สถานะ: เผยแพร่อย่างเสถียร
การใช้งานเอเจนต์ DreamerV2 ใน TensorFlow 2 รวมกราฟการฝึกอบรมสำหรับเกมทั้งหมด 55 เกมแล้ว
หากคุณพบว่าโค้ดนี้มีประโยชน์ โปรดอ้างอิงในรายงานของคุณ:
@article{hafner2020dreamerv2,
title={Mastering Atari with Discrete World Models},
author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
journal={arXiv preprint arXiv:2010.02193},
year={2020}
}
DreamerV2 คือตัวแทนโมเดลระดับโลกรายแรกที่ได้รับประสิทธิภาพระดับมนุษย์บนเกณฑ์มาตรฐาน Atari DreamerV2 ยังมีประสิทธิภาพเหนือกว่าประสิทธิภาพขั้นสุดท้ายของเอเจนต์ไร้โมเดลชั้นนำอย่าง Rainbow และ IQN โดยใช้ประสบการณ์และการคำนวณที่เท่ากัน การใช้งานในพื้นที่เก็บข้อมูลนี้จะสลับระหว่างการฝึกโมเดลโลก การฝึกนโยบาย และการรวบรวมประสบการณ์และทำงานบน GPU ตัวเดียว
DreamerV2 เรียนรู้แบบจำลองสภาพแวดล้อมโดยตรงจากรูปภาพอินพุตที่มีมิติสูง สำหรับสิ่งนี้ คาดการณ์ล่วงหน้าโดยใช้สถานะการเรียนรู้แบบกะทัดรัด สถานะประกอบด้วยส่วนที่กำหนดไว้และตัวแปรหมวดหมู่หลายตัวที่ถูกสุ่มตัวอย่าง ข้อมูลก่อนหน้าสำหรับหมวดหมู่เหล่านี้เรียนรู้ผ่านการสูญเสีย KL แบบจำลองโลกได้รับการเรียนรู้ตั้งแต่ต้นจนจบผ่านการไล่ระดับสีแบบตรง ซึ่งหมายความว่าการไล่ระดับสีของความหนาแน่นจะถูกตั้งค่าเป็นการไล่ระดับสีของตัวอย่าง
DreamerV2 เรียนรู้เครือข่ายนักแสดงและนักวิจารณ์จากเส้นทางจินตนาการของรัฐที่แฝงอยู่ วิถีเริ่มต้นที่สถานะที่เข้ารหัสของลำดับที่พบก่อนหน้านี้ จากนั้นโมเดลโลกจะคาดการณ์ล่วงหน้าโดยใช้การกระทำที่เลือกและสถานะที่เรียนรู้ไว้ก่อนหน้า นักวิจารณ์ได้รับการฝึกฝนโดยใช้การเรียนรู้ความแตกต่างชั่วคราว และนักแสดงได้รับการฝึกฝนเพื่อเพิ่มฟังก์ชันคุณค่าสูงสุดผ่านการเสริมแรงและการไล่ระดับสีแบบตรง
สำหรับข้อมูลเพิ่มเติม:
วิธีที่ง่ายที่สุดในการรัน DreamerV2 ในสภาพแวดล้อมใหม่คือการติดตั้งแพ็คเกจผ่าน pip3 install dreamerv2
รหัสจะตรวจจับโดยอัตโนมัติว่าสภาพแวดล้อมใช้การดำเนินการแบบไม่ต่อเนื่องหรือต่อเนื่อง นี่คือตัวอย่างการใช้งานที่ฝึก DreamerV2 บนสภาพแวดล้อม MiniGrid:
import gym
import gym_minigrid
import dreamerv2 . api as dv2
config = dv2 . defaults . update ({
'logdir' : '~/logdir/minigrid' ,
'log_every' : 1e3 ,
'train_every' : 10 ,
'prefill' : 1e5 ,
'actor_ent' : 3e-3 ,
'loss_scales.kl' : 1.0 ,
'discount' : 0.99 ,
}). parse_flags ()
env = gym . make ( 'MiniGrid-DoorKey-6x6-v0' )
env = gym_minigrid . wrappers . RGBImgPartialObsWrapper ( env )
dv2 . train ( env , config )
หากต้องการแก้ไขเอเจนต์ DreamerV2 ให้โคลนพื้นที่เก็บข้อมูลและปฏิบัติตามคำแนะนำด้านล่าง นอกจากนี้ยังมี Dockerfile ในกรณีที่คุณไม่ต้องการติดตั้งการพึ่งพาในระบบของคุณ
รับการพึ่งพา:
pip3 install tensorflow==2.6.0 tensorflow_probability ruamel.yaml ' gym[atari] ' dm_control
รถไฟบนอาตาริ:
python3 dreamerv2/train.py --logdir ~ /logdir/atari_pong/dreamerv2/1
--configs atari --task atari_pong
ฝึกควบคุม DM:
python3 dreamerv2/train.py --logdir ~ /logdir/dmc_walker_walk/dreamerv2/1
--configs dmc_vision --task dmc_walker_walk
ติดตามผล:
tensorboard --logdir ~ /logdir
สร้างแปลง:
python3 common/plot.py --indir ~ /logdir --outdir ~ /plots
--xaxis step --yaxis eval_return --bins 1e6
Dockerfile ช่วยให้คุณรัน DreamerV2 โดยไม่ต้องติดตั้งการพึ่งพาในระบบของคุณ สิ่งนี้ต้องการให้คุณตั้งค่า Docker พร้อมการเข้าถึง GPU
ตรวจสอบการตั้งค่าของคุณ:
docker run -it --rm --gpus all tensorflow/tensorflow:2.4.2-gpu nvidia-smi
รถไฟบนอาตาริ:
docker build -t dreamerv2 .
docker run -it --rm --gpus all -v ~ /logdir:/logdir dreamerv2
python3 dreamerv2/train.py --logdir /logdir/atari_pong/dreamerv2/1
--configs atari --task atari_pong
ฝึกควบคุม DM:
docker build -t dreamerv2 . --build-arg MUJOCO_KEY= " $( cat ~ /.mujoco/mjkey.txt ) "
docker run -it --rm --gpus all -v ~ /logdir:/logdir dreamerv2
python3 dreamerv2/train.py --logdir /logdir/dmc_walker_walk/dreamerv2/1
--configs dmc_vision --task dmc_walker_walk
การดีบักที่มีประสิทธิภาพ คุณสามารถใช้การกำหนด debug
เช่นเดียวกับใน --configs atari debug
ซึ่งจะช่วยลดขนาดแบตช์ เพิ่มความถี่ในการประเมิน และปิดใช้งานการคอมไพล์กราฟ tf.function
เพื่อการแก้ไขจุดบกพร่องทีละบรรทัดอย่างง่ายดาย
บรรทัดฐานการไล่ระดับสีอนันต์ นี่เป็นเรื่องปกติและอธิบายไว้ภายใต้มาตราส่วนการสูญเสียในคู่มือความแม่นยำแบบผสม คุณสามารถปิดใช้งานความแม่นยำแบบผสมได้โดยส่ง --precision 32
ไปยังสคริปต์การฝึก ความแม่นยำแบบผสมจะเร็วกว่า แต่โดยหลักการแล้วสามารถทำให้เกิดความไม่เสถียรเชิงตัวเลขได้
การเข้าถึงตัวชี้วัดที่บันทึกไว้ เมตริกจะถูกจัดเก็บในรูปแบบบรรทัด TensorBoard และ JSON คุณสามารถโหลดได้โดยตรงโดยใช้ pandas.read_json()
สคริปต์การลงจุดยังจัดเก็บตัววัดแบบรวมและแบบรวมของการเรียกใช้หลายรายการไว้ในไฟล์ JSON ไฟล์เดียวเพื่อให้ง่ายต่อการลงจุดด้วยตนเอง