Dokumentasi | Tensordict | Fitur | Contoh, tutorial dan demo | Kutipan | Instalasi | Mengajukan pertanyaan | Berkontribusi
Torchrl adalah perpustakaan pembelajaran penguatan open-source (RL) untuk Pytorch.
Baca makalah lengkap untuk deskripsi perpustakaan yang lebih dikuratori.
Periksa tutorial awal kami untuk dengan cepat meningkatkan fitur dasar perpustakaan!
Dokumentasi Torchrl dapat ditemukan di sini. Ini berisi tutorial dan referensi API.
Torchrl juga menyediakan basis pengetahuan RL untuk membantu Anda men -debug kode Anda, atau hanya mempelajari dasar -dasar RL. Lihat di sini.
Kami memiliki beberapa video pengantar untuk Anda ketahui perpustakaan dengan lebih baik, periksa:
Torchrl menjadi domain-agnostik, Anda dapat menggunakannya di berbagai bidang. Berikut beberapa contoh:
TensorDict
Algoritma RL sangat heterogen, dan mungkin sulit untuk mendaur ulang basis kode di seluruh pengaturan (misalnya dari online ke offline, dari pembelajaran berbasis negara ke pixel). Torchrl memecahkan masalah ini melalui TensorDict
, struktur data yang nyaman (1) yang dapat digunakan untuk merampingkan basis kode RL seseorang. Dengan alat ini, seseorang dapat menulis skrip pelatihan PPO lengkap dalam kurang dari 100 baris kode !
import torch
from tensordict . nn import TensorDictModule
from tensordict . nn . distributions import NormalParamExtractor
from torch import nn
from torchrl . collectors import SyncDataCollector
from torchrl . data . replay_buffers import TensorDictReplayBuffer ,
LazyTensorStorage , SamplerWithoutReplacement
from torchrl . envs . libs . gym import GymEnv
from torchrl . modules import ProbabilisticActor , ValueOperator , TanhNormal
from torchrl . objectives import ClipPPOLoss
from torchrl . objectives . value import GAE
env = GymEnv ( "Pendulum-v1" )
model = TensorDictModule (
nn . Sequential (
nn . Linear ( 3 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 2 ),
NormalParamExtractor ()
),
in_keys = [ "observation" ],
out_keys = [ "loc" , "scale" ]
)
critic = ValueOperator (
nn . Sequential (
nn . Linear ( 3 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 1 ),
),
in_keys = [ "observation" ],
)
actor = ProbabilisticActor (
model ,
in_keys = [ "loc" , "scale" ],
distribution_class = TanhNormal ,
distribution_kwargs = { "low" : - 1.0 , "high" : 1.0 },
return_log_prob = True
)
buffer = TensorDictReplayBuffer (
storage = LazyTensorStorage ( 1000 ),
sampler = SamplerWithoutReplacement (),
batch_size = 50 ,
)
collector = SyncDataCollector (
env ,
actor ,
frames_per_batch = 1000 ,
total_frames = 1_000_000 ,
)
loss_fn = ClipPPOLoss ( actor , critic )
adv_fn = GAE ( value_network = critic , average_gae = True , gamma = 0.99 , lmbda = 0.95 )
optim = torch . optim . Adam ( loss_fn . parameters (), lr = 2e-4 )
for data in collector : # collect data
for epoch in range ( 10 ):
adv_fn ( data ) # compute advantage
buffer . extend ( data )
for sample in buffer : # consume data
loss_vals = loss_fn ( sample )
loss_val = sum (
value for key , value in loss_vals . items () if
key . startswith ( "loss" )
)
loss_val . backward ()
optim . step ()
optim . zero_grad ()
print ( f"avg reward: { data [ 'next' , 'reward' ]. mean (). item (): 4.4f } " )
Berikut adalah contoh bagaimana API lingkungan bergantung pada Tensordict untuk membawa data dari satu fungsi ke fungsi lainnya selama eksekusi peluncuran:
TensorDict
memudahkan untuk menggunakan kembali kode di seluruh lingkungan, model, dan algoritma.
Misalnya, inilah cara membuat kode peluncuran di Torchrl:
- obs, done = env.reset()
+ tensordict = env.reset()
policy = SafeModule(
model,
in_keys=["observation_pixels", "observation_vector"],
out_keys=["action"],
)
out = []
for i in range(n_steps):
- action, log_prob = policy(obs)
- next_obs, reward, done, info = env.step(action)
- out.append((obs, next_obs, action, log_prob, reward, done))
- obs = next_obs
+ tensordict = policy(tensordict)
+ tensordict = env.step(tensordict)
+ out.append(tensordict)
+ tensordict = step_mdp(tensordict) # renames next_observation_* keys to observation_*
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0) # TensorDict supports multiple tensor operations
Menggunakan ini, Torchrl abstrak dari tanda tangan input / output dari modul, env, kolektor, buffer replay dan kerugian perpustakaan, memungkinkan semua primitif dengan mudah didaur ulang di seluruh pengaturan.
Berikut adalah contoh lain dari loop pelatihan off-kebijakan di Torchrl (dengan asumsi bahwa pengumpul data, buffer replay, kerugian dan pengoptimal telah dipakai):
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
- replay_buffer.add((obs, next_obs, action, log_prob, reward, done))
+ replay_buffer.add(tensordict)
for j in range(num_optim_steps):
- obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
- loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+ tensordict = replay_buffer.sample(batch_size)
+ loss = loss_fn(tensordict)
loss.backward()
optim.step()
optim.zero_grad()
Loop pelatihan ini dapat digunakan kembali di seluruh algoritma karena membuat jumlah minimal asumsi tentang struktur data.
Tensordict mendukung beberapa operasi tensor pada perangkat dan bentuknya (bentuk Tensordict, atau ukuran batchnya, adalah dimensi pertama yang arbitrer N dari semua tensor yang terkandung):
# stack and cat
tensordict = torch . stack ( list_of_tensordicts , 0 )
tensordict = torch . cat ( list_of_tensordicts , 0 )
# reshape
tensordict = tensordict . view ( - 1 )
tensordict = tensordict . permute ( 0 , 2 , 1 )
tensordict = tensordict . unsqueeze ( - 1 )
tensordict = tensordict . squeeze ( - 1 )
# indexing
tensordict = tensordict [: 2 ]
tensordict [:, 2 ] = sub_tensordict
# device and memory location
tensordict . cuda ()
tensordict . to ( "cuda:1" )
tensordict . share_memory_ ()
Tensordict hadir dengan modul tensordict.nn
khusus yang berisi semua yang mungkin Anda butuhkan untuk menulis model Anda dengannya. Dan itu adalah functorch
dan torch.compile
.
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]
Kelas TensorDictSequential
memungkinkan untuk cabang urutan instance nn.Module
dengan cara yang sangat modular. Misalnya, berikut adalah implementasi transformator yang menggunakan blok encoder dan decoder:
encoder_module = TransformerEncoder (...)
encoder = TensorDictSequential ( encoder_module , in_keys = [ "src" , "src_mask" ], out_keys = [ "memory" ])
decoder_module = TransformerDecoder (...)
decoder = TensorDictModule ( decoder_module , in_keys = [ "tgt" , "memory" ], out_keys = [ "output" ])
transformer = TensorDictSequential ( encoder , decoder )
assert transformer . in_keys == [ "src" , "src_mask" , "tgt" ]
assert transformer . out_keys == [ "memory" , "output" ]
TensorDictSequential
memungkinkan untuk mengisolasi subgraph dengan menanyakan satu set kunci input / output yang diinginkan:
transformer . select_subsequence ( out_keys = [ "memory" ]) # returns the encoder
transformer . select_subsequence ( in_keys = [ "tgt" , "memory" ]) # returns the decoder
Periksa tutorial Tensordict untuk mempelajari lebih lanjut!
Antarmuka umum untuk lingkungan yang mendukung perpustakaan umum (gym openai, laboratorium kontrol DeepMind, dll.) (1) dan eksekusi tanpa negara (misalnya lingkungan berbasis model). Wadah lingkungan batch memungkinkan eksekusi paralel (2) . Kelas Pytorch-First yang umum dari kelas spesifikasi tensor juga disediakan. Lingkungan Torchrl API sederhana tetapi ketat dan spesifik. Periksa dokumentasi dan tutorial untuk mempelajari lebih lanjut!
env_make = lambda : GymEnv ( "Pendulum-v1" , from_pixels = True )
env_parallel = ParallelEnv ( 4 , env_make ) # creates 4 envs in parallel
tensordict = env_parallel . rollout ( max_steps = 20 , policy = None ) # random rollout (no policy given)
assert tensordict . shape == [ 4 , 20 ] # 4 envs, 20 steps rollout
env_parallel . action_spec . is_in ( tensordict [ "action" ]) # spec check returns True
multiproses dan pengumpul data terdistribusi (2) yang bekerja secara serempak atau asinkron. Melalui penggunaan Tensordict, loop pelatihan Torchrl dibuat sangat mirip dengan loop pelatihan reguler dalam pembelajaran yang diawasi (meskipun "DataLoader"-baca pengumpul data-dimodifikasi saat terbang):
env_make = lambda : GymEnv ( "Pendulum-v1" , from_pixels = True )
collector = MultiaSyncDataCollector (
[ env_make , env_make ],
policy = policy ,
devices = [ "cuda:0" , "cuda:0" ],
total_frames = 10000 ,
frames_per_batch = 50 ,
...
)
for i , tensordict_data in enumerate ( collector ):
loss = loss_module ( tensordict_data )
loss . backward ()
optim . step ()
optim . zero_grad ()
collector . update_policy_weights_ ()
Periksa contoh kolektor terdistribusi kami untuk mempelajari lebih lanjut tentang pengumpulan data ultra-cepat dengan Torchrl.
Efisien (2) dan generik (1) buffer replay dengan penyimpanan modularisasi:
storage = LazyMemmapStorage ( # memory-mapped (physical) storage
cfg . buffer_size ,
scratch_dir = "/tmp/"
)
buffer = TensorDictPrioritizedReplayBuffer (
alpha = 0.7 ,
beta = 0.5 ,
collate_fn = lambda x : x ,
pin_memory = device != torch . device ( "cpu" ),
prefetch = 10 , # multi-threaded sampling
storage = storage
)
Buffer replay juga ditawarkan sebagai pembungkus di sekitar dataset umum untuk Offline RL :
from torchrl . data . replay_buffers import SamplerWithoutReplacement
from torchrl . data . datasets . d4rl import D4RLExperienceReplay
data = D4RLExperienceReplay (
"maze2d-open-v0" ,
split_trajs = True ,
batch_size = 128 ,
sampler = SamplerWithoutReplacement ( drop_last = True ),
)
for sample in data : # or alternatively sample = data.sample()
fun ( sample )
Transformasi Lingkungan Lintang-Perpustakaan (1) , dieksekusi pada perangkat dan dengan cara yang divektifisasi (2) , yang memproses dan menyiapkan data yang keluar dari lingkungan yang akan digunakan oleh agen:
env_make = lambda : GymEnv ( "Pendulum-v1" , from_pixels = True )
env_base = ParallelEnv ( 4 , env_make , device = "cuda:0" ) # creates 4 envs in parallel
env = TransformedEnv (
env_base ,
Compose (
ToTensorImage (),
ObservationNorm ( loc = 0.5 , scale = 1.0 )), # executes the transforms once and on device
)
tensordict = env . reset ()
assert tensordict . device == torch . device ( "cuda:0" )
Transformasi lain meliputi: penskalaan imbalan ( RewardScaling
), operasi bentuk (gabungan tensor, unqueezing dll.), Penggabungan operasi berturut -turut ( CatFrames
), mengubah ukuran ( Resize
) dan banyak lagi.
Tidak seperti perpustakaan lain, transformasi ditumpuk sebagai daftar (dan tidak dibungkus satu sama lain), yang membuatnya mudah untuk ditambahkan dan menghapusnya sesuka hati:
env . insert_transform ( 0 , NoopResetEnv ()) # inserts the NoopResetEnv transform at the index 0
Namun demikian, Transforms dapat mengakses dan menjalankan operasi pada lingkungan induk:
transform = env . transform [ 1 ] # gathers the second transform of the list
parent_env = transform . parent # returns the base environment of the second transform, i.e. the base env + the first transform
berbagai alat untuk pembelajaran terdistribusi (misalnya tensor memetakan memori) (2) ;
Berbagai arsitektur dan model (misalnya aktor-kritik) (1) :
# create an nn.Module
common_module = ConvNet (
bias_last_layer = True ,
depth = None ,
num_cells = [ 32 , 64 , 64 ],
kernel_sizes = [ 8 , 4 , 3 ],
strides = [ 4 , 2 , 1 ],
)
# Wrap it in a SafeModule, indicating what key to read in and where to
# write out the output
common_module = SafeModule (
common_module ,
in_keys = [ "pixels" ],
out_keys = [ "hidden" ],
)
# Wrap the policy module in NormalParamsWrapper, such that the output
# tensor is split in loc and scale, and scale is mapped onto a positive space
policy_module = SafeModule (
NormalParamsWrapper (
MLP ( num_cells = [ 64 , 64 ], out_features = 32 , activation = nn . ELU )
),
in_keys = [ "hidden" ],
out_keys = [ "loc" , "scale" ],
)
# Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
# SafeProbabilisticModule, indicating how to build the
# torch.distribution.Distribution object and what to do with it
policy_module = SafeProbabilisticTensorDictSequential ( # stochastic policy
policy_module ,
SafeProbabilisticModule (
in_keys = [ "loc" , "scale" ],
out_keys = "action" ,
distribution_class = TanhNormal ,
),
)
value_module = MLP (
num_cells = [ 64 , 64 ],
out_features = 1 ,
activation = nn . ELU ,
)
# Wrap the policy and value funciton in a common module
actor_value = ActorValueOperator ( common_module , policy_module , value_module )
# standalone policy from this
standalone_policy = actor_value . get_policy_operator ()
Pembungkus eksplorasi dan modul untuk dengan mudah bertukar antara eksplorasi dan eksploitasi (1) :
policy_explore = EGreedyWrapper ( policy )
with set_exploration_type ( ExplorationType . RANDOM ):
tensordict = policy_explore ( tensordict ) # will use eps-greedy
with set_exploration_type ( ExplorationType . DETERMINISTIC ):
tensordict = policy_explore ( tensordict ) # will not use eps-greedy
Serangkaian modul kerugian yang efisien dan pengembalian fungsional yang sangat vektor dan komputasi keuntungan.
from torchrl . objectives import DQNLoss
loss_module = DQNLoss ( value_network = value_network , gamma = 0.99 )
tensordict = replay_buffer . sample ( batch_size )
loss = loss_module ( tensordict )
from torchrl . objectives . value . functional import vec_td_lambda_return_estimate
advantage = vec_td_lambda_return_estimate ( gamma , lmbda , next_state_value , reward , done , terminated )
Kelas pelatih generik (1) yang mengeksekusi loop pelatihan yang disebutkan di atas. Melalui mekanisme pengait, ini juga mendukung operasi penebangan atau transformasi data pada waktu tertentu.
Berbagai resep untuk membangun model yang sesuai dengan lingkungan yang digunakan.
Jika Anda merasa fitur hilang dari perpustakaan, silakan kirimkan masalah! Jika Anda ingin berkontribusi pada fitur baru, periksa panggilan kami untuk kontribusi dan halaman kontribusi kami.
Serangkaian implementasi canggih dilengkapi dengan tujuan ilustrasi:
Algoritma | Kompilasi dukungan ** | API bebas tarik | Kerugian modular | Kontinu dan diskrit |
Dqn | 1.9x | + | Na | + (melalui ActionDiscretizer Transform) |
Ddpg | 1.87x | + | + | - (hanya kontinu) |
IQL | 3.22x | + | + | + |
Cql | 2.68x | + | + | + |
TD3 | 2.27x | + | + | - (hanya kontinu) |
TD3+BC | belum dicoba | + | + | - (hanya kontinu) |
A2C | 2.67x | + | - | + |
PPO | 2.42x | + | - | + |
KANTUNG | 2.62x | + | - | + |
Redq | 2.28x | + | - | - (hanya kontinu) |
Pemimpi V1 | belum dicoba | + | + (kelas yang berbeda) | - (hanya kontinu) |
Transformer Keputusan | belum dicoba | + | Na | - (hanya kontinu) |
Crossq | belum dicoba | + | + | - (hanya kontinu) |
Gail | belum dicoba | + | Na | + |
Impala | belum dicoba | + | - | + |
IQL (Marl) | belum dicoba | + | + | + |
DDPG (Marl) | belum dicoba | + | + | - (hanya kontinu) |
PPO (marl) | belum dicoba | + | - | + |
Qmix-vdn (marl) | belum dicoba | + | Na | + |
Sac (marl) | belum dicoba | + | - | + |
Rlhf | Na | + | Na | Na |
** Angka tersebut menunjukkan kecepatan yang diharapkan dibandingkan dengan mode yang bersemangat saat dieksekusi pada CPU. Angka dapat bervariasi tergantung pada arsitektur dan perangkat.
Dan banyak lagi yang akan datang!
Contoh kode yang menampilkan cuplikan kode mainan dan skrip pelatihan juga tersedia
Periksa direktori contoh untuk detail lebih lanjut tentang menangani berbagai pengaturan konfigurasi.
Kami juga menyediakan tutorial dan demo yang memberikan perasaan tentang apa yang dapat dilakukan perpustakaan.
Jika Anda menggunakan Torchrl, silakan merujuk ke entri Bibtex ini untuk mengutip karya ini:
@misc{bou2023torchrl,
title={TorchRL: A data-driven decision-making library for PyTorch},
author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
year={2023},
eprint={2306.00577},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Buat lingkungan Conda di mana paket akan diinstal.
conda create --name torch_rl python=3.9
conda activate torch_rl
Pytorch
Bergantung pada penggunaan functorch yang ingin Anda buat, Anda mungkin ingin menginstal rilis Pytorch terbaru (malam) atau versi stabil Pytorch terbaru. Lihat di sini untuk daftar perintah terperinci, termasuk pip3
atau instruksi instalasi khusus lainnya.
Torchrl
Anda dapat menginstal rilis stabil terbaru dengan menggunakan
pip3 install torchrl
Ini harus bekerja di Linux, Windows 10 dan OSX (chip Intel atau silikon). Pada mesin Windows tertentu (Windows 11), seseorang harus menginstal perpustakaan secara lokal (lihat di bawah).
Bangunan malam dapat diinstal melalui
pip3 install torchrl-nightly
yang saat ini kami hanya mengirimkan mesin Linux dan OSX (Intel). Yang penting, bangunan malam juga membutuhkan pembangunan pytorch malam hari.
Untuk memasang dependensi ekstra, hubungi
pip3 install " torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,open_spiel,checkpointing] "
atau subset dari ini.
Seseorang mungkin juga ingin menginstal perpustakaan secara lokal. Tiga alasan utama dapat memotivasi ini:
Untuk menginstal perpustakaan secara lokal, mulailah dengan mengkloning repo:
git clone https://github.com/pytorch/rl
Dan jangan lupa untuk memeriksa cabang atau tag yang ingin Anda gunakan untuk build:
git checkout v0.4.0
Pergi ke direktori tempat Anda mengkloning repo torchrl dan menginstalnya (setelah menginstal ninja
)
cd /path/to/torchrl/
pip3 install ninja -U
python setup.py develop
Seseorang juga dapat membangun roda untuk mendistribusikan ke rekan kerja menggunakan
python setup.py bdist_wheel
Roda Anda akan disimpan di sana ./dist/torchrl<name>.whl
dan dapat diinstal melalui
pip install torchrl < name > .whl
PERINGATAN : Sayangnya, pip3 install -e .
saat ini tidak berhasil. Kontribusi untuk membantu memperbaiki ini dipersilakan!
Pada mesin M1, ini harus bekerja di luar kotak dengan build malam pytorch. Jika generasi artefak ini dalam macOS M1 tidak berfungsi dengan benar atau dalam eksekusi pesan (mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))
muncul, lalu coba, coba, lalu coba, coba, coba, coba, coba coba, maka coba coba, coba coba, coba coba ,asihungam
ARCHFLAGS="-arch arm64" python setup.py develop
Untuk menjalankan pemeriksaan kewarasan cepat, tinggalkan direktori itu (misalnya dengan mengeksekusi cd ~/
) dan mencoba mengimpor perpustakaan.
python -c "import torchrl"
Ini seharusnya tidak mengembalikan peringatan atau kesalahan apa pun.
Dependensi opsional
Perpustakaan berikut dapat diinstal tergantung pada penggunaan yang ingin dilakukan oleh Torchrl:
# diverse
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher
# rendering
pip3 install moviepy
# deepmind control suite
pip3 install dm_control
# gym, atari games
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame
# tests
pip3 install pytest pyyaml pytest-instafail
# tensorboard
pip3 install tensorboard
# wandb
pip3 install wandb
Pemecahan masalah
Jika ModuleNotFoundError: No module named 'torchrl._torchrl
kesalahan terjadi (atau peringatan yang menunjukkan bahwa binari C ++ tidak dapat dimuat), itu berarti bahwa ekstensi C ++ tidak dipasang atau tidak ditemukan.
develop
: cd ~/path/to/rl/repo
python -c 'from torchrl.envs.libs.gym import GymEnv'
python setup.py develop
. Salah satu penyebab umum adalah perbedaan versi G ++/C ++ dan/atau masalah dengan pustaka ninja
. wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
python collect_env.py
OS: macOS *** (arm64)
OS: macOS **** (x86_64)
Masalah versi dapat menyebabkan pesan kesalahan dari undefined symbol
dan semacamnya. Untuk ini, lihat dokumen masalah versi untuk penjelasan lengkap dan solusi yang diusulkan.
Jika Anda melihat bug di perpustakaan, silakan angkat masalah dalam repo ini.
Jika Anda memiliki pertanyaan yang lebih umum tentang RL di Pytorch, posting di Forum Pytorch.
Kolaborasi internal untuk Torchrl dipersilakan! Jangan ragu untuk membayar, kirimkan masalah dan PR. Anda dapat memeriksa panduan kontribusi terperinci di sini. Seperti disebutkan di atas, daftar kontribusi terbuka dapat ditemukan di sini.
Kontributor disarankan untuk memasang kait pra-komit (menggunakan pre-commit install
). Pra-komit akan memeriksa masalah terkait linting ketika kode dilakukan secara lokal. Anda dapat menonaktifkan cek dengan menambahkan -n
ke perintah komit Anda: git commit -m <commit message> -n
Perpustakaan ini dirilis sebagai fitur beta Pytorch. Perubahan BC-Breaking kemungkinan terjadi tetapi mereka akan diperkenalkan dengan garansi pendesahan setelah beberapa siklus rilis.
Torchrl dilisensikan di bawah lisensi MIT. Lihat lisensi untuk detailnya.