Implementação oficial do VideoMambaPro: um salto em frente para o Mamba na compreensão do vídeo
investigamos semelhanças e diferenças entre autoatenção e Mamba na perspectiva deste último, e revelamos as limitações do Mamba na tarefa de compreensão de vídeo. Propomos o VideoMambaPro que usa o VideoMamba como backbone, mas melhorando significativamente o desempenho na tarefa de compreensão de vídeo, diminuindo a lacuna com os transformadores.
Os pacotes necessários estão no arquivo requirements.txt
e você pode executar o seguinte comando para instalar o ambiente
conda create -n videomambapro python=3.10
conda activate videomambapro
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
pip install causal_conv1d==1.4.0 (we recommend to install through .whl file)
pip install mamba-ssm
pip install -r requirements.txt
Lemos e processamos da mesma forma que o VideoMAE, mas com uma convenção diferente para o formato do arquivo de lista de dados.
Pré-treinamos o modelo no conjunto de dados ImageNet-1K, onde o modelo carrega um arquivo de lista de dados com o seguinte formato:
frame_folder_path etiqueta total_frames
Existem duas implementações de nosso conjunto de dados de ajuste fino VideoClsDataset
e RawFrameClsDataset
, suportando dados de vídeo e dados de rawframes, respectivamente. Onde SSV2 usa RawFrameClsDataset
por padrão e o restante dos conjuntos de dados usa VideoClsDataset
.
VideoClsDataset
carrega um arquivo de lista de dados com o seguinte formato:
rótulo video_path
enquanto RawFrameClsDataset
carrega um arquivo de lista de dados com o seguinte formato:
frame_folder_path etiqueta total_frames
Por exemplo, a lista de dados de vídeo e a lista de dados de rawframes são mostradas abaixo:
# The path prefix 'your_path' can be specified by `--data_root ${PATH_PREFIX}` in scripts when training or inferencing.
# k400 video data validation list
your_path/k400/jf7RDuUTrsQ.mp4 325
your_path/k400/JTlatknwOrY.mp4 233
your_path/k400/NUG7kwJ-614.mp4 103
your_path/k400/y9r115bgfNk.mp4 320
your_path/k400/ZnIDviwA8CE.mp4 244
...
# ssv2 rawframes data validation list
your_path/SomethingV2/frames/74225 62 140
your_path/SomethingV2/frames/116154 51 127
your_path/SomethingV2/frames/198186 47 173
your_path/SomethingV2/frames/137878 29 99
your_path/SomethingV2/frames/151151 31 166
...
Nosso projeto é baseado no VideoMamba para comparação justa. Para resolver as limitações 1 e 2 em nosso artigo, alteramos principalmente o pipeline do Mamba aplicando a máscara diagonal durante o SSM reverso e aplicando a conexão residual no SSM bidirecional. A conexão residual de Ab é realizada na função selective_scan_ref em mamba/mamba_ssm/ops/selective_scan_interface.py, e a opção chave está abaixo:
x = u[:, :, 0].unsqueeze(-1).expand(-1, -1, dstate)
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
A atribuição da máscara é realizada através da configuração de duas funções seletivas, nomeadamente selective_scan_ref e selective_scan_ref_sub em mamba/mamba_ssm/ops/selective_scan_interface.py. Ao calcular o mamba bidirecional, por exemplo, em bimamba_inner_ref de mamba/mamba_ssm/ops/selective_scan_interface.py, o código-chave está abaixo:
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
y_b = selective_scan_ref_sub(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True)
y = y + y_b.flip([-1])
链接: https://pan.baidu.com/s/1vJN_XTRct65cDA_0AB259g?pwd=ghqb 提取码: ghqb