tf2 變壓器聊天機器人
1.0.0
在 TensorFlow 2 中使用 Transformer 建立端到端聊天機器人。
setup.sh
腳本以安裝 Apple Silicon 版本的 TensorFlow 2.9(僅當您喜歡冒險時才使用此腳本)。PositionalEncoding
和MultiHeadAttentionLayer
,以允許透過model.save()
或tf.keras.models.save_model()
儲存模型。train.py
展示如何呼叫model.save()
和tf.keras.models.load_model()
。chatbot
conda create -n chatbot python=3.8
conda activate chatbot
sh setup.sh
tensorflow-metal
有大量錯誤,例如Adam 優化器不起作用)。max_samples
個對話對提取到questions
和answers
清單中。start_token
和end_token
來表示每個句子的開始和結束。max_length
標記的句子。max_length
python main.py --help
python train.py --output_dir runs/save_model --batch_size 256 --epochs 50 --max_samples 50000
runs/save_model
。 input: where have you been?
output: i m not talking about that .
input: it's a trap!
output: no , it s not .