Documentação | Tensordito | Recursos | Exemplos, tutoriais e demos | Citação | Instalação | Fazendo uma pergunta | Contribuindo
Torchrl é uma biblioteca de aprendizado de reforço de código aberto (RL) para Pytorch.
Leia o papel completo para uma descrição mais curada da biblioteca.
Verifique nossos tutoriais para começar rapidamente com os recursos básicos da biblioteca!
A documentação do Torchrl pode ser encontrada aqui. Ele contém tutoriais e a referência da API.
O Torchrl também fornece uma base de conhecimento da RL para ajudá -lo a depurar seu código ou simplesmente aprender o básico do RL. Confira aqui.
Temos alguns vídeos introdutórios para você conhecer melhor a biblioteca, confira:
Torchrl sendo o domínio-agnóstico, você pode usá-lo em muitos campos diferentes. Aqui estão alguns exemplos:
TensorDict
Os algoritmos RL são muito heterogêneos e pode ser difícil reciclar uma base de código entre as configurações (por exemplo, on-line a offline, da aprendizagem baseada no estado à aprendizagem baseada em pixels). O Torchrl resolve esse problema através TensorDict
, uma estrutura de dados conveniente (1) que pode ser usada para otimizar a base de código RL. Com esta ferramenta, pode -se escrever um script de treinamento PPO completo em menos de 100 linhas de código !
import torch
from tensordict . nn import TensorDictModule
from tensordict . nn . distributions import NormalParamExtractor
from torch import nn
from torchrl . collectors import SyncDataCollector
from torchrl . data . replay_buffers import TensorDictReplayBuffer ,
LazyTensorStorage , SamplerWithoutReplacement
from torchrl . envs . libs . gym import GymEnv
from torchrl . modules import ProbabilisticActor , ValueOperator , TanhNormal
from torchrl . objectives import ClipPPOLoss
from torchrl . objectives . value import GAE
env = GymEnv ( "Pendulum-v1" )
model = TensorDictModule (
nn . Sequential (
nn . Linear ( 3 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 2 ),
NormalParamExtractor ()
),
in_keys = [ "observation" ],
out_keys = [ "loc" , "scale" ]
)
critic = ValueOperator (
nn . Sequential (
nn . Linear ( 3 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 128 ), nn . Tanh (),
nn . Linear ( 128 , 1 ),
),
in_keys = [ "observation" ],
)
actor = ProbabilisticActor (
model ,
in_keys = [ "loc" , "scale" ],
distribution_class = TanhNormal ,
distribution_kwargs = { "low" : - 1.0 , "high" : 1.0 },
return_log_prob = True
)
buffer = TensorDictReplayBuffer (
storage = LazyTensorStorage ( 1000 ),
sampler = SamplerWithoutReplacement (),
batch_size = 50 ,
)
collector = SyncDataCollector (
env ,
actor ,
frames_per_batch = 1000 ,
total_frames = 1_000_000 ,
)
loss_fn = ClipPPOLoss ( actor , critic )
adv_fn = GAE ( value_network = critic , average_gae = True , gamma = 0.99 , lmbda = 0.95 )
optim = torch . optim . Adam ( loss_fn . parameters (), lr = 2e-4 )
for data in collector : # collect data
for epoch in range ( 10 ):
adv_fn ( data ) # compute advantage
buffer . extend ( data )
for sample in buffer : # consume data
loss_vals = loss_fn ( sample )
loss_val = sum (
value for key , value in loss_vals . items () if
key . startswith ( "loss" )
)
loss_val . backward ()
optim . step ()
optim . zero_grad ()
print ( f"avg reward: { data [ 'next' , 'reward' ]. mean (). item (): 4.4f } " )
Aqui está um exemplo de como a API do ambiente se baseia no tensordito para transportar dados de uma função para outra durante uma execução de lançamento:
TensorDict
facilita a reutilização de peças de código em ambientes, modelos e algoritmos.
Por exemplo, veja como codificar um lançamento no Torchrl:
- obs, done = env.reset()
+ tensordict = env.reset()
policy = SafeModule(
model,
in_keys=["observation_pixels", "observation_vector"],
out_keys=["action"],
)
out = []
for i in range(n_steps):
- action, log_prob = policy(obs)
- next_obs, reward, done, info = env.step(action)
- out.append((obs, next_obs, action, log_prob, reward, done))
- obs = next_obs
+ tensordict = policy(tensordict)
+ tensordict = env.step(tensordict)
+ out.append(tensordict)
+ tensordict = step_mdp(tensordict) # renames next_observation_* keys to observation_*
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0) # TensorDict supports multiple tensor operations
Usando isso, a Torchrl abstrava as assinaturas de entrada / saída dos módulos, Env, colecionadores, buffers de repetição e perdas da biblioteca, permitindo que todas as primitivas sejam facilmente recicladas entre as configurações.
Aqui está outro exemplo de um loop de treinamento fora da política no Torchrl (assumindo que um coletor de dados, um buffer de repetição, uma perda e um otimizador foram instanciados):
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
- replay_buffer.add((obs, next_obs, action, log_prob, reward, done))
+ replay_buffer.add(tensordict)
for j in range(num_optim_steps):
- obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
- loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+ tensordict = replay_buffer.sample(batch_size)
+ loss = loss_fn(tensordict)
loss.backward()
optim.step()
optim.zero_grad()
Esse loop de treinamento pode ser reutilizado em todos os algoritmos, pois faz um número mínimo de suposições sobre a estrutura dos dados.
O tensordito suporta várias operações tensoras em seu dispositivo e forma (a forma do tensordito, ou seu tamanho de lote, é a primeira dimensões arbitrárias de todos os seus tensores contidos):
# stack and cat
tensordict = torch . stack ( list_of_tensordicts , 0 )
tensordict = torch . cat ( list_of_tensordicts , 0 )
# reshape
tensordict = tensordict . view ( - 1 )
tensordict = tensordict . permute ( 0 , 2 , 1 )
tensordict = tensordict . unsqueeze ( - 1 )
tensordict = tensordict . squeeze ( - 1 )
# indexing
tensordict = tensordict [: 2 ]
tensordict [:, 2 ] = sub_tensordict
# device and memory location
tensordict . cuda ()
tensordict . to ( "cuda:1" )
tensordict . share_memory_ ()
O Tensordict vem com um módulo tensordict.nn
dedicado que contém tudo o que você pode precisar para escrever seu modelo com ele. E é functorch
e torch.compile
compatível!
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]
A classe TensorDictSequential
permite ramificar sequências de instâncias nn.Module
de maneira altamente modular. Por exemplo, aqui está uma implementação de um transformador usando os blocos de codificador e decodificador:
encoder_module = TransformerEncoder (...)
encoder = TensorDictSequential ( encoder_module , in_keys = [ "src" , "src_mask" ], out_keys = [ "memory" ])
decoder_module = TransformerDecoder (...)
decoder = TensorDictModule ( decoder_module , in_keys = [ "tgt" , "memory" ], out_keys = [ "output" ])
transformer = TensorDictSequential ( encoder , decoder )
assert transformer . in_keys == [ "src" , "src_mask" , "tgt" ]
assert transformer . out_keys == [ "memory" , "output" ]
TensorDictSequential
permite isolar os subgrafos, consultando um conjunto de chaves de entrada / saída desejadas:
transformer . select_subsequence ( out_keys = [ "memory" ]) # returns the encoder
transformer . select_subsequence ( in_keys = [ "tgt" , "memory" ]) # returns the decoder
Verifique os tutoriais tensordados para saber mais!
Uma interface comum para ambientes que suporta bibliotecas comuns (academia OpenAI, Laboratório de Controle DeepMind, etc.) (1) e execução sem estado (por exemplo, ambientes baseados em modelo). Os contêineres de ambientes em lotes permitem a execução paralela (2) . Também é fornecida uma classe comum de Pytorch-primeiro da classe de especificação de tensores. A API de ambientes de Torchrl é simples, mas rigorosa e específica. Verifique a documentação e o tutorial para saber mais!
env_make = lambda : GymEnv ( "Pendulum-v1" , from_pixels = True )
env_parallel = ParallelEnv ( 4 , env_make ) # creates 4 envs in parallel
tensordict = env_parallel . rollout ( max_steps = 20 , policy = None ) # random rollout (no policy given)
assert tensordict . shape == [ 4 , 20 ] # 4 envs, 20 steps rollout
env_parallel . action_spec . is_in ( tensordict [ "action" ]) # spec check returns True
coletores de dados multiprocessos e distribuídos (2) que funcionam de maneira síncrona ou assíncrona. Através do uso do tensordito, os loops de treinamento da Torchrl são muito semelhantes aos loops de treinamento regulares em aprendizado supervisionado (embora o "Dataloader"-leia o coletor de dados-seja modificado na fly):
env_make = lambda : GymEnv ( "Pendulum-v1" , from_pixels = True )
collector = MultiaSyncDataCollector (
[ env_make , env_make ],
policy = policy ,
devices = [ "cuda:0" , "cuda:0" ],
total_frames = 10000 ,
frames_per_batch = 50 ,
...
)
for i , tensordict_data in enumerate ( collector ):
loss = loss_module ( tensordict_data )
loss . backward ()
optim . step ()
optim . zero_grad ()
collector . update_policy_weights_ ()
Verifique nossos exemplos de colecionadores distribuídos para saber mais sobre a coleta de dados ultra-rápida com a Torchrl.
buffers de repetição eficientes (2) e genéricos (1) com armazenamento modularizado:
storage = LazyMemmapStorage ( # memory-mapped (physical) storage
cfg . buffer_size ,
scratch_dir = "/tmp/"
)
buffer = TensorDictPrioritizedReplayBuffer (
alpha = 0.7 ,
beta = 0.5 ,
collate_fn = lambda x : x ,
pin_memory = device != torch . device ( "cpu" ),
prefetch = 10 , # multi-threaded sampling
storage = storage
)
Os buffers de repetição também são oferecidos como invólucros em torno de conjuntos de dados comuns para RL offline :
from torchrl . data . replay_buffers import SamplerWithoutReplacement
from torchrl . data . datasets . d4rl import D4RLExperienceReplay
data = D4RLExperienceReplay (
"maze2d-open-v0" ,
split_trajs = True ,
batch_size = 128 ,
sampler = SamplerWithoutReplacement ( drop_last = True ),
)
for sample in data : # or alternatively sample = data.sample()
fun ( sample )
O ambiente de biblioteca transversal transforma (1) , executada no dispositivo e de maneira vetorizada (2) , que processam e preparam os dados que saem dos ambientes a serem usados pelo agente:
env_make = lambda : GymEnv ( "Pendulum-v1" , from_pixels = True )
env_base = ParallelEnv ( 4 , env_make , device = "cuda:0" ) # creates 4 envs in parallel
env = TransformedEnv (
env_base ,
Compose (
ToTensorImage (),
ObservationNorm ( loc = 0.5 , scale = 1.0 )), # executes the transforms once and on device
)
tensordict = env . reset ()
assert tensordict . device == torch . device ( "cuda:0" )
Outras transformadas incluem: escala de recompensa ( RewardScaling
), operações de forma (concatenação de tensores, inquietação etc.), concatenação de operações sucessivas ( CatFrames
), redimensionamento ( Resize
) e muitos mais.
Ao contrário de outras bibliotecas, as transformações são empilhadas como uma lista (e não embrulhadas uma na outra), o que facilita adicioná -las e removê -las à vontade:
env . insert_transform ( 0 , NoopResetEnv ()) # inserts the NoopResetEnv transform at the index 0
No entanto, o Transforms pode acessar e executar operações no ambiente pai:
transform = env . transform [ 1 ] # gathers the second transform of the list
parent_env = transform . parent # returns the base environment of the second transform, i.e. the base env + the first transform
várias ferramentas para aprendizado distribuído (por exemplo, tensores mapeados de memória) (2) ;
Várias arquiteturas e modelos (por exemplo, ator-crítico) (1) :
# create an nn.Module
common_module = ConvNet (
bias_last_layer = True ,
depth = None ,
num_cells = [ 32 , 64 , 64 ],
kernel_sizes = [ 8 , 4 , 3 ],
strides = [ 4 , 2 , 1 ],
)
# Wrap it in a SafeModule, indicating what key to read in and where to
# write out the output
common_module = SafeModule (
common_module ,
in_keys = [ "pixels" ],
out_keys = [ "hidden" ],
)
# Wrap the policy module in NormalParamsWrapper, such that the output
# tensor is split in loc and scale, and scale is mapped onto a positive space
policy_module = SafeModule (
NormalParamsWrapper (
MLP ( num_cells = [ 64 , 64 ], out_features = 32 , activation = nn . ELU )
),
in_keys = [ "hidden" ],
out_keys = [ "loc" , "scale" ],
)
# Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
# SafeProbabilisticModule, indicating how to build the
# torch.distribution.Distribution object and what to do with it
policy_module = SafeProbabilisticTensorDictSequential ( # stochastic policy
policy_module ,
SafeProbabilisticModule (
in_keys = [ "loc" , "scale" ],
out_keys = "action" ,
distribution_class = TanhNormal ,
),
)
value_module = MLP (
num_cells = [ 64 , 64 ],
out_features = 1 ,
activation = nn . ELU ,
)
# Wrap the policy and value funciton in a common module
actor_value = ActorValueOperator ( common_module , policy_module , value_module )
# standalone policy from this
standalone_policy = actor_value . get_policy_operator ()
invólucros de exploração e módulos para trocar facilmente entre exploração e exploração (1) :
policy_explore = EGreedyWrapper ( policy )
with set_exploration_type ( ExplorationType . RANDOM ):
tensordict = policy_explore ( tensordict ) # will use eps-greedy
with set_exploration_type ( ExplorationType . DETERMINISTIC ):
tensordict = policy_explore ( tensordict ) # will not use eps-greedy
Uma série de módulos de perda eficiente e computação funcional e vantagem altamente vetorizada.
from torchrl . objectives import DQNLoss
loss_module = DQNLoss ( value_network = value_network , gamma = 0.99 )
tensordict = replay_buffer . sample ( batch_size )
loss = loss_module ( tensordict )
from torchrl . objectives . value . functional import vec_td_lambda_return_estimate
advantage = vec_td_lambda_return_estimate ( gamma , lmbda , next_state_value , reward , done , terminated )
Uma classe de treinador genérico (1) que executa o loop de treinamento acima mencionado. Através de um mecanismo de gancho, ele também suporta qualquer operação de registro ou transformação de dados a qualquer momento.
Várias receitas para criar modelos que correspondem ao ambiente que está sendo implantado.
Se você acha que um recurso está faltando na biblioteca, envie um problema! Se você deseja contribuir com novos recursos, consulte nossa chamada de contribuições e nossa página de contribuição.
Uma série de implementações de última geração é fornecida com um objetivo ilustrativo:
Algoritmo | Suporte de compilação ** | API livre de tensordes | Perdas modulares | Contínuo e discreto |
Dqn | 1.9x | + | N / D | + (através da transformação de discretizador de ação) |
Ddpg | 1.87x | + | + | - (somente contínuo) |
IQL | 3.22x | + | + | + |
CQL | 2.68x | + | + | + |
TD3 | 2.27x | + | + | - (somente contínuo) |
TD3+BC | não testado | + | + | - (somente contínuo) |
A2C | 2.67x | + | - | + |
PPO | 2.42x | + | - | + |
SACO | 2.62x | + | - | + |
Redq | 2.28x | + | - | - (somente contínuo) |
Dreamer v1 | não testado | + | + (classes diferentes) | - (somente contínuo) |
Transformadores de decisão | não testado | + | N / D | - (somente contínuo) |
Crossq | não testado | + | + | - (somente contínuo) |
Gail | não testado | + | N / D | + |
Impala | não testado | + | - | + |
IQL (MARL) | não testado | + | + | + |
DDPG (MARL) | não testado | + | + | - (somente contínuo) |
PPO (Marl) | não testado | + | - | + |
Qmix-vdn (marl) | não testado | + | N / D | + |
SAC (Marl) | não testado | + | - | + |
RlHf | N / D | + | N / D | N / D |
** O número indica aceleração esperada em comparação com o modo ansioso quando executado na CPU. Os números podem variar dependendo da arquitetura e do dispositivo.
E muitos mais por vir!
Exemplos de código exibindo trechos de código de brinquedos e scripts de treinamento também estão disponíveis
Verifique o diretório Exemplos para obter mais detalhes sobre o manuseio das várias definições de configuração.
Também fornecemos tutoriais e demos que dão uma noção do que a biblioteca pode fazer.
Se você estiver usando o Torchrl, consulte esta entrada do Bibtex para citar este trabalho:
@misc{bou2023torchrl,
title={TorchRL: A data-driven decision-making library for PyTorch},
author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
year={2023},
eprint={2306.00577},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Crie um ambiente do CONDA onde os pacotes serão instalados.
conda create --name torch_rl python=3.9
conda activate torch_rl
Pytorch
Dependendo do uso do Functorch que você deseja fazer, você pode instalar a versão mais recente (noturna) do Pytorch ou a mais recente versão estável do Pytorch. Veja aqui uma lista detalhada de comandos, incluindo pip3
ou outras instruções de instalação especiais.
Torchrl
Você pode instalar o último lançamento estável usando
pip3 install torchrl
Isso deve funcionar no Linux, Windows 10 e OSX (chips Intel ou Silicon). Em determinadas máquinas do Windows (Windows 11), deve -se instalar a biblioteca localmente (veja abaixo).
A construção noturna pode ser instalada via
pip3 install torchrl-nightly
que atualmente enviamos apenas para máquinas Linux e OSX (Intel). É importante ressaltar que as construções noturnas também exigem as construções noturnas de Pytorch.
Para instalar dependências extras, ligue
pip3 install " torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,open_spiel,checkpointing] "
ou um subconjunto destes.
Pode -se também desejar instalar a biblioteca localmente. Três motivos principais podem motivar isso:
Para instalar a biblioteca localmente, comece clonando o repositório:
git clone https://github.com/pytorch/rl
E não se esqueça de conferir a filial ou a etiqueta que você deseja usar para a construção:
git checkout v0.4.0
Vá para o diretório onde clonou o repositório de Torchrl e instale -o (depois de instalar ninja
)
cd /path/to/torchrl/
pip3 install ninja -U
python setup.py develop
Pode-se também construir as rodas para distribuir para colegas de trabalho usando
python setup.py bdist_wheel
Suas rodas serão armazenadas lá ./dist/torchrl<name>.whl
e instalável via
pip install torchrl < name > .whl
AVISO : Infelizmente, pip3 install -e .
atualmente não funciona. Contribuições para ajudar a corrigir isso são bem -vindos!
Nas máquinas M1, isso deve funcionar pronto para uso com a construção noturna de Pytorch. Se a geração desse artefato no MacOS M1 não funcionar corretamente ou na execução da mensagem (mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))
aparecer, então tente
ARCHFLAGS="-arch arm64" python setup.py develop
Para executar uma verificação rápida da sanidade, deixe esse diretório (por exemplo, executando cd ~/
) e tente importar a biblioteca.
python -c "import torchrl"
Isso não deve retornar nenhum aviso ou erro.
Dependências opcionais
As bibliotecas a seguir podem ser instaladas, dependendo do uso que se deseja fazer de Torchrl:
# diverse
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher
# rendering
pip3 install moviepy
# deepmind control suite
pip3 install dm_control
# gym, atari games
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame
# tests
pip3 install pytest pyyaml pytest-instafail
# tensorboard
pip3 install tensorboard
# wandb
pip3 install wandb
Solução de problemas
Se um ModuleNotFoundError: No module named 'torchrl._torchrl
ocorrer (ou um aviso indicando que os binários C ++ não puderam ser carregados), significa que as extensões C ++ não foram instaladas ou não encontradas.
develop
: cd ~/path/to/rl/repo
python -c 'from torchrl.envs.libs.gym import GymEnv'
python setup.py develop
. Uma causa comum é uma versão G ++/C ++ discrepância e/ou um problema na biblioteca ninja
. wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
python collect_env.py
OS: macOS *** (arm64)
OS: macOS **** (x86_64)
Os problemas de versão podem causar mensagem de erro do tipo de undefined symbol
e tal. Para isso, consulte o documento de problemas de versão para obter uma explicação completa e soluções alternativas propostas.
Se você encontrar um bug na biblioteca, por favor, levante um problema neste repositório.
Se você tiver uma pergunta mais genérica sobre RL em Pytorch, publique -a no fórum Pytorch.
As colaborações internas do Torchrl são bem -vindas! Sinta -se à vontade para bifurcar, envie questões e PRs. Você pode verificar o guia de contribuição detalhado aqui. Como mencionado acima, uma lista de contribuições abertas pode ser encontrada aqui.
Recomenda-se os colaboradores para instalar ganchos pré-comprometidos (usando pre-commit install
). O pré-compromisso verificará os problemas relacionados ao linha quando o código for comprometido localmente. Você pode desativar a verificação, anexando -n
ao seu comando commit: git commit -m <commit message> -n
Esta biblioteca é lançada como um recurso beta pytorch. As mudanças de quebra de BC provavelmente acontecerão, mas serão introduzidas com uma garantia de depreciação após alguns ciclos de liberação.
A Torchrl é licenciada sob a licença do MIT. Consulte a licença para obter detalhes.