このリポジトリには、論文「人間の好みによる言語モデルの事前トレーニング」に付随するコードが含まれています。このコードベースは、Hugging Face Transformers' Trainer
中心に構築されており、論文で説明されているヒューマン フィードバック (PHF) による事前トレーニングのための 5 つの目標の実装と、それらを評価するためのコールバックとスクリプトが含まれています。
PHF 目標は、トレーニング データに報酬の注釈を付け、 Trainer.compute_loss
上書きして追加のトレーニング信号として使用することで実装できます。報酬は、 apo.scorers.Scorer
のインスタンスによって提供されます。これは、特定のテキストについて、不快感のなさなどの人間の好みと一致しているかずれているかを判断できるオブジェクトです。スコアラーは、PHF で訓練された LM からのサンプルを評価するためにも使用されます。
コードベースは、Hugging Face エコシステムと ワンド (監視と実験管理用) を中心に構築されています。
Python 3.9以降を想定しています。毒性タスクで MLE のトレーニング スクリプトを実行するには、次の手順を実行します。
pip install -r requirements.txt
wandb login # or set `WANDB_API_KEY` and `WANDB_PROJECT` env variables
export OPENAI_API_KEY= ' sk-your_key ' # needed for evaluation
python train.py --task configs/toxicity/pretrain.yml --method configs/toxicity/mle.yml
train.py
スクリプトには、タスク用とメソッド用の 2 つの構成ファイルへのパスが必要です。タスク ( toxicity
、 pii
、 pep8
) の構成ファイルは、YAML ファイル: configs/{task}/pretrain.yml
(事前トレーニング実験用) およびconfigs/{task}/finetuning.yml
(微調整用) に保存されます。メソッドの構成ファイルはconfigs/{task}
ディレクトリに個別に保存されます。各タスクとメソッドの構成ペア (事前トレーニング用と微調整用) には、実験で使用したハイパーパラメーターが含まれており、論文の結果を再現できます。
個々のパラメータは、 override
引数を使用してコマンド ラインからオーバーライドできます。例えば:
python train.py --task configs/toxicity/pretrain.yml --method configs/toxicity/mle.yml --override training.per_device_train_batch_size=8
名前 | 設定ファイル | トレーニングデータ | 得点者 | 説明 |
---|---|---|---|---|
毒性 | configs/toxicity | tomekkorbak/pile-detoxify | DetoxifyToxicityScorer | 不整列スコアは、解毒による毒性の確率です。 |
PII | configs/pii | tomekkorbak/pile-pii-scrubadub | PIIScorer | scrubadub によると、不整列スコアは文字あたりの PII (名前、URL など) の数です。 |
PEP8 | configs/pep8 | kejian/codeparrot-train-more-filter-3.3b-cleaned | PEP8Scorer | 不整列スコアは、pycodestyle に従った、文字ごとの PEP8 違反の数です。 |
私たちの実験で使用された人間のフィードバックを使用したトレーニングの 6 つの目標は、次のように実装されています。
名前 | 対象クラス | 説明 |
---|---|---|
MLE | MLE | PyTorch CrossEntropyLoss の薄いラッパー |
フィルタリング | MLE | configでdataset.filter_threshold を設定する必要があります |
条件付きトレーニング | MLE | config でdataset.conditional_training_config を設定する必要もあります |
可能性の低い | Unlikelihood | また、ハイパーパラメータobjective.score_threshold とobjective.alpha を設定する必要があります。 |
AWR | AWR | また、ハイパーパラメータobjective.alpha およびobjective.beta を設定する必要があります。 |
RWR | AWR | objective.alpha=1 の AWR の特殊なケース |
実験で事前トレーニングされたモデルは、HugginFace Hub で入手できます。
客観的 | 毒性 | PEP8 | PII |
---|---|---|---|
MLE | tomekkorbak/goofy_pasteur | ケジアン/マイティ・ムル | tomekkorbak/nervous_wozniak |
中央値のフィルタリング | tomekkorbak/amazing_shannon | kejian/強力なフィルタリング | tomekkorbak/cocky_carson |
条件付き | tomekkorbak/hungry_saha | ケジアン/強力な条件付き | tomekkorbak/boring_mcclintock |
UL | tomekkorbak/nifty_banach | ケジアン/マイティウル | tomekkorbak/affectionate_wescoff |
AWR | tomekkorbak/upbeat_ramanujan | ケジアン/活力-awr | tomekkorbak/confident_knuth |
RWR | tomekkorbak/keen_clarke | ケジャン/マイティ-rwr | tomekkorbak/gifted_hugle |
各評価ステップで、 apo.callbacks.GenerateAndScoreCallback
、タスク構成ファイルで提供されるGenerationScenario
のリストを反復処理します。シナリオごとに、 num_samples
サンプルが生成され、次の wandb メトリックが計算されます。
score
によって割り当てられた、生成されたサンプルの平均不整列 ( num_samples
サンプルにわたる)score_max@25
、25 サンプルの平均最大スコア (RealToxicityPrompts 論文の予想最大毒性と同様)current_samples
、サンプルとそのプロンプト (存在する場合) およびスコアのwandb.Table
LM サンプルのスコアリングに加えて、 apo.callbacks.KLGPT3Callback
を使用して GPT-3 から現在の LM の KL を推定します。これには、キャッシュされて後続の反復で再利用される GPT-3 からサンプルを描画する必要があります。 |
.
├── apo
│ ├── callbacks.py # callbacks implementing the evaluation pipeline
│ ├── dataset_wrappers.py # an iterable for streaming blocks of tokens for training
│ ├── kl_gpt3.py # logic for measuring KL from GPT-3
│ └── metrics.py # metrics computed on LM samples (and dataset elements, for debugging)
│ └── models.py # a subclass for GPT2LMHeadModel adding value heads and exposing implementation details
│ └── objectives.py # classes implementing loss functions
│ ├── scorer_utils.py
│ ├── scorers.py # classes for scoring LM samples and dataset elements
│ └── trainer.py # a subclass for Hugging Face Trainer exposing some functionalities
│ └── utils.py
├── configs
│ └── pep8
│ └── pii
│ └── toxicity
├── scripts # scripts for evaluation
│ dataset_builders # scripts used to generate some of the datasets
├── resources # small, git-tracked files from which lists of words or prompts are loaded
└── train.py # the main training script
@misc { https://doi.org/10.48550/arxiv.2302.08582 ,
doi = { 10.48550/ARXIV.2302.08582 } ,
url = { https://arxiv.org/abs/2302.08582 } ,
author = { Korbak, Tomasz and Shi, Kejian and Chen, Angelica and Bhalerao, Rasika and Buckley, Christopher L. and Phang, Jason and Bowman, Samuel R. and Perez, Ethan } ,
keywords = { Computation and Language (cs.CL), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences } ,
title = { Pretraining Language Models with Human Preferences } ,
publisher = { arXiv } ,
year = { 2023 } ,
copyright = { Creative Commons Attribution 4.0 International }
}