このリポジトリは、思考クローニング (人間の思考を模倣して行動しながら考えることを学ぶ) の公式実装を提供します。 Thought Cloning (TC) は、エージェントが人間のように考えるようにトレーニングすることで、エージェントの能力、AI の安全性、解釈可能性を強化する、新しい模倣学習フレームワークです。このリポジトリは、人間の合成思考データセットを使用して、シミュレートされた部分的に観察可能な 2D グリッドワールド ドメイン BabyAI に TC を実装します。紹介ツイートスレッドもチェックしてください。
python-venv または conda で環境を作成します。 python-venv を使用した例を次に示します。
python3 -m venv thoughtcloning
source thoughtcloning/bin/activate
このリポジトリは、Python 3.9.10 および PyTorch 1.7.1+cu110 でテストされています。コードは、gym、numpy、または Gym-minigrid の上位バージョンと互換性がない可能性があります。
git clone https://github.com/ShengranHu/Thought-Cloning.git
cd Thought-Cloning
pip3 install --upgrade pip
pip3 install --editable .
注: pip3 install --editable .
後でプロジェクト ディレクトリが変更された場合は、再度実行します。
/babyai/utils/__init__.py のstorage_dir
の出力ディレクトリを変更します。 path-to-thought-cloning
Thought Cloning プロジェクトのパスに設定します。
BossLevelの合成人間の思考データセット、トレーニングされた TC モデルの重み、および配布外のパフォーマンスをテストするための収集された環境は、Google ドライブで利用できます。
合成思考データセットを再現するには、次のコマンドでscripts/make_agent_demos.py
スクリプトを使用できます。
scripts/make_agent_demos.py --episodes <NUM_OF_EPISODES> --env <ENV_NAME> --noise-rate 0.01
論文の図 3 に示されている主なパフォーマンス結果を再現するには、次のコマンドを使用します。
scripts/train_tc.py --env BabyAI-BossLevel-v0 --demos <DEMO_NAME> --memory-dim=2048 --recurrence=80 --batch-size=180 --instr-arch=attgru --instr-dim=256 --val-interval 2 --log-interval 2 --lr 5e-4 --epochs 160 --epoch-length 51200 --seed 0 --val-seed 2023 --model <NAME_OF_MODEL> --sg-coef 2 --warm-start --stop-tf 10
実験ごとに、 group_name
引数を変更して、ログとモデルの出力を特定のフォルダーにグループ化することができます。
論文の図 4(a) に示されているゼロショット評価結果を再現するには、次のコマンドを使用します。
scripts/evaluate_levels.py --env BabyAI-BossLevel-v0 --model <NAME_OF_MODEL> --testing-levels-path <PATH_TO_TESTING_LEVELS_PICKLE>
私たちの実装は、BabyAI 1.1 (ドメインおよび模倣学習ベースライン)、dan-visdial (上位レベル コンポーネント トランスフォーマー エンコーダー)、および visdial-rl (上位レベル コンポーネント RNN デコーダー) に基づいています。
このプロジェクトが役立つと思われる場合は、以下を引用することを検討してください。
@article{hu2023ThoughtCloning,
title={{Thought Cloning}: Learning to think while acting by imitating human thinking},
author={Hu, Shengran and Clune, Jeff},
journal={Advances in Neural Information Processing Systems},
volume={36},
year={2023}
}