Uma biblioteca abrangente para modelos de fundação pós-treinamento
TRL é uma biblioteca de ponta projetada para modelos básicos pós-treinamento usando técnicas avançadas como ajuste fino supervisionado (SFT), otimização de política proximal (PPO) e otimização de preferência direta (DPO). Construído em cima do? Ecossistema de Transformers, o TRL suporta uma variedade de arquiteturas e modalidades de modelos e pode ser ampliado em várias configurações de hardware.
Eficiente e escalável :
PEFT
permite o treinamento em modelos grandes com hardware modesto via quantização e LoRA/QLoRA.Interface de linha de comando (CLI) : uma interface simples permite ajustar e interagir com modelos sem a necessidade de escrever código.
Treinadores : Vários métodos de ajuste fino são facilmente acessíveis por meio de treinadores como SFTTrainer
, DPOTrainer
, RewardTrainer
, ORPOTrainer
e muito mais.
AutoModels : Use classes de modelo predefinidas como AutoModelForCausalLMWithValueHead
para simplificar o aprendizado por reforço (RL) com LLMs.
Instale a biblioteca usando pip
:
pip install trl
Se quiser usar os recursos mais recentes antes do lançamento oficial, você pode instalar o TRL a partir do código-fonte:
pip install git+https://github.com/huggingface/trl.git
Se quiser usar os exemplos você pode clonar o repositório com o seguinte comando:
git clone https://github.com/huggingface/trl.git
Você pode usar a interface de linha de comando (CLI) do TRL para começar rapidamente com o ajuste fino supervisionado (SFT) e a otimização de preferência direta (DPO), ou verificar seu modelo com a CLI de bate-papo:
OFVM:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B
--dataset_name trl-lib/Capybara
--output_dir Qwen2.5-0.5B-SFT
DPO:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
--dataset_name argilla/Capybara-Preferences
--output_dir Qwen2.5-0.5B-DPO
Bater papo:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
Leia mais sobre CLI na seção de documentação relevante ou use --help
para obter mais detalhes.
Para maior flexibilidade e controle sobre o treinamento, o TRL oferece aulas de treinamento dedicadas para modelos de linguagem pós-treinamento ou adaptadores PEFT em um conjunto de dados personalizado. Cada treinador no TRL é um invólucro leve em torno do? Treinador de Transformers e oferece suporte nativo a métodos de treinamento distribuído como DDP, DeepSpeed ZeRO e FSDP.
SFTTrainer
Aqui está um exemplo básico de como usar o SFTTrainer
:
from trl import SFTConfig , SFTTrainer
from datasets import load_dataset
dataset = load_dataset ( "trl-lib/Capybara" , split = "train" )
training_args = SFTConfig ( output_dir = "Qwen/Qwen2.5-0.5B-SFT" )
trainer = SFTTrainer (
args = training_args ,
model = "Qwen/Qwen2.5-0.5B" ,
train_dataset = dataset ,
)
trainer . train ()
RewardTrainer
Aqui está um exemplo básico de como usar o RewardTrainer
:
from trl import RewardConfig , RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification , AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
model . config . pad_token_id = tokenizer . pad_token_id
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = RewardConfig ( output_dir = "Qwen2.5-0.5B-Reward" , per_device_train_batch_size = 2 )
trainer = RewardTrainer (
args = training_args ,
model = model ,
processing_class = tokenizer ,
train_dataset = dataset ,
)
trainer . train ()
RLOOTrainer
RLOOTrainer
implementa uma otimização estilo REINFORCE para RLHF que tem melhor desempenho e eficiência de memória do que PPO. Aqui está um exemplo básico de como usar o RLOOTrainer
:
from trl import RLOOConfig , RLOOTrainer , apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM ,
AutoModelForSequenceClassification ,
AutoTokenizer ,
)
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
reward_model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
ref_policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback-prompt" )
dataset = dataset . map ( apply_chat_template , fn_kwargs = { "tokenizer" : tokenizer })
dataset = dataset . map ( lambda x : tokenizer ( x [ "prompt" ]), remove_columns = "prompt" )
training_args = RLOOConfig ( output_dir = "Qwen2.5-0.5B-RL" )
trainer = RLOOTrainer (
config = training_args ,
processing_class = tokenizer ,
policy = policy ,
ref_policy = ref_policy ,
reward_model = reward_model ,
train_dataset = dataset [ "train" ],
eval_dataset = dataset [ "test" ],
)
trainer . train ()
DPOTrainer
DPOTrainer
implementa o popular algoritmo Direct Preference Optimization (DPO) que foi usado para pós-treinar o Llama 3 e muitos outros modelos. Aqui está um exemplo básico de como usar o DPOTrainer
:
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from trl import DPOConfig , DPOTrainer
model = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = DPOConfig ( output_dir = "Qwen2.5-0.5B-DPO" )
trainer = DPOTrainer ( model = model , args = training_args , train_dataset = dataset , processing_class = tokenizer )
trainer . train ()
Se você quiser contribuir com trl
ou personalizá-lo de acordo com suas necessidades, leia o guia de contribuição e faça uma instalação de desenvolvimento:
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
@misc { vonwerra2022trl ,
author = { Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec } ,
title = { TRL: Transformer Reinforcement Learning } ,
year = { 2020 } ,
publisher = { GitHub } ,
journal = { GitHub repository } ,
howpublished = { url{https://github.com/huggingface/trl} }
}
O código-fonte deste repositório está disponível sob a licença Apache-2.0.