基礎モデルをトレーニング後に使用するための包括的なライブラリ
TRL は、教師あり微調整 (SFT)、近接ポリシー最適化 (PPO)、直接優先最適化 (DPO) などの高度な技術を使用して、トレーニング後の基礎モデル用に設計された最先端のライブラリです。 ?の上に建てられています。 Transformers エコシステムである TRL は、さまざまなモデル アーキテクチャとモダリティをサポートしており、さまざまなハードウェア セットアップにわたってスケールアップできます。
効率的でスケーラブル:
PEFT
との完全な統合により、量子化と LoRA/QLoRA を介した適度なハードウェアを使用した大規模モデルのトレーニングが可能になります。コマンド ライン インターフェイス (CLI) : シンプルなインターフェイスにより、コードを記述することなくモデルを微調整したり操作したりできます。
トレーナー: SFTTrainer
、 DPOTrainer
、 RewardTrainer
、 ORPOTrainer
などのトレーナーを介して、さまざまな微調整方法に簡単にアクセスできます。
AutoModels : AutoModelForCausalLMWithValueHead
などの事前定義されたモデル クラスを使用して、LLM による強化学習 (RL) を簡素化します。
pip
を使用してライブラリをインストールします。
pip install trl
正式リリース前に最新の機能を使用したい場合は、ソースから TRL をインストールできます。
pip install git+https://github.com/huggingface/trl.git
例を使用する場合は、次のコマンドを使用してリポジトリのクローンを作成できます。
git clone https://github.com/huggingface/trl.git
TRL コマンド ライン インターフェイス (CLI) を使用して、教師あり微調整 (SFT) と直接優先最適化 (DPO) をすぐに開始したり、チャット CLI でモデルの雰囲気をチェックしたりできます。
SFT:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B
--dataset_name trl-lib/Capybara
--output_dir Qwen2.5-0.5B-SFT
DPO:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
--dataset_name argilla/Capybara-Preferences
--output_dir Qwen2.5-0.5B-DPO
チャット:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
関連ドキュメントのセクションで CLI の詳細を参照するか、 --help
を使用して詳細を確認してください。
トレーニングの柔軟性と制御を高めるために、TRL は、カスタム データセットで言語モデルまたは PEFT アダプターをトレーニング後に使用するための専用のトレーナー クラスを提供します。 TRL の各トレーナーは、? の軽いラッパーです。 Transformers トレーナーは、DDP、DeepSpeed ZeRO、FSDP などの分散トレーニング方法をネイティブにサポートします。
SFTTrainer
以下は、 SFTTrainer
の使用方法の基本的な例です。
from trl import SFTConfig , SFTTrainer
from datasets import load_dataset
dataset = load_dataset ( "trl-lib/Capybara" , split = "train" )
training_args = SFTConfig ( output_dir = "Qwen/Qwen2.5-0.5B-SFT" )
trainer = SFTTrainer (
args = training_args ,
model = "Qwen/Qwen2.5-0.5B" ,
train_dataset = dataset ,
)
trainer . train ()
RewardTrainer
RewardTrainer
の基本的な使用例を次に示します。
from trl import RewardConfig , RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification , AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
model . config . pad_token_id = tokenizer . pad_token_id
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = RewardConfig ( output_dir = "Qwen2.5-0.5B-Reward" , per_device_train_batch_size = 2 )
trainer = RewardTrainer (
args = training_args ,
model = model ,
processing_class = tokenizer ,
train_dataset = dataset ,
)
trainer . train ()
RLOOTrainer
RLOOTrainer
PPO よりもパフォーマンスとメモリ効率が高い RLHF 用の REINFORCE スタイルの最適化を実装します。 RLOOTrainer
の基本的な使用例を次に示します。
from trl import RLOOConfig , RLOOTrainer , apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM ,
AutoModelForSequenceClassification ,
AutoTokenizer ,
)
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
reward_model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
ref_policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback-prompt" )
dataset = dataset . map ( apply_chat_template , fn_kwargs = { "tokenizer" : tokenizer })
dataset = dataset . map ( lambda x : tokenizer ( x [ "prompt" ]), remove_columns = "prompt" )
training_args = RLOOConfig ( output_dir = "Qwen2.5-0.5B-RL" )
trainer = RLOOTrainer (
config = training_args ,
processing_class = tokenizer ,
policy = policy ,
ref_policy = ref_policy ,
reward_model = reward_model ,
train_dataset = dataset [ "train" ],
eval_dataset = dataset [ "test" ],
)
trainer . train ()
DPOTrainer
DPOTrainer
Llama 3 や他の多くのモデルのポストトレーニングに使用された、一般的な Direct Preference Optimization (DPO) アルゴリズムを実装しています。 DPOTrainer
の基本的な使用例を次に示します。
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from trl import DPOConfig , DPOTrainer
model = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = DPOConfig ( output_dir = "Qwen2.5-0.5B-DPO" )
trainer = DPOTrainer ( model = model , args = training_args , train_dataset = dataset , processing_class = tokenizer )
trainer . train ()
trl
に貢献したい場合、またはニーズに合わせてカスタマイズしたい場合は、必ず貢献ガイドを読み、dev インストールを行ってください。
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
@misc { vonwerra2022trl ,
author = { Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec } ,
title = { TRL: Transformer Reinforcement Learning } ,
year = { 2020 } ,
publisher = { GitHub } ,
journal = { GitHub repository } ,
howpublished = { url{https://github.com/huggingface/trl} }
}
このリポジトリのソース コードは、Apache-2.0 ライセンスに基づいて利用できます。