型号|示例脚本|开始使用|代码概述|安装|贡献|执照
TorchMultimodal是一个 PyTorch 库,用于大规模训练最先进的多模式多任务模型,包括内容理解和生成模型。 TorchMultimodal 包含:
TorchMultimodal 包含许多模型,包括
除了上述模型之外,我们还提供了用于在流行的多模式任务上训练、微调和评估模型的示例脚本。示例可以在示例/下找到,包括
模型 | 支持的任务 |
---|---|
ALBEF | 检索 视觉问答 |
DDPM | 训练和推理(笔记本) |
弗拉瓦 | 预训练 微调 零射击 |
多维数据传输 | 短语接地 视觉问答 |
无限 | 文本到视频检索 文本到视频生成 |
杂食动物 | 预训练 评估 |
下面我们给出了一些简单的示例,说明如何使用 TorchMultimodal 的组件编写简单的训练或零样本评估脚本。
import torch
from PIL import Image
from torchmultimodal . models . flava . model import flava_model
from torchmultimodal . transforms . bert_text_transform import BertTextTransform
from torchmultimodal . transforms . flava_transform import FLAVAImageTransform
# Define helper function for zero-shot prediction
def predict ( zero_shot_model , image , labels ):
zero_shot_model . eval ()
with torch . no_grad ():
image = image_transform ( img )[ "image" ]. unsqueeze ( 0 )
texts = text_transform ( labels )
_ , image_features = zero_shot_model . encode_image ( image , projection = True )
_ , text_features = zero_shot_model . encode_text ( texts , projection = True )
scores = image_features @ text_features . t ()
probs = torch . nn . Softmax ( dim = - 1 )( scores )
label = labels [ torch . argmax ( probs )]
print (
"Label probabilities: " ,
{ labels [ i ]: probs [:, i ] for i in range ( len ( labels ))},
)
print ( f"Predicted label: { label } " )
image_transform = FLAVAImageTransform ( is_train = False )
text_transform = BertTextTransform ()
zero_shot_model = flava_model ( pretrained = True )
img = Image . open ( "my_image.jpg" ) # point to your own image
predict ( zero_shot_model , img , [ "dog" , "cat" , "house" ])
# Example output:
# Label probabilities: {'dog': tensor([0.80590]), 'cat': tensor([0.0971]), 'house': tensor([0.0970])}
# Predicted label: dog
import torch
from torch . utils . data import DataLoader
from torchmultimodal . models . masked_auto_encoder . model import vit_l_16_image_mae
from torchmultimodal . models . masked_auto_encoder . utils import (
CosineWithWarmupAndLRScaling ,
)
from torchmultimodal . modules . losses . reconstruction_loss import ReconstructionLoss
from torchmultimodal . transforms . mae_transform import ImagePretrainTransform
mae_transform = ImagePretrainTransform ()
dataset = MyDatasetClass ( transforms = mae_transform ) # you should define this
dataloader = DataLoader ( dataset , batch_size = 8 )
# Instantiate model and loss
mae_model = vit_l_16_image_mae ()
mae_loss = ReconstructionLoss ()
# Define optimizer and lr scheduler
optimizer = torch . optim . AdamW ( mae_model . parameters ())
lr_scheduler = CosineWithWarmupAndLRScaling (
optimizer , max_iters = 1000 , warmup_iters = 100 # you should set these
)
# Train one epoch
for batch in dataloader :
model_out = mae_model ( batch [ "images" ])
loss = mae_loss ( model_out . decoder_pred , model_out . label_patches , model_out . mask )
loss . backward ()
optimizer . step ()
lr_scheduler . step ()
diffusive_labs 包含用于构建扩散模型的组件。有关这些组件的更多详细信息,请参阅diffusion_labs/README.md。
在此处查找模型类以及特定于给定架构的任何其他建模代码。例如,目录 torchmultimodal/models/blip2 包含特定于 BLIP-2 的建模组件。
在这里查找可以拼接在一起构建新架构的常见通用构建块。这包括码本、补丁嵌入或变压器编码器/解码器等层、温度对比损失或重建损失等损失、ViT 和 BERT 等编码器以及 Deep Set fusion 等融合模块。
在此处查找流行模型(例如 CLIP、FLAVA 和 MAE)的常见数据转换。
TorchMultimodal 需要 Python >= 3.8。该库可以在有或没有 CUDA 支持的情况下安装。以下假设已安装 conda。
安装conda环境
conda create -n torch-multimodal python=
conda activate torch-multimodal
安装 pytorch、torchvision 和 torchaudio。请参阅 PyTorch 文档。
# Use the current CUDA version as seen [here](https://pytorch.org/get-started/locally/)
# Select the nightly Pytorch build, Linux as the OS, and conda. Pick the most recent CUDA version.
conda install pytorch torchvision torchaudio pytorch-cuda= -c pytorch-nightly -c nvidia
# For CPU-only install
conda install pytorch torchvision torchaudio cpuonly -c pytorch-nightly
Linux 上的 Python 3.8 和 3.9 的 Nightly 二进制文件可以通过 pipwheels 安装。目前我们仅通过 PyPI 支持 Linux 平台。
python -m pip install torchmultimodal-nightly
或者,您也可以从我们的源代码构建并运行我们的示例:
git clone --recursive https://github.com/facebookresearch/multimodal.git multimodal
cd multimodal
pip install -e .
开发者请按照开发安装。
我们欢迎来自社区的任何功能请求、错误报告或拉取请求。请参阅贡献文件以了解如何提供帮助。
TorchMultimodal 已获得 BSD 许可,如 LICENSE 文件中所示。