Supplementary material can be found at https://github.com/sacktock/AMBS/blob/main/supplementary-material.pdf. Also, feel free to check out our papers:
AMBS is now compatible with Safety Gym
GitHub repository for "Approximate Model-Based Shielding for Safe Reinforcement Learning" and "Leveraging Approximate Model-based Shielding for Probabilistic Safety Guarantees in Continuous Environments".
By leveraging world models for look-ahead shielding we obtain a general-purpose algorithm for verifying the safety of learned RL policies. Specifically, we use DreamerV3 to simulate possible futures (traces) in the latent space of a learned dynamics model. We check each of these traces and estimate the probability of committing a safety violation in the near future. If this probability is not sufficiently low, we override the learned policy with a safe policy trained to minimise constraint violations.
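The decision rule is roughly as follows; this is a hypothetical sketch with illustrative names (world_model, task_policy, safe_policy, imagine_step, predict_unsafe), not the repository's actual API.

# Hypothetical sketch of approximate look-ahead shielding (illustrative names only).
def estimate_violation_prob(world_model, task_policy, state, horizon=15, n_traces=128):
    # Roll the task policy forward inside the learned world model and count
    # how many imagined traces hit a predicted safety violation.
    violations = 0
    for _ in range(n_traces):
        z = world_model.encode(state)              # latent state
        for _ in range(horizon):
            a = task_policy(z)
            z = world_model.imagine_step(z, a)     # latent dynamics rollout
            if world_model.predict_unsafe(z):      # learned safety classifier
                violations += 1
                break
    return violations / n_traces

def shielded_action(world_model, task_policy, safe_policy, state, delta=0.1):
    # Override the task policy whenever the estimated violation probability
    # is not sufficiently low (threshold delta).
    if estimate_violation_prob(world_model, task_policy, state) > delta:
        return safe_policy(state)
    return task_policy(state)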
In our experiments detailed in "Approximate Model-Based Shielding for Safe Reinforcement Learning" all agents are implemented with JAX. Although the dependencies mostly overlap, they may differ slightly depending on whether you are running a DreamerV3-based agent or a dopamine-based agent.
We recommend creating separate conda environments for the DreamerV3-based agents and the dopamine-based agents.
conda create -n jax --clone base
conda create -n jax_dopamine --clone base
We refer to the DreamerV3 repository for the list of dependencies associated with the DreamerV3-based agents, and to google/dopamine for the list of dependencies associated with the dopamine-based agents.
conda activate jax
pip install <DreamerV3 dependencies>
conda activate jax_dopamine
pip install <dopamine dependencies>
Alternatively, use our requirements files, although we stress that the specific JAX installation required is hardware dependent.
conda activate jax
pip install -r requirements_jax.txt
conda activate jax_dopamine
pip install -r requirements_jax_dopamine.txt
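For example, on a machine with an NVIDIA GPU and CUDA 12 the JAX install typically looks like the command below. This is an illustration only; check the official JAX installation guide for the exact wheel matching your hardware and driver versions.
pip install -U "jax[cuda12]"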
For DreamerV3-based agents, navigate to the relevant subdirectory and run train.py. The following command will run DreamerV3 with AMBS on Seaquest for 40M frames. The --env.atari.labels flag is used to specify the safety labels (death, early-surface, out-of-oxygen), and the xlarge option determines the model size (xlarge is the default for Atari games).
cd dreamerV3-shield
python train.py --logdir ./logdir/seaquest/shield --configs atari xlarge --task atari_seaquest --env.atari.labels death early-surface out-of-oxygen --run.steps 10000000
The random seed can be set with the --seed flag (default 0).
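For example, to repeat the run above with a different seed (writing to a separate log directory):
python train.py --logdir ./logdir/seaquest/shield_seed1 --configs atari xlarge --task atari_seaquest --env.atari.labels death early-surface out-of-oxygen --run.steps 10000000 --seed 1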
For dopamine-based agents, navigate to the dopamine subdirectory and run the desired agent.
cd dopamine
python -um dopamine.discrete_domains.train --base_dir ./logdir/seaquest/rainbow --gin_files ./dopamine/jax/agents/full_rainbow/configs/full_rainbow_seaquest.gin
The random seed can be set by modifying the corresponding .gin file (e.g. JaxFullRainbowAgent.seed=0).
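For example, add or edit the following line in full_rainbow_seaquest.gin (illustrative; the rest of the file is unchanged):
JaxFullRainbowAgent.seed = 0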
For plotting runs we use TensorBoard. Navigate to the relevant subdirectory and start TensorBoard.
cd dreamerV3-shield
tensorboard --logdir ./logdir/seaquest
Safety Gym is built on MuJoCo. For MuJoCo installation details, we refer you here. Additional requirements for Safety Gym include the following:
gym>=0.15.3
joblib>=0.14.0
mujoco_py>=2.0.2.7
numpy>=1.17.4
xmltodict>=0.12.0
If you have followed the earlier setup and installation instructions, these dependencies may already be satisfied.
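As a quick sanity check, the following minimal sketch creates a Safety Gym environment and takes one step. It assumes the openai/safety-gym package is installed, which registers the Safexp-* environments with gym.

# Sanity check: create a Safety Gym environment and take one step.
import gym
import safety_gym  # importing registers the Safexp-* environments

env = gym.make('Safexp-PointGoal1-v0')
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print('reward:', reward, 'cost:', info.get('cost', 0.0))  # 'cost' is the per-step safety signal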
To get AMBS to work on Safety Gym, we need to penalise the (unshielded) task policy. We provide the following three penalty techniques, which help AMBS converge to a policy that maximises reward while satisfying the safety constraints of the Safety Gym environments:
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_penl --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 1.0 --normalise_ret False --penl_critic_type vfunction --run.steps 500000
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_plpg --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 0.8 --plpg True --normalise_ret False --penl_critic_type vfunction --run.steps 500000
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_copt --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 1.0 --copt True --normalise_ret False --penl_critic_type vfunction --run.steps 500000
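As a rough illustration of the simplest of these techniques, the naive penalty approach shapes the task reward by subtracting a scaled cost signal. This is an illustrative sketch only, not the repository's exact implementation; penalty_coeff mirrors the --penalty_coeff flag above.

# Illustrative sketch of reward penalisation (not the repo's exact code).
# The task policy is trained on a shaped reward r - penalty_coeff * c,
# where c is the per-step cost reported by Safety Gym.
def penalised_reward(reward, cost, penalty_coeff=1.0):
    return reward - penalty_coeff * cost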
Our research paper outlining these techniques can be found here: http://arxiv.org/abs/2402.00816
Our code builds on the following repositories: