此儲存庫提供了自回歸變壓器(僅限 GPU)上的一級模型編輯(ROME)的實作。我們目前支援 OpenAI 的 GPT-2 XL (1.5B) 和 EleutherAI 的 GPT-J (6B)。 EleutherAI 預計很快就會發布類似 20B GPT 的模型;我們希望盡快支持它。
如果您發現任何問題,請隨時提出問題;我們正在積極開發這個儲存庫,並將密切監控票證。
我們建議使用conda
來管理 Python、CUDA 和 PyTorch 相關的依賴項,並使用pip
來管理其他所有內容。首先,只需安裝conda
並運行:
./scripts/setup_conda.sh
notebooks/causal_trace.ipynb
示範了因果跟踪,可以對其進行修改以將跟踪應用於任何語句的處理。
notebooks/rome.ipynb
示範了 ROME。 API簡單;只需指定以下形式的請求重寫:
request = {
"prompt" : "{} plays the sport of" ,
"subject" : "LeBron James" ,
"target_new" : {
"str" : "football"
}
}
筆記本中包含了幾個類似的範例。
詳細資訊即將推出!
有關可用基線的說明,請參閱baselines/
。
experiments/evaluate.py
可用於評估baselines/
中的任何方法。若要開始使用(例如在 GPT-2 XL 上使用 ROME),請執行:
python3 -m experiments.evaluate
--alg_name=ROME
--model_name=gpt2-xl
--hparams_fname=gpt2-xl.json
每次運行的結果以特定格式儲存在results/<method_name>/run_<run_id>
中:
results/
| __ ROME/
| __ run_ < run_id > /
| __ params.json
| __ case_0.json
| __ case_1.json
| __ ...
| __ case_10000.json
要總結結果,您可以使用experiments/summarize.py
:
python3 -m experiments.summarize --dir_name=ROME --runs=run_ < run_id >
運行python3 -m experiments.evaluate -h
或python3 -m experiments.summarize -h
提供有關命令列標誌的詳細資訊。
假設您有一個新方法X
,並希望在 CounterFact 上進行基準測試。將X
與我們的跑步者整合:
HyperParams
子類別為XHyperParams
並指定所有超參數欄位。有關範例實現,請參閱ROMEHyperParameters
。hparams/X/gpt2-xl.json
建立一個超參數檔並指定一些預設值。有關範例,請參閱hparams/ROME/gpt2-xl.json
。apply_X_to_model
,它接受多個參數並傳回 (i) 重寫的模型和 (ii) 已編輯參數的原始權重值(採用字典格式{weight_name: original_weight_value}
)。有關範例,請參閱rome/rome_main.py
。"X": (XHyperParams, apply_X_to_model)
將X
加入experiments/evaluate.py
中的ALG_DICT
。最後,運行主要腳本:
python3 -m experiments.evaluate
--alg_name=X
--model_name=gpt2-xl
--hparams_fname=gpt2-xl.json
python3 -m experiments.summarize --dir_name=X --runs=run_ < run_id >
我們目前僅支援使用 PyTorch 後端編輯自回歸 HuggingFace 模型的方法。我們正在研究一組通用方法(可在例如 TensorFlow 上使用,無需 HuggingFace),並將很快發布。
@article { meng2022locating ,
title = { Locating and Editing Factual Associations in {GPT} } ,
author = { Kevin Meng and David Bau and Alex Andonian and Yonatan Belinkov } ,
journal = { Advances in Neural Information Processing Systems } ,
volume = { 35 } ,
year = { 2022 }
}