ห้องสมุดที่ครอบคลุมสำหรับโมเดลพื้นฐานหลังการฝึก
TRL เป็นคลังข้อมูลล้ำสมัยที่ออกแบบมาสำหรับโมเดลพื้นฐานหลังการฝึกอบรมโดยใช้เทคนิคขั้นสูง เช่น Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO) และ Direct Preference Optimization (DPO) สร้างขึ้นบน ? ระบบนิเวศของ Transformers TRL รองรับสถาปัตยกรรมโมเดลและรูปแบบที่หลากหลาย และสามารถขยายขนาดได้ผ่านการตั้งค่าฮาร์ดแวร์ต่างๆ
มีประสิทธิภาพและปรับขนาดได้ :
PEFT
ช่วยให้สามารถฝึกอบรมโมเดลขนาดใหญ่ที่มีฮาร์ดแวร์ขนาดเล็กผ่านการหาปริมาณและ LoRA/QLoRAอินเทอร์เฟซบรรทัดคำสั่ง (CLI) : อินเทอร์เฟซที่เรียบง่ายช่วยให้คุณปรับแต่งและโต้ตอบกับโมเดลได้อย่างละเอียดโดยไม่จำเป็นต้องเขียนโค้ด
ผู้ฝึกสอน : วิธีการปรับแต่งอย่างละเอียดต่างๆ สามารถเข้าถึงได้ง่ายผ่านผู้ฝึกสอน เช่น SFTTrainer
, DPOTrainer
, RewardTrainer
, ORPOTrainer
และอื่นๆ
AutoModels : ใช้คลาสโมเดลที่กำหนดไว้ล่วงหน้า เช่น AutoModelForCausalLMWithValueHead
เพื่อทำให้การเรียนรู้แบบเสริมกำลัง (RL) ง่ายขึ้นด้วย LLM
ติดตั้งไลบรารี่โดยใช้ pip
:
pip install trl
หากคุณต้องการใช้ฟีเจอร์ล่าสุดก่อนการเปิดตัวอย่างเป็นทางการ คุณสามารถติดตั้ง TRL จากแหล่งที่มา:
pip install git+https://github.com/huggingface/trl.git
หากคุณต้องการใช้ตัวอย่าง คุณสามารถโคลนพื้นที่เก็บข้อมูลด้วยคำสั่งต่อไปนี้:
git clone https://github.com/huggingface/trl.git
คุณสามารถใช้ TRL Command Line Interface (CLI) เพื่อเริ่มต้นอย่างรวดเร็วด้วย Supervised Fine-tuning (SFT) และ Direct Preference Optimization (DPO) หรือตรวจสอบโมเดลของคุณด้วย Chat CLI:
เอสเอฟที:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B
--dataset_name trl-lib/Capybara
--output_dir Qwen2.5-0.5B-SFT
อ.ส.ค.:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
--dataset_name argilla/Capybara-Preferences
--output_dir Qwen2.5-0.5B-DPO
แชท:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
อ่านเพิ่มเติมเกี่ยวกับ CLI ในส่วนเอกสารประกอบที่เกี่ยวข้อง หรือใช้ --help
เพื่อดูรายละเอียดเพิ่มเติม
เพื่อความยืดหยุ่นและการควบคุมการฝึกอบรมที่มากขึ้น TRL มีคลาสผู้ฝึกอบรมเฉพาะสำหรับโมเดลภาษาหลังการฝึกหรืออะแดปเตอร์ PEFT บนชุดข้อมูลที่กำหนดเอง ผู้ฝึกสอนแต่ละคนใน TRL จะเป็นเสื้อคลุมสีอ่อนรอบๆ ? ผู้ฝึกสอน Transformers และรองรับวิธีการฝึกอบรมแบบกระจายเช่น DDP, DeepSpeed ZeRO และ FSDP
SFTTrainer
นี่คือตัวอย่างพื้นฐานของวิธีใช้ SFTTrainer
:
from trl import SFTConfig , SFTTrainer
from datasets import load_dataset
dataset = load_dataset ( "trl-lib/Capybara" , split = "train" )
training_args = SFTConfig ( output_dir = "Qwen/Qwen2.5-0.5B-SFT" )
trainer = SFTTrainer (
args = training_args ,
model = "Qwen/Qwen2.5-0.5B" ,
train_dataset = dataset ,
)
trainer . train ()
RewardTrainer
นี่คือตัวอย่างพื้นฐานของวิธีใช้ RewardTrainer
:
from trl import RewardConfig , RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification , AutoTokenizer
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
model . config . pad_token_id = tokenizer . pad_token_id
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = RewardConfig ( output_dir = "Qwen2.5-0.5B-Reward" , per_device_train_batch_size = 2 )
trainer = RewardTrainer (
args = training_args ,
model = model ,
processing_class = tokenizer ,
train_dataset = dataset ,
)
trainer . train ()
RLOOTrainer
RLOOTrainer
ใช้การปรับให้เหมาะสมสไตล์ REINFORCE สำหรับ RLHF ที่มีประสิทธิภาพและประสิทธิภาพของหน่วยความจำมากกว่า PPO นี่คือตัวอย่างพื้นฐานของวิธีใช้ RLOOTrainer
:
from trl import RLOOConfig , RLOOTrainer , apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM ,
AutoModelForSequenceClassification ,
AutoTokenizer ,
)
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
reward_model = AutoModelForSequenceClassification . from_pretrained (
"Qwen/Qwen2.5-0.5B-Instruct" , num_labels = 1
)
ref_policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
policy = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback-prompt" )
dataset = dataset . map ( apply_chat_template , fn_kwargs = { "tokenizer" : tokenizer })
dataset = dataset . map ( lambda x : tokenizer ( x [ "prompt" ]), remove_columns = "prompt" )
training_args = RLOOConfig ( output_dir = "Qwen2.5-0.5B-RL" )
trainer = RLOOTrainer (
config = training_args ,
processing_class = tokenizer ,
policy = policy ,
ref_policy = ref_policy ,
reward_model = reward_model ,
train_dataset = dataset [ "train" ],
eval_dataset = dataset [ "test" ],
)
trainer . train ()
DPOTrainer
DPOTrainer
ใช้อัลกอริทึม Direct Preference Optimization (DPO) ยอดนิยมที่ใช้ในการฝึก Llama 3 และรุ่นอื่นๆ อีกมากมาย นี่คือตัวอย่างพื้นฐานของวิธีใช้ DPOTrainer
:
from datasets import load_dataset
from transformers import AutoModelForCausalLM , AutoTokenizer
from trl import DPOConfig , DPOTrainer
model = AutoModelForCausalLM . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
tokenizer = AutoTokenizer . from_pretrained ( "Qwen/Qwen2.5-0.5B-Instruct" )
dataset = load_dataset ( "trl-lib/ultrafeedback_binarized" , split = "train" )
training_args = DPOConfig ( output_dir = "Qwen2.5-0.5B-DPO" )
trainer = DPOTrainer ( model = model , args = training_args , train_dataset = dataset , processing_class = tokenizer )
trainer . train ()
หากคุณต้องการสนับสนุน trl
หรือปรับแต่งตามความต้องการของคุณ โปรดอ่านคู่มือการสนับสนุนและตรวจสอบให้แน่ใจว่าคุณได้ทำการติดตั้ง dev:
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
@misc { vonwerra2022trl ,
author = { Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec } ,
title = { TRL: Transformer Reinforcement Learning } ,
year = { 2020 } ,
publisher = { GitHub } ,
journal = { GitHub repository } ,
howpublished = { url{https://github.com/huggingface/trl} }
}
ซอร์สโค้ดของที่เก็บนี้มีให้ใช้งานภายใต้ลิขสิทธิ์ Apache-2.0