Une bibliothèque complète pour post-former les modèles de fondation
TRL est une bibliothèque de pointe conçue pour les modèles de base post-formation utilisant des techniques avancées telles que le réglage fin supervisé (SFT), l'optimisation des politiques proximales (PPO) et l'optimisation des préférences directes (DPO). Construit au-dessus du ? Écosystème Transformers, TRL prend en charge une variété d'architectures et de modalités de modèles et peut être étendu à diverses configurations matérielles.
Efficace et évolutif :
PEFT
permet la formation sur de grands modèles avec un matériel modeste via la quantification et LoRA/QLoRA.Interface de ligne de commande (CLI) : une interface simple vous permet d'affiner et d'interagir avec les modèles sans avoir besoin d'écrire du code.
Formateurs : Diverses méthodes de réglage fin sont facilement accessibles via des formateurs comme SFTTrainer
, DPOTrainer
, RewardTrainer
, ORPOTrainer
et plus encore.
AutoModels : utilisez des classes de modèles prédéfinies comme AutoModelForCausalLMWithValueHead
pour simplifier l'apprentissage par renforcement (RL) avec les LLM.
Installez la bibliothèque en utilisant pip
:
pip install trl
Si vous souhaitez utiliser les dernières fonctionnalités avant une version officielle, vous pouvez installer TRL à partir des sources :
pip install git+https://github.com/huggingface/trl.git
Si vous souhaitez utiliser les exemples, vous pouvez cloner le référentiel avec la commande suivante :
git clone https://github.com/huggingface/trl.git
Vous pouvez utiliser l'interface de ligne de commande (CLI) TRL pour démarrer rapidement avec le réglage fin supervisé (SFT) et l'optimisation des préférences directes (DPO), ou vérifier votre modèle avec la CLI de chat :
SFT :
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B
--dataset_name trl-lib/Capybara
--output_dir Qwen2.5-0.5B-SFT
DPD :
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
--dataset_name argilla/Capybara-Preferences
--output_dir Qwen2.5-0.5B-DPO
Chat:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
En savoir plus sur CLI dans la section de documentation appropriée ou utiliser --help
pour plus de détails.
Pour plus de flexibilité et de contrôle sur la formation, TRL propose des classes de formateurs dédiées pour post-entraîner des modèles de langage ou des adaptateurs PEFT sur un ensemble de données personnalisé. Chaque entraîneur de TRL est un emballage léger autour du ? Entraîneur Transformers et prend en charge nativement les méthodes de formation distribuées telles que DDP, DeepSpeed ZeRO et FSDP.
SFTTrainer
Voici un exemple de base d'utilisation du SFTTrainer
:
from trl import SFTConfig , SFTTrainer
from datasets import load_dataset
dataset = load_dataset ( "trl-lib/Capybara" , split = "train" )
training_args = SFTConfig ( output_dir = "Qwen/Qwen2.5-0.5B-SFT" )
trainer = SFTTrainer (
args = training_args ,
model = "Qwen/Qwen2.5-0.5B" ,
train_dataset = dataset ,
)
trainer . train ()
RewardTrainer
Voici un exemple de base de la façon d'utiliser le RewardTrainer
:
from trl import RewardConfig , RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification , AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
model . config . pad_token_id = tokenizer . pad_token_id
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = RewardConfig ( output_dir = "Qwen2.5-0.5B-Reward" , per_device_train_batch_size = 2 )
trainer = RewardTrainer (
args = training_args ,
model = model ,
processing_class = tokenizer ,
train_dataset = dataset ,
)
trainer . train ()
RLOOTrainer
RLOOTrainer
implémente une optimisation de style REINFORCE pour RLHF qui est plus performante et plus économe en mémoire que PPO. Voici un exemple de base d'utilisation du RLOOTrainer
:
from trl import RLOOConfig , RLOOTrainer , apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM ,
AutoModelForSequenceClassification ,
AutoTokenizer ,
)
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
reward_model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
ref_policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback-prompt" )
dataset = dataset . map ( apply_chat_template , fn_kwargs = { "tokenizer" : tokenizer })
dataset = dataset . map ( lambda x : tokenizer ( x [ "prompt" ]), remove_columns = "prompt" )
training_args = RLOOConfig ( output_dir = "Qwen2.5-0.5B-RL" )
trainer = RLOOTrainer (
config = training_args ,
processing_class = tokenizer ,
policy = policy ,
ref_policy = ref_policy ,
reward_model = reward_model ,
train_dataset = dataset [ "train" ],
eval_dataset = dataset [ "test" ],
)
trainer . train ()
DPOTrainer
DPOTrainer
implémente l'algorithme populaire d'optimisation des préférences directes (DPO) qui a été utilisé pour post-entraîner Llama 3 et de nombreux autres modèles. Voici un exemple de base d'utilisation du DPOTrainer
:
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from trl import DPOConfig , DPOTrainer
model = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = DPOConfig ( output_dir = "Qwen2.5-0.5B-DPO" )
trainer = DPOTrainer ( model = model , args = training_args , train_dataset = dataset , processing_class = tokenizer )
trainer . train ()
Si vous souhaitez contribuer à trl
ou le personnaliser selon vos besoins, assurez-vous de lire le guide de contribution et assurez-vous d'effectuer une installation en développement :
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
@misc { vonwerra2022trl ,
author = { Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec } ,
title = { TRL: Transformer Reinforcement Learning } ,
year = { 2020 } ,
publisher = { GitHub } ,
journal = { GitHub repository } ,
howpublished = { url{https://github.com/huggingface/trl} }
}
Le code source de ce référentiel est disponible sous la licence Apache-2.0.