このリポジトリは、 AdvPrompter (arxiv:2404.16873) の公式実装です。
「スター」を付けてください。このリポジトリを読んで、私たちの成果を気に入っていただいた (および/または使用していただいた) 場合は、論文を引用してください。ありがとうございます!
conda create -n advprompter python=3.11.4
conda activate advprompter
pip install -r requirements.txt
構成管理ツールとして hydra を使用します。主な設定ファイル: ./conf/{train,eval,eval_suffix_dataset,base}.yaml
AdvPrompter と TargetLLM は conf/base.yaml で指定されており、さまざまなオプションがすでに実装されています。
コードベースは、conf/base.yaml で対応するオプションを設定することにより、オプションで wandb をサポートします。
走る
python3 main.py --config-name=eval
指定されたデータセット上の TargetLLM に対して指定された AdvPrompter のパフォーマンスをテストします。 conf/base.yaml で TargetLLM と AdvPrompter を指定する必要があります。また、AdvPrompter が以前に微調整されていた場合は、peft_checkpoint へのパスを指定することもできます。
// conf/prompter/llama2.yaml を参照
lora_params:
ウォームスタート: true
lora_checkpoint: "peft_checkpoint へのパス"
評価中に生成されたサフィックスは、後で使用できるように、 ./exp/.../suffix_dataset
の run-directory の下にある新しいデータセットに保存されます。このようなデータセットは、TargetLLM に対してベースラインや手作りのサフィックスを評価するのにも役立ちます。また、次のコマンドを実行して評価できます。
python3 main.py --config-name=eval_suffix_dataset
eval_suffix_dataset.yaml
にsuffix_dataset_pth_dct
設定した後
走る
python3 main.py --config-name=train
指定された AdvPrompter を TargetLLM に対してトレーニングします。上記で指定した評価を定期的に自動的に実行し、後のウォームスタートに備えて AdvPrompter の中間バージョンを./exp/.../checkpoints
の下の run-directory に保存します。チェックポイントは、モデル構成のlora_checkpoint
パラメーターで指定できます (「1.1 評価」で説明)。トレーニングでは、各エポックについて、 AdvPrompterOpt で生成されたターゲット サフィックスも./exp/.../suffix_opt_dataset
に保存されます。これにより、 train.yaml
の pretrain で対応するパスを指定することで、このようなサフィックスのデータセットでの事前トレーニングが可能になります。
conf/train.yaml で考慮すべき重要なハイパーパラメータ: [epochs, lr, top_k, num_chunks, lambda_val]
注: target_llm.llm_params.checkpoint をローカル パスに置き換えることもできます。
例 1: Vicuna-7B の AdvPrompter:
python3 main.py --config-name=train target_llm=vicuna_chat target_llm.llm_params.model_name=vicuna-7b-v1.5
例 2: Vicuna-13B の AdvPrompter:
python3 main.py --config-name=train target_llm=vicuna_chat target_llm.llm_params.model_name=vicuna-13b-v1.5 target_llm.llm_params.checkpoint=lmsys/vicuna-13b-v1.5 train.q_params.num_chunks=2
例 3: Mistral-7B-chat の AdvPrompter:
python3 main.py --config-name=train target_llm=mistral_chat
例 4: Llama2-7B チャットの AdvPrompter:
python3 main.py --config-name=train target_llm=llama2_chat train.q_params.lambda_val=150
アンセルム・パウルス*、アルマン・ザルマガンベトフ*、チュアン・グオ、ブランドン・エイモス**、ユアンドン・ティアン**
(* = 同等の第一著者、** = 同等のアドバイス)
私たちのソースコードは CC-BY-NC 4.0 ライセンスの下にあります。