Repo ini berisi kode yang menyertai makalah Pra-Pelatihan Model Bahasa dengan Preferensi Manusia. Basis kode dibangun berdasarkan Trainer
Hugging Face Transformers dan berisi implementasi lima tujuan untuk prapelatihan dengan umpan balik manusia (PHF) yang dibahas dalam makalah, serta callback dan skrip untuk mengevaluasinya.
Tujuan PHF dapat diimplementasikan dengan menganotasi data pelatihan dengan reward dan menimpa Trainer.compute_loss
untuk menggunakannya sebagai sinyal pelatihan tambahan. Hadiah diberikan oleh contoh apo.scorers.Scorer
: sebuah objek yang dapat menentukan, untuk bagian teks tertentu, apakah teks tersebut selaras atau tidak selaras dengan preferensi manusia seperti tidak menyinggung. Pencetak skor juga digunakan untuk mengevaluasi sampel dari LM yang dilatih PHF.
Basis kode dibangun berdasarkan ekosistem dan tongkat Hugging Face (untuk pemantauan dan manajemen eksperimen).
Kami mengasumsikan Python 3.9+. Untuk menjalankan skrip pelatihan MLE pada tugas toksisitas, lakukan:
pip install -r requirements.txt
wandb login # or set `WANDB_API_KEY` and `WANDB_PROJECT` env variables
export OPENAI_API_KEY= ' sk-your_key ' # needed for evaluation
python train.py --task configs/toxicity/pretrain.yml --method configs/toxicity/mle.yml
Skrip train.py
memerlukan jalur ke dua file konfigurasi: untuk tugas dan untuk metode. File konfigurasi untuk tugas ( toxicity
, pii
, pep8
) disimpan dalam file YAML: configs/{task}/pretrain.yml
(untuk eksperimen pra-pelatihan) dan configs/{task}/finetuning.yml
(untuk menyempurnakan). File konfigurasi untuk metode disimpan secara terpisah di direktori configs/{task}
. Setiap pasangan konfigurasi metode tugas (untuk pra-pelatihan dan untuk penyesuaian) berisi hyperparameter yang kami gunakan dalam eksperimen dan memungkinkan untuk mereproduksi hasil dari makalah.
Parameter individual dapat diganti dari baris perintah menggunakan argumen override
. Misalnya:
python train.py --task configs/toxicity/pretrain.yml --method configs/toxicity/mle.yml --override training.per_device_train_batch_size=8
Nama | File konfigurasi | Data pelatihan | Pencetak gol | Keterangan |
---|---|---|---|---|
Toksisitas | configs/toxicity | tomekkorbak/pile-detoxify | DetoxifyToxicityScorer | Skor misalignment adalah kemungkinan toksisitas menurut detoksifikasi |
PII | configs/pii | tomekkorbak/pile-pii-scrubadub | PIIScorer | Skor misalignment adalah jumlah PII (misalnya nama, URL) per karakter, menurut scrubadub |
PEP8 | configs/pep8 | kejian/codeparrot-train-more-filter-3.3b-cleaned | PEP8Scorer | Skor misalignment adalah jumlah pelanggaran PEP8 per karakter, menurut pycodestyle |
Enam tujuan pelatihan dengan umpan balik manusia yang digunakan dalam eksperimen kami diterapkan sebagai berikut:
Nama | Kelas objektif | Keterangan |
---|---|---|
MLE | MLE | Pembungkus tipis di sekitar PyTorch CrossEntropyLoss |
Penyaringan | MLE | Anda perlu menyetel dataset.filter_threshold di konfigurasi |
Pelatihan bersyarat | MLE | Anda juga perlu menyetel dataset.conditional_training_config di config` |
Kemungkinannya | Unlikelihood | Anda juga perlu menyetel hyperparameter objective.score_threshold dan objective.alpha |
AWR | AWR | Anda juga perlu menyetel hyperparameter objective.alpha dan objective.beta |
RWR | AWR | Kasus khusus AWR dengan objective.alpha=1 |
Model yang dilatih sebelumnya dalam eksperimen kami tersedia di HugginFace Hub:
Tujuan | Toksisitas | PEP8 | PII |
---|---|---|---|
MLE | tomekkorbak/goofy_pasteur | kejian/mle perkasa | tomekkorbak/nervous_wozniak |
Memfilter median | tomekkorbak/amazing_shannon | kejian/penyaringan perkasa | tomekkorbak/sombong_carson |
Bersyarat | tomekkorbak/hungry_saha | kejian/perkasa-kondisional | tomekkorbak/boring_mcclintock |
UL | tomekkorbak/nifty_banach | kejian/perkasa-ul | tomekkorbak/affectionate_wescoff |
AWR | tomekkorbak/upbeat_ramanujan | kejian/vigor-awr | tomekkorbak/confident_knuth |
RWR | tomekkorbak/keen_clarke | kejian/perkasa-rwr | tomekkorbak/gifted_hugle |
Pada setiap langkah evaluasi, apo.callbacks.GenerateAndScoreCallback
mengulangi daftar GenerationScenario
yang disediakan dalam file konfigurasi tugas. Untuk setiap skenario, num_samples
sampel dihasilkan dan metrik tongkat sihir berikut dihitung:
score
, ketidakselarasan rata-rata (di seluruh num_samples
sampel) dari sampel yang dihasilkan yang ditetapkan oleh pencetak golscore_max@25
, skor maksimum rata-rata dalam 25 sampel (mirip dengan toksisitas maksimum yang diharapkan dalam makalah RealToxicityPrompts)current_samples
, sebuah wandb.Table
Tabel sampel beserta petunjuknya (jika ada) dan skornya Selain menilai sampel LM, kami menggunakan apo.callbacks.KLGPT3Callback
untuk memperkirakan KL LM saat ini dari GPT-3. Hal ini memerlukan pengambilan sampel dari GPT-3 yang di-cache dan digunakan kembali pada iterasi berikutnya. |
.
├── apo
│ ├── callbacks.py # callbacks implementing the evaluation pipeline
│ ├── dataset_wrappers.py # an iterable for streaming blocks of tokens for training
│ ├── kl_gpt3.py # logic for measuring KL from GPT-3
│ └── metrics.py # metrics computed on LM samples (and dataset elements, for debugging)
│ └── models.py # a subclass for GPT2LMHeadModel adding value heads and exposing implementation details
│ └── objectives.py # classes implementing loss functions
│ ├── scorer_utils.py
│ ├── scorers.py # classes for scoring LM samples and dataset elements
│ └── trainer.py # a subclass for Hugging Face Trainer exposing some functionalities
│ └── utils.py
├── configs
│ └── pep8
│ └── pii
│ └── toxicity
├── scripts # scripts for evaluation
│ dataset_builders # scripts used to generate some of the datasets
├── resources # small, git-tracked files from which lists of words or prompts are loaded
└── train.py # the main training script
@misc { https://doi.org/10.48550/arxiv.2302.08582 ,
doi = { 10.48550/ARXIV.2302.08582 } ,
url = { https://arxiv.org/abs/2302.08582 } ,
author = { Korbak, Tomasz and Shi, Kejian and Chen, Angelica and Bhalerao, Rasika and Buckley, Christopher L. and Phang, Jason and Bowman, Samuel R. and Perez, Ethan } ,
keywords = { Computation and Language (cs.CL), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences } ,
title = { Pretraining Language Models with Human Preferences } ,
publisher = { arXiv } ,
year = { 2023 } ,
copyright = { Creative Commons Attribution 4.0 International }
}