This repository contains a new adversarial training method for our Generative Conversational Agent (GCA).
Further details on this new training method can be found in the paper: Oswaldo Ludwig, "End-to-end Adversarial Learning for Generative Conversational Agents," arXiv:1711.10122 [cs.CL], Nov 2017. If you publish work that uses ideas or pieces of code from this repository, please cite this paper.
Our method treats the GCA as a generator that aims to fool a discriminator which labels dialogues as human-generated or machine-generated. In our approach, the discriminator performs token-level classification, i.e. it indicates whether the current token was generated by a human or by the machine. To do so, the discriminator also receives as input the context utterances (the dialogue history) and the incomplete answer up to the current token. This makes end-to-end training by backpropagation possible. A self-conversation process produces a more diverse set of generated data for the adversarial training. This approach improves performance on questions not related to the training data.
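For intuition, the token-level discrimination can be sketched as follows (a minimal illustration with hypothetical names, not the repository's code): the discriminator scores every prefix of the generated answer given the padded dialogue context.

```python
from keras.preprocessing.sequence import pad_sequences

# Hypothetical helper: `discriminator` is assumed to be a Keras model taking
# [padded context, padded partial answer] and returning a human/machine
# probability. It is called once per prefix of the answer, so one score is
# produced per generated token.
def token_level_scores(discriminator, context, answer_ids, maxlen_answer):
    scores = []
    for t in range(1, len(answer_ids) + 1):
        partial = pad_sequences([answer_ids[:t]], maxlen=maxlen_answer)  # incomplete answer up to token t
        p_human = discriminator.predict([context, partial], verbose=0)[0, 0]
        scores.append(float(p_human))
    return scores
```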
The pre-trained model available here was trained on a dataset collected from dialogues of online English courses, which is also available here.
Our GCA model can be explained by the following flowchart:
while the following pseudocode explains our GCA algorithm:
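In Python terms, that token-by-token decoding loop amounts roughly to the following (a hedged sketch only; `gca_model`, `BOS`, `EOS` and the padding lengths are assumed names, not the repository's exact identifiers):

```python
import numpy as np
from keras.preprocessing.sequence import pad_sequences

# Greedy decoding sketch: the model is assumed to take the padded context and
# the incomplete answer so far, and to return a distribution over the next token.
def generate_answer(gca_model, context, maxlen_answer, BOS, EOS):
    answer_ids = [BOS]                                   # start with the begin-of-sentence token
    for _ in range(maxlen_answer - 1):
        partial = pad_sequences([answer_ids], maxlen=maxlen_answer)
        probs = gca_model.predict([context, partial], verbose=0)[0]
        next_id = int(np.argmax(probs))                  # greedy choice of the next token
        answer_ids.append(next_id)
        if next_id == EOS:                               # stop at the end-of-sentence token
            break
    return answer_ids
```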
Our new end-to-end adversarial training can be explained by the following Keras model (implemented in the file train_bot_GAN.py), which is composed of the generator and the discriminator. The yellow blocks belong to the GCA (the generator), while the green blocks belong to the discriminator. The white blocks are shared between the generator and the discriminator:
while the following pseudocode explains the new algorithm (see the paper for the definition of the variables):
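As an illustration only, a heavily simplified Keras stand-in for this composite model might look like the sketch below. Layer types, sizes, and names here are assumptions, not the architecture of train_bot_GAN.py; the point is that the generator's softmax output feeds directly into the discriminator, so gradients flow end-to-end through both.

```python
from keras.models import Model
from keras.layers import Input, Embedding, LSTM, Dense, concatenate

VOCAB, DIM, MAXLEN = 7000, 100, 50                      # assumed sizes

context_in = Input(shape=(MAXLEN,), name="context")
answer_in = Input(shape=(MAXLEN,), name="partial_answer")

shared_emb = Embedding(VOCAB, DIM, name="shared_embedding")      # "white" block: shared
ctx_enc = LSTM(256)(shared_emb(context_in))                       # "yellow" blocks: generator encoders
ans_enc = LSTM(256)(shared_emb(answer_in))
next_token = Dense(VOCAB, activation="softmax",
                   name="generator_output")(concatenate([ctx_enc, ans_enc]))

disc_hidden = Dense(128, activation="relu")(                       # "green" blocks: discriminator
    concatenate([ctx_enc, ans_enc, next_token]))
human_prob = Dense(1, activation="sigmoid", name="discriminator_output")(disc_hidden)

adversarial_model = Model([context_in, answer_in], [next_token, human_prob])
adversarial_model.compile(optimizer="adam",
                          loss=["categorical_crossentropy", "binary_crossentropy"])
```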
To chat with the pre-trained models:
To evaluate dialog lines using the pre-trained discriminator:
To train end-to-end using the new adversarial method:
If you want to start the adversarial training from scratch, make the weight file my_model_weights.h5 (pre-trained with the new adversarial method) equal to my_model_weights20.h5 (pre-trained by teacher forcing) and run train_script.py.
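For example, assuming both weight files and train_script.py sit in the repository root, those two steps can be done like this:

```python
import shutil
import subprocess

# Reuse the teacher-forcing weights as the starting point for adversarial training,
# then launch the training script.
shutil.copyfile("my_model_weights20.h5", "my_model_weights.h5")
subprocess.run(["python", "train_script.py"], check=True)
```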