Implementação de um modelo gerador de diálogo neural com XLNet
pré-treinado Yang et al. (2019) e arquitetura GPT2
Radford et al. (2019) em três conjuntos de dados atualmente: DailyDialog
Li et al. (2017) , PersonaChat
Zhang et al. (2018) e o novo TopicalChat
Gopalakrishnan et al. (2019) do Alexa Prize Socialbot Grand Challenge 3. Amostragem Top-k Fan et al. (2018) e decodificação de núcleo Holtzman et al. (2019) estão disponíveis como técnicas de decodificação. O objetivo do treinamento é a modelagem de linguagem autorregressiva nos enunciados e nas histórias de diálogo.
O modelo pode aproveitar o treinamento de precisão mista da nvidia/apex. Observe que o apex não é obrigatório e só será usado se estiver disponível. Para guia de instalação, consulte as instruções oficiais. Usar este módulo não é útil para todas as GPUs (apenas Volta e Turing) e você deve verificar antes se sua instância suporta treinamento de precisão mista.
Para treinar o modelo, clone este repositório e instale as dependências. O projeto usa Cython para montar lotes para um pipeline de entrada mais rápido. Também preferiu usar um python virtualenv.
git clone https://github.com/bme-chatbots/dialogue-generation.git
cd dialogue-generation
pip install -r requirements.txt
python setup.py build_ext --inplace
O comando a seguir iniciará o treinamento em uma única GPU/CPU com modelo gpt2-medium
no PersonaChat
. --name
é o nome do subdiretório na pasta do modelo, onde os logs e pontos de verificação são salvos.
python -m src.train --model gpt2-medium --data personachat --name my_test_run
Para treinamento multi-GPU distribuído, o script de treinamento deve ser chamado assim.
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS src/train.py --model gpt2
Você também pode usar configurações predefinidas passando o caminho do arquivo json de configuração como argumento --config
. Eles estão disponíveis na pasta src/configs
e seus resultados de treinamento podem ser vistos abaixo da seção de resultados.
python -m src.train --config src/configs/xlnet-dailydialog.json
Treinar o modelo é rápido e fácil no Google Colaboratory ou no kernel Kaggle . É importante definir o tipo de tempo de execução para GPU com a nova unidade Tesla P100 ou Tesla T4, pois ela pode aproveitar totalmente o treinamento de precisão mista e é muito mais rápida do que a versão mais antiga do Tesla K80. Você pode verificar o tipo atual executando !nvidia-smi
em uma célula do seu colab.
Como atalho, aqui está um exemplo completo, que você pode simplesmente importar para o seu Google Drive como um arquivo colaborativo.
Copie e execute o código a seguir em uma célula do seu arquivo colab (ou kernel Kaggle) para instalar o modelo. Se você usa o kernel Kaggle, também deve habilitar o acesso à Internet.
! git clone https://github.com/bme-chatbots/dialogue-generation.git
! python -m pip install --upgrade pip
# installing apex is optional and is only useful if Colab's Tesla P100 or T4 is used
# !git clone https://github.com/NVIDIA/apex
# !cd apex; pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
# building the cython code and installing the required packages
! cd dialogue-generation ; pip install -r requirements.txt ; python setup.py build_ext --inplace
As métricas de treinamento e validação são registradas no Tensorboard, que também pode ser rastreado no arquivo colab se o código abaixo for executado antes da célula de treinamento.
%load_ext tensorboard
%tensorboard --logdir " dialogue-generation/model "
O modelo pode ser treinado simplesmente executando o script train
com os sinalizadores padrão. Você pode ver todos os sinalizadores aceitos pelo script train.py
fornecendo o sinalizador -h
.
! cd dialogue-generation ; python -m src.train
Após o treinamento, o modelo pode ser baixado definindo o link de download no snippet a seguir para aquele registrado pelo script após a avaliação. ( Saving model to dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt
)
from IPython . display import FileLink
# note that in case of kaggle kernel you have to give path
# relative to your working directory
FileLink ( r'dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt' )
Um modo de avaliação interativo está disponível no modelo treinado executando o script interact
e fornecendo o caminho do modelo treinado com --model_file
. Você também pode fornecer o arquivo --config
ou simplesmente fornecer o mesmo argumento --model
e --name
, que foi usado durante o treinamento.
python -m src.interact --model gpt2-medium --name my_test_run
python -m src.interact --config src/configs/xlnet-dailydialog.json
Para treinar qualquer modelo em seu próprio conjunto de dados, basta criar uma subclasse de DialogDataset
e implementar a geração de dados a partir dos arquivos brutos. Dado um train.txt
, valid.txt
e test.txt
colocado em data<name of your data>
, onde cada turno em uma caixa de diálogo está em uma nova linha e as caixas de diálogo separadas são divididas por uma linha vazia extra.
├── data
| ├── dailydialog
| └── custom_dataset # name of your data ( custom_dataset by default )
| ├── train.txt # the correctly formated train, valid and test files
| ├── valid.txt
| └── test.txt
├── src
Hello how are you?
Hi I'm fine thanks. And you?
Me too thanks for asking.
Hi my name Peter.
Nice to meet you I am Eric.
Um exemplo de classe de conjunto de dados personalizado chamada CustomDataset
é implementada em data.py
que lê um conjunto de dados com essas propriedades.
A caixa de diálogo abaixo é uma amostra de gpt2
treinado com parâmetros padrão para 7 épocas (149051 etapas) e --max_hist 3
no conjunto de dados PersonaChat
.
User: hi
Bot: greetings, i'm gerald.
User: how are you ?
Bot: good, thanks. getting ready for another hike, i love hiking!
User: oh realy? where are you going?
Bot: pikes peak. i am curently exploring famous mountains when i have time.
User: that is cool are you going alone?
Bot: my brother is in the us, we go hiking with him every month.
User: how old are you by the way?
Bot: a little under 46 years old
User: are you in a relationship?
Bot: yes i am married, 3 boys, i am pregnant again. i am too busy to focus
A caixa de diálogo abaixo é uma amostra de gpt2
treinado com parâmetros padrão para 8 épocas (48824 etapas) e --max_hist 3
no conjunto de dados DailyDialog
.
User: hi how are you ?
Bot: fine. i got a promotion at work earlier this week.
User: where do you work ?
Bot: I work at the polling place.
User: where is that ?
Bot: I live in the dorms across the street.
User: do you like it ?
Bot: it's great. I heard it is far from here.
User: are you a student ?
Bot: yes. I am a student too.