Materi pelengkap dapat ditemukan di sini https://github.com/sacktock/AMBS/blob/main/supplementary-material.pdf. Juga, jangan ragu untuk mengunjungi makalah kami di sini:
AMBS sekarang kompatibel dengan Safety Gym
Repositori GitHub untuk "Perkiraan Perlindungan Berbasis Model untuk Pembelajaran Penguatan yang Aman" dan "Memanfaatkan Perlindungan Berbasis Model Perkiraan untuk Jaminan Keamanan Probabilistik di Lingkungan Berkelanjutan".
Dengan memanfaatkan model dunia untuk perlindungan masa depan, kami mendapatkan algoritme tujuan umum untuk memverifikasi keamanan kebijakan RL yang dipelajari. Secara khusus, kami menggunakan DreamerV3 untuk mensimulasikan kemungkinan masa depan (jejak) dalam ruang laten model dinamika yang dipelajari, kami memeriksa setiap jejak ini dan memperkirakan kemungkinan melakukan pelanggaran keselamatan dalam waktu dekat. Jika probabilitas ini tidak cukup rendah, maka kami akan mengganti kebijakan yang dipelajari dengan kebijakan aman yang dilatih untuk meminimalkan pelanggaran batasan.
Dalam eksperimen kami yang dirinci dalam "Perkiraan Perlindungan Berbasis Model untuk Pembelajaran Penguatan yang Aman" semua agen diimplementasikan dengan JAX, meskipun sebagian besar dependensinya tumpang tindih, namun mungkin sedikit berbeda tergantung pada apakah Anda menjalankan agen berbasis DreamerV3 atau agen berbasis dopamin.
Kami merekomendasikan untuk membuat lingkungan conda terpisah untuk agen berbasis DreamerV3 dan agen berbasis dopamin.
conda create -n jax --clone base
conda create -n jax_dopamine --clone base
Kami merujuk ke repositori DreamerV3 untuk daftar dependensi yang terkait dengan agen berbasis DreamerV3. Dan kami merujuk ke google/dopamin untuk daftar dependensi yang terkait dengan agen berbasis dopamin.
conda activate jax
pip install //DreamerV3 dependencies
conda activate jax_dopamine
pip install //dopamine dependencies
Sebagai alternatif, gunakan file persyaratan kami, meskipun kami menekankan bahwa instalasi JAX spesifik yang diperlukan bergantung pada perangkat keras.
conda activate jax
pip install -r requirements_jax.txt
conda activate jax_dopamine
pip install -r requirements_jax_dopamine.txt
Untuk agen berbasis DreamerV3, navigasikan ke subdirektori yang relevan dan jalankan train.py
. Perintah berikut akan menjalankan DreamerV3 dengan AMBS di Seaquest untuk 40 juta frame. Bendera --env.atari.labels
digunakan untuk menentukan label keselamatan death
, early-surface
, out-of-oxygen
, dan opsi xlarge
menentukan ukuran model ( xlarge
adalah default untuk game atari).
cd dreamerV3-shield
python train.py --logdir ./logdir/seaquest/shield --configs atari xlarge --task atari_seaquest --env.atari.labels death early-surface out-of-oxygen --run.steps 10000000
Seed acak dapat diatur dengan flag --seed
(default 0).
Untuk agen berbasis dopamin, navigasikan ke subdirektori dopamin dan jalankan agen yang diinginkan.
cd dopamine
python -um dopamine.discrete_domains.train --base_dir ./logdir/seaquest/rainbow --gin_files ./dopamine/jax/agents/full_rainbow/configs/full_rainbow_seaquest.gin
Seed acak dapat diatur dengan memodifikasi file .gin yang sesuai (misalnya JaxFullRainbowAgent.seed=0)
Untuk merencanakan proses, kami menggunakan tensorboard. Navigasikan ke subdirektori yang relevan dan mulai tensorboard.
cd dreamerV3-shield
tensorboard --logdir ./logdir/seaquest
Safety Gym dibangun di MuJoCo. Untuk detail instalasi MuJoCo kami merujuk Anda ke sini. Persyaratan tambahan untuk Safety Gym meliputi yang berikut:
gym>=0.15.3
joblib>=0.14.0
mujoco_py>=2.0.2.7
numpy>=1.17.4
xmltodict>=0.12.0
Jika Anda telah mengikuti petunjuk pengaturan dan instalasi sebelumnya, dependensi ini mungkin sudah terpenuhi.
Agar AMBS dapat bekerja di Safety Gym kita perlu memberikan sanksi terhadap kebijakan tugas (tanpa pelindung). Kami memberikan tiga teknik penalti berikut yang membantu AMBS menyatu dengan kebijakan yang memaksimalkan imbalan dan memenuhi batasan keselamatan di lingkungan Safety Gym:
python train.py ./logdir/safetygym/PointGoal1/shield_penl --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 1.0 --normalise_ret False --penl_critic_type vfunction --run.steps 500000
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_plpg --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 0.8 --plpg True --normalise_ret False --penl_critic_type vfunction --run.steps 500000
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_copt --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 1.0 --copt True --normalise_ret False --penl_critic_type vfunction --run.steps 500000
Makalah penelitian kami yang menguraikan teknik-teknik ini dapat ditemukan di sini: http://arxiv.org/abs/2402.00816
Kami mengacu pada repositori berikut tempat kode kami dikembangkan: