補足資料は、https://github.com/sacktock/AMBS/blob/main/supplementary-material.pdf でご覧いただけます。また、こちらの論文もお気軽にご覧ください。
AMBS が Safety Gym に対応しました
「安全な強化学習のための近似モデルベースのシールド」および「連続環境における確率論的安全保証のための近似モデルベースのシールドの活用」の GitHub リポジトリ。
先読みシールドのワールド モデルを活用することで、学習された RL ポリシーの安全性を検証するための汎用アルゴリズムを作成します。具体的には、DreamerV3 を使用して、学習されたダイナミクス モデルの潜在空間で起こり得る将来 (トレース) をシミュレートし、これらの各トレースをチェックして、近い将来に安全違反が発生する確率を推定します。この確率が十分に低くない場合は、学習されたポリシーを、制約違反を最小限に抑えるように訓練された安全なポリシーでオーバーライドします。
「安全な強化学習のためのモデルベースのシールドの近似」で詳しく説明されている実験では、すべてのエージェントが JAX を使用して実装されていますが、依存関係はほとんど重複していますが、DreamerV3 ベースのエージェントを実行しているか、ドーパミン ベースのエージェントを実行しているかによって若干異なる場合があります。
DreamerV3 ベースのエージェントとドーパミン ベースのエージェントに対して個別の conda 環境を作成することをお勧めします。
conda create -n jax --clone base
conda create -n jax_dopamine --clone base
DreamerV3 ベースのエージェントに関連付けられた依存関係のリストについては、DreamerV3 リポジトリを参照します。また、ドーパミンベースのエージェントに関連する依存関係のリストについては、google/dopmine を参照します。
conda activate jax
pip install //DreamerV3 dependencies
conda activate jax_dopamine
pip install //dopamine dependencies
あるいは、要件ファイルを使用してください。ただし、必要な特定の JAX インストールはハードウェアに依存することを強調します。
conda activate jax
pip install -r requirements_jax.txt
conda activate jax_dopamine
pip install -r requirements_jax_dopamine.txt
DreamerV3 ベースのエージェントの場合は、関連するサブディレクトリに移動し、 train.py
実行します。次のコマンドは、Seaquest で AMBS を使用して DreamerV3 を 40M フレーム実行します。 --env.atari.labels
フラグは、安全ラベルdeath
、 early-surface
、 out-of-oxygen
を指定するために使用され、 xlarge
オプションはモデル サイズを決定します ( xlarge
は Atari ゲームのデフォルトです)。
cd dreamerV3-shield
python train.py --logdir ./logdir/seaquest/shield --configs atari xlarge --task atari_seaquest --env.atari.labels death early-surface out-of-oxygen --run.steps 10000000
ランダム シードは--seed
フラグ (デフォルトは 0) で設定できます。
ドーパミン ベースのエージェントの場合は、ドーパミン サブディレクトリに移動し、目的のエージェントを実行します。
cd dopamine
python -um dopamine.discrete_domains.train --base_dir ./logdir/seaquest/rainbow --gin_files ./dopamine/jax/agents/full_rainbow/configs/full_rainbow_seaquest.gin
ランダム シードは、対応する .gin ファイルを変更することで設定できます (例: JaxFullRainbowAgent.seed=0)。
実行のプロットにはテンソルボードを使用します。関連するサブディレクトリに移動し、tensorboard を開始します。
cd dreamerV3-shield
tensorboard --logdir ./logdir/seaquest
セーフティジムはMuJoCo上に構築されています。 MuJoCoのインストールの詳細については、こちらを参照してください。セーフティ ジムの追加要件には次のものが含まれます。
gym>=0.15.3
joblib>=0.14.0
mujoco_py>=2.0.2.7
numpy>=1.17.4
xmltodict>=0.12.0
以前のセットアップとインストール手順に従っている場合は、これらの依存関係はすでに満たされている可能性があります。
AMBS をセーフティ ジムで動作させるには、(シールドされていない) タスク ポリシーにペナルティを与える必要があります。 AMBS が報酬を最大化し、セーフティ ジム環境の安全制約を満たすポリシーに収束するのに役立つ、次の 3 つのペナルティ テクニックを提供します。
python train.py ./logdir/safetygym/PointGoal1/shield_penl --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 1.0 --normalise_ret False --penl_critic_type vfunction --run.steps 500000
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_plpg --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 0.8 --plpg True --normalise_ret False --penl_critic_type vfunction --run.steps 500000
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_copt --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 1.0 --copt True --normalise_ret False --penl_critic_type vfunction --run.steps 500000
これらの技術を概説した私たちの研究論文は、http://arxiv.org/abs/2402.00816 でご覧いただけます。
コードの開発元となる次のリポジトリを参照します。