用於訓練任意多模式基礎模型的架構。
可擴展。開源。跨越數十種模式和任務。
洛桑聯邦理工學院 - 蘋果
Website
| BibTeX
| ? Demo
官方實施與預訓練模型:
4M:大規模多模態掩模建模,NeurIPS 2023(聚焦)
David Mizrahi*、Roman Bachmann*、Oğuzhan Fatih Kar、Teresa Yeo、高明飛、Afshin Dehghan、Amir Zamir
4M-21:適用於數十種任務和模式的任意視覺模型,NeurIPS 2024
Roman Bachmann*、Oğuzhan Fatih Kar*、David Mizrahi*、Ali Garjani、高明飛、David Griffiths、胡家明、Afshin Dehghan、Amir Zamir
4M 是一個用於訓練「任意到任意」基礎模型的框架,使用標記化和遮蔽來擴展到多種不同的模式。使用 4M 訓練的模型可以執行廣泛的視覺任務,可以很好地遷移到看不見的任務和模式,並且是靈活且可操縱的多模式生成模型。我們正在發布「4M:大規模多模態掩蔽建模」(此處表示為4M-7)以及「4M-21:適用於數十種任務和模態的任意視覺模型」(此處表示為4M )的程式碼和模型-21)。
git clone https://github.com/apple/ml-4m
cd ml-4m
conda create -n fourm python=3.9 -y
conda activate fourm
pip install --upgrade pip # enable PEP 660 support
pip install -e .
# Run in Python shell
import torch
print(torch.cuda.is_available()) # Should return True
如果 CUDA 不可用,請考慮按照官方安裝說明重新安裝 PyTorch。同樣,如果您想安裝 xFormers(可選,為了更快的分詞器),請按照其 README 操作以確保 CUDA 版本正確。
我們提供了一個示範包裝器,可以快速開始使用 4M 模型進行 RGB-to-all 或 {caption,bounding box}-to-all 生成任務。例如,要從給定的 RGB 輸入產生所有模態,請呼叫:
from fourm . demo_4M_sampler import Demo4MSampler , img_from_url
sampler = Demo4MSampler ( fm = 'EPFL-VILAB/4M-21_XL' ). cuda ()
img = img_from_url ( 'https://storage.googleapis.com/four_m_site/images/demo_rgb.png' ) # 1x3x224x224 ImageNet-standardized PyTorch Tensor
preds = sampler ({ 'rgb@224' : img . cuda ()}, seed = None )
sampler . plot_modalities ( preds , save_path = None )
您應該會看到如下所示的輸出:
要執行「caption-to-all」生成,您可以將採樣器輸入替換為: preds = sampler({'caption': 'A lake house with a boat in front [S_1]'})
。有關可用 4M 模型的列表,請參閱下面的模型動物園,並參閱 README_GENERATION.md 以了解有關生成的更多說明。
有關如何準備對齊的多模式資料集的說明,請參閱 README_DATA.md。
有關如何訓練特定於模態的分詞器的說明,請參閱 README_TOKENIZATION.md。
有關如何訓練 4M 模型的說明,請參閱 README_TRAINING.md。
有關如何使用 4M 模型進行推理/產生的說明,請參閱 README_GENERATION.md。我們還提供了一個生成筆記本,其中包含 4M 推理的範例,專門執行條件影像生成和常見視覺任務(即 RGB-to-All)。
我們提供 4M 和 tokenizer 檢查點作為安全張量,並且還透過 Hugging Face Hub 提供輕鬆加載。
模型 | # 模組。 | 數據集 | # 參數 | 配置 | 重量 |
---|---|---|---|---|---|
4M-B | 7 | CC12M | 198M | 配置 | 檢查點/高頻集線器 |
4M-B | 7 | 科約700M | 198M | 配置 | 檢查點/高頻集線器 |
4M-B | 21 | CC12M+COYO700M+C4 | 198M | 配置 | 檢查點/高頻集線器 |
4M-L | 7 | CC12M | 705M | 配置 | 檢查點/高頻集線器 |
4M-L | 7 | 科約700M | 705M | 配置 | 檢查點/高頻集線器 |
4M-L | 21 | CC12M+COYO700M+C4 | 705M | 配置 | 檢查點/高頻集線器 |
4M-XL | 7 | CC12M | 2.8B | 配置 | 檢查點/高頻集線器 |
4M-XL | 7 | 科約700M | 2.8B | 配置 | 檢查點/高頻集線器 |
4M-XL | 21 | CC12M+COYO700M+C4 | 2.8B | 配置 | 檢查點/高頻集線器 |
要從 Hugging Face Hub 載入模型:
from fourm . models . fm import FM
fm7b_cc12m = FM . from_pretrained ( 'EPFL-VILAB/4M-7_B_CC12M' )
fm7b_coyo = FM . from_pretrained ( 'EPFL-VILAB/4M-7_B_COYO700M' )
fm21b = FM . from_pretrained ( 'EPFL-VILAB/4M-21_B' )
fm7l_cc12m = FM . from_pretrained ( 'EPFL-VILAB/4M-7_L_CC12M' )
fm7l_coyo = FM . from_pretrained ( 'EPFL-VILAB/4M-7_L_COYO700M' )
fm21l = FM . from_pretrained ( 'EPFL-VILAB/4M-21_L' )
fm7xl_cc12m = FM . from_pretrained ( 'EPFL-VILAB/4M-7_XL_CC12M' )
fm7xl_coyo = FM . from_pretrained ( 'EPFL-VILAB/4M-7_XL_COYO700M' )
fm21xl = FM . from_pretrained ( 'EPFL-VILAB/4M-21_XL' )
若要手動載入檢查點,請先從上面的連結下載 safetensors 檔案並呼叫:
from fourm . utils import load_safetensors
from fourm . models . fm import FM
ckpt , config = load_safetensors ( '/path/to/checkpoint.safetensors' )
fm = FM ( config = config )
fm . load_state_dict ( ckpt )
這些模型使用標準 4M-7 CC12M 模型進行初始化,但繼續使用嚴重偏向文字輸入的模態混合進行訓練。它們仍然能夠執行所有其他任務,但與未微調的模型相比,它們在文字到圖像生成方面表現更好。
模型 | # 模組。 | 數據集 | # 參數 | 配置 | 重量 |
---|---|---|---|---|---|
4M-T2I-B | 7 | CC12M | 198M | 配置 | 檢查點/高頻集線器 |
4M-T2I-L | 7 | CC12M | 705M | 配置 | 檢查點/高頻集線器 |
4M-T2I-XL | 7 | CC12M | 2.8B | 配置 | 檢查點/高頻集線器 |
要從 Hugging Face Hub 載入模型:
from fourm . models . fm import FM
fm7b_t2i_cc12m = FM . from_pretrained ( 'EPFL-VILAB/4M-7-T2I_B_CC12M' )
fm7l_t2i_cc12m = FM . from_pretrained ( 'EPFL-VILAB/4M-7-T2I_L_CC12M' )
fm7xl_t2i_cc12m = FM . from_pretrained ( 'EPFL-VILAB/4M-7-T2I_XL_CC12M' )
從檢查點手動載入的執行方式與上述基本 4M 模型相同。
模型 | # 模組。 | 數據集 | # 參數 | 配置 | 重量 |
---|---|---|---|---|---|
4M-SR-L | 7 | CC12M | 198M | 配置 | 檢查點/高頻集線器 |
要從 Hugging Face Hub 載入模型:
from fourm . models . fm import FM
fm7l_sr_cc12m = FM . from_pretrained ( 'EPFL-VILAB/4M-7-SR_L_CC12M' )
從檢查點手動載入的執行方式與上述基本 4M 模型相同。
模態 | 解決 | 代幣數量 | 碼本大小 | 擴散解碼器 | 重量 |
---|---|---|---|---|---|
RGB | 224-448 | 196-784 | 16k | ✓ | 檢查點/高頻集線器 |
深度 | 224-448 | 196-784 | 8k | ✓ | 檢查點/高頻集線器 |
法線 | 224-448 | 196-784 | 8k | ✓ | 檢查點/高頻集線器 |
邊緣(Canny、SAM) | 224-512 | 196-1024 | 8k | ✓ | 檢查點/高頻集線器 |
COCO語意分割 | 224-448 | 196-784 | 4k | ✗ | 檢查點/高頻集線器 |
夾子-B/16 | 224-448 | 196-784 | 8k | ✗ | 檢查點/高頻集線器 |
恐龍v2-B/14 | 224-448 | 256-1024 | 8k | ✗ | 檢查點/高頻集線器 |
DINOv2-B/14(全球) | 224 | 16 | 8k | ✗ | 檢查點/高頻集線器 |
影像綁定-H/14 | 224-448 | 256-1024 | 8k | ✗ | 檢查點/高頻集線器 |
ImageBind-H/14(全球) | 224 | 16 | 8k | ✗ | 檢查點/高頻集線器 |
SAM 實例 | - | 64 | 1k | ✗ | 檢查點/高頻集線器 |
3D人體姿勢 | - | 8 | 1k | ✗ | 檢查點/高頻集線器 |
要從 Hugging Face Hub 載入模型:
from fourm . vq . vqvae import VQVAE , DiVAE
# 4M-7 modalities
tok_rgb = DiVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_rgb_16k_224-448' )
tok_depth = DiVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_depth_8k_224-448' )
tok_normal = DiVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_normal_8k_224-448' )
tok_semseg = VQVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_semseg_4k_224-448' )
tok_clip = VQVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_CLIP-B16_8k_224-448' )
# 4M-21 modalities
tok_edge = DiVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_edge_8k_224-512' )
tok_dinov2 = VQVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_DINOv2-B14_8k_224-448' )
tok_dinov2_global = VQVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224' )
tok_imagebind = VQVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_ImageBind-H14_8k_224-448' )
tok_imagebind_global = VQVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_ImageBind-H14-global_8k_16_224' )
sam_instance = VQVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_sam-instance_1k_64' )
human_poses = VQVAE . from_pretrained ( 'EPFL-VILAB/4M_tokenizers_human-poses_1k_8' )
若要手動載入檢查點,請先從上面的連結下載 safetensors 檔案並呼叫:
from fourm . utils import load_safetensors
from fourm . vq . vqvae import VQVAE , DiVAE
ckpt , config = load_safetensors ( '/path/to/checkpoint.safetensors' )
tok = VQVAE ( config = config ) # Or DiVAE for models with a diffusion decoder
tok . load_state_dict ( ckpt )
此儲存庫中的程式碼是根據 Apache 2.0 授權發布的,如 LICENSE 檔案中所示。
此儲存庫中的模型權重是在 LICENSE_WEIGHTS 檔案中找到的範例程式碼許可證下發布的。
如果您發現此儲存庫有幫助,請考慮引用我們的工作:
@inproceedings{4m,
title={{4M}: Massively Multimodal Masked Modeling},
author={David Mizrahi and Roman Bachmann and O{u{g}}uzhan Fatih Kar and Teresa Yeo and Mingfei Gao and Afshin Dehghan and Amir Zamir},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
}
@article{4m21,
title={{4M-21}: An Any-to-Any Vision Model for Tens of Tasks and Modalities},
author={Roman Bachmann and O{u{g}}uzhan Fatih Kar and David Mizrahi and Ali Garjani and Mingfei Gao and David Griffiths and Jiaming Hu and Afshin Dehghan and Amir Zamir},
journal={arXiv 2024},
year={2024},
}