Repo ini berisi kode yang menyertai postingan blog? Cara membangun AI Percakapan yang Canggih dengan Pembelajaran Transfer.
Kode ini adalah basis kode yang bersih dan diberi komentar dengan skrip pelatihan dan pengujian yang dapat digunakan untuk melatih agen dialog yang memanfaatkan pembelajaran transfer dari model bahasa OpenAI GPT dan GPT-2 Transformer.
Basis kode ini dapat digunakan untuk mereproduksi hasil partisipasi HuggingFace pada kompetisi dialog NeurIPS 2018 ConvAI2 yang merupakan metrik otomatis tercanggih. 3k+ baris kode kompetisi disaring menjadi sekitar 250 baris kode pelatihan dengan opsi terdistribusi & FP16 untuk membentuk repositori saat ini.
Model ini dapat dilatih dalam waktu sekitar satu jam pada 8 instans cloud V100 (saat ini berharga sekitar $25) dan model terlatih juga tersedia.
Untuk menginstal dan menggunakan skrip pelatihan dan inferensi, harap kloning repo dan instal persyaratannya:
git clone https://github.com/huggingface/transfer-learning-conv-ai
cd transfer-learning-conv-ai
pip install -r requirements.txt
python -m spacy download en
Untuk menginstal menggunakan buruh pelabuhan, silakan buat image mandiri:
docker build -t convai .
Catatan: Pastikan pengaturan Docker Anda mengalokasikan cukup memori untuk membangun container. Membangun dengan default 1,75GB akan gagal karena roda Pytorch yang besar.
Anda kemudian dapat memasukkan gambar
ip-192-168-22-157:transfer-learning-conv-ai loretoparisi$ docker run --rm -it convai bash
root@91e241bb823e:/ # ls
Dockerfile README.md boot dev home lib media models proc root sbin sys train.py utils.py
LICENCE bin convai_evaluation.py etc interact.py lib64 mnt opt requirements.txt run srv tmp usr var
Anda kemudian dapat menjalankan skrip interact.py
pada model yang telah dilatih sebelumnya:
python3 interact.py --model models/
Kami menyediakan model yang telah dilatih dan disempurnakan di S3 kami di sini. Cara termudah untuk mengunduh dan menggunakan model ini adalah dengan menjalankan skrip interact.py
untuk berbicara dengan model tersebut. Tanpa argumen apa pun, skrip ini akan secara otomatis mengunduh dan menyimpan model kita dalam cache.
Skrip pelatihan dapat digunakan dalam pengaturan GPU tunggal atau multi GPU:
python ./train.py # Single GPU training
python -m torch.distributed.launch --nproc_per_node=8 ./train.py # Training on 8 GPUs
Skrip pelatihan menerima beberapa argumen untuk mengubah pelatihan:
Argumen | Jenis | Nilai bawaan | Keterangan |
---|---|---|---|
kumpulan data_jalur | str | "" | Jalur atau url kumpulan data. Jika kosong unduh dari S3. |
kumpulan data_cache | str | './dataset_cache.bin' | Jalur atau url cache kumpulan data |
model | str | "openai-gpt" | Jalur, url, atau nama pendek model |
jumlah_kandidat | int | 2 | Jumlah kandidat untuk pelatihan |
max_history | int | 2 | Jumlah pertukaran sebelumnya yang perlu disimpan dalam sejarah |
train_batch_size | int | 4 | Ukuran batch untuk pelatihan |
valid_batch_size | int | 4 | Ukuran batch untuk validasi |
gradien_akumulasi_langkah | int | 8 | Akumulasi gradien pada beberapa langkah |
lr | float | 6.25e-5 | Kecepatan pembelajaran |
lm_coef | float | 1.0 | Koefisien kerugian LM |
mc_coef | float | 1.0 | Koefisien kerugian pilihan ganda |
max_norm | float | 1.0 | Memotong norma gradien |
n_zaman | int | 3 | Jumlah periode pelatihan |
kepribadian_permutasi | int | 1 | Banyaknya permutasi kalimat kepribadian |
perangkat | str | "cuda" if torch.cuda.is_available() else "cpu" | Perangkat (cuda atau cpu) |
fp16 | str | "" | Atur ke O0, O1, O2 atau O3 untuk pelatihan fp16 (lihat dokumentasi apex) |
peringkat_lokal | int | -1 | Peringkat lokal untuk pelatihan terdistribusi (-1: tidak terdistribusi) |
Berikut ini cara mereproduksi hasil kami di server dengan 8 GPU V100 (sesuaikan jumlah node dan ukuran batch dengan konfigurasi Anda):
python -m torch.distributed.launch --nproc_per_node=8 ./train.py --gradient_accumulation_steps=4 --lm_coef=2.0 --max_history=2 --n_epochs=1 --num_candidates=4 --personality_permutations=2 --train_batch_size=2 --valid_batch_size=2
Model ini harus memberikan Hits@1 lebih dari 79, kebingungan 20,5 dan F1 16,5 menggunakan skrip evaluasi convai2 (lihat di bawah).
Angka tersebut sedikit lebih rendah dibandingkan angka yang kami peroleh pada kompetisi ConvAI2. Inilah yang dapat Anda sesuaikan untuk mencapai hasil yang sama:
Skrip pelatihan menyimpan semua eksperimen dan pos pemeriksaan dalam sub-folder yang diberi nama dengan stempel waktu eksperimen di folder ./runs
pada folder dasar repositori.
Anda kemudian dapat menggunakan skrip interaktif untuk berinteraksi dengan model hanya dengan menunjuk ke folder ini.
Berikut adalah contoh baris perintah untuk menjalankan skrip interaktif:
python ./interact.py --model_checkpoint ./data/Apr17_13-31-38_thunder/ # run the interactive script with a training checkpoint
python ./interact.py # run the interactive script with the finetuned model on our S3
Model yang disempurnakan akan memberikan FINAL Hits@1: 0,715
Skrip interaktif menerima beberapa argumen untuk mengubah algoritma decoding:
Argumen | Jenis | Nilai bawaan | Keterangan |
---|---|---|---|
kumpulan data_jalur | str | "" | Jalur atau url kumpulan data. Jika kosong unduh dari S3. |
kumpulan data_cache | str | './dataset_cache.bin' | Jalur atau url cache kumpulan data |
model | str | "openai-gpt" | Jalur, url, atau nama pendek model |
max_history | int | 2 | Jumlah ucapan sebelumnya yang perlu disimpan dalam sejarah |
perangkat | str | cuda if torch.cuda.is_available() lain cpu | Perangkat (cuda atau cpu) |
tidak ada_sampel | store_true tindakan_true | Atur untuk menggunakan decoding serakah alih-alih pengambilan sampel | |
panjang_maks | int | 20 | Panjang maksimum ucapan keluaran |
min_length | int | 1 | Panjang minimal keluaran ujaran |
benih | int | 42 | Benih |
suhu | int | 0.7 | Pengambilan sampel suhu softmax |
teratas_k | int | 0 | Filter token teratas sebelum pengambilan sampel ( <=0 : tanpa pemfilteran) |
atas_p | float | 0.9 | Pemfilteran inti (top-p) sebelum pengambilan sampel ( <=0.0 : tanpa pemfilteran) |
Untuk menjalankan skrip evaluasi tantangan ConvAI2, Anda harus menginstal ParlAI
terlebih dahulu di folder basis repo seperti ini:
git clone https://github.com/facebookresearch/ParlAI.git
cd ParlAI
python setup.py develop
Anda kemudian dapat menjalankan skrip evaluasi dari folder dasar ParlAI
:
cd ParlAI
python ../convai_evaluation.py --eval_type hits@1 # to download and evaluate our fine-tuned model on hits@1 metric
python ../convai_evaluation.py --eval_type hits@1 --model_checkpoint ./data/Apr17_13-31-38_thunder/ # to evaluate a training checkpoint on hits@1 metric
Skrip evaluasi menerima beberapa argumen untuk memilih metrik evaluasi dan mengubah algoritma decoding:
Argumen | Jenis | Nilai bawaan | Keterangan |
---|---|---|---|
eval_type | str | "hits@1" | Evaluasi model pada metrik hits@1 , ppl atau f1 pada kumpulan data validasi ConvAI2 |
model | str | "openai-gpt" | Jalur, url, atau nama pendek model |
max_history | int | 2 | Jumlah ucapan sebelumnya yang perlu disimpan dalam sejarah |
perangkat | str | cuda if torch.cuda.is_available() lain cpu | Perangkat (cuda atau cpu) |
tidak ada_sampel | store_true tindakan_true | Atur untuk menggunakan decoding serakah alih-alih pengambilan sampel | |
panjang_maks | int | 20 | Panjang maksimum ucapan keluaran |
min_length | int | 1 | Panjang minimal keluaran ujaran |
benih | int | 42 | Benih |
suhu | int | 0.7 | Pengambilan sampel suhu softmax |
teratas_k | int | 0 | Filter token teratas sebelum pengambilan sampel ( <=0 : tanpa pemfilteran) |
atas_p | float | 0.9 | Pemfilteran inti (top-p) sebelum pengambilan sampel ( <=0.0 : tanpa pemfilteran) |
lihat example_entry.py
, dan komentar di atas.
Jika Anda menggunakan kode ini dalam penelitian Anda, Anda dapat mengutip makalah lokakarya NeurIPS CAI kami:
@article{DBLP:journals/corr/abs-1901-08149,
author = {Thomas Wolf and
Victor Sanh and
Julien Chaumond and
Clement Delangue},
title = {TransferTransfo: {A} Transfer Learning Approach for Neural Network
Based Conversational Agents},
journal = {CoRR},
volume = {abs/1901.08149},
year = {2019},
url = {http://arxiv.org/abs/1901.08149},
archivePrefix = {arXiv},
eprint = {1901.08149},
timestamp = {Sat, 02 Feb 2019 16:56:00 +0100},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1901-08149},
bibsource = {dblp computer science bibliography, https://dblp.org}
}