tf2 トランスフォーマー チャットボット
1.0.0
TensorFlow 2 の Transformer を使用してエンドツーエンドのチャットボットを構築します。blog.tensorflow.org で私のチュートリアルをチェックしてください。
setup.sh
スクリプトを更新して、Apple Silicon バージョンの TensorFlow 2.9 をインストールします (冒険したい場合にのみこれを使用してください)。model.save()
またはtf.keras.models.save_model()
を介してモデルを保存できるように、2 つのカスタム レイヤーPositionalEncoding
とMultiHeadAttentionLayer
更新しました。train.py
model.save()
とtf.keras.models.load_model()
を呼び出す方法を示しています。chatbot
を初期化します conda create -n chatbot python=3.8
conda activate chatbot
sh setup.sh
tensorflow-metal
インストールします (Apple Silicon GPU 上の TensorFlow には大量のバグがあることに注意してください。たとえば、Adam オプティマイザーが機能しません)。max_samples
会話ペアをquestions
とanswers
のリストに抽出します。start_token
とend_token
を追加して各文の開始と終了を示します。max_length
トークンを超える文をフィルターで除外します。max_length
までパディングしますpython main.py --help
python train.py --output_dir runs/save_model --batch_size 256 --epochs 50 --max_samples 50000
runs/save_model
に保存されます。 input: where have you been?
output: i m not talking about that .
input: it's a trap!
output: no , it s not .