事前トレーニングされたXLNet
Yang らによるニューラル対話ジェネレーター モデルの実装。 (2019)およびGPT2
アーキテクチャRadford et al. (2019)現在 3 つのデータセットについて: DailyDialog
Li et al. (2017) 、 PersonaChat
チャンら。 (2018)と新しいTopicalChat
Gopalakrishnan et al. (2019) Alexa 賞ソーシャルボット グランド チャレンジ 3 より。Top-k サンプリングFan et al。 (2018)および核解読Holtzman et al。 (2019) のデコード技術が利用可能です。トレーニングの目的は、発話と対話履歴に基づいた自己回帰言語モデリングです。
このモデルは、nvidia/apex の混合精度トレーニングを利用できます。 apex は必須ではなく、使用可能な場合にのみ使用されることに注意してください。インストールガイドについては、公式の手順を参照してください。このモジュールの使用はすべての GPU ( Volta と Turing のみ) にとって有用ではないため、インスタンスが混合精度トレーニングをサポートしているかどうかを事前にチェックインする必要があります。
モデルをトレーニングするには、このリポジトリのクローンを作成し、依存関係をインストールします。このプロジェクトでは、cython を使用してバッチを組み立て、入力パイプラインを高速化します。また、Python virtualenv を使用することも推奨されます。
git clone https://github.com/bme-chatbots/dialogue-generation.git
cd dialogue-generation
pip install -r requirements.txt
python setup.py build_ext --inplace
次のコマンドはPersonaChat
でgpt2-medium
モデルを使用して単一の GPU/CPU でトレーニングを開始します。 --name
、ログとチェックポイントが保存されるモデル フォルダー内のサブディレクトリの名前です。
python -m src.train --model gpt2-medium --data personachat --name my_test_run
分散マルチ GPU トレーニングの場合、トレーニング スクリプトは次のように呼び出す必要があります。
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS src/train.py --model gpt2
config json ファイルのパスを--config
引数として渡すことで、事前定義された構成を使用することもできます。これらはsrc/configs
フォルダーで入手でき、トレーニング結果は結果セクションの下に表示されます。
python -m src.train --config src/configs/xlnet-dailydialog.json
Google ColaboratoryまたはKaggle カーネルを使用すると、モデルのトレーニングが迅速かつ簡単になります。新しい Tesla P100 または Tesla T4 ユニットでは、混合精度トレーニングを最大限に活用でき、古い Tesla K80 バージョンよりもはるかに高速であるため、ランタイム タイプを GPU に設定することが重要です。現在のタイプを確認するには、colab のセルで!nvidia-smi
実行します。
ショートカットとして、完全なサンプル Gist をここに示します。これを共同作業ファイルとして Google ドライブにインポートするだけです。
colab (または Kaggle カーネル)ファイルのセルに次のコードをコピーして実行し、モデルをインストールします。 Kaggle カーネルを使用する場合は、インターネット アクセスも有効にする必要があります。
! git clone https://github.com/bme-chatbots/dialogue-generation.git
! python -m pip install --upgrade pip
# installing apex is optional and is only useful if Colab's Tesla P100 or T4 is used
# !git clone https://github.com/NVIDIA/apex
# !cd apex; pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
# building the cython code and installing the required packages
! cd dialogue-generation ; pip install -r requirements.txt ; python setup.py build_ext --inplace
トレーニングと検証のメトリクスは Tensorboard に記録されます。トレーニング セルの前に以下のコードを実行すると、colab ファイルで追跡することもできます。
%load_ext tensorboard
%tensorboard --logdir " dialogue-generation/model "
モデルは、デフォルトのフラグを使用してtrain
スクリプトを実行するだけでトレーニングできます。 -h
フラグを指定すると、 train.py
スクリプトで受け入れられるすべてのフラグを確認できます。
! cd dialogue-generation ; python -m src.train
トレーニング後、次のスニペットのダウンロード リンクを、評価後にスクリプトによって記録されたリンクに設定することで、モデルをダウンロードできます。 ( Saving model to dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt
)
from IPython . display import FileLink
# note that in case of kaggle kernel you have to give path
# relative to your working directory
FileLink ( r'dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt' )
対話型評価モードは、 interact
スクリプトを実行し、 --model_file
でトレーニング済みモデルのパスを指定することで、トレーニング済みモデルで使用できます。 --config
ファイルを指定することも、トレーニング中に使用したのと同じ--model
および--name
引数を単に指定することもできます。
python -m src.interact --model gpt2-medium --name my_test_run
python -m src.interact --config src/configs/xlnet-dailydialog.json
独自のデータセットでモデルをトレーニングするには、 DialogDataset
からサブクラス化し、生ファイルからのデータ生成を実装するだけです。 train.txt
、 valid.txt
、 test.txt
data<name of your data>
に配置されているとすると、ダイアログの各ターンは新しい行になり、個別のダイアログは追加の空行で分割されます。
├── data
| ├── dailydialog
| └── custom_dataset # name of your data ( custom_dataset by default )
| ├── train.txt # the correctly formated train, valid and test files
| ├── valid.txt
| └── test.txt
├── src
Hello how are you?
Hi I'm fine thanks. And you?
Me too thanks for asking.
Hi my name Peter.
Nice to meet you I am Eric.
CustomDataset
という名前のカスタム データセット クラスの例は、これらのプロパティを持つデータセットを読み取るdata.py
に実装されています。
以下のダイアログは、 PersonaChat
データセットの 7 エポック (149051 ステップ) と--max_hist 3
のデフォルト パラメーターでトレーニングされたgpt2
からサンプリングされたものです。
User: hi
Bot: greetings, i'm gerald.
User: how are you ?
Bot: good, thanks. getting ready for another hike, i love hiking!
User: oh realy? where are you going?
Bot: pikes peak. i am curently exploring famous mountains when i have time.
User: that is cool are you going alone?
Bot: my brother is in the us, we go hiking with him every month.
User: how old are you by the way?
Bot: a little under 46 years old
User: are you in a relationship?
Bot: yes i am married, 3 boys, i am pregnant again. i am too busy to focus
以下のダイアログは、 DailyDialog
データセット上の 8 エポック (48824 ステップ) および--max_hist 3
のデフォルト パラメーターでトレーニングされたgpt2
からサンプリングされたものです。
User: hi how are you ?
Bot: fine. i got a promotion at work earlier this week.
User: where do you work ?
Bot: I work at the polling place.
User: where is that ?
Bot: I live in the dorms across the street.
User: do you like it ?
Bot: it's great. I heard it is far from here.
User: are you a student ?
Bot: yes. I am a student too.