型號|範例腳本|開始使用|程式碼概述|安裝|貢獻|執照
TorchMultimodal是一個 PyTorch 函式庫,用於大規模訓練最先進的多模式多工模型,包括內容理解和生成模型。 TorchMultimodal 包含:
TorchMultimodal 包含許多模型,包括
除了上述模型之外,我們還提供了在流行的多模式任務上訓練、微調和評估模型的範例腳本。範例可以在範例/下找到,包括
模型 | 支援的任務 |
---|---|
ALBEF | 檢索 視覺問答 |
DDPM | 訓練和推理(筆記本) |
弗拉瓦 | 預訓練 微調 零射擊 |
多維資料傳輸 | 短語接地 視覺問答 |
無限 | 文字轉影片檢索 文字到影片生成 |
雜食動物 | 預訓練 評估 |
下面我們給了一些簡單的範例,說明如何使用 TorchMultimodal 的元件來編寫簡單的訓練或零樣本評估腳本。
import torch
from PIL import Image
from torchmultimodal . models . flava . model import flava_model
from torchmultimodal . transforms . bert_text_transform import BertTextTransform
from torchmultimodal . transforms . flava_transform import FLAVAImageTransform
# Define helper function for zero-shot prediction
def predict ( zero_shot_model , image , labels ):
zero_shot_model . eval ()
with torch . no_grad ():
image = image_transform ( img )[ "image" ]. unsqueeze ( 0 )
texts = text_transform ( labels )
_ , image_features = zero_shot_model . encode_image ( image , projection = True )
_ , text_features = zero_shot_model . encode_text ( texts , projection = True )
scores = image_features @ text_features . t ()
probs = torch . nn . Softmax ( dim = - 1 )( scores )
label = labels [ torch . argmax ( probs )]
print (
"Label probabilities: " ,
{ labels [ i ]: probs [:, i ] for i in range ( len ( labels ))},
)
print ( f"Predicted label: { label } " )
image_transform = FLAVAImageTransform ( is_train = False )
text_transform = BertTextTransform ()
zero_shot_model = flava_model ( pretrained = True )
img = Image . open ( "my_image.jpg" ) # point to your own image
predict ( zero_shot_model , img , [ "dog" , "cat" , "house" ])
# Example output:
# Label probabilities: {'dog': tensor([0.80590]), 'cat': tensor([0.0971]), 'house': tensor([0.0970])}
# Predicted label: dog
import torch
from torch . utils . data import DataLoader
from torchmultimodal . models . masked_auto_encoder . model import vit_l_16_image_mae
from torchmultimodal . models . masked_auto_encoder . utils import (
CosineWithWarmupAndLRScaling ,
)
from torchmultimodal . modules . losses . reconstruction_loss import ReconstructionLoss
from torchmultimodal . transforms . mae_transform import ImagePretrainTransform
mae_transform = ImagePretrainTransform ()
dataset = MyDatasetClass ( transforms = mae_transform ) # you should define this
dataloader = DataLoader ( dataset , batch_size = 8 )
# Instantiate model and loss
mae_model = vit_l_16_image_mae ()
mae_loss = ReconstructionLoss ()
# Define optimizer and lr scheduler
optimizer = torch . optim . AdamW ( mae_model . parameters ())
lr_scheduler = CosineWithWarmupAndLRScaling (
optimizer , max_iters = 1000 , warmup_iters = 100 # you should set these
)
# Train one epoch
for batch in dataloader :
model_out = mae_model ( batch [ "images" ])
loss = mae_loss ( model_out . decoder_pred , model_out . label_patches , model_out . mask )
loss . backward ()
optimizer . step ()
lr_scheduler . step ()
diffusive_labs 包含用於建立擴散模型的元件。有關這些組件的更多詳細信息,請參閱diffusion_labs/README.md。
在此處查找模型類別以及特定於給定架構的任何其他建模程式碼。例如,目錄 torchmultimodal/models/blip2 包含特定於 BLIP-2 的建模元件。
在這裡尋找可以拼接在一起建立新架構的常見通用建構塊。這包括碼本、補丁嵌入或變壓器編碼器/解碼器等層、溫度對比損失或重建損失等損失、ViT 和 BERT 等編碼器以及 Deep Set fusion 等融合模組。
在此處尋找流行模型(例如 CLIP、FLAVA 和 MAE)的常見資料轉換。
TorchMultimodal 需要 Python >= 3.8。該庫可以在有或沒有 CUDA 支援的情況下安裝。以下假設已安裝 conda。
安裝conda環境
conda create -n torch-multimodal python=
conda activate torch-multimodal
安裝 pytorch、torchvision 和 torchaudio。請參閱 PyTorch 文件。
# Use the current CUDA version as seen [here](https://pytorch.org/get-started/locally/)
# Select the nightly Pytorch build, Linux as the OS, and conda. Pick the most recent CUDA version.
conda install pytorch torchvision torchaudio pytorch-cuda= -c pytorch-nightly -c nvidia
# For CPU-only install
conda install pytorch torchvision torchaudio cpuonly -c pytorch-nightly
Linux 上的 Python 3.8 和 3.9 的 Nightly 二進位檔案可以透過 pipwheels 安裝。目前我們僅透過 PyPI 支援 Linux 平台。
python -m pip install torchmultimodal-nightly
或者,您也可以從我們的原始程式碼建置並運行我們的範例:
git clone --recursive https://github.com/facebookresearch/multimodal.git multimodal
cd multimodal
pip install -e .
開發者請按照開發安裝。
我們歡迎來自社區的任何功能請求、錯誤報告或拉取請求。請參閱貢獻文件以了解如何提供協助。
TorchMultimodal 已獲得 BSD 許可,如 LICENSE 文件所示。