ACL 2017 で長い論文として公開された「Learning Discourse-level Diversity for Neural Dialog Models using Conditional variational Autoencoders」で説明されている CVAE ベースのダイアログ モデルの TensorFlow 実装を提供します。詳細については、論文を参照してください。
このツールキットに含まれるソース コードまたはデータセットを仕事で使用する場合は、次の論文を引用してください。ビブテックスは以下のとおりです。
[Zhao et al, 2017]:
@inproceedings{zhao2017learning,
title={Learning Discourse-level Diversity for Neural Dialog Models using Conditional Variational Autoencoders},
author={Zhao, Tiancheng and Zhao, Ran and Eskenazi, Maxine},
booktitle={Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
volume={1},
pages={654--664},
year={2017}
}
同じ SwitchBoard データセットを使用するベースライン メソッド HRED は、汎用テキスト生成ツールキットであるTexarにも実装されています。ここでチェックアウトしてください。
python kgcvae_swda.py
デフォルトのトレーニングを実行し、モデルを ./working に保存します
既存のモデルを実行するには、kgcvae_swda.py の先頭にある TF フラグを次のように変更します。
forward_only: False -> True
test_path: set to the folder contains the model. E.g. runxxxx
その後、次の方法でモデルを実行できます。
python kgcvae_swda.py
出力は標準出力に出力され、生成された応答は test_path の test.txt に保存されます。
https://nlp.stanford.edu/projects/glove/ から Glove の単語埋め込みをダウンロードします。デフォルト設定では、Twitter でトレーニングされた 200 次元の単語埋め込みが使用されます。
最後にkgcvae_swda.pyの15行目にword2vec_pathを設定します。
2 つのデータセットをリリースします。
独自のデータでモデルをトレーニングしたい場合。次の形式の pickle ファイルを作成してください。
# The top directory is a python dictionary
type(data) = dict
data.keys() = ['train', 'valid', 'test']
# Train/valid/test is a list, each element is one dialog
train = data['train']
type(train) = list
# Each dialog is a dict
dialog = train[0]
type(dialog)= dict
dialog.keys() = ['A', 'B', 'topic', 'utts']
# A, B contain meta info about speaker A and B.
# topic defines the dialog prompt topic in Switchboard Corpus.
# utts is a list, each element is a tuple that contain info about an utterance
utts = dialog['utts']
type(utts) = list
utts[0] = ("A" or "B", "utterance in string", [dialog_act, other_meta_info])
# For example, a utterance look like this:
('B','especially your foreign cars',['statement-non-opinion'])
結果のファイルを ./data に配置し、kgcvae_swda.py にdata_dirを設定します。