该存储库托管论文“Augmenting Neural Response Generation with Context-Aware Topical Attention”的实现。
THRED 是一个多轮响应生成系统,旨在生成上下文和主题感知响应。该代码库是从 Tensorflow NMT 存储库演变而来的。
TL;DR使用此框架创建对话代理的步骤:
conda env create -f thred_env.yml
安装依赖项(要使用pip
,请参阅依赖项)MODEL_DIR
是将模型保存到的目录。我们建议在至少 2 个 GPU 上进行训练,否则您可以减小数据大小(通过省略训练文件中的对话)和模型大小(通过修改配置文件)。 python -m thred --mode train --config conf/thred_medium.yml --model_dir
--train_data --dev_data --test_data
python -m thred --mode interactive --model_dir
1 个软件包仅用于解析和清理 Reddit 数据。 2仅用于命令行交互模式下测试对话模型
要使用pip
安装依赖项,请运行pip install -r requirements
。对于 Anaconda,运行conda env create -f thred_env.yml
(推荐)。完成依赖项后,运行pip install -e .
安装 thred 包。
我们的 Reddit 数据集,我们称之为 Reddit 对话语料库 (RCC),是从 95 个选定的 Reddit 子版块(此处列出)中收集的。我们处理了从 2016 年 11 月到 2018 年 8 月期间 20 个月的 Reddit(不包括 2017 年 6 月和 2017 年 7 月;我们利用这两个月以及 2016 年 10 月的数据来训练 LDA 模型)。请参阅此处了解如何构建 Reddit 数据集的详细信息,包括预处理和清理原始 Reddit 文件。下表总结了 RCC 信息:
语料库 | #火车 | #dev | #测试 | 下载 | 带主题词下载 |
---|---|---|---|---|---|
每线 3 圈 | 9.2M | 508K | 406K | 下载 (773MB) | 下载(2.5GB) |
每线 4 圈 | 4M | 223K | 178K | 下载 (442MB) | 下载(1.2GB) |
每线 5 圈 | 1.8M | 10万 | 80K | 下载 (242MB) | 下载 (594MB) |
在数据文件中,每一行对应一个对话,其中话语以制表符分隔。主题词出现在最后一句话之后,也用制表符分隔。
请注意,3 圈/4 圈/5 圈文件包含相似的内容,尽管每行的话语数不同。它们都是从同一来源提取的。如果您发现数据中有任何错误或任何不当言论,请在此处报告您的疑虑。
在模型配置文件(即conf中的YAML文件)中,嵌入类型可以是以下任意一种: glove840B
、 fastText
、 word2vec
和hub_word2vec
。为了处理预先训练的嵌入向量,我们利用 Pymagnitude 和 Tensorflow-Hub。请注意,您还可以在响应生成模型的训练过程中使用random300
(300 指的是嵌入向量的维度,可以替换为任意值)来学习向量。 word_embeddings.yml 中提供了与嵌入模型相关的设置。
训练配置应在类似于 Tensorflow NMT 的 YAML 文件中定义。此处提供了 THRED 和其他基线的示例配置。
实现的模型有 Seq2Seq、HRED、Topic Aware-Seq2Seq 和 THRED。
请注意,虽然大多数参数在不同模型中是通用的,但某些模型可能具有附加参数(例如,主题模型具有topic_words_per_utterance
和boost_topic_gen_prob
参数)。
要训练模型,请运行以下命令:
python main.py --mode train --config < YAML_FILE >
--train_data < TRAIN_DATA > --dev_data < DEV_DATA > --test_data < TEST_DATA >
--model_dir < MODEL_DIR >
中存储词汇文件和 Tensorflow 模型文件。可以通过执行以下命令恢复训练:
python main.py --mode train --model_dir < MODEL_DIR >
使用以下命令,可以根据测试数据集测试模型。
python main.py --mode test --model_dir < MODEL_DIR > --test_data < TEST_DATA >
在测试期间可以覆盖测试参数。这些参数是:光束宽度--beam_width
、长度惩罚权重--length_penalty_weight
和采样温度--sampling_temperature
。
实现了一个简单的命令行界面,允许您与学习的模型进行对话(与测试模式类似,测试参数也可以被覆盖):
python main.py --mode interactive --model_dir < MODEL_DIR >
在交互模式下,需要预先训练的LDA模型将推断的主题词输入到模型中。我们使用 Gensim 在为此目的收集的 Reddit 语料库上训练了一个 LDA 模型。可以从这里下载。下载的文件应解压缩并通过--lda_model_dir
传递给程序。
如果您在研究中使用了我们的工作,请引用以下论文:
@article{dziri2018augmenting,
title={Augmenting Neural Response Generation with Context-Aware Topical Attention},
author={Dziri, Nouha and Kamalloo, Ehsan and Mathewson, Kory W and Zaiane, Osmar R},
journal={arXiv preprint arXiv:1811.01063},
year={2018}
}