此儲存庫包含論文《根據人類偏好預訓練語言模型》隨附的程式碼。該程式碼庫圍繞著 Hugging Face Transformers 的Trainer
構建,包含論文中討論的利用人類反饋 (PHF) 進行預訓練的五個目標的實現,以及用於評估它們的回調和腳本。
PHF 目標可以透過以獎勵註釋訓練資料並覆蓋Trainer.compute_loss
以將其用作附加訓練訊號來實現。獎勵由apo.scorers.Scorer
的實例提供:對於給定的文本片段,該物件能夠確定它是否與人類偏好(例如非攻擊性)一致或不一致。評分器也用於評估經過 PHF 訓練的 LM 的樣本。
該程式碼庫圍繞著 Hugging Face 生態系統和魔杖(用於監控和實驗管理)建構。
我們假設 Python 3.9+。若要在毒性任務上執行 MLE 訓練腳本,請執行以下操作:
pip install -r requirements.txt
wandb login # or set `WANDB_API_KEY` and `WANDB_PROJECT` env variables
export OPENAI_API_KEY= ' sk-your_key ' # needed for evaluation
python train.py --task configs/toxicity/pretrain.yml --method configs/toxicity/mle.yml
train.py
腳本需要兩個設定檔的路徑:任務和方法。任務的設定檔( toxicity
、 pii
、 pep8
)儲存在YAML檔案中: configs/{task}/pretrain.yml
(用於預訓練實驗)和configs/{task}/finetuning.yml
(用於微調)。方法的設定檔單獨儲存在configs/{task}
目錄中。每個任務方法配置對(用於預訓練和微調)包含我們在實驗中使用的超參數,並允許重現論文中的結果。
可以使用override
參數從命令列覆蓋各個參數。例如:
python train.py --task configs/toxicity/pretrain.yml --method configs/toxicity/mle.yml --override training.per_device_train_batch_size=8
姓名 | 設定檔 | 訓練資料 | 得分手 | 描述 |
---|---|---|---|---|
毒性 | configs/toxicity | tomekkorbak/pile-detoxify | DetoxifyToxicityScorer | 錯位分數是根據解毒的毒性機率 |
個人識別資訊 | configs/pii | tomekkorbak/pile-pii-scrubadub | PIIScorer | 根據 scrapadub 的數據,錯位分數是每個字元的 PII(例如名稱、URL)數量 |
PEP8 | configs/pep8 | kejian/codeparrot-train-more-filter-3.3b-cleaned | PEP8Scorer | 根據 pycodestyle,錯位分數是每個字元 PEP8 違規的數量 |
我們的實驗中所使用的人類回饋訓練的六個目標的實現如下:
姓名 | 客觀類 | 描述 |
---|---|---|
最大LE | MLE | PyTorch CrossEntropyLoss 的薄包裝 |
濾 | MLE | 您需要在配置中設定dataset.filter_threshold |
有條件的培訓 | MLE | 您還需要在 config` 中設定dataset.conditional_training_config |
可能性 | Unlikelihood | 您還需要設定超參數objective.score_threshold 和objective.alpha |
預警機 | AWR | 您還需要設定超參數objective.alpha 和objective.beta |
反應爐 | AWR | objective.alpha=1 的 AWR 特例 |
我們實驗中預先訓練的模型可在 HugginFace Hub 上找到:
客觀的 | 毒性 | PEP8 | 個人識別資訊 |
---|---|---|---|
最大LE | 托梅科巴克/goofy_pasteur | 科健/mighty-mle | 托梅科巴克/nervous_wozniak |
過濾中位數 | 托梅科巴克/amazing_shannon | 科健/強力過濾 | 托梅科巴克/cocky_carson |
有條件的 | 托梅科巴克/hungry_saha | kejian/強大的條件 | 托梅科巴克/boring_mcclintock |
UL | 托梅科巴克/nifty_banach | 克健/mighty-ul | tomekkorbak/affectionate_wescoff |
預警機 | tomekkorbak/upbeat_ramanujan | 科健/活力-awr | tomekkorbak/confident_knuth |
反應爐 | 托梅科巴克/keen_clarke | 科健/mighty-rwr | tomekkorbak/gifted_hugle |
在每個評估步驟中, apo.callbacks.GenerateAndScoreCallback
都會迭代任務設定檔中提供的GenerationScenario
清單。對於每個場景,都會產生num_samples
個樣本,並計算以下 wandb 指標:
score
,評分器分配的生成樣本的平均錯位(跨num_samples
個樣本)score_max@25
,25個樣本的平均最大分數(類似於RealToxicityPrompts論文中預期的最大毒性)current_samples
,一個wandb.Table
樣本及其提示(如果有)和分數除了對 LM 樣本進行評分之外,我們還使用apo.callbacks.KLGPT3Callback
來估計 GPT-3 中目前 LM 的 KL。這需要從 GPT-3 中抽取樣本,這些樣本會被快取並在後續迭代中重複使用。 |
.
├── apo
│ ├── callbacks.py # callbacks implementing the evaluation pipeline
│ ├── dataset_wrappers.py # an iterable for streaming blocks of tokens for training
│ ├── kl_gpt3.py # logic for measuring KL from GPT-3
│ └── metrics.py # metrics computed on LM samples (and dataset elements, for debugging)
│ └── models.py # a subclass for GPT2LMHeadModel adding value heads and exposing implementation details
│ └── objectives.py # classes implementing loss functions
│ ├── scorer_utils.py
│ ├── scorers.py # classes for scoring LM samples and dataset elements
│ └── trainer.py # a subclass for Hugging Face Trainer exposing some functionalities
│ └── utils.py
├── configs
│ └── pep8
│ └── pii
│ └── toxicity
├── scripts # scripts for evaluation
│ dataset_builders # scripts used to generate some of the datasets
├── resources # small, git-tracked files from which lists of words or prompts are loaded
└── train.py # the main training script
@misc { https://doi.org/10.48550/arxiv.2302.08582 ,
doi = { 10.48550/ARXIV.2302.08582 } ,
url = { https://arxiv.org/abs/2302.08582 } ,
author = { Korbak, Tomasz and Shi, Kejian and Chen, Angelica and Bhalerao, Rasika and Buckley, Christopher L. and Phang, Jason and Bowman, Samuel R. and Perez, Ethan } ,
keywords = { Computation and Language (cs.CL), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences } ,
title = { Pretraining Language Models with Human Preferences } ,
publisher = { arXiv } ,
year = { 2023 } ,
copyright = { Creative Commons Attribution 4.0 International }
}