使用預先訓練的XLNet
實現神經對話生成器模型Yang 等人。 (2019)和GPT2
架構Radford 等人。 (2019)目前三個資料集: DailyDialog
Li 等人。 (2017) , PersonaChat
張等人。 (2018)和新的TopicalChat
Gopalakrishnan 等人。 (2019)來自 Alexa 獎社交機器人大挑戰 3 。 (2018)和核解碼Holtzman 等人。 (2019)可用作解碼技術。訓練目標是對話語和對話歷史進行自回歸語言建模。
此模型可以利用 nvidia/apex 的混合精度訓練。請注意,apex 不是必需的,僅在可用時才使用。安裝指南請參閱官方說明。使用此模組並不適用於所有 GPU(僅 Volta 和 Turing),您應該事先檢查您的執行個體是否支援混合精度訓練。
若要訓練模型,請複製此儲存庫並安裝相依性。此專案使用 cython 來組裝批次以加快輸入管道的速度。它也喜歡使用 python virtualenv。
git clone https://github.com/bme-chatbots/dialogue-generation.git
cd dialogue-generation
pip install -r requirements.txt
python setup.py build_ext --inplace
以下指令將開始在PersonaChat
上使用gpt2-medium
模型在單一 GPU/CPU 上進行訓練。 --name
是模型資料夾中保存日誌和檢查點的子目錄的名稱。
python -m src.train --model gpt2-medium --data personachat --name my_test_run
對於分散式多 GPU 訓練,訓練腳本應該這樣呼叫。
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS src/train.py --model gpt2
您也可以透過將配置 json 檔案的路徑作為--config
參數傳遞來使用預先定義的配置。這些可以在src/configs
資料夾中找到,它們的訓練結果可以在結果部分下方看到。
python -m src.train --config src/configs/xlnet-dailydialog.json
在Google Colaboratory或Kaggle 核心上訓練模型既快速又簡單。使用新的 Tesla P100 或 Tesla T4 單元將運行時類型設為 GPU 非常重要,因為它可以充分利用混合精度訓練,並且比舊的 Tesla K80 版本快得多。您可以透過在 Colab 的單元中執行!nvidia-smi
來檢查目前類型。
作為一個快捷方式,這裡有一個完整的範例要點,您可以將其作為協作文件匯入到您的 Google 雲端硬碟中。
在 colab (或 Kaggle 核心)檔案的儲存格中複製並執行以下程式碼以安裝模型。如果您使用 Kaggle 內核,您還必須啟用網路存取。
! git clone https://github.com/bme-chatbots/dialogue-generation.git
! python -m pip install --upgrade pip
# installing apex is optional and is only useful if Colab's Tesla P100 or T4 is used
# !git clone https://github.com/NVIDIA/apex
# !cd apex; pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
# building the cython code and installing the required packages
! cd dialogue-generation ; pip install -r requirements.txt ; python setup.py build_ext --inplace
訓練和驗證指標會記錄到 Tensorboard,如果在訓練單元之前執行以下程式碼,也可以在 colab 檔案中進行追蹤。
%load_ext tensorboard
%tensorboard --logdir " dialogue-generation/model "
然後只需使用預設標誌執行train
腳本即可訓練模型。您可以透過提供-h
標誌來查看train.py
腳本接受的所有標誌。
! cd dialogue-generation ; python -m src.train
訓練後,可以透過將以下程式碼片段中的下載連結設定為評估後腳本記錄的連結來下載模型。 ( Saving model to dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt
)
from IPython . display import FileLink
# note that in case of kaggle kernel you have to give path
# relative to your working directory
FileLink ( r'dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt' )
透過執行interact
腳本並使用--model_file
提供訓練模型的路徑,可以在訓練模型上使用互動式評估模式。您也可以提供--config
文件,或只是簡單地提供訓練期間使用的相同--model
和--name
參數。
python -m src.interact --model gpt2-medium --name my_test_run
python -m src.interact --config src/configs/xlnet-dailydialog.json
要在您自己的資料集上訓練任何模型,您只需從DialogDataset
進行子類化並從原始檔案實現資料生成即可。給定一個放在data<name of your data>
中的train.txt
、 valid.txt
和test.txt
,其中對話框中的每個回合都在一個新行中,並且單獨的對話框由額外的空行分隔。
├── data
| ├── dailydialog
| └── custom_dataset # name of your data ( custom_dataset by default )
| ├── train.txt # the correctly formated train, valid and test files
| ├── valid.txt
| └── test.txt
├── src
Hello how are you?
Hi I'm fine thanks. And you?
Me too thanks for asking.
Hi my name Peter.
Nice to meet you I am Eric.
data.py
中實作了一個名為CustomDataset
的範例自訂資料集類,它讀取具有這些屬性的資料集。
下面的對話方塊是從gpt2
取樣的,在PersonaChat
資料集上使用預設參數訓練了 7 個 epoch(149051 步驟)和--max_hist 3
。
User: hi
Bot: greetings, i'm gerald.
User: how are you ?
Bot: good, thanks. getting ready for another hike, i love hiking!
User: oh realy? where are you going?
Bot: pikes peak. i am curently exploring famous mountains when i have time.
User: that is cool are you going alone?
Bot: my brother is in the us, we go hiking with him every month.
User: how old are you by the way?
Bot: a little under 46 years old
User: are you in a relationship?
Bot: yes i am married, 3 boys, i am pregnant again. i am too busy to focus
下面的對話方塊是從gpt2
取樣的,在DailyDialog
資料集上使用預設參數訓練了 8 個時期(48824 個步驟)和--max_hist 3
。
User: hi how are you ?
Bot: fine. i got a promotion at work earlier this week.
User: where do you work ?
Bot: I work at the polling place.
User: where is that ?
Bot: I live in the dorms across the street.
User: do you like it ?
Bot: it's great. I heard it is far from here.
User: are you a student ?
Bot: yes. I am a student too.