该存储库包含论文《根据人类偏好预训练语言模型》随附的代码。该代码库围绕 Hugging Face Transformers 的Trainer
构建,包含论文中讨论的利用人类反馈 (PHF) 进行预训练的五个目标的实现,以及用于评估它们的回调和脚本。
PHF 目标可以通过用奖励注释训练数据并覆盖Trainer.compute_loss
以将其用作附加训练信号来实现。奖励由apo.scorers.Scorer
的实例提供:对于给定的文本片段,该对象能够确定它是否与人类偏好(例如非攻击性)一致或不一致。评分器还用于评估来自经过 PHF 训练的 LM 的样本。
该代码库围绕 Hugging Face 生态系统和魔杖(用于监控和实验管理)构建。
我们假设 Python 3.9+。要在毒性任务上运行 MLE 训练脚本,请执行以下操作:
pip install -r requirements.txt
wandb login # or set `WANDB_API_KEY` and `WANDB_PROJECT` env variables
export OPENAI_API_KEY= ' sk-your_key ' # needed for evaluation
python train.py --task configs/toxicity/pretrain.yml --method configs/toxicity/mle.yml
train.py
脚本需要两个配置文件的路径:任务和方法。任务的配置文件( toxicity
、 pii
、 pep8
)存储在YAML文件中: configs/{task}/pretrain.yml
(用于预训练实验)和configs/{task}/finetuning.yml
(用于微调)。方法的配置文件单独存储在configs/{task}
目录中。每个任务方法配置对(用于预训练和微调)包含我们在实验中使用的超参数,并允许重现论文中的结果。
可以使用override
参数从命令行覆盖各个参数。例如:
python train.py --task configs/toxicity/pretrain.yml --method configs/toxicity/mle.yml --override training.per_device_train_batch_size=8
姓名 | 配置文件 | 训练数据 | 得分手 | 描述 |
---|---|---|---|---|
毒性 | configs/toxicity | tomekkorbak/pile-detoxify | DetoxifyToxicityScorer | 错位分数是根据解毒的毒性概率 |
个人身份信息 | configs/pii | tomekkorbak/pile-pii-scrubadub | PIIScorer | 根据 scrapadub 的数据,错位分数是每个字符的 PII(例如名称、URL)数量 |
PEP8 | configs/pep8 | kejian/codeparrot-train-more-filter-3.3b-cleaned | PEP8Scorer | 根据 pycodestyle,错位分数是每个字符 PEP8 违规的数量 |
我们的实验中使用的人类反馈训练的六个目标的实现如下:
姓名 | 客观类 | 描述 |
---|---|---|
最大LE | MLE | PyTorch CrossEntropyLoss 的薄包装 |
过滤 | MLE | 您需要在配置中设置dataset.filter_threshold |
有条件的培训 | MLE | 您还需要在 config` 中设置dataset.conditional_training_config |
可能性 | Unlikelihood | 您还需要设置超参数objective.score_threshold 和objective.alpha |
预警机 | AWR | 您还需要设置超参数objective.alpha 和objective.beta |
反应堆 | AWR | objective.alpha=1 的 AWR 特例 |
我们实验中预训练的模型可在 HugginFace Hub 上找到:
客观的 | 毒性 | PEP8 | 个人身份信息 |
---|---|---|---|
最大LE | 托梅科巴克/goofy_pasteur | 科健/mighty-mle | 托梅科巴克/nervous_wozniak |
过滤中位数 | 托梅科巴克/amazing_shannon | 科健/强力过滤 | 托梅科巴克/cocky_carson |
有条件的 | 托梅科巴克/hungry_saha | kejian/强大的条件 | 托梅科巴克/boring_mcclintock |
UL | 托梅科巴克/nifty_banach | 克健/mighty-ul | tomekkorbak/affectionate_wescoff |
预警机 | tomekkorbak/upbeat_ramanujan | 科健/活力-awr | tomekkorbak/confident_knuth |
反应堆 | 托梅科巴克/keen_clarke | 科健/mighty-rwr | tomekkorbak/gifted_hugle |
在每个评估步骤中, apo.callbacks.GenerateAndScoreCallback
都会迭代任务配置文件中提供的GenerationScenario
列表。对于每个场景,都会生成num_samples
个样本,并计算以下 wandb 指标:
score
,评分器分配的生成样本的平均错位(跨num_samples
个样本)score_max@25
,25个样本的平均最大得分(类似于RealToxicityPrompts论文中预期的最大毒性)current_samples
,一个wandb.Table
样本及其提示(如果有)和分数除了对 LM 样本进行评分之外,我们还使用apo.callbacks.KLGPT3Callback
来估计 GPT-3 中当前 LM 的 KL。这需要从 GPT-3 中抽取样本,这些样本会被缓存并在后续迭代中重用。 |
.
├── apo
│ ├── callbacks.py # callbacks implementing the evaluation pipeline
│ ├── dataset_wrappers.py # an iterable for streaming blocks of tokens for training
│ ├── kl_gpt3.py # logic for measuring KL from GPT-3
│ └── metrics.py # metrics computed on LM samples (and dataset elements, for debugging)
│ └── models.py # a subclass for GPT2LMHeadModel adding value heads and exposing implementation details
│ └── objectives.py # classes implementing loss functions
│ ├── scorer_utils.py
│ ├── scorers.py # classes for scoring LM samples and dataset elements
│ └── trainer.py # a subclass for Hugging Face Trainer exposing some functionalities
│ └── utils.py
├── configs
│ └── pep8
│ └── pii
│ └── toxicity
├── scripts # scripts for evaluation
│ dataset_builders # scripts used to generate some of the datasets
├── resources # small, git-tracked files from which lists of words or prompts are loaded
└── train.py # the main training script
@misc { https://doi.org/10.48550/arxiv.2302.08582 ,
doi = { 10.48550/ARXIV.2302.08582 } ,
url = { https://arxiv.org/abs/2302.08582 } ,
author = { Korbak, Tomasz and Shi, Kejian and Chen, Angelica and Bhalerao, Rasika and Buckley, Christopher L. and Phang, Jason and Bowman, Samuel R. and Perez, Ethan } ,
keywords = { Computation and Language (cs.CL), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences } ,
title = { Pretraining Language Models with Human Preferences } ,
publisher = { arXiv } ,
year = { 2023 } ,
copyright = { Creative Commons Attribution 4.0 International }
}