Shalev Lifshitz*、Keiran Paster*、Harris Chan†、Jimmy Ba、Sheila McIlraith
项目页面| ArXiv | PDF
构建响应文本指令的人工智能模型具有挑战性,特别是对于顺序决策任务。这项工作引入了一种受 unCLIP 启发的方法,用于在不依赖大量指令标记轨迹数据集的情况下调整行为的生成模型。使用这种方法,我们创建了一个名为 STEVE-1 的指令调整视频预训练 (VPT) 模型,该模型可以遵循 Minecraft™ 中的短视野开放式文本和视觉指令。 STEVE-1 的训练分为两个步骤:调整预训练的 VPT 模型以遵循 MineCLIP 潜在空间中的命令,然后训练先验以从文本中预测潜在代码。这使我们能够通过自我监督的行为克隆和事后重新标记来微调 VPT,从而减少对昂贵的人类文本注释的需求,而所有这些只需 60 美元的计算成本。通过利用 VPT 和 MineCLIP 等预训练模型,并采用文本条件图像生成的最佳实践,STEVE-1 通过低级控制(鼠标和键盘)和原始像素输入,为 Minecraft 中的开放式指令遵循设定了新的标准,远远优于之前的基准,并稳健地完成了我们早期游戏评估套件中 13 项任务中的 12 项。我们提供了实验证据,强调了下游性能的关键因素,包括预训练、无分类器指导和数据扩展。所有资源,包括我们的模型权重、训练脚本和评估工具,都可用于进一步研究。
.
├── README.md
├── steve1
│ ├── All agent, dataset, and training code.
├── run_agent
│ ├── Scripts for running the agent.
├── train
│ ├── Script for training the agent and generating the dataset.
我们建议使用 conda 环境和 python 3.10 在 Linux 上运行。
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
pip install minedojo git+https://github.com/MineDojo/MineCLIP
pip install git+https://github.com/minerllabs/[email protected]
pip install gym==0.19 gym3 attrs opencv-python
pip install gdown tqdm accelerate==0.18.0 wandb
steve1
: pip install -e .
如果您在无头服务器上运行,则需要安装xvfb
并使用xvfb-run
运行每个 python 脚本。例如, xvfb-run python script_name.py
。
另请注意,我们使用 MineRL 环境,而不是 MineDojo 环境。因此,按照“MineDojo 安装”说明中所述设置MINEDOJO_HEADLESS=1
将不会产生任何效果。
运行以下命令下载数据和权重:
. download_weights.sh
要从头开始训练 STEVE-1,请运行以下步骤:
. train/1_generate_dataset.sh
. train/2_create_sampling.sh
: . train/2_create_sampling.sh
. train/3_train.sh
. train/4_train_prior.sh
我们提供了两个脚本,用于使用不同的提示来测试代理。要测试您自己训练过的代理,请修改脚本中的--in_weights
参数。
. run_agent/1_gen_paper_videos.sh
生成论文中使用的视频。. run_agent/2_gen_vid_for_text_prompt.sh
为任意文本提示生成视频。. run_agent/3_run_interactive_session.sh
启动与 STEVE-1 的交互式会话。这在无头模式下不起作用。 如果您发现 STEVE-1 对您的研究有用,请引用我们的论文:
@article{lifshitz2023steve1,
title={STEVE-1: A Generative Model for Text-to-Behavior in Minecraft},
author={Shalev Lifshitz and Keiran Paster and Harris Chan and Jimmy Ba and Sheila McIlraith},
year={2023},
eprint={2306.00937},
archivePrefix={arXiv},
primaryClass={cs.LG}
}