Q-Transformer
Implementation of Q-Transformer, Scalable Offline Reinforcement Learning via Autoregressive Q-Functions, out of Google Deepmind
I will be keeping around the logic for Q-learning on a single action, just for final comparison with the proposed autoregressive Q-learning on multiple actions. It also serves as education for myself and the public.
The autoregressive Q-learning formulation has been reproduced by Kotb et al.
$ pip install q-transformer
import torch

from q_transformer import (
    QRoboticTransformer,
    QLearner,
    Agent,
    ReplayMemoryDataset
)
# the attention model

model = QRoboticTransformer(
    vit = dict(
        num_classes = 1000,
        dim_conv_stem = 64,
        dim = 64,
        dim_head = 64,
        depth = (2, 2, 5, 2),
        window_size = 7,
        mbconv_expansion_rate = 4,
        mbconv_shrinkage_rate = 0.25,
        dropout = 0.1
    ),
    num_actions = 8,
    action_bins = 256,
    depth = 1,
    heads = 8,
    dim_head = 64,
    cond_drop_prob = 0.2,
    dueling = True
)
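
# notes (mine, not part of the original readme): the vit dict configures the MaxViT image
# backbone referenced in the todo items further down; num_actions and action_bins define the
# discretized action space (8 action dimensions, 256 bins each); cond_drop_prob is the
# conditioning dropout used for classifier-free guidance; dueling turns on the optional
# dueling head, also mentioned below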
# you need to supply your own environment, by overriding BaseEnvironment

from q_transformer.mocks import MockEnvironment

env = MockEnvironment(
    state_shape = (3, 6, 224, 224),
    text_embed_shape = (768,)
)

# env.init()     should return instructions and initial state: Tuple[str, Tensor[*state_shape]]
# env(actions)   should return rewards, next state, and done flag: Tuple[Tensor[()], Tensor[*state_shape], Tensor[()]]
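
# below is a minimal sketch (mine, not part of the original readme) of what overriding
# BaseEnvironment could look like. assumptions: BaseEnvironment is importable from the
# top-level q_transformer package (otherwise import it from the submodule defining it),
# its constructor takes the same state_shape / text_embed_shape keywords as MockEnvironment,
# and subclasses implement init() and forward(actions) with the signatures described above

from q_transformer import BaseEnvironment

class MyRobotEnvironment(BaseEnvironment):
    def __init__(self):
        super().__init__(
            state_shape = (3, 6, 224, 224),   # (channels, frames, height, width)
            text_embed_shape = (768,)
        )

    def init(self):
        # return the language instruction and the initial state
        instruction = 'pick up the red block'   # hypothetical instruction
        initial_state = torch.randn(3, 6, 224, 224)
        return instruction, initial_state

    def forward(self, actions):
        # step the (hypothetical) robot with the discretized actions,
        # then return reward, next state, and done flag as tensors
        reward = torch.tensor(0.)
        next_state = torch.randn(3, 6, 224, 224)
        done = torch.tensor(False)
        return reward, next_state, done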
# agent is a class that allows the q-model to interact with the environment to generate a replay memory dataset for learning

agent = Agent(
    model,
    environment = env,
    num_episodes = 1000,
    max_num_steps_per_episode = 100,
)

agent()
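
# note (mine): calling the agent rolls out num_episodes episodes in the environment and writes
# the collected transitions to disk as replay memories; the argument-less ReplayMemoryDataset()
# below is assumed to read from that same default location - check both classes if you want to
# point them at a custom folder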
# Q-learning on the replay memory dataset, on the model

q_learner = QLearner(
    model,
    dataset = ReplayMemoryDataset(),
    num_train_steps = 10000,
    learning_rate = 3e-4,
    batch_size = 4,
    grad_accum_every = 16,
)

q_learner()
# after much learning
# your robot should be better at selecting optimal actions

video = torch.randn(2, 3, 6, 224, 224)

instructions = [
    'bring me that apple sitting on the table',
    'please pass the butter'
]

actions = model.get_optimal_actions(video, instructions)
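
# note (mine): assuming the config above, actions should come back as an integer tensor of
# shape (2, 8) - one discretized bin index in [0, 255] per action dimension, for each of the
# two video / instruction pairs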

first work way towards single action support
offer a batchnorm-less variant of maxvit, as done in the SOTA weather model metnet3
add an optional deep dueling architecture
add n-step Q-learning
build out the conservative regularization
build out the main proposal in the paper (autoregressive discrete actions up until the last action, with the reward given only on the last) - a rough sketch of this objective follows the list below
improvise a decoder head variant, instead of concatenating the previous actions at the frames + learned tokens stage. in other words, use a classic encoder - decoder
redo maxvit with axial rotary embeddings + sigmoid gating for attending to nothing. enable flash attention for maxvit with this change
build out a simple dataset creator class, taking in the environment and the model and returning a folder that can be accepted by a ReplayDataset
handle multiple instructions correctly
show a simple end-to-end example, in the same style as all the other repositories
handle no instructions, leveraging the null conditioner in the CFG library
cache the kv for action decoding
for exploration, allow for finely randomizing a subset of the actions, rather than all of the actions at once
consult some RL experts and figure out whether there are any new headways into resolving delusional bias
figure out whether one can train with a randomized order of actions - the order could be sent as conditioning that is concatenated or summed before the attention layers
simple beam search function for optimal actions
improvise cross attention to past actions and states of the timestep, transformer-xl fashion (with structured memory dropout)
see whether the main idea in this paper is applicable to language models here
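
To make the main-proposal item above a bit more concrete, here is my reading of the autoregressive Q-learning target from the paper (a hedged paraphrase, not code from this repository). With d discretized action dimensions per timestep, the Bellman backup runs one action dimension at a time, and the reward and discount only enter on the last dimension:

$$y^i_t = \max_{a^{i+1}} Q\big(s_t,\, a^{1:i}_t,\, a^{i+1}\big) \qquad \text{for } 1 \le i < d$$

$$y^d_t = R(s_t, a_t) + \gamma \max_{a^{1}} Q\big(s_{t+1},\, a^{1}\big)$$

The conservative regularization item, as I understand it, additionally pushes the Q-values of action bins not present in the dataset toward the minimal attainable value (zero, for rewards in [0, 1]).
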
@inproceedings{qtransformer,
    title     = {Q-Transformer: Scalable Offline Reinforcement Learning via Autoregressive Q-Functions},
    author    = {Yevgen Chebotar and Quan Vuong and Alex Irpan and Karol Hausman and Fei Xia and Yao Lu and Aviral Kumar and Tianhe Yu and Alexander Herzog and Karl Pertsch and Keerthana Gopalakrishnan and Julian Ibarz and Ofir Nachum and Sumedh Sontakke and Grecia Salazar and Huong T Tran and Jodilyn Peralta and Clayton Tan and Deeksha Manjunath and Jaspiar Singh and Brianna Zitkovich and Tomas Jackson and Kanishka Rao and Chelsea Finn and Sergey Levine},
    booktitle = {7th Annual Conference on Robot Learning},
    year      = {2023}
}

@inproceedings{dao2022flashattention,
    title     = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author    = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year      = {2022}
}

@inproceedings{Kumar2023MaintainingPI,
    title     = {Maintaining Plasticity in Continual Learning via Regenerative Regularization},
    author    = {Saurabh Kumar and Henrik Marklund and Benjamin Van Roy},
    year      = {2023},
    url       = {https://api.semanticscholar.org/CorpusID:261076021}
}