Eine umfassende Bibliothek zum Nachtrainieren von Fundamentmodellen
TRL ist eine hochmoderne Bibliothek, die für Post-Training-Grundlagenmodelle entwickelt wurde und fortschrittliche Techniken wie Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO) und Direct Preference Optimization (DPO) verwendet. Aufgebaut auf dem ? TRL unterstützt im Transformers-Ökosystem eine Vielzahl von Modellarchitekturen und -modalitäten und kann auf verschiedene Hardware-Setups skaliert werden.
Effizient und skalierbar :
PEFT
ermöglicht das Training an großen Modellen mit bescheidener Hardware über Quantisierung und LoRA/QLoRA.Befehlszeilenschnittstelle (CLI) : Eine einfache Schnittstelle ermöglicht Ihnen die Feinabstimmung und Interaktion mit Modellen, ohne Code schreiben zu müssen.
Trainer : Verschiedene Feinabstimmungsmethoden sind über Trainer wie SFTTrainer
, DPOTrainer
, RewardTrainer
, ORPOTrainer
und mehr leicht zugänglich.
AutoModels : Verwenden Sie vordefinierte Modellklassen wie AutoModelForCausalLMWithValueHead
um das Reinforcement Learning (RL) mit LLMs zu vereinfachen.
Installieren Sie die Bibliothek mit pip
:
pip install trl
Wenn Sie die neuesten Funktionen vor einer offiziellen Veröffentlichung nutzen möchten, können Sie TRL von der Quelle installieren:
pip install git+https://github.com/huggingface/trl.git
Wenn Sie die Beispiele verwenden möchten, können Sie das Repository mit dem folgenden Befehl klonen:
git clone https://github.com/huggingface/trl.git
Sie können die TRL-Befehlszeilenschnittstelle (CLI) verwenden, um schnell mit Supervised Fine-tuning (SFT) und Direct Preference Optimization (DPO) zu beginnen, oder Ihr Modell mit der Chat-CLI auf Vibe überprüfen:
SFT:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B
--dataset_name trl-lib/Capybara
--output_dir Qwen2.5-0.5B-SFT
Datenschutzbeauftragter:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
--dataset_name argilla/Capybara-Preferences
--output_dir Qwen2.5-0.5B-DPO
Chat:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
Lesen Sie mehr über CLI im entsprechenden Dokumentationsabschnitt oder verwenden Sie --help
für weitere Details.
Für mehr Flexibilität und Kontrolle über das Training bietet TRL spezielle Trainerklassen zum Nachtrainieren von Sprachmodellen oder PEFT-Adaptern auf einem benutzerdefinierten Datensatz an. Jeder Trainer in TRL ist eine leichte Hülle um das ? Transformers-Trainer und unterstützt nativ verteilte Trainingsmethoden wie DDP, DeepSpeed ZeRO und FSDP.
SFTTrainer
Hier ist ein einfaches Beispiel für die Verwendung des SFTTrainer
:
from trl import SFTConfig , SFTTrainer
from datasets import load_dataset
dataset = load_dataset ( "trl-lib/Capybara" , split = "train" )
training_args = SFTConfig ( output_dir = "Qwen/Qwen2.5-0.5B-SFT" )
trainer = SFTTrainer (
args = training_args ,
model = "Qwen/Qwen2.5-0.5B" ,
train_dataset = dataset ,
)
trainer . train ()
RewardTrainer
Hier ist ein einfaches Beispiel für die Verwendung des RewardTrainer
:
from trl import RewardConfig , RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification , AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
model . config . pad_token_id = tokenizer . pad_token_id
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = RewardConfig ( output_dir = "Qwen2.5-0.5B-Reward" , per_device_train_batch_size = 2 )
trainer = RewardTrainer (
args = training_args ,
model = model ,
processing_class = tokenizer ,
train_dataset = dataset ,
)
trainer . train ()
RLOOTrainer
RLOOTrainer
implementiert eine Optimierung im REINFORCE-Stil für RLHF, die leistungsfähiger und speichereffizienter ist als PPO. Hier ist ein einfaches Beispiel für die Verwendung des RLOOTrainer
:
from trl import RLOOConfig , RLOOTrainer , apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM ,
AutoModelForSequenceClassification ,
AutoTokenizer ,
)
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
reward_model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
ref_policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback-prompt" )
dataset = dataset . map ( apply_chat_template , fn_kwargs = { "tokenizer" : tokenizer })
dataset = dataset . map ( lambda x : tokenizer ( x [ "prompt" ]), remove_columns = "prompt" )
training_args = RLOOConfig ( output_dir = "Qwen2.5-0.5B-RL" )
trainer = RLOOTrainer (
config = training_args ,
processing_class = tokenizer ,
policy = policy ,
ref_policy = ref_policy ,
reward_model = reward_model ,
train_dataset = dataset [ "train" ],
eval_dataset = dataset [ "test" ],
)
trainer . train ()
DPOTrainer
DPOTrainer
implementiert den beliebten DPO-Algorithmus (Direct Preference Optimization), der zum Nachtrainieren von Llama 3 und vielen anderen Modellen verwendet wurde. Hier ist ein einfaches Beispiel für die Verwendung des DPOTrainer
:
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from trl import DPOConfig , DPOTrainer
model = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = DPOConfig ( output_dir = "Qwen2.5-0.5B-DPO" )
trainer = DPOTrainer ( model = model , args = training_args , train_dataset = dataset , processing_class = tokenizer )
trainer . train ()
Wenn Sie zu trl
beitragen oder es an Ihre Bedürfnisse anpassen möchten, lesen Sie unbedingt den Beitragsleitfaden und führen Sie eine Entwicklerinstallation durch:
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
@misc { vonwerra2022trl ,
author = { Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec } ,
title = { TRL: Transformer Reinforcement Learning } ,
year = { 2020 } ,
publisher = { GitHub } ,
journal = { GitHub repository } ,
howpublished = { url{https://github.com/huggingface/trl} }
}
Der Quellcode dieses Repositorys ist unter der Apache-2.0-Lizenz verfügbar.