Adversarial Learning for Generative Conversational Agents
versarial Learning for Generative Conversational Agents
此儲存庫包含用於我們的生成對話代理程式(GCA)的新對抗訓練方法。
有關這種新訓練方法的更多詳細信息,請參閱 Oswaldo Ludwig 論文,“生成對話代理的端到端對抗性學習”,arXiv:1711.10122 cs.CL,2017 年 11 月。請引用本文。
我們的方法假設 GCA 是一個生成器,旨在愚弄鑑別器,將對話標記為人類生成或機器生成。在我們的方法中,鑑別器執行令牌層級的分類,即它指示當前令牌是由人類還是機器產生的。為此,鑑別器也接收上下文話語(對話歷史)和當前標記之前的不完整答案作為輸入。這種新方法使得透過反向傳播進行端到端訓練成為可能。自對話過程能夠產生一組具有更多多樣性的生成數據,用於對抗訓練。這種方法提高了與訓練資料無關的問題的效能。
此處提供的訓練模型使用了從線上英語課程對話中收集的資料集(此處提供)。
我們的GCA模型可以用下面的流程圖來解釋:
而下面的偽代碼解釋了我們的 GCA 演算法:
我們新的端到端對抗訓練可以透過以下 Keras 模型(在檔案 train_bot_GAN.py 中實現)來解釋,該模型由生成器和鑑別器組成。黃色塊屬於 GCA(生成器),而綠色塊屬於鑑別器。白色區塊在生成器和鑑別器之間共用:
而以下偽代碼解釋了新演算法(有關變數的定義,請參閱論文):
與預先訓練的模型聊天:
使用預先訓練的鑑別器評估對話台詞:
使用新的對抗方法進行端到端訓練:
如果您想從頭開始對抗訓練,請將權重檔案 my_model_weights.h5 (預先訓練新對抗方法)設定為 my_model_weights20.h5 (透過教師強制預訓練)並執行 train_script.py。