Adversarial Learning for Generative Conversational Agents
versarial Learning for Generative Conversational Agents
该存储库包含用于我们的生成对话代理(GCA)的新对抗训练方法。
有关这种新训练方法的更多详细信息,请参阅 Oswaldo Ludwig 论文,“生成对话代理的端到端对抗性学习”,arXiv:1711.10122 cs.CL,2017 年 11 月。来自此存储库的代码,请引用本文。
我们的方法假设 GCA 是一个生成器,旨在愚弄鉴别器,将对话标记为人类生成或机器生成。在我们的方法中,鉴别器执行令牌级别的分类,即它指示当前令牌是由人类还是机器生成的。为此,鉴别器还接收上下文话语(对话历史)和当前标记之前的不完整答案作为输入。这种新方法使得通过反向传播进行端到端训练成为可能。自对话过程能够生成一组具有更多多样性的生成数据,用于对抗训练。这种方法提高了与训练数据无关的问题的性能。
此处提供的训练模型使用了从在线英语课程对话中收集的数据集(此处提供)。
我们的GCA模型可以用下面的流程图来解释:
而下面的伪代码解释了我们的 GCA 算法:
我们新的端到端对抗训练可以通过以下 Keras 模型(在文件 train_bot_GAN.py 中实现)来解释,该模型由生成器和鉴别器组成。黄色块属于 GCA(生成器),而绿色块属于鉴别器。白色块在生成器和鉴别器之间共享:
而以下伪代码解释了新算法(有关变量的定义,请参阅论文):
与预先训练的模型聊天:
使用预先训练的鉴别器评估对话台词:
使用新的对抗方法进行端到端训练:
如果您想从头开始对抗训练,请将权重文件 my_model_weights.h5 (预训练新对抗方法)设置为 my_model_weights20.h5 (通过教师强制预训练)并运行 train_script.py。