此儲存庫託管論文「透過上下文感知主題注意力增強神經響應生成」的實現。
THRED 是一個多輪回應產生系統,旨在產生情境和主題感知回應。該程式碼庫是從 Tensorflow NMT 儲存庫演進而來的。
TL;DR使用此框架建立對話代理程式的步驟:
conda env create -f thred_env.yml
安裝依賴項(若要使用pip
,請參閱依賴項)MODEL_DIR
是將模型儲存到的目錄。我們建議在至少 2 個 GPU 上進行訓練,否則您可以減少資料大小(透過從訓練檔案中省略對話)和模型大小(透過修改設定檔)。 python -m thred --mode train --config conf/thred_medium.yml --model_dir
--train_data --dev_data --test_data
python -m thred --mode interactive --model_dir
1 個軟體包僅用於解析和清理 Reddit 資料。 2僅用於命令列互動模式下測試對話模型
若要使用pip
安裝依賴項,請執行pip install -r requirements
。對於 Anaconda,執行conda env create -f thred_env.yml
(建議)。完成依賴項後,執行pip install -e .
安裝 thred 包。
我們的 Reddit 資料集,我們稱之為 Reddit 對話語料庫 (RCC),是從 95 個選定的 Reddit 子版塊(此處列出)中收集的。我們處理了 Reddit,時間範圍為 2016 年 11 月至 2018 年 8 月(不包括 2017 年 6 月和 2017 年 7 月;我們利用這兩個月以及 2016 年 10 月的資料來訓練 LDA 模型)。請參閱此處以了解如何建立 Reddit 資料集的詳細信息,包括預處理和清理原始 Reddit 文件。下表總結了 RCC 資訊:
語料庫 | #火車 | #dev | #測試 | 下載 | 附主題詞下載 |
---|---|---|---|---|---|
每條 3 圈 | 9.2M | 508K | 406K | 下載 (773MB) | 下載(2.5GB) |
每條 4 圈 | 4M | 223K | 178K | 下載 (442MB) | 下載(1.2GB) |
每條線 5 圈 | 1.8M | 10萬 | 80K | 下載 (242MB) | 下載 (594MB) |
在資料檔案中,每一行對應一個對話,其中話語以製表符分隔。主題詞出現在最後一句話之後,也用製表符號分隔。
請注意,3 圈/4 圈/5 圈文件包含相似的內容,儘管每行的話語數不同。它們都是從同一來源提取的。如果您發現數據中有任何錯誤或任何不當言論,請在此處報告您的疑慮。
在模型設定檔(即conf中的YAML檔)中,嵌入類型可以是以下任一種: glove840B
、 fastText
、 word2vec
和hub_word2vec
。為了處理預先訓練的嵌入向量,我們利用 Pymagnitude 和 Tensorflow-Hub。請注意,您也可以在反應生成模型的訓練過程中使用random300
(300 指的是嵌入向量的維度,可以替換為任意值)來學習向量。 word_embeddings.yml 中提供了與嵌入模型相關的設定。
訓練配置應在類似 Tensorflow NMT 的 YAML 檔案中定義。此處提供了 THRED 和其他基線的範例配置。
實作的模型有 Seq2Seq、HRED、Topic Aware-Seq2Seq 和 THRED。
請注意,雖然大多數參數在不同模型中是通用的,但某些模型可能具有附加參數(例如,主題模型具有topic_words_per_utterance
和boost_topic_gen_prob
參數)。
要訓練模型,請執行以下命令:
python main.py --mode train --config < YAML_FILE >
--train_data < TRAIN_DATA > --dev_data < DEV_DATA > --test_data < TEST_DATA >
--model_dir < MODEL_DIR >
中儲存詞彙檔案和 Tensorflow 模型檔案。可以透過執行以下命令來恢復訓練:
python main.py --mode train --model_dir < MODEL_DIR >
使用以下命令,可以根據測試資料集測試模型。
python main.py --mode test --model_dir < MODEL_DIR > --test_data < TEST_DATA >
在測試期間可以覆蓋測試參數。這些參數是:光束寬度--beam_width
、長度懲罰權重--length_penalty_weight
和採樣溫度--sampling_temperature
。
實作了一個簡單的命令列介面,讓您與學習的模型進行對話(與測試模式類似,測試參數也可以被覆蓋):
python main.py --mode interactive --model_dir < MODEL_DIR >
在互動模式下,需要預先訓練的LDA模型將推論的主題詞輸入模型中。我們使用 Gensim 在為此目的收集的 Reddit 語料庫上訓練了一個 LDA 模型。可以從這裡下載。下載的檔案應解壓縮並透過--lda_model_dir
傳遞給程式。
如果您在研究中使用了我們的工作,請引用以下論文:
@article{dziri2018augmenting,
title={Augmenting Neural Response Generation with Context-Aware Topical Attention},
author={Dziri, Nouha and Kamalloo, Ehsan and Mathewson, Kory W and Zaiane, Osmar R},
journal={arXiv preprint arXiv:1811.01063},
year={2018}
}