Material suplementar pode ser encontrado aqui https://github.com/sacktock/AMBS/blob/main/supplementary-material.pdf. Além disso, fique à vontade para visitar nossos artigos aqui:
AMBS agora é compatível com Safety Gym
Repositório GitHub para "Blindagem baseada em modelo aproximado para aprendizagem de reforço seguro" e "Aproveitando blindagem baseada em modelo aproximado para garantias de segurança probabilísticas em ambientes contínuos".
Ao aproveitar modelos mundiais para proteção antecipada, obtemos um algoritmo de propósito geral para verificar a segurança das políticas de RL aprendidas. Especificamente, usamos o DreamerV3 para simular possíveis futuros (traços) no espaço latente de um modelo dinâmico aprendido, verificamos cada um desses traços e estimamos a probabilidade de cometer uma violação de segurança em um futuro próximo. Se esta probabilidade não for suficientemente baixa, então substituímos a política aprendida por uma política segura treinada para minimizar violações de restrições.
Em nossos experimentos detalhados em "Blindagem baseada em modelo aproximado para aprendizado de reforço seguro", todos os agentes são implementados com JAX, embora a maioria das dependências se sobreponham, elas podem diferir ligeiramente dependendo se você está executando um agente baseado em DreamerV3 ou um agente baseado em dopamina.
Recomendamos a criação de ambientes conda separados para os agentes baseados em DreamerV3 e os agentes baseados em dopamina.
conda create -n jax --clone base
conda create -n jax_dopamine --clone base
Referimo-nos ao repositório DreamerV3 para obter a lista de dependências associadas aos agentes baseados em DreamerV3. E nos referimos a google/dopamine para obter a lista de dependências associadas aos agentes baseados em dopamina.
conda activate jax
pip install //DreamerV3 dependencies
conda activate jax_dopamine
pip install //dopamine dependencies
Como alternativa, use nossos arquivos de requisitos, embora enfatizemos que a instalação específica do JAX necessária depende do hardware.
conda activate jax
pip install -r requirements_jax.txt
conda activate jax_dopamine
pip install -r requirements_jax_dopamine.txt
Para agentes baseados em DreamerV3, navegue até o subdiretório relevante e execute train.py
. O comando a seguir executará o DreamerV3 com AMBS no Seaquest para quadros de 40M. O sinalizador --env.atari.labels
é usado para especificar os rótulos de segurança death
, early-surface
, out-of-oxygen
, e a opção xlarge
determina o tamanho do modelo ( xlarge
é o padrão para jogos Atari).
cd dreamerV3-shield
python train.py --logdir ./logdir/seaquest/shield --configs atari xlarge --task atari_seaquest --env.atari.labels death early-surface out-of-oxygen --run.steps 10000000
A semente aleatória pode ser definida com o sinalizador --seed
(padrão 0).
Para agentes baseados em dopamina, navegue até o subdiretório dopamina e execute o agente desejado.
cd dopamine
python -um dopamine.discrete_domains.train --base_dir ./logdir/seaquest/rainbow --gin_files ./dopamine/jax/agents/full_rainbow/configs/full_rainbow_seaquest.gin
A semente aleatória pode ser definida modificando o arquivo .gin correspondente (por exemplo, JaxFullRainbowAgent.seed=0)
Para traçar execuções, usamos tensorboard. Navegue até o subdiretório relevante e inicie o tensorboard.
cd dreamerV3-shield
tensorboard --logdir ./logdir/seaquest
O Safety Gym é construído no MuJoCo. Para obter detalhes de instalação do MuJoCo, consulte aqui. Os requisitos adicionais para o Safety Gym incluem o seguinte:
gym>=0.15.3
joblib>=0.14.0
mujoco_py>=2.0.2.7
numpy>=1.17.4
xmltodict>=0.12.0
Se você seguiu as instruções anteriores de configuração e instalação, essas dependências já podem estar satisfeitas.
Para fazer com que o AMBS trabalhe no Safety Gym, precisamos penalizar a política de tarefas (não blindadas). Fornecemos as três técnicas de penalidade a seguir que ajudam a AMBS a convergir para uma política que maximiza a recompensa e satisfaz as restrições de segurança dos ambientes do Safety Gym:
python train.py ./logdir/safetygym/PointGoal1/shield_penl --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 1.0 --normalise_ret False --penl_critic_type vfunction --run.steps 500000
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_plpg --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 0.8 --plpg True --normalise_ret False --penl_critic_type vfunction --run.steps 500000
python train.py --logdir ./logdir/safetygym/PointGoal1/shield_copt --configs safetygym_vision large --task safetygym_Safexp-PointGoal1-v0 --penalty_coeff 1.0 --copt True --normalise_ret False --penl_critic_type vfunction --run.steps 500000
Nosso artigo de pesquisa descrevendo essas técnicas pode ser encontrado aqui: http://arxiv.org/abs/2402.00816
Referimo-nos aos seguintes repositórios a partir dos quais nosso código é desenvolvido: