VideoMambaPro 正式上線:Mamba 在視訊理解上的飛躍
我們從後者的角度研究了 self-attention 和 Mamba 的異同,揭示了 Mamba 在視訊理解任務上的限制。我們提出VideoMambaPro,它使用VideoMamba作為骨幹,但顯著增強了視訊理解任務的效能,縮小了與Transformers的差距。
所需的套件在requirements.txt
檔案中,您可以執行以下命令來安裝環境
conda create -n videomambapro python=3.10
conda activate videomambapro
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
pip install causal_conv1d==1.4.0 (we recommend to install through .whl file)
pip install mamba-ssm
pip install -r requirements.txt
我們以與 VideoMAE 相同的方式讀取和處理,但資料列表檔案的格式採用不同的約定。
我們在 ImageNet-1K 資料集上預先訓練模型,其中模型載入以下格式的資料列表檔案:
Frame_folder_path Total_frames 標籤
我們的微調資料集VideoClsDataset
和RawFrameClsDataset
有兩種實現,分別支援視訊資料和 rawframes 資料。其中 SSV2 預設使用RawFrameClsDataset
,其餘資料集使用VideoClsDataset
。
VideoClsDataset
載入以下格式的資料列表檔案:
視訊路徑標籤
而RawFrameClsDataset
則載入一個資料列表文件,格式如下:
Frame_folder_path Total_frames 標籤
例如視訊資料列表和rawframes資料列表如下所示:
# The path prefix 'your_path' can be specified by `--data_root ${PATH_PREFIX}` in scripts when training or inferencing.
# k400 video data validation list
your_path/k400/jf7RDuUTrsQ.mp4 325
your_path/k400/JTlatknwOrY.mp4 233
your_path/k400/NUG7kwJ-614.mp4 103
your_path/k400/y9r115bgfNk.mp4 320
your_path/k400/ZnIDviwA8CE.mp4 244
...
# ssv2 rawframes data validation list
your_path/SomethingV2/frames/74225 62 140
your_path/SomethingV2/frames/116154 51 127
your_path/SomethingV2/frames/198186 47 173
your_path/SomethingV2/frames/137878 29 99
your_path/SomethingV2/frames/151151 31 166
...
我們的專案基於 VideoMamba 進行公平比較。為了解決本文中的限制 1 和 2,我們主要透過在後向 SSM 期間應用對角遮罩並在雙向 SSM 上應用殘差連接來改變 Mamba 的管道。 Ab的殘差連接在mamba/mamba_ssm/ops/selective_scan_interface.py中的函數selective_scan_ref中實現,關鍵選項如下:
x = u[:, :, 0].unsqueeze(-1).expand(-1, -1, dstate)
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
mask分配是透過在mamba/mamba_ssm/ops/selective_scan_interface.py中設定兩個選擇性函數來實現的,即selective_scan_ref和selective_scan_ref_sub。計算雙向mamba時,例如mamba/mamba_ssm/ops/selective_scan_interface.py的bimamba_inner_ref中,關鍵程式碼如下:
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
y_b = selective_scan_ref_sub(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True)
y = y + y_b.flip([-1])
連結: https://pan.baidu.com/s/1vJN_XTRct65cDA_0AB259g?pwd=ghqb 提取碼: ghqb