Perpustakaan komprehensif untuk model pondasi pasca pelatihan
TRL adalah perpustakaan mutakhir yang dirancang untuk model dasar pasca pelatihan menggunakan teknik canggih seperti Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), dan Direct Preference Optimization (DPO). Dibangun di atas? Ekosistem Transformers, TRL mendukung berbagai model arsitektur dan modalitas, dan dapat ditingkatkan di berbagai pengaturan perangkat keras.
Efisien dan terukur :
PEFT
memungkinkan pelatihan pada model besar dengan perangkat keras sederhana melalui kuantisasi dan LoRA/QLoRA.Antarmuka Baris Perintah (CLI) : Antarmuka sederhana memungkinkan Anda menyempurnakan dan berinteraksi dengan model tanpa perlu menulis kode.
Pelatih : Berbagai metode penyesuaian mudah diakses melalui pelatih seperti SFTTrainer
, DPOTrainer
, RewardTrainer
, ORPOTrainer
, dan banyak lagi.
AutoModels : Gunakan kelas model yang telah ditentukan sebelumnya seperti AutoModelForCausalLMWithValueHead
untuk menyederhanakan pembelajaran penguatan (RL) dengan LLM.
Instal perpustakaan menggunakan pip
:
pip install trl
Jika Anda ingin menggunakan fitur terbaru sebelum rilis resmi, Anda dapat menginstal TRL dari sumber:
pip install git+https://github.com/huggingface/trl.git
Jika Anda ingin menggunakan contoh, Anda dapat mengkloning repositori dengan perintah berikut:
git clone https://github.com/huggingface/trl.git
Anda dapat menggunakan TRL Command Line Interface (CLI) untuk memulai dengan cepat Supervised Fine-tuning (SFT) dan Direct Preference Optimization (DPO), atau memeriksa model Anda dengan chat CLI:
SFT:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B
--dataset_name trl-lib/Capybara
--output_dir Qwen2.5-0.5B-SFT
DPO:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
--dataset_name argilla/Capybara-Preferences
--output_dir Qwen2.5-0.5B-DPO
Mengobrol:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
Baca selengkapnya tentang CLI di bagian dokumentasi yang relevan atau gunakan --help
untuk detail selengkapnya.
Untuk lebih banyak fleksibilitas dan kontrol atas pelatihan, TRL menyediakan kelas pelatih khusus untuk model bahasa pasca-latihan atau adaptor PEFT pada kumpulan data khusus. Setiap pelatih di TRL adalah pembungkus ringan di sekeliling ? Pelatih Transformers dan secara asli mendukung metode pelatihan terdistribusi seperti DDP, DeepSpeed ZeRO, dan FSDP.
SFTTrainer
Berikut adalah contoh dasar cara menggunakan SFTTrainer
:
from trl import SFTConfig , SFTTrainer
from datasets import load_dataset
dataset = load_dataset ( "trl-lib/Capybara" , split = "train" )
training_args = SFTConfig ( output_dir = "Qwen/Qwen2.5-0.5B-SFT" )
trainer = SFTTrainer (
args = training_args ,
model = "Qwen/Qwen2.5-0.5B" ,
train_dataset = dataset ,
)
trainer . train ()
RewardTrainer
Berikut adalah contoh dasar cara menggunakan RewardTrainer
:
from trl import RewardConfig , RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification , AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
model . config . pad_token_id = tokenizer . pad_token_id
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = RewardConfig ( output_dir = "Qwen2.5-0.5B-Reward" , per_device_train_batch_size = 2 )
trainer = RewardTrainer (
args = training_args ,
model = model ,
processing_class = tokenizer ,
train_dataset = dataset ,
)
trainer . train ()
RLOOTrainer
RLOOTrainer
mengimplementasikan optimasi gaya REINFORCE untuk RLHF yang lebih berperforma dan hemat memori dibandingkan PPO. Berikut adalah contoh dasar cara menggunakan RLOOTrainer
:
from trl import RLOOConfig , RLOOTrainer , apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM ,
AutoModelForSequenceClassification ,
AutoTokenizer ,
)
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
reward_model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
ref_policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback-prompt" )
dataset = dataset . map ( apply_chat_template , fn_kwargs = { "tokenizer" : tokenizer })
dataset = dataset . map ( lambda x : tokenizer ( x [ "prompt" ]), remove_columns = "prompt" )
training_args = RLOOConfig ( output_dir = "Qwen2.5-0.5B-RL" )
trainer = RLOOTrainer (
config = training_args ,
processing_class = tokenizer ,
policy = policy ,
ref_policy = ref_policy ,
reward_model = reward_model ,
train_dataset = dataset [ "train" ],
eval_dataset = dataset [ "test" ],
)
trainer . train ()
DPOTrainer
DPOTrainer
mengimplementasikan algoritma Direct Preference Optimization (DPO) populer yang digunakan untuk pasca-pelatihan Llama 3 dan banyak model lainnya. Berikut adalah contoh dasar cara menggunakan DPOTrainer
:
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from trl import DPOConfig , DPOTrainer
model = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = DPOConfig ( output_dir = "Qwen2.5-0.5B-DPO" )
trainer = DPOTrainer ( model = model , args = training_args , train_dataset = dataset , processing_class = tokenizer )
trainer . train ()
Jika Anda ingin berkontribusi pada trl
atau menyesuaikannya dengan kebutuhan Anda, pastikan untuk membaca panduan kontribusi dan pastikan Anda melakukan instalasi dev:
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
@misc { vonwerra2022trl ,
author = { Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec } ,
title = { TRL: Transformer Reinforcement Learning } ,
year = { 2020 } ,
publisher = { GitHub } ,
journal = { GitHub repository } ,
howpublished = { url{https://github.com/huggingface/trl} }
}
Kode sumber repositori ini tersedia di bawah Lisensi Apache-2.0.