Implementierung eines neuronalen Dialoggeneratormodells mit vortrainiertem XLNet
Yang et al. (2019) und GPT2
Architektur Radford et al. (2019) zu derzeit drei Datensätzen: DailyDialog
Li et al. (2017) , PersonaChat
Zhang et al. (2018) und der neue TopicalChat
Gopalakrishnan et al. (2019) von Alexa Prize Socialbot Grand Challenge 3. Top-K-Sampling Fan et al. (2018) und Kerndekodierung Holtzman et al. (2019) stehen als Dekodierungstechniken zur Verfügung. Das Trainingsziel ist die autoregressive Sprachmodellierung auf den Äußerungen und Dialogverläufen.
Das Modell kann das gemischte Präzisionstraining von NVIDIA/Apex nutzen. Beachten Sie, dass Apex nicht erforderlich ist und nur verwendet wird, wenn es verfügbar ist. Eine Installationsanleitung finden Sie in den offiziellen Anweisungen. Die Verwendung dieses Moduls ist nicht für alle GPUs (nur Volta und Turing) sinnvoll und Sie sollten vorher prüfen, ob Ihre Instanz gemischtes Präzisionstraining unterstützt.
Um das Modell zu trainieren, klonen Sie dieses Repository und installieren Sie Abhängigkeiten. Das Projekt verwendet Cython, um Stapel für eine schnellere Eingabepipeline zusammenzustellen. Es wurde auch bevorzugt, eine virtuelle Python-Umgebung zu verwenden.
git clone https://github.com/bme-chatbots/dialogue-generation.git
cd dialogue-generation
pip install -r requirements.txt
python setup.py build_ext --inplace
Der folgende Befehl startet das Training auf einer einzelnen GPU/CPU mit gpt2-medium
-Modell auf PersonaChat
. --name
ist der Name des Unterverzeichnisses im Modellordner, in dem Protokolle und Prüfpunkte gespeichert werden.
python -m src.train --model gpt2-medium --data personachat --name my_test_run
Für verteiltes Multi-GPU-Training sollte das Train-Skript so aufgerufen werden.
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS src/train.py --model gpt2
Sie können auch vordefinierte Konfigurationen verwenden, indem Sie den Pfad der JSON-Konfigurationsdatei als Argument --config
übergeben. Diese sind im Ordner src/configs
verfügbar und ihre Trainingsergebnisse können unterhalb des Ergebnisabschnitts eingesehen werden.
python -m src.train --config src/configs/xlnet-dailydialog.json
Das Training des Modells erfolgt schnell und einfach im Google Colaboratory- oder Kaggle-Kernel . Bei der neuen Tesla P100- oder Tesla T4-Einheit ist es wichtig, den Laufzeittyp auf GPU einzustellen, da diese das Mixed-Precision-Training voll ausnutzen kann und viel schneller ist als die ältere Tesla K80-Version. Sie können den aktuellen Typ überprüfen, indem Sie !nvidia-smi
in einer Zelle Ihres Colab ausführen.
Als Abkürzung finden Sie hier einen vollständigen Beispielinhalt, den Sie einfach als gemeinsame Datei in Ihr Google Drive importieren können.
Kopieren Sie den folgenden Code und führen Sie ihn in einer Zelle Ihrer Colab-Datei (oder Kaggle-Kernel-Datei) aus, um das Modell zu installieren. Wenn Sie den Kaggle-Kernel verwenden, müssen Sie auch den Internetzugang aktivieren.
! git clone https://github.com/bme-chatbots/dialogue-generation.git
! python -m pip install --upgrade pip
# installing apex is optional and is only useful if Colab's Tesla P100 or T4 is used
# !git clone https://github.com/NVIDIA/apex
# !cd apex; pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
# building the cython code and installing the required packages
! cd dialogue-generation ; pip install -r requirements.txt ; python setup.py build_ext --inplace
Die Trainings- und Validierungsmetriken werden in Tensorboard protokolliert, was auch in der Colab-Datei verfolgt werden kann, wenn der folgende Code vor der Trainingszelle ausgeführt wird.
%load_ext tensorboard
%tensorboard --logdir " dialogue-generation/model "
Das Modell kann dann trainiert werden, indem einfach das train
mit den Standardflags ausgeführt wird. Sie können alle vom train.py
Skript akzeptierten Flags anzeigen, indem Sie das Flag -h
angeben.
! cd dialogue-generation ; python -m src.train
Nach dem Training kann das Modell heruntergeladen werden, indem der Download-Link im folgenden Snippet auf den vom Skript nach der Auswertung protokollierten Link gesetzt wird. ( Saving model to dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt
)
from IPython . display import FileLink
# note that in case of kaggle kernel you have to give path
# relative to your working directory
FileLink ( r'dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt' )
Für das trainierte Modell ist ein interaktiver Bewertungsmodus verfügbar, indem Sie das interact
ausführen und den Pfad des trainierten Modells mit --model_file
angeben. Sie können auch die Datei --config
bereitstellen oder einfach dasselbe Argument --model
und --name
angeben, das während des Trainings verwendet wurde.
python -m src.interact --model gpt2-medium --name my_test_run
python -m src.interact --config src/configs/xlnet-dailydialog.json
Um ein beliebiges Modell auf Ihrem eigenen Datensatz zu trainieren, müssen Sie lediglich eine Unterklasse von DialogDataset
erstellen und die Datengenerierung aus den Rohdateien implementieren. Angenommen sind train.txt
, valid.txt
und test.txt
in data<name of your data>
, wobei jede Runde in einem Dialog in einer neuen Zeile steht und separate Dialoge durch eine zusätzliche Leerzeile getrennt sind.
├── data
| ├── dailydialog
| └── custom_dataset # name of your data ( custom_dataset by default )
| ├── train.txt # the correctly formated train, valid and test files
| ├── valid.txt
| └── test.txt
├── src
Hello how are you?
Hi I'm fine thanks. And you?
Me too thanks for asking.
Hi my name Peter.
Nice to meet you I am Eric.
In data.py
ist eine beispielhafte benutzerdefinierte Datensatzklasse mit dem Namen CustomDataset
implementiert, die einen Datensatz mit diesen Eigenschaften liest.
Der folgende Dialog stammt aus gpt2
das mit Standardparametern für 7 Epochen (149051 Schritte) und --max_hist 3
im PersonaChat
Datensatz trainiert wurde.
User: hi
Bot: greetings, i'm gerald.
User: how are you ?
Bot: good, thanks. getting ready for another hike, i love hiking!
User: oh realy? where are you going?
Bot: pikes peak. i am curently exploring famous mountains when i have time.
User: that is cool are you going alone?
Bot: my brother is in the us, we go hiking with him every month.
User: how old are you by the way?
Bot: a little under 46 years old
User: are you in a relationship?
Bot: yes i am married, 3 boys, i am pregnant again. i am too busy to focus
Der folgende Dialog stammt aus gpt2
das mit Standardparametern für 8 Epochen (48824 Schritte) und --max_hist 3
im DailyDialog
Datensatz trainiert wurde.
User: hi how are you ?
Bot: fine. i got a promotion at work earlier this week.
User: where do you work ?
Bot: I work at the polling place.
User: where is that ?
Bot: I live in the dorms across the street.
User: do you like it ?
Bot: it's great. I heard it is far from here.
User: are you a student ?
Bot: yes. I am a student too.