该存储库提供了自回归变压器(仅限 GPU)上的一级模型编辑(ROME)的实现。我们目前支持 OpenAI 的 GPT-2 XL (1.5B) 和 EleutherAI 的 GPT-J (6B)。 EleutherAI 预计很快就会发布类似 20B GPT 的模型;我们希望尽快支持它。
如果您发现任何问题,请随时提出问题;我们正在积极开发这个存储库,并将密切监控票证。
我们建议使用conda
来管理 Python、CUDA 和 PyTorch 相关的依赖项,并使用pip
来管理其他所有内容。首先,只需安装conda
并运行:
./scripts/setup_conda.sh
notebooks/causal_trace.ipynb
演示了因果跟踪,可以对其进行修改以将跟踪应用于任何语句的处理。
notebooks/rome.ipynb
演示了 ROME。 API简单;只需指定以下形式的请求重写:
request = {
"prompt" : "{} plays the sport of" ,
"subject" : "LeBron James" ,
"target_new" : {
"str" : "football"
}
}
笔记本中包含了几个类似的示例。
详细信息即将推出!
有关可用基线的说明,请参阅baselines/
。
experiments/evaluate.py
可用于评估baselines/
中的任何方法。要开始使用(例如在 GPT-2 XL 上使用 ROME),请运行:
python3 -m experiments.evaluate
--alg_name=ROME
--model_name=gpt2-xl
--hparams_fname=gpt2-xl.json
每次运行的结果以特定格式存储在results/<method_name>/run_<run_id>
中:
results/
| __ ROME/
| __ run_ < run_id > /
| __ params.json
| __ case_0.json
| __ case_1.json
| __ ...
| __ case_10000.json
要总结结果,您可以使用experiments/summarize.py
:
python3 -m experiments.summarize --dir_name=ROME --runs=run_ < run_id >
运行python3 -m experiments.evaluate -h
或python3 -m experiments.summarize -h
提供有关命令行标志的详细信息。
假设您有一个新方法X
,并希望在 CounterFact 上对其进行基准测试。将X
与我们的跑步者集成:
HyperParams
子类为XHyperParams
并指定所有超参数字段。有关示例实现,请参阅ROMEHyperParameters
。hparams/X/gpt2-xl.json
创建一个超参数文件并指定一些默认值。有关示例,请参阅hparams/ROME/gpt2-xl.json
。apply_X_to_model
,它接受多个参数并返回 (i) 重写的模型和 (ii) 已编辑参数的原始权重值(采用字典格式{weight_name: original_weight_value}
)。有关示例,请参阅rome/rome_main.py
。"X": (XHyperParams, apply_X_to_model)
将X
添加到experiments/evaluate.py
中的ALG_DICT
中。最后,运行主要脚本:
python3 -m experiments.evaluate
--alg_name=X
--model_name=gpt2-xl
--hparams_fname=gpt2-xl.json
python3 -m experiments.summarize --dir_name=X --runs=run_ < run_id >
我们目前仅支持使用 PyTorch 后端编辑自回归 HuggingFace 模型的方法。我们正在研究一组通用方法(可在例如 TensorFlow 上使用,无需 HuggingFace),并将很快发布。
@article { meng2022locating ,
title = { Locating and Editing Factual Associations in {GPT} } ,
author = { Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov } ,
journal = { Advances in Neural Information Processing Systems } ,
volume = { 35 } ,
year = { 2022 }
}