Implementation of Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google DeepMind.
I will be keeping around the logic for Q-learning on a single action, for final comparison with the proposed autoregressive Q-learning on multiple actions. It also serves as education for myself and the public.
The autoregressive Q-learning formulation has been reproduced by Kotb et al.
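As a rough illustration of what "autoregressive Q-learning on multiple actions" means, here is a minimal, dependency-free sketch of the per-dimension Bellman targets for one transition. It is not this repository's code: the function name `autoregressive_q_targets` and its arguments are hypothetical, the Q-values are plain Python lists, and details such as target networks and the conservative regularizer are left out. Each action dimension except the last bootstraps from the max over the next dimension's bins within the same timestep; only the last dimension receives the reward and the discounted max over the first dimension of the next state.

```python
from typing import List


def autoregressive_q_targets(
    q_curr: List[List[float]],
    q_next_first: List[float],
    reward: float,
    done: bool,
    gamma: float = 0.99,
) -> List[float]:
    """Per-dimension targets for a single transition (hypothetical helper).

    q_curr[i][b]  : Q-value of bin b for action dimension i in the current state,
                    conditioned on the dataset actions chosen for dimensions < i
    q_next_first  : Q-values over the bins of the first action dimension
                    in the next state
    """
    num_actions = len(q_curr)
    targets = []

    for i in range(num_actions):
        if i < num_actions - 1:
            # intermediate dimension: no reward, no discount,
            # bootstrap from the best bin of the next dimension
            targets.append(max(q_curr[i + 1]))
        else:
            # last dimension: reward plus discounted value of the next state
            bootstrap = 0.0 if done else gamma * max(q_next_first)
            targets.append(reward + bootstrap)

    return targets


example_targets = autoregressive_q_targets(
    q_curr = [[0.1, 0.5], [0.2, 0.9], [0.3, 0.4]],
    q_next_first = [1.0, 2.0],
    reward = 1.0,
    done = False,
    gamma = 0.5,
)
# example_targets == [0.9, 0.4, 2.0]
```

With a single action dimension this reduces to the usual one-step Q-learning target, which is why the single-action logic is a useful baseline for comparison.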
$ pip install q-transformer
import torch

from q_transformer import (
    QRoboticTransformer,
    QLearner,
    Agent,
    ReplayMemoryDataset
)

# the attention model

model = QRoboticTransformer(
    vit = dict(
        num_classes = 1000,
        dim_conv_stem = 64,
        dim = 64,
        dim_head = 64,
        depth = (2, 2, 5, 2),
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1
    ),
    num_actions = 8,
    action_bins = 256,
    depth = 1,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2,
    dueling = True
)

# you need to supply your own environment, by overriding BaseEnvironment

from q_transformer.mocks import MockEnvironment

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),
    text_embed_shape = (768,)
)

# env.init() should return instructions and initial state: Tuple[str, Tensor[*state_shape]]
# env(actions) should return rewards, next state, and done flag: Tuple[Tensor[()], Tensor[*state_shape], Tensor[()]]

# agent is a class that allows the q-model to interact with the environment to generate a replay memory dataset for learning

agent = Agent(
    model,
    environment = env,
    num_episodes = 1000,
    max_num_steps_per_episode = 100,
)

agent()

# Q learning on the replay memory dataset on the model

q_learner = QLearner(
    model,
    dataset = ReplayMemoryDataset(),
    num_train_steps = 10000,
    learning_rate = 3e-4,
    batch_size = 4,
    grad_accum_every = 16,
)

q_learner()

# after much learning
# your robot should be better at selecting optimal actions

video = torch.randn(2, 3, 6, 224, 224)

instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

actions = model.get_optimal_actions(video, instructions)
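To make the environment contract in the comments above more concrete, here is a small dependency-free sketch of the interaction loop that an agent runs to build up replay memories. Everything in it is an assumption for illustration: `ToyEnvironment`, `collect_episode` and the lambda policy are hypothetical names, states are plain lists rather than tensors, and the real `Agent` and `BaseEnvironment` classes in this package may differ in their details. It only mirrors the shape described above: `init()` returns an instruction and an initial state, and calling the environment with actions returns a reward, the next state, and a done flag.

```python
from typing import Callable, List, Tuple

State = List[float]
Transition = Tuple[State, List[int], float, State, bool]


class ToyEnvironment:
    """Hypothetical stand-in that follows the init() / __call__ contract."""

    def __init__(self, max_steps: int = 5):
        self.max_steps = max_steps
        self.step_count = 0

    def init(self) -> Tuple[str, State]:
        # returns (instruction, initial state)
        self.step_count = 0
        return 'please pass the butter', [0.0]

    def __call__(self, actions: List[int]) -> Tuple[float, State, bool]:
        # returns (reward, next state, done)
        self.step_count += 1
        reward = 1.0 if sum(actions) == 0 else 0.0  # arbitrary toy reward
        next_state = [float(self.step_count)]
        done = self.step_count >= self.max_steps
        return reward, next_state, done


def collect_episode(
    env: ToyEnvironment,
    policy: Callable[[str, State], List[int]],
    max_num_steps: int,
) -> Tuple[str, List[Transition]]:
    """Roll out one episode and record (state, actions, reward, next_state, done)."""
    instruction, state = env.init()
    transitions: List[Transition] = []

    for _ in range(max_num_steps):
        actions = policy(instruction, state)
        reward, next_state, done = env(actions)
        transitions.append((state, actions, reward, next_state, done))
        state = next_state
        if done:
            break

    return instruction, transitions


instruction, transitions = collect_episode(
    ToyEnvironment(max_steps = 3),
    policy = lambda instruction, state: [0, 0],  # always picks bin 0 for both action dimensions
    max_num_steps = 100,
)
```

The recorded tuples are the kind of data a Q-learner needs for its Bellman updates, which is why the usage example runs the agent before constructing the learner.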
first, work towards single action support
offer batchnorm-less variant of maxvit, as done in SOTA weather model metnet3
add optional deep dueling architecture
add n-step Q learning
build the conservative regularization
build out main proposal in paper (autoregressive discrete actions until last action, reward given only on last)
improvise decoder head variant, instead of concatenating previous actions at the frames + learned tokens stage. in other words, use classic encoder-decoder
redo maxvit with axial rotary embeddings + sigmoid gating for attending to nothing. enable flash attention for maxvit with this change
build out a simple dataset creator class, taking in the environment and model and returning a folder that can be accepted by a ReplayDataset
ReplayDataset that takes in folder
handle multiple instructions correctly
show a simple end-to-end example, in the same style as all other repos
handle no instructions, leverage null conditioner in CFG library
cache kv for action decoding
for exploration, allow for finely randomizing a subset of actions, and not all actions at once
consult some RL experts and figure out if there has been any new headway in resolving delusional bias
figure out if one can train with randomized orders of actions - order could be sent as a conditioning that is concatenated or summed before attention layers
simple beam search function for optimal actions
improvise cross attention to past actions and states of timestep, transformer-xl fashion (w/ structured memory dropout)
see if the main idea in this paper is applicable to language models here
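For readers less familiar with the n-step Q-learning item in the list above, the following is a minimal, dependency-free sketch of an n-step return. It is a generic textbook formulation, not the implementation in this repository; the function name `n_step_target` and its arguments are hypothetical.

```python
from typing import List


def n_step_target(
    rewards: List[float],
    bootstrap_value: float,
    gamma: float = 0.99,
    done: bool = False,
) -> float:
    """n-step return (hypothetical helper).

    rewards         : r_t, ..., r_{t+n-1} observed along the trajectory
    bootstrap_value : max over actions of Q(s_{t+n}, a)
    done            : True if the episode ended within these n steps,
                      in which case nothing is bootstrapped
    """
    target = 0.0 if done else bootstrap_value

    # fold the rewards in from the last step back to the first
    for reward in reversed(rewards):
        target = reward + gamma * target

    return target


example = n_step_target([1.0, 1.0, 1.0], bootstrap_value = 8.0, gamma = 0.5)
# example == 2.75, i.e. 1 + 0.5 + 0.25 + 0.125 * 8
```

With a single reward in the list this is the ordinary one-step target; longer lists propagate reward information back over more timesteps per update.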
@inproceedings{qtransformer,
title = {Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions},
author = {Yevgen Chebotar and Quan Vuong and Alex Irpan and Karol Hausman and Fei Xia and Yao Lu and Aviral Kumar and Tianhe Yu and Alexander Herzog and Karl Pertsch and Keerthana Gopalakrishnan and Julian Ibarz and Ofir Nachum and Sumedh Sontakke and Grecia Salazar and Huong T Tran and Jodilyn Peralta and Clayton Tan and Deeksha Manjunath and Jaspiar Singh and Brianna Zitkovich and Tomas Jackson and Kanishka Rao and Chelsea Finn and Sergey Levine},
booktitle = {7th Annual Conference on Robot Learning},
year = {2023}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@inproceedings{Kumar2023MaintainingPI,
title = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
author = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:261076021}
}