状态:稳定发布
在 TensorFlow 2 中实现 DreamerV2 代理。包含所有 55 个游戏的训练曲线。
如果您发现此代码有用,请在您的论文中引用:
@article{hafner2020dreamerv2,
title={Mastering Atari with Discrete World Models},
author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
journal={arXiv preprint arXiv:2010.02193},
year={2020}
}
DreamerV2 是世界上第一个在 Atari 基准测试中达到人类水平性能的模型代理。使用相同的经验和计算量,DreamerV2 的最终性能也优于顶级无模型代理 Rainbow 和 IQN。该存储库中的实现在训练世界模型、训练策略和收集经验之间交替进行,并在单个 GPU 上运行。
DreamerV2 直接从高维输入图像中学习环境模型。为此,它使用紧凑的学习状态进行提前预测。状态由确定性部分和采样的几个分类变量组成。这些分类的先验是通过 KL 损失学习的。世界模型是通过直通梯度进行端到端学习的,这意味着密度的梯度设置为样本的梯度。
DreamerV2 从潜在状态的想象轨迹中学习演员和评论家网络。轨迹从先前遇到的序列的编码状态开始。然后,世界模型使用所选的动作及其先前学习的状态来预测。使用时间差异学习来训练批评者,并训练演员通过强化和直通梯度来最大化价值函数。
欲了解更多信息:
在新环境中运行 DreamerV2 的最简单方法是通过pip3 install dreamerv2
安装软件包。代码自动检测环境是否使用离散操作或连续操作。以下是在 MiniGrid 环境上训练 DreamerV2 的使用示例:
import gym
import gym_minigrid
import dreamerv2 . api as dv2
config = dv2 . defaults . update ({
'logdir' : '~/logdir/minigrid' ,
'log_every' : 1e3 ,
'train_every' : 10 ,
'prefill' : 1e5 ,
'actor_ent' : 3e-3 ,
'loss_scales.kl' : 1.0 ,
'discount' : 0.99 ,
}). parse_flags ()
env = gym . make ( 'MiniGrid-DoorKey-6x6-v0' )
env = gym_minigrid . wrappers . RGBImgPartialObsWrapper ( env )
dv2 . train ( env , config )
要修改 DreamerV2 代理,请克隆存储库并按照以下说明进行操作。如果您不想在系统上安装依赖项,还有一个可用的 Dockerfile。
获取依赖项:
pip3 install tensorflow==2.6.0 tensorflow_probability ruamel.yaml ' gym[atari] ' dm_control
在 Atari 上训练:
python3 dreamerv2/train.py --logdir ~ /logdir/atari_pong/dreamerv2/1
--configs atari --task atari_pong
DM 控制培训:
python3 dreamerv2/train.py --logdir ~ /logdir/dmc_walker_walk/dreamerv2/1
--configs dmc_vision --task dmc_walker_walk
监测结果:
tensorboard --logdir ~ /logdir
生成图:
python3 common/plot.py --indir ~ /logdir --outdir ~ /plots
--xaxis step --yaxis eval_return --bins 1e6
Dockerfile 允许您运行 DreamerV2,而无需在系统中安装其依赖项。这需要您设置具有 GPU 访问权限的 Docker。
检查您的设置:
docker run -it --rm --gpus all tensorflow/tensorflow:2.4.2-gpu nvidia-smi
在 Atari 上训练:
docker build -t dreamerv2 .
docker run -it --rm --gpus all -v ~ /logdir:/logdir dreamerv2
python3 dreamerv2/train.py --logdir /logdir/atari_pong/dreamerv2/1
--configs atari --task atari_pong
DM 控制培训:
docker build -t dreamerv2 . --build-arg MUJOCO_KEY= " $( cat ~ /.mujoco/mjkey.txt ) "
docker run -it --rm --gpus all -v ~ /logdir:/logdir dreamerv2
python3 dreamerv2/train.py --logdir /logdir/dmc_walker_walk/dreamerv2/1
--configs dmc_vision --task dmc_walker_walk
高效调试。您可以使用--configs atari debug
中debug
配置。这可以减少批量大小,增加评估频率,并禁用tf.function
图形编译,以便于逐行调试。
无限梯度范数。这是正常现象,并在混合精度指南中的损失缩放下进行了描述。您可以通过将--precision 32
传递给训练脚本来禁用混合精度。混合精度速度更快,但原则上会导致数值不稳定。
访问记录的指标。指标以 TensorBoard 和 JSON 行格式存储。您可以使用pandas.read_json()
直接加载它们。绘图脚本还将多次运行的分箱和聚合指标存储到单个 JSON 文件中,以便于手动绘图。