我們提出了DiGIT ,一種自回歸生成模型,在源自自監督學習(SSL)模型的抽象潛在空間中執行下一個標記預測。透過對 DINOv2 模型的隱藏狀態採用 K-Means 聚類,我們有效地創建了一種新穎的離散分詞器。此方法顯著提高了 ImageNet 資料集上的影像產生效能,類別無條件任務的 FID 分數為 4.59 ,類別條件任務的 FID 得分為 3.39 。此外,該模型增強了影像理解,實現了80.3 的線性探測精度。
方法 | # 代幣 | 特徵 | # 參數 | 前 1 名 Acc. |
---|---|---|---|---|
iGPT-L | 32 | 1536 | 1362M | 60.3 |
iGPT-XL | 64 | 3072 | 6801M | 68.7 |
VIM+VQGAN | 32 | 1024 | 650M | 61.8 |
VIM+dVAE | 32 | 1024 | 650M | 63.8 |
VIM+ViT-VQGAN | 32 | 1024 | 650M | 65.1 |
VIM+ViT-VQGAN | 32 | 2048 | 1697M | 73.2 |
目的 | 16 | 1536 | 0.6B | 70.5 |
數字(我們的) | 16 | 1024 | 219M | 71.7 |
數字(我們的) | 16 | 1536 | 732M | 80.3 |
類型 | 方法 | # 參數 | # 紀元 | 火焰離子化偵測器 | 是 |
---|---|---|---|---|---|
產生網路 | 大GAN | 70M | - | 38.6 | 24.70 |
差異。 | LDM | 395M | - | 39.1 | 22.83 |
差異。 | ADM | 554M | - | 26.2 | 39.70 |
MIM | 法師 | 200M | 1600 | 11.1 | 81.17 |
MIM | 法師 | 463M | 1600 | 9.10 | 105.1 |
MIM | 掩模GIT | 227M | 300 | 20.7 | 42.08 |
MIM | 數字 (+MaskGIT) | 219M | 200 | 9.04 | 75.04 |
擴增實境 | VQGAN | 214M | 200 | 24.38 | 30.93 |
擴增實境 | 數字 (+VQGAN) | 219M | 400 | 9.13 | 73.85 |
擴增實境 | 數字 (+VQGAN) | 732M | 200 | 4.59 | 141.29 |
類型 | 方法 | # 參數 | # 紀元 | 火焰離子化偵測器 | 是 |
---|---|---|---|---|---|
產生網路 | 大GAN | 160M | - | 6.95 | 198.2 |
差異。 | ADM | 554M | - | 10.94 | 101.0 |
差異。 | LDM-4 | 400M | - | 10.56 | 103.5 |
差異。 | DiT-XL/2 | 675M | - | 9.62 | 121.50 |
差異。 | L-DiT-7B | 7B | - | 6.09 | 153.32 |
MIM | CQR-Trans | 371M | 300 | 5.45 | 172.6 |
MIM+AR | VAR | 310M | 200 | 4.64 | - |
MIM+AR | VAR | 310M | 200 | 3.60* | 257.5* |
MIM+AR | VAR | 600M | 250 | 2.95* | 306.1* |
MIM | MAGVIT-v2 | 307M | 1080 | 3.65 | 200.5 |
擴增實境 | VQVAE-2 | 13.5B | - | 31.11 | 45 |
擴增實境 | RQ-傳輸 | 480M | - | 15.72 | 86.8 |
擴增實境 | RQ-傳輸 | 3.8B | - | 7.55 | 134.0 |
擴增實境 | ViTVQGAN | 650M | 360 | 11.20 | 97.2 |
擴增實境 | ViTVQGAN | 1.7B | 360 | 5.3 | 149.9 |
MIM | 掩模GIT | 227M | 300 | 6.18 | 182.1 |
MIM | 數字 (+MaskGIT) | 219M | 200 | 4.62 | 146.19 |
擴增實境 | VQGAN | 227M | 300 | 18.65 | 80.4 |
擴增實境 | 數字 (+VQGAN) | 219M | 400 | 4.79 | 142.87 |
擴增實境 | 數字 (+VQGAN) | 732M | 200 | 3.39 | 205.96 |
*:VAR 是在無分類器指導下進行訓練的,而所有其他模型則不然。
K-Means npy 檔案和模型檢查點可以從以下位置下載:
模型 | 關聯 |
---|---|
高頻配重? | 抱臉 |
對於基本模型,我們使用 DINOv2-base 和 DINOv2-large 作為大尺寸模型。我們使用的VQGAN與MAGE相同。
DiGIT
└── data/
├── ILSVRC2012
├── dinov2_base_short_224_l3
├── km_8k.npy
├── dinov2_large_short_224_l3
├── km_16k.npy
└── outputs/
├── base_8k_stage1
├── ...
└── models/
├── vqgan_jax_strongaug.ckpt
├── dinov2_vitb14_reg4_pretrain.pth
├── dinov2_vitl14_reg4_pretrain.pth
git clone https://github.com/DAMO-NLP-SG/DiGIT.git
cd DiGIT
pip install fairseq
fairseq
安裝 fairseq 。下載 ImageNet 資料集,並將其放置在資料集目錄$PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012
中。
提取 SSL 功能並將其儲存為 .npy 檔案。使用 K-Means 演算法和 faiss 來計算質心。您也可以利用 Huggingface 上提供的預先訓練的質心。
bash preprocess/run.sh
步驟1
使用判別標記器訓練 GPT 模型。您可以在scripts/train_stage1_ar.sh
中找到訓練腳本,超級參數位於config/stage1/dino_base.yaml
。對於類別條件產生配置,請參閱scripts/train_stage1_classcond.sh
。
步驟2
訓練以判別標記為條件的像素解碼器(AR 模型或 NAR 模型)。您可以在scripts/train_stage2_ar.sh
中找到自回歸訓練腳本,在scripts/train_stage2_nar.sh
中找到NAR訓練腳本。
將建立一個名為outputs/EXP_NAME/checkpoints
的資料夾來儲存檢查點。 TensorBoard 日誌檔案保存在outputs/EXP_NAME/tb
。日誌將記錄在outputs/EXP_NAME/train.log
中。
您可以使用tensorboard --logdir=outputs/EXP_NAME/tb
監控訓練過程。
首先使用scripts/infer_stage1_ar.sh
對判別性標記進行採樣。對於基本模型大小,我們建議設定 topk=200,對於大模型大小,使用 topk=400。
然後執行scripts/infer_stage2_ar.sh
根據先前採樣的判別性標記來採樣VQ標記。
產生的令牌和合成影像將儲存在名為outputs/EXP_NAME/results
目錄中。
準備用於 FID 評估的 ImageNet 驗證集:
python prepare_imgnet_val.py --data_path $PATH_TO_YOUR_WORKSPACE /dataset/ILSVRC2012 --output_dir imagenet-val
透過執行pip install torch-fidelity
安裝評估工具。
執行以下命令來評估 FID:
python fairseq_user/eval_fid.py --results-path $IMG_SAVE_DIR --subset $GEN_SUBSET
bash scripts/train_stage1_linearprobe.sh
該項目根據 MIT 許可證獲得許可 - 有關詳細信息,請參閱許可證文件。
如果您發現我們的專案有用,希望您可以為我們的程式碼庫加註星標並引用我們的工作,如下所示。
@misc { zhu2024stabilize ,
title = { Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective } ,
author = { Yongxin Zhu and Bocheng Li and Hang Zhang and Xin Li and Linli Xu and Lidong Bing } ,
year = { 2024 } ,
eprint = { 2410.12490 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}