Una biblioteca completa para postentrenar modelos de cimentación
TRL es una biblioteca de vanguardia diseñada para modelos básicos posteriores al entrenamiento que utiliza técnicas avanzadas como ajuste fino supervisado (SFT), optimización de políticas próximas (PPO) y optimización de preferencias directas (DPO). Construido sobre el? En el ecosistema de Transformers, TRL admite una variedad de arquitecturas y modalidades de modelos, y se puede ampliar a través de varias configuraciones de hardware.
Eficiente y escalable :
PEFT
permite el entrenamiento en modelos grandes con hardware modesto mediante cuantificación y LoRA/QLoRA.Interfaz de línea de comandos (CLI) : una interfaz sencilla que le permite ajustar e interactuar con los modelos sin necesidad de escribir código.
Entrenadores : se puede acceder fácilmente a varios métodos de ajuste a través de entrenadores como SFTTrainer
, DPOTrainer
, RewardTrainer
, ORPOTrainer
y más.
AutoModels : utilice clases de modelos predefinidas como AutoModelForCausalLMWithValueHead
para simplificar el aprendizaje por refuerzo (RL) con LLM.
Instale la biblioteca usando pip
:
pip install trl
Si desea utilizar las funciones más recientes antes del lanzamiento oficial, puede instalar TRL desde la fuente:
pip install git+https://github.com/huggingface/trl.git
Si desea utilizar los ejemplos, puede clonar el repositorio con el siguiente comando:
git clone https://github.com/huggingface/trl.git
Puede utilizar la interfaz de línea de comandos (CLI) de TRL para comenzar rápidamente con el ajuste fino supervisado (SFT) y la optimización directa de preferencias (DPO), o verificar su modelo con la CLI de chat:
OFV:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B
--dataset_name trl-lib/Capybara
--output_dir Qwen2.5-0.5B-SFT
DPO:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
--dataset_name argilla/Capybara-Preferences
--output_dir Qwen2.5-0.5B-DPO
Charlar:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
Lea más sobre CLI en la sección de documentación correspondiente o utilice --help
para obtener más detalles.
Para obtener más flexibilidad y control sobre la capacitación, TRL proporciona clases de capacitadores dedicados para entrenar posteriormente modelos de lenguaje o adaptadores PEFT en un conjunto de datos personalizado. Cada entrenador en TRL es una envoltura ligera alrededor del? Entrenador de Transformers y admite de forma nativa métodos de entrenamiento distribuidos como DDP, DeepSpeed ZeRO y FSDP.
SFTTrainer
A continuación se muestra un ejemplo básico de cómo utilizar SFTTrainer
:
from trl import SFTConfig , SFTTrainer
from datasets import load_dataset
dataset = load_dataset ( "trl-lib/Capybara" , split = "train" )
training_args = SFTConfig ( output_dir = "Qwen/Qwen2.5-0.5B-SFT" )
trainer = SFTTrainer (
args = training_args ,
model = "Qwen/Qwen2.5-0.5B" ,
train_dataset = dataset ,
)
trainer . train ()
RewardTrainer
A continuación se muestra un ejemplo básico de cómo utilizar RewardTrainer
:
from trl import RewardConfig , RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification , AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
model . config . pad_token_id = tokenizer . pad_token_id
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = RewardConfig ( output_dir = "Qwen2.5-0.5B-Reward" , per_device_train_batch_size = 2 )
trainer = RewardTrainer (
args = training_args ,
model = model ,
processing_class = tokenizer ,
train_dataset = dataset ,
)
trainer . train ()
RLOOTrainer
RLOOTrainer
implementa una optimización de estilo REINFORCE para RLHF que es más eficaz y eficiente en memoria que PPO. A continuación se muestra un ejemplo básico de cómo utilizar RLOOTrainer
:
from trl import RLOOConfig , RLOOTrainer , apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM ,
AutoModelForSequenceClassification ,
AutoTokenizer ,
)
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
reward_model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
ref_policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback-prompt" )
dataset = dataset . map ( apply_chat_template , fn_kwargs = { "tokenizer" : tokenizer })
dataset = dataset . map ( lambda x : tokenizer ( x [ "prompt" ]), remove_columns = "prompt" )
training_args = RLOOConfig ( output_dir = "Qwen2.5-0.5B-RL" )
trainer = RLOOTrainer (
config = training_args ,
processing_class = tokenizer ,
policy = policy ,
ref_policy = ref_policy ,
reward_model = reward_model ,
train_dataset = dataset [ "train" ],
eval_dataset = dataset [ "test" ],
)
trainer . train ()
DPOTrainer
DPOTrainer
implementa el popular algoritmo de Optimización de Preferencia Directa (DPO) que se utilizó para entrenar posteriormente Llama 3 y muchos otros modelos. A continuación se muestra un ejemplo básico de cómo utilizar DPOTrainer
:
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from trl import DPOConfig , DPOTrainer
model = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = DPOConfig ( output_dir = "Qwen2.5-0.5B-DPO" )
trainer = DPOTrainer ( model = model , args = training_args , train_dataset = dataset , processing_class = tokenizer )
trainer . train ()
Si desea contribuir a trl
o personalizarlo según sus necesidades, asegúrese de leer la guía de contribución y de realizar una instalación de desarrollo:
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
@misc { vonwerra2022trl ,
author = { Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec } ,
title = { TRL: Transformer Reinforcement Learning } ,
year = { 2020 } ,
publisher = { GitHub } ,
journal = { GitHub repository } ,
howpublished = { url{https://github.com/huggingface/trl} }
}
El código fuente de este repositorio está disponible bajo la licencia Apache-2.0.