我们提出了DiGIT ,一种自回归生成模型,在源自自监督学习(SSL)模型的抽象潜在空间中执行下一个标记预测。通过对 DINOv2 模型的隐藏状态采用 K-Means 聚类,我们有效地创建了一种新颖的离散分词器。该方法显着提高了 ImageNet 数据集上的图像生成性能,类无条件任务的 FID 得分为 4.59 ,类条件任务的 FID 得分为 3.39 。此外,该模型增强了图像理解,实现了80.3 的线性探测精度。
方法 | # 代币 | 特征 | # 参数 | 前 1 名 Acc. |
---|---|---|---|---|
iGPT-L | 32 | 1536 | 1362M | 60.3 |
iGPT-XL | 64 | 3072 | 6801M | 68.7 |
VIM+VQGAN | 32 | 1024 | 650M | 61.8 |
VIM+dVAE | 32 | 1024 | 650M | 63.8 |
VIM+ViT-VQGAN | 32 | 1024 | 650M | 65.1 |
VIM+ViT-VQGAN | 32 | 2048 | 1697M | 73.2 |
目的 | 16 | 1536 | 0.6B | 70.5 |
数字(我们的) | 16 | 1024 | 219M | 71.7 |
数字(我们的) | 16 | 1536 | 732M | 80.3 |
类型 | 方法 | # 参数 | # 纪元 | 火焰离子化检测器 | 是 |
---|---|---|---|---|---|
生成网络 | 大GAN | 70M | - | 38.6 | 24.70 |
差异。 | LDM | 395M | - | 39.1 | 22.83 |
差异。 | ADM | 554M | - | 26.2 | 39.70 |
MIM | 法师 | 200M | 1600 | 11.1 | 81.17 |
MIM | 法师 | 463M | 1600 | 9.10 | 105.1 |
MIM | 掩模GIT | 227M | 300 | 20.7 | 42.08 |
MIM | 数字 (+MaskGIT) | 219M | 200 | 9.04 | 75.04 |
增强现实 | VQGAN | 214M | 200 | 24.38 | 30.93 |
增强现实 | 数字 (+VQGAN) | 219M | 400 | 9.13 | 73.85 |
增强现实 | 数字 (+VQGAN) | 732M | 200 | 4.59 | 141.29 |
类型 | 方法 | # 参数 | # 纪元 | 火焰离子化检测器 | 是 |
---|---|---|---|---|---|
生成网络 | 大GAN | 160M | - | 6.95 | 198.2 |
差异。 | ADM | 554M | - | 10.94 | 101.0 |
差异。 | LDM-4 | 400M | - | 10.56 | 103.5 |
差异。 | DiT-XL/2 | 675M | - | 9.62 | 121.50 |
差异。 | L-DiT-7B | 7B | - | 6.09 | 153.32 |
MIM | CQR-Trans | 371M | 300 | 5.45 | 172.6 |
MIM+AR | VAR | 310M | 200 | 4.64 | - |
MIM+AR | VAR | 310M | 200 | 3.60* | 257.5* |
MIM+AR | VAR | 600M | 250 | 2.95* | 306.1* |
MIM | MAGVIT-v2 | 307M | 1080 | 3.65 | 200.5 |
增强现实 | VQVAE-2 | 13.5B | - | 31.11 | 45 |
增强现实 | RQ-传输 | 480M | - | 15.72 | 86.8 |
增强现实 | RQ-传输 | 3.8B | - | 7.55 | 134.0 |
增强现实 | ViTVQGAN | 650M | 360 | 11.20 | 97.2 |
增强现实 | ViTVQGAN | 1.7B | 360 | 5.3 | 149.9 |
MIM | 掩模GIT | 227M | 300 | 6.18 | 182.1 |
MIM | 数字 (+MaskGIT) | 219M | 200 | 4.62 | 146.19 |
增强现实 | VQGAN | 227M | 300 | 18.65 | 80.4 |
增强现实 | 数字 (+VQGAN) | 219M | 400 | 4.79 | 142.87 |
增强现实 | 数字 (+VQGAN) | 732M | 200 | 3.39 | 205.96 |
*:VAR 是在无分类器指导下进行训练的,而所有其他模型则不然。
K-Means npy 文件和模型检查点可以从以下位置下载:
模型 | 关联 |
---|---|
高频配重? | 抱脸 |
对于基本模型,我们使用 DINOv2-base 和 DINOv2-large 作为大尺寸模型。我们使用的VQGAN与MAGE相同。
DiGIT
└── data/
├── ILSVRC2012
├── dinov2_base_short_224_l3
├── km_8k.npy
├── dinov2_large_short_224_l3
├── km_16k.npy
└── outputs/
├── base_8k_stage1
├── ...
└── models/
├── vqgan_jax_strongaug.ckpt
├── dinov2_vitb14_reg4_pretrain.pth
├── dinov2_vitl14_reg4_pretrain.pth
git clone https://github.com/DAMO-NLP-SG/DiGIT.git
cd DiGIT
pip install fairseq
fairseq
安装 fairseq 。下载 ImageNet 数据集,并将其放置在数据集目录$PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012
中。
提取 SSL 功能并将其保存为 .npy 文件。使用 K-Means 算法和 faiss 来计算质心。您还可以利用 Huggingface 上提供的预先训练的质心。
bash preprocess/run.sh
步骤1
使用判别标记器训练 GPT 模型。您可以在scripts/train_stage1_ar.sh
中找到训练脚本,超级参数位于config/stage1/dino_base.yaml
中。对于类条件生成配置,请参阅scripts/train_stage1_classcond.sh
。
步骤2
训练以判别标记为条件的像素解码器(AR 模型或 NAR 模型)。您可以在scripts/train_stage2_ar.sh
中找到自回归训练脚本,在scripts/train_stage2_nar.sh
中找到NAR训练脚本。
将创建一个名为outputs/EXP_NAME/checkpoints
的文件夹来保存检查点。 TensorBoard 日志文件保存在outputs/EXP_NAME/tb
。日志将记录在outputs/EXP_NAME/train.log
中。
您可以使用tensorboard --logdir=outputs/EXP_NAME/tb
监控训练过程。
首先使用scripts/infer_stage1_ar.sh
对判别性标记进行采样。对于基本模型大小,我们建议设置 topk=200,对于大模型大小,使用 topk=400。
然后运行scripts/infer_stage2_ar.sh
根据之前采样的判别性标记来采样VQ标记。
生成的令牌和合成图像将存储在名为outputs/EXP_NAME/results
目录中。
准备用于 FID 评估的 ImageNet 验证集:
python prepare_imgnet_val.py --data_path $PATH_TO_YOUR_WORKSPACE /dataset/ILSVRC2012 --output_dir imagenet-val
通过运行pip install torch-fidelity
安装评估工具。
执行以下命令来评估 FID:
python fairseq_user/eval_fid.py --results-path $IMG_SAVE_DIR --subset $GEN_SUBSET
bash scripts/train_stage1_linearprobe.sh
该项目根据 MIT 许可证获得许可 - 有关详细信息,请参阅许可证文件。
如果您发现我们的项目有用,希望您可以为我们的代码库加注星标并引用我们的工作,如下所示。
@misc { zhu2024stabilize ,
title = { Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective } ,
author = { Yongxin Zhu and Bocheng Li and Hang Zhang and Xin Li and Linli Xu and Lidong Bing } ,
year = { 2024 } ,
eprint = { 2410.12490 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}