Официальная реализация VideoMambaPro: шаг вперед для Mamba в понимании видео
мы исследуем сходства и различия самовнимания и Мамбы с точки зрения последней и выявляем ограничения Мамбы в задаче понимания видео. Мы предлагаем VideoMambaPro, который использует VideoMamba в качестве основы, но значительно повышает производительность в задаче понимания видео, сокращая разрыв с преобразователями.
Необходимые пакеты находятся в файле requirements.txt
, и вы можете запустить следующую команду, чтобы установить среду.
conda create -n videomambapro python=3.10
conda activate videomambapro
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
pip install causal_conv1d==1.4.0 (we recommend to install through .whl file)
pip install mamba-ssm
pip install -r requirements.txt
Мы читаем и обрабатываем так же, как VideoMAE, но с другим соглашением о формате файла списка данных.
Мы предварительно обучаем модель на наборе данных ImageNet-1K, где модель загружает файл списка данных в следующем формате:
метка_frame_folder_path total_frames
Существует две реализации нашего набора данных точной настройки VideoClsDataset
и RawFrameClsDataset
, поддерживающие видеоданные и данные rawframes соответственно. Где SSV2 по умолчанию использует RawFrameClsDataset
, а остальные наборы данных используют VideoClsDataset
.
VideoClsDataset
загружает файл списка данных в следующем формате:
метка пути_видео
в то время как RawFrameClsDataset
загружает файл списка данных в следующем формате:
метка_frame_folder_path total_frames
Например, список видеоданных и список данных rawframes показаны ниже:
# The path prefix 'your_path' can be specified by `--data_root ${PATH_PREFIX}` in scripts when training or inferencing.
# k400 video data validation list
your_path/k400/jf7RDuUTrsQ.mp4 325
your_path/k400/JTlatknwOrY.mp4 233
your_path/k400/NUG7kwJ-614.mp4 103
your_path/k400/y9r115bgfNk.mp4 320
your_path/k400/ZnIDviwA8CE.mp4 244
...
# ssv2 rawframes data validation list
your_path/SomethingV2/frames/74225 62 140
your_path/SomethingV2/frames/116154 51 127
your_path/SomethingV2/frames/198186 47 173
your_path/SomethingV2/frames/137878 29 99
your_path/SomethingV2/frames/151151 31 166
...
Наш проект основан на VideoMamba для честного сравнения. Чтобы устранить ограничения 1 и 2 в нашей статье, мы в основном меняем конвейер Mamba, применяя диагональную маску во время обратного SSM и применяя остаточное соединение в двунаправленном SSM. Остаточное соединение Ab реализовано в функции selective_scan_ref в mamba/mamba_ssm/ops/selective_scan_interface.py, а ключевой параметр приведен ниже:
x = u[:, :, 0].unsqueeze(-1).expand(-1, -1, dstate)
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
Назначение маски реализуется путем установки двух выборочных функций, а именно selective_scan_ref и selective_scan_ref_sub в mamba/mamba_ssm/ops/selective_scan_interface.py. При вычислении двунаправленной мамбы, например, в bimamba_inner_ref файла mamba/mamba_ssm/ops/selective_scan_interface.py, код ключа приведен ниже:
y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
y_b = selective_scan_ref_sub(x.flip([-1]), delta.flip([-1]), A_b, B.flip([-1]), C.flip([-1]), D, z.flip([-1]), delta_bias, delta_softplus=True)
y = y + y_b.flip([-1])
链接: https://pan.baidu.com/s/1vJN_XTRct65cDA_0AB259g?pwd=ghqb 提取码: ghqb