TF2 트랜스포머 챗봇
1.0.0
TensorFlow 2에서 Transformer를 사용하여 엔드투엔드 챗봇을 구축하세요. blog.tensorflow.org에서 제 튜토리얼을 확인해 보세요.
setup.sh
스크립트를 업데이트하여 Apple Silicon 버전의 TensorFlow 2.9를 설치하세요(모험적인 느낌이 드는 경우에만 이 스크립트를 사용하세요).model.save()
또는 tf.keras.models.save_model()
통해 모델을 저장할 수 있도록 두 개의 사용자 정의 레이어인 PositionalEncoding
및 MultiHeadAttentionLayer
업데이트했습니다.train.py
model.save()
및 tf.keras.models.load_model()
호출하는 방법을 보여줍니다.chatbot
초기화 conda create -n chatbot python=3.8
conda activate chatbot
sh setup.sh
tensorflow-metal
통해 설치합니다(Apple Silicon GPU의 TensorFlow에는 수많은 버그가 있습니다. 예를 들어 Adam 최적화 프로그램은 작동하지 않습니다).max_samples
대화 쌍을 questions
및 answers
목록으로 추출합니다.start_token
및 end_token
추가하여 각 문장의 시작과 끝을 나타냅니다.max_length
토큰보다 많은 문장을 필터링합니다.max_length
로 채웁니다.python main.py --help
python train.py --output_dir runs/save_model --batch_size 256 --epochs 50 --max_samples 50000
runs/save_model
에 저장됩니다. input: where have you been?
output: i m not talking about that .
input: it's a trap!
output: no , it s not .