用于后期训练基础模型的综合库
TRL 是一个尖端库,专为训练后基础模型而设计,使用监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术。建立在? Transformers 生态系统,TRL 支持各种模型架构和模式,并且可以跨各种硬件设置进行扩展。
高效且可扩展:
PEFT
完全集成,可以通过量化和 LoRA/QLoRA 对具有适度硬件的大型模型进行训练。命令行界面 (CLI) :简单的界面让您无需编写代码即可微调模型并与模型交互。
训练器:通过SFTTrainer
、 DPOTrainer
、 RewardTrainer
、 ORPOTrainer
等训练器可以轻松访问各种微调方法。
AutoModels :使用AutoModelForCausalLMWithValueHead
等预定义模型类来简化法学硕士的强化学习 (RL)。
使用pip
安装库:
pip install trl
如果您想在正式发布之前使用最新功能,可以从源代码安装 TRL:
pip install git+https://github.com/huggingface/trl.git
如果您想使用示例,可以使用以下命令克隆存储库:
git clone https://github.com/huggingface/trl.git
您可以使用 TRL 命令行界面 (CLI) 快速开始监督微调 (SFT) 和直接偏好优化 (DPO),或者使用聊天 CLI 检查您的模型:
斯夫特:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B
--dataset_name trl-lib/Capybara
--output_dir Qwen2.5-0.5B-SFT
数据保护专员:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
--dataset_name argilla/Capybara-Preferences
--output_dir Qwen2.5-0.5B-DPO
聊天:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
在相关文档部分中阅读有关 CLI 的更多信息,或使用--help
了解更多详细信息。
为了提高训练的灵活性和控制力,TRL 提供了专用的训练器课程,用于在自定义数据集上训练后语言模型或 PEFT 适配器。 TRL 中的每个训练器都是围绕 ? 的轻型包装器。 Transformers 训练器本身支持 DDP、DeepSpeed ZeRO 和 FSDP 等分布式训练方法。
SFTTrainer
以下是如何使用SFTTrainer
的基本示例:
from trl import SFTConfig , SFTTrainer
from datasets import load_dataset
dataset = load_dataset ( "trl-lib/Capybara" , split = "train" )
training_args = SFTConfig ( output_dir = "Qwen/Qwen2.5-0.5B-SFT" )
trainer = SFTTrainer (
args = training_args ,
model = "Qwen/Qwen2.5-0.5B" ,
train_dataset = dataset ,
)
trainer . train ()
RewardTrainer
以下是如何使用RewardTrainer
的基本示例:
from trl import RewardConfig , RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification , AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
model . config . pad_token_id = tokenizer . pad_token_id
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = RewardConfig ( output_dir = "Qwen2.5-0.5B-Reward" , per_device_train_batch_size = 2 )
trainer = RewardTrainer (
args = training_args ,
model = model ,
processing_class = tokenizer ,
train_dataset = dataset ,
)
trainer . train ()
RLOOTrainer
RLOOTrainer
为 RLHF 实现了 REINFORCE 式优化,比 PPO 具有更高的性能和内存效率。以下是如何使用RLOOTrainer
的基本示例:
from trl import RLOOConfig , RLOOTrainer , apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM ,
AutoModelForSequenceClassification ,
AutoTokenizer ,
)
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
reward_model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
ref_policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback-prompt" )
dataset = dataset . map ( apply_chat_template , fn_kwargs = { "tokenizer" : tokenizer })
dataset = dataset . map ( lambda x : tokenizer ( x [ "prompt" ]), remove_columns = "prompt" )
training_args = RLOOConfig ( output_dir = "Qwen2.5-0.5B-RL" )
trainer = RLOOTrainer (
config = training_args ,
processing_class = tokenizer ,
policy = policy ,
ref_policy = ref_policy ,
reward_model = reward_model ,
train_dataset = dataset [ "train" ],
eval_dataset = dataset [ "test" ],
)
trainer . train ()
DPOTrainer
DPOTrainer
实现了流行的直接偏好优化 (DPO) 算法,该算法用于对 Llama 3 和许多其他模型进行后期训练。以下是如何使用DPOTrainer
的基本示例:
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from trl import DPOConfig , DPOTrainer
model = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = DPOConfig ( output_dir = "Qwen2.5-0.5B-DPO" )
trainer = DPOTrainer ( model = model , args = training_args , train_dataset = dataset , processing_class = tokenizer )
trainer . train ()
如果您想为trl
做出贡献或根据您的需要对其进行自定义,请务必阅读贡献指南并确保进行开发安装:
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
@misc { vonwerra2022trl ,
author = { Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec } ,
title = { TRL: Transformer Reinforcement Learning } ,
year = { 2020 } ,
publisher = { GitHub } ,
journal = { GitHub repository } ,
howpublished = { url{https://github.com/huggingface/trl} }
}
该存储库的源代码可根据 Apache-2.0 许可证获得。