Status: Rilis stabil
Implementasi agen DreamerV2 di TensorFlow 2. Kurva pelatihan untuk 55 game disertakan.
Jika Anda merasa kode ini berguna, silakan rujuk di makalah Anda:
@article{hafner2020dreamerv2,
title={Mastering Atari with Discrete World Models},
author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
journal={arXiv preprint arXiv:2010.02193},
year={2020}
}
DreamerV2 adalah agen model dunia pertama yang mencapai kinerja tingkat manusia pada benchmark Atari. DreamerV2 juga mengungguli kinerja akhir agen bebas model teratas Rainbow dan IQN dengan menggunakan jumlah pengalaman dan komputasi yang sama. Implementasi dalam repositori ini bergantian antara melatih model dunia, melatih kebijakan, dan mengumpulkan pengalaman dan dijalankan pada satu GPU.
DreamerV2 mempelajari model lingkungan langsung dari gambar masukan berdimensi tinggi. Untuk melakukan hal ini, ia memprediksi masa depan dengan menggunakan keadaan terpelajar yang ringkas. Negara bagian terdiri dari bagian deterministik dan beberapa variabel kategori yang dijadikan sampel. Prioritas kategorikal ini dipelajari melalui kerugian KL. Model dunia dipelajari secara end-to-end melalui gradien lurus, artinya gradien kepadatan diatur ke gradien sampel.
DreamerV2 mempelajari jaringan aktor dan kritikus dari lintasan keadaan laten yang dibayangkan. Lintasan dimulai pada keadaan yang dikodekan dari urutan yang ditemui sebelumnya. Model dunia kemudian memprediksi masa depan menggunakan tindakan yang dipilih dan keadaan yang dipelajari sebelumnya. Kritikus dilatih menggunakan pembelajaran perbedaan temporal dan aktor dilatih untuk memaksimalkan fungsi nilai melalui penguatan dan gradien langsung.
Untuk informasi lebih lanjut:
Cara termudah untuk menjalankan DreamerV2 di lingkungan baru adalah dengan menginstal paket melalui pip3 install dreamerv2
. Kode secara otomatis mendeteksi apakah lingkungan menggunakan tindakan terpisah atau berkelanjutan. Berikut adalah contoh penggunaan yang melatih DreamerV2 di lingkungan MiniGrid:
import gym
import gym_minigrid
import dreamerv2 . api as dv2
config = dv2 . defaults . update ({
'logdir' : '~/logdir/minigrid' ,
'log_every' : 1e3 ,
'train_every' : 10 ,
'prefill' : 1e5 ,
'actor_ent' : 3e-3 ,
'loss_scales.kl' : 1.0 ,
'discount' : 0.99 ,
}). parse_flags ()
env = gym . make ( 'MiniGrid-DoorKey-6x6-v0' )
env = gym_minigrid . wrappers . RGBImgPartialObsWrapper ( env )
dv2 . train ( env , config )
Untuk memodifikasi agen DreamerV2, kloning repositori dan ikuti petunjuk di bawah. Ada juga Dockerfile yang tersedia, jika Anda tidak ingin menginstal dependensi pada sistem Anda.
Dapatkan dependensi:
pip3 install tensorflow==2.6.0 tensorflow_probability ruamel.yaml ' gym[atari] ' dm_control
Berlatih di Atari:
python3 dreamerv2/train.py --logdir ~ /logdir/atari_pong/dreamerv2/1
--configs atari --task atari_pong
Latih Kontrol DM:
python3 dreamerv2/train.py --logdir ~ /logdir/dmc_walker_walk/dreamerv2/1
--configs dmc_vision --task dmc_walker_walk
Pantau hasil:
tensorboard --logdir ~ /logdir
Hasilkan plot:
python3 common/plot.py --indir ~ /logdir --outdir ~ /plots
--xaxis step --yaxis eval_return --bins 1e6
Dockerfile memungkinkan Anda menjalankan DreamerV2 tanpa menginstal dependensinya di sistem Anda. Ini mengharuskan Anda menyiapkan Docker dengan akses GPU.
Periksa pengaturan Anda:
docker run -it --rm --gpus all tensorflow/tensorflow:2.4.2-gpu nvidia-smi
Berlatih di Atari:
docker build -t dreamerv2 .
docker run -it --rm --gpus all -v ~ /logdir:/logdir dreamerv2
python3 dreamerv2/train.py --logdir /logdir/atari_pong/dreamerv2/1
--configs atari --task atari_pong
Latih Kontrol DM:
docker build -t dreamerv2 . --build-arg MUJOCO_KEY= " $( cat ~ /.mujoco/mjkey.txt ) "
docker run -it --rm --gpus all -v ~ /logdir:/logdir dreamerv2
python3 dreamerv2/train.py --logdir /logdir/dmc_walker_walk/dreamerv2/1
--configs dmc_vision --task dmc_walker_walk
Proses debug yang efisien. Anda dapat menggunakan konfigurasi debug
seperti pada --configs atari debug
. Hal ini mengurangi ukuran batch, meningkatkan frekuensi evaluasi, dan menonaktifkan kompilasi grafik tf.function
untuk memudahkan proses debug baris demi baris.
Norma gradien tak terbatas. Hal ini normal dan dijelaskan pada penskalaan kerugian dalam panduan presisi campuran. Anda dapat menonaktifkan presisi campuran dengan meneruskan --precision 32
ke skrip pelatihan. Presisi campuran lebih cepat tetapi pada prinsipnya dapat menyebabkan ketidakstabilan numerik.
Mengakses metrik yang dicatat. Metrik disimpan dalam format garis TensorBoard dan JSON. Anda dapat langsung memuatnya menggunakan pandas.read_json()
. Skrip pembuatan plot juga menyimpan metrik yang digabungkan dan digabungkan dari beberapa proses ke dalam satu file JSON untuk memudahkan pembuatan plot manual.