VideoMambaPro 正式上线:Mamba 在视频理解上的飞跃
我们从后者的角度研究了 self-attention 和 Mamba 的异同,揭示了 Mamba 在视频理解任务上的局限性。我们提出VideoMambaPro,它使用VideoMamba作为骨干,但显着增强了视频理解任务的性能,缩小了与Transformers的差距。
所需的包在requirements.txt
文件中,您可以运行以下命令来安装环境
conda create -n videomambapro python=3.10
conda activate videomambapro
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
pip install causal_conv1d==1.4.0 (we recommend to install through .whl file)
pip install mamba-ssm
pip install -r requirements.txt
我们以与 VideoMAE 相同的方式读取和处理,但数据列表文件的格式采用不同的约定。
我们在 ImageNet-1K 数据集上预训练模型,其中模型加载以下格式的数据列表文件:
Frame_folder_path Total_frames 标签
我们的微调数据集VideoClsDataset
和RawFrameClsDataset
有两种实现,分别支持视频数据和 rawframes 数据。其中 SSV2 默认使用RawFrameClsDataset
,其余数据集使用VideoClsDataset
。
VideoClsDataset
加载以下格式的数据列表文件:
视频路径标签
而RawFrameClsDataset
则加载一个数据列表文件,格式如下:
Frame_folder_path Total_frames 标签
例如视频数据列表和rawframes数据列表如下所示:
# The path prefix 'your_path' can be specified by `--data_root ${PATH_PREFIX}` in scripts when training or inferencing.
# k400 video data validation list
your_path/k400/jf7RDuUTrsQ.mp4 325
your_path/k400/JTlatknwOrY.mp4 233
your_path/k400/NUG7kwJ-614.mp4 103
your_path/k400/y9r115bgfNk.mp4 320
your_path/k400/ZnIDviwA8CE.mp4 244
...
# ssv2 rawframes data validation list
your_path/SomethingV2/frames/74225 62 140
your_path/SomethingV2/frames/116154 51 127
your_path/SomethingV2/frames/198186 47 173
your_path/SomethingV2/frames/137878 29 99
your_path/SomethingV2/frames/151151 31 166
...
我们的项目基于 VideoMamba 进行公平比较。为了解决本文中的限制 1 和 2,我们主要通过在后向 SSM 期间应用对角掩码并在双向 SSM 上应用残差连接来改变 Mamba 的管道。 Ab的残差连接在mamba/mamba_ssm/ops/selective_scan_interface.py中的函数selective_scan_ref中实现,关键选项如下:
x = u[:, :, 0].unsqueeze(-1).expand(-1, -1, dstate)
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
mask分配是通过在mamba/mamba_ssm/ops/selective_scan_interface.py中设置两个选择性函数来实现的,即selective_scan_ref和selective_scan_ref_sub。计算双向mamba时,例如mamba/mamba_ssm/ops/selective_scan_interface.py的bimamba_inner_ref中,关键代码如下:
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
y_b = selective_scan_ref_sub(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True)
y = y + y_b.flip([-1])
链接: https://pan.baidu.com/s/1vJN_XTRct65cDA_0AB259g?pwd=ghqb 提取码: ghqb