tf2 变压器聊天机器人
1.0.0
在 TensorFlow 2 中使用 Transformer 构建端到端聊天机器人。在 blog.tensorflow.org 上查看我的教程。
setup.sh
脚本以安装 Apple Silicon 版本的 TensorFlow 2.9(仅当您喜欢冒险时才使用此脚本)。PositionalEncoding
和MultiHeadAttentionLayer
,以允许通过model.save()
或tf.keras.models.save_model()
保存模型。train.py
展示了如何调用model.save()
和tf.keras.models.load_model()
。chatbot
conda create -n chatbot python=3.8
conda activate chatbot
sh setup.sh
tensorflow-metal
有大量错误,例如 Adam 优化器不起作用)。max_samples
个对话对提取到questions
和answers
列表中。start_token
和end_token
来表示每个句子的开始和结束。max_length
标记的句子。max_length
python main.py --help
python train.py --output_dir runs/save_model --batch_size 256 --epochs 50 --max_samples 50000
runs/save_model
。 input: where have you been?
output: i m not talking about that .
input: it's a trap!
output: no , it s not .