사전 훈련된 XLNet
사용한 신경 대화 생성기 모델 구현 Yang et al. (2019) 및 GPT2
아키텍처 Radford et al. (2019) 현재 3개의 데이터세트에 대해: DailyDialog
Li et al. (2017) , PersonaChat
Zhang 외. (2018) 및 새로운 TopicalChat
Gopalakrishnan et al. (2019) Alexa Prize Socialbot Grand Challenge 3. Top-k 샘플링 Fan et al. (2018) 및 핵 디코딩 Holtzman et al. (2019)은 디코딩 기술로 사용 가능합니다. 훈련 목표는 발화 및 대화 기록에 대한 자동 회귀 언어 모델링입니다.
이 모델은 nvidia/apex의 혼합 정밀 교육을 활용할 수 있습니다. apex는 필수가 아니며 사용 가능한 경우에만 사용됩니다. 설치 가이드는 공식 지침을 참조하세요. 이 모듈을 사용하는 것은 모든 GPU(Volta 및 Turing만)에 유용하지 않으며 인스턴스가 혼합 정밀도 교육을 지원하는지 사전에 확인해야 합니다.
모델을 훈련하려면 이 저장소를 복제하고 종속 항목을 설치하세요. 이 프로젝트는 더 빠른 입력 파이프라인을 위해 cython을 사용하여 배치를 어셈블합니다. 또한 Python virtualenv를 사용하는 것이 선호되었습니다.
git clone https://github.com/bme-chatbots/dialogue-generation.git
cd dialogue-generation
pip install -r requirements.txt
python setup.py build_ext --inplace
다음 명령은 PersonaChat
에서 gpt2-medium
모델을 사용하는 단일 GPU/CPU에 대한 훈련을 시작합니다. --name
은 로그와 체크포인트가 저장되는 모델 폴더의 하위 디렉터리 이름입니다.
python -m src.train --model gpt2-medium --data personachat --name my_test_run
분산 다중 GPU 훈련의 경우 기차 스크립트를 다음과 같이 호출해야 합니다.
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS src/train.py --model gpt2
구성 json 파일의 경로를 --config
인수로 전달하여 사전 정의된 구성을 사용할 수도 있습니다. 이는 src/configs
폴더에서 사용할 수 있으며 훈련 결과는 결과 섹션 아래에서 볼 수 있습니다.
python -m src.train --config src/configs/xlnet-dailydialog.json
Google Colaboratory 또는 Kaggle 커널 에서는 모델 학습이 빠르고 쉽습니다. 혼합 정밀도 훈련을 완벽하게 활용할 수 있고 이전 Tesla K80 버전보다 훨씬 빠르기 때문에 새로운 Tesla P100 또는 Tesla T4 장치에서는 런타임 유형을 GPU로 설정하는 것이 중요합니다. Colab의 셀에서 !nvidia-smi
실행하여 현재 유형을 확인할 수 있습니다.
여기 바로 가기로 전체 예제 요점이 있으며 이를 공동 작업 파일로 Google 드라이브에 간단히 가져올 수 있습니다.
colab (또는 Kaggle 커널) 파일의 셀에서 다음 코드를 복사하고 실행하여 모델을 설치합니다. Kaggle 커널을 사용하는 경우 인터넷 액세스도 활성화해야 합니다.
! git clone https://github.com/bme-chatbots/dialogue-generation.git
! python -m pip install --upgrade pip
# installing apex is optional and is only useful if Colab's Tesla P100 or T4 is used
# !git clone https://github.com/NVIDIA/apex
# !cd apex; pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
# building the cython code and installing the required packages
! cd dialogue-generation ; pip install -r requirements.txt ; python setup.py build_ext --inplace
훈련 및 검증 지표는 Tensorboard에 기록되며, 아래 코드가 훈련 셀 전에 실행되는 경우 colab 파일에서도 추적할 수 있습니다.
%load_ext tensorboard
%tensorboard --logdir " dialogue-generation/model "
그런 다음 기본 플래그를 사용하여 train
스크립트를 실행하기만 하면 모델을 훈련할 수 있습니다. -h
플래그를 제공하면 train.py
스크립트에서 허용하는 모든 플래그를 볼 수 있습니다.
! cd dialogue-generation ; python -m src.train
학습 후 다음 코드 조각의 다운로드 링크를 평가 후 스크립트에 의해 기록된 링크로 설정하여 모델을 다운로드할 수 있습니다. ( Saving model to dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt
)
from IPython . display import FileLink
# note that in case of kaggle kernel you have to give path
# relative to your working directory
FileLink ( r'dialogue-generation/src/../model/gpt2/19.11.03-12:59:47/model.pt' )
interact
스크립트를 실행하고 --model_file
로 훈련된 모델의 경로를 제공하여 훈련된 모델에서 대화형 평가 모드를 사용할 수 있습니다. --config
파일을 제공하거나 교육 중에 사용된 것과 동일한 --model
및 --name
인수를 제공할 수도 있습니다.
python -m src.interact --model gpt2-medium --name my_test_run
python -m src.interact --config src/configs/xlnet-dailydialog.json
자신의 데이터세트로 모델을 훈련하려면 DialogDataset
에서 하위 클래스를 만들고 원시 파일에서 데이터 생성을 구현하기만 하면 됩니다. train.txt
, valid.txt
및 test.txt
가 data<name of your data>
에 배치되어 있는 경우 대화 상자의 각 차례는 새 줄에 있고 별도의 대화 상자는 추가 빈 줄로 구분됩니다.
├── data
| ├── dailydialog
| └── custom_dataset # name of your data ( custom_dataset by default )
| ├── train.txt # the correctly formated train, valid and test files
| ├── valid.txt
| └── test.txt
├── src
Hello how are you?
Hi I'm fine thanks. And you?
Me too thanks for asking.
Hi my name Peter.
Nice to meet you I am Eric.
CustomDataset
라는 예제 사용자 정의 데이터 세트 클래스는 이러한 속성이 있는 데이터 세트를 읽는 data.py
에 구현됩니다.
아래 대화는 7세대(149051단계)에 대한 기본 매개변수와 PersonaChat
데이터세트의 --max_hist 3
으로 훈련된 gpt2
에서 샘플링되었습니다.
User: hi
Bot: greetings, i'm gerald.
User: how are you ?
Bot: good, thanks. getting ready for another hike, i love hiking!
User: oh realy? where are you going?
Bot: pikes peak. i am curently exploring famous mountains when i have time.
User: that is cool are you going alone?
Bot: my brother is in the us, we go hiking with him every month.
User: how old are you by the way?
Bot: a little under 46 years old
User: are you in a relationship?
Bot: yes i am married, 3 boys, i am pregnant again. i am too busy to focus
아래 대화는 DailyDialog
데이터세트에서 8개 에포크(48824 단계) 및 --max_hist 3
에 대한 기본 매개변수로 훈련된 gpt2
에서 샘플링되었습니다.
User: hi how are you ?
Bot: fine. i got a promotion at work earlier this week.
User: where do you work ?
Bot: I work at the polling place.
User: where is that ?
Bot: I live in the dorms across the street.
User: do you like it ?
Bot: it's great. I heard it is far from here.
User: are you a student ?
Bot: yes. I am a student too.