Реализация модели нейронного генератора диалогов с предварительно обученным XLNet
Yang et al. (2019) и архитектура GPT2
Radford et al. (2019) на трех наборах данных: DailyDialog
Li et al. (2017) , PersonaChat
Чжан и др. (2018) и новый TopicalChat
Gopalakrishnan et al. (2019) из Alexa Prize Socialbot Grand Challenge 3. Выборка Top-k Fan et al. (2018) и декодирование ядра Holtzman et al. (2019) доступны в качестве методов декодирования. Целью обучения является авторегрессионное языковое моделирование на основе историй высказываний и диалогов.
Модель может использовать обучение смешанной точности от nvidia/apex. Обратите внимание, что апекс не является обязательным и используется только в том случае, если он доступен. Руководство по установке смотрите в официальной инструкции. Использование этого модуля бесполезно не для всех графических процессоров (только Volta и Turing), и вам следует заранее проверить, поддерживает ли ваш экземпляр обучение смешанной точности.
Чтобы обучить модель, клонируйте этот репозиторий и установите зависимости. В проекте используется Cython для сборки пакетов для ускорения конвейера ввода. Также было предпочтительнее использовать виртуальную среду Python.
git clone https://github.com/bme-chatbots/dialogue-generation.git
cd dialogue-generation
pip install -r requirements.txt
python setup.py build_ext --inplace
Следующая команда начнет обучение на одном графическом процессоре/процессоре со gpt2-medium
в PersonaChat
. --name
— имя подкаталога в папке модели, где сохраняются логи и контрольные точки.
python -m src.train --model gpt2-medium --data personachat --name my_test_run
Для распределенного обучения с несколькими графическими процессорами сценарий поезда должен вызываться следующим образом.
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS src/train.py --model gpt2
Вы также можете использовать предопределенные конфигурации, передав путь к файлу конфигурации json в качестве аргумента --config
. Они доступны в папке src/configs
, а результаты их обучения можно увидеть под разделом результатов.
python -m src.train --config src/configs/xlnet-dailydialog.json
Обучение модели происходит быстро и легко с помощью ядра Google Colaboratory или Kaggle . Важно установить тип среды выполнения на графический процессор для нового устройства Tesla P100 или Tesla T4, поскольку оно может полностью использовать обучение смешанной точности и работает намного быстрее, чем старая версия Tesla K80. Вы можете проверить текущий тип, запустив !nvidia-smi
в ячейке вашего колаба.
В качестве ярлыка здесь приведен полный пример, который вы можете просто импортировать на свой Google Диск как файл для совместной работы.
Скопируйте и запустите следующий код в ячейке вашего файла colab (или ядра Kaggle), чтобы установить модель. Если вы используете ядро Kaggle, вам также необходимо включить доступ в Интернет.
! git clone https://github.com/bme-chatbots/dialogue-generation.git
! python -m pip install --upgrade pip
# installing apex is optional and is only useful if Colab's Tesla P100 or T4 is used
# !git clone https://github.com/NVIDIA/apex
# !cd apex; pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
# building the cython code and installing the required packages
! cd dialogue-generation ; pip install -r requirements.txt ; python setup.py build_ext --inplace
Метрики обучения и проверки записываются в Tensorboard, что также можно отслеживать в файле colab, если приведенный ниже код выполняется перед обучающей ячейкой.
%load_ext tensorboard
%tensorboard --logdir " dialogue-generation/model "
Затем модель можно обучить, просто запустив сценарий train
с флагами по умолчанию. Вы можете просмотреть все флаги, принятые сценарием train.py
, указав флаг -h
.
! cd dialogue-generation ; python -m src.train
После обучения модель можно загрузить, установив ссылку для скачивания в следующем фрагменте на ссылку, записанную сценарием после оценки. ( Saving model to dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt
)
from IPython . display import FileLink
# note that in case of kaggle kernel you have to give path
# relative to your working directory
FileLink ( r'dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt' )
Режим интерактивной оценки доступен для обученной модели путем запуска сценария interact
и указания пути к обученной модели с помощью --model_file
. Вы также можете предоставить файл --config
или просто указать тот же аргумент --model
и --name
, который использовался во время обучения.
python -m src.interact --model gpt2-medium --name my_test_run
python -m src.interact --config src/configs/xlnet-dailydialog.json
Чтобы обучить любую модель на вашем собственном наборе данных, вам просто нужно создать подкласс DialogDataset
и реализовать генерацию данных из необработанных файлов. Учитывая, что train.txt
, valid.txt
и test.txt
помещены в data<name of your data>
, где каждый поворот в диалоге находится на новой строке, а отдельные диалоги разделены дополнительной пустой строкой.
├── data
| ├── dailydialog
| └── custom_dataset # name of your data ( custom_dataset by default )
| ├── train.txt # the correctly formated train, valid and test files
| ├── valid.txt
| └── test.txt
├── src
Hello how are you?
Hi I'm fine thanks. And you?
Me too thanks for asking.
Hi my name Peter.
Nice to meet you I am Eric.
Пример класса пользовательского набора данных с именем CustomDataset
реализован в data.py
, который считывает набор данных с этими свойствами.
Диалоговое окно ниже взято из gpt2
обученного с параметрами по умолчанию для 7 эпох (149051 шагов) и --max_hist 3
в наборе данных PersonaChat
.
User: hi
Bot: greetings, i'm gerald.
User: how are you ?
Bot: good, thanks. getting ready for another hike, i love hiking!
User: oh realy? where are you going?
Bot: pikes peak. i am curently exploring famous mountains when i have time.
User: that is cool are you going alone?
Bot: my brother is in the us, we go hiking with him every month.
User: how old are you by the way?
Bot: a little under 46 years old
User: are you in a relationship?
Bot: yes i am married, 3 boys, i am pregnant again. i am too busy to focus
Диалоговое окно ниже взято из gpt2
обученного с параметрами по умолчанию для 8 эпох (48824 шагов) и --max_hist 3
в наборе данных DailyDialog
.
User: hi how are you ?
Bot: fine. i got a promotion at work earlier this week.
User: where do you work ?
Bot: I work at the polling place.
User: where is that ?
Bot: I live in the dorms across the street.
User: do you like it ?
Bot: it's great. I heard it is far from here.
User: are you a student ?
Bot: yes. I am a student too.