Inglês | 中文
Repositório oficial do artigo Robust High-Resolution Video Matting with Temporal Guidance. O RVM foi projetado especificamente para um revestimento de vídeo humano robusto. Ao contrário dos modelos neurais existentes que processam quadros como imagens independentes, o RVM utiliza uma rede neural recorrente para processar vídeos com memória temporal. O RVM pode realizar matização em tempo real em qualquer vídeo sem entradas adicionais. Atinge 4K 76FPS e HD 104FPS em uma GPU Nvidia GTX 1080 Ti. O projeto foi desenvolvido na ByteDance Inc.
[03 de novembro de 2021] Corrigido um bug em train.py.
[16 de setembro de 2021] O código é relançado sob a licença GPL-3.0.
[25 de agosto de 2021] O código-fonte e os modelos pré-treinados são publicados.
[27 de julho de 2021] O artigo foi aceito pelo WACV 2022.
Assista ao vídeo do showreel (YouTube, Bilibili) para ver a atuação da modelo.
Todas as filmagens do vídeo estão disponíveis no Google Drive.
Demonstração de webcam: execute o modelo ao vivo em seu navegador. Visualize estados recorrentes.
Demonstração do Colab: teste nosso modelo em seus próprios vídeos com GPU gratuita.
Recomendamos modelos MobileNetv3 para a maioria dos casos de uso. Os modelos ResNet50 são a variante maior com pequenas melhorias de desempenho. Nosso modelo está disponível em várias estruturas de inferência. Consulte a documentação de inferência para obter mais instruções.
Estrutura | Download | Notas |
PyTorch | rvm_mobilenetv3.pth rvm_resnet50.pth | Pesos oficiais para PyTorch. Doutor |
TorchHub | Nada para baixar. | A maneira mais fácil de usar nosso modelo em seu projeto PyTorch. Doutor |
TorchScript | rvm_mobilenetv3_fp32.torchscript rvm_mobilenetv3_fp16.torchscript rvm_resnet50_fp32.torchscript rvm_resnet50_fp16.torchscript | Se fizer inferência em dispositivos móveis, considere exportar você mesmo modelos quantizados int8. Doutor |
ONNX | rvm_mobilenetv3_fp32.onnx rvm_mobilenetv3_fp16.onnx rvm_resnet50_fp32.onnx rvm_resnet50_fp16.onnx | Testado em ONNX Runtime com back-ends de CPU e CUDA. Os modelos fornecidos usam opset 12. Doc, Exportador. |
TensorFlow | rvm_mobilenetv3_tf.zip rvm_resnet50_tf.zip | Modelo salvo do TensorFlow 2. Doutor |
TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | Execute o modelo na web. Demonstração, código inicial |
CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodel rvm_mobilenetv3_1280x720_s0.375_int8.mlmodel rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodelo rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodel | CoreML não oferece suporte à resolução dinâmica. Outras resoluções podem ser exportadas você mesmo. Os modelos requerem iOS 13+. s denota downsample_ratio . Documento, Exportador |
Todos os modelos estão disponíveis no Google Drive e Baidu Pan (código: gym7).
Instale dependências:
pip instalar -r requisitos_inferência.txt
Carregue o modelo:
importar tocha do modelo importar MattingNetworkmodel = MattingNetwork('mobilenetv3').eval().cuda() # ou "resnet50"model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
Para converter vídeos, fornecemos uma API de conversão simples:
from inference import convert_videoconvert_video(model, # O modelo, pode estar em qualquer dispositivo (cpu ou cuda).input_source='input.mp4', # Um arquivo de vídeo ou uma sequência de imagens directory.output_type='video', # Escolha "video " ou "png_sequence"output_composition='com.mp4', # Caminho do arquivo se for vídeo; caminho do diretório se png sequencia.output_alpha="pha.mp4", # [Opcional] Produza o alfa bruto Prediction.output_foreground="fgr.mp4", # [Opcional] Produza o primeiro plano bruto Prediction.output_video_mbps=4, # Produza mbps de vídeo não necessário para png sequence.downsample_ratio=None, # Um hiperparâmetro para ajustar ou usar. Nenhum para auto.seq_chunk=12, # Processe n quadros de uma vez para melhor paralelismo.)
Ou escreva seu próprio código de inferência:
de torch.utils.data import DataLoaderfrom torchvision.transforms import ToTensorfrom inference_utils import VideoReader, VideoWriterreader = VideoReader('input.mp4', transform=ToTensor())writer = VideoWriter('output.mp4', frame_rate=30)bgr = tocha .tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Fundo verde.rec = [None] * 4 # Estados recorrentes iniciais.downsample_ratio = 0,25 # Ajuste com base no seu vídeo.with torch.no_grad():for src in DataLoader(reader): # Tensor RGB normalizado para 0 ~ 1.fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # Ciclo do recorrente states.com = fgr * pha + bgr * (1 - pha) # Composto em fundo verde. escritor.write(com) # Escrever quadro.
Os modelos e a API do conversor também estão disponíveis no TorchHub.
# Carregue o model.model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # ou "resnet50"# Converter API.convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
Consulte a documentação de inferência para obter detalhes sobre o hiperparâmetro downsample_ratio
, mais argumentos do conversor e uso mais avançado.
Consulte a documentação de treinamento para treinar e avaliar seu próprio modelo.
A velocidade é medida com inference_speed_test.py
para referência.
GPU | dTipo | HD (1920x1080) | 4K (3840x2160) |
---|---|---|---|
RTX3090 | FP16 | 172 FPS | 154 FPS |
RTX 2060 Super | FP16 | 134 FPS | 108 FPS |
GTX 1080Ti | FP32 | 104 FPS | 74 FPS |
Nota 1: HD usa downsample_ratio=0.25
, 4K usa downsample_ratio=0.125
. Todos os testes usam tamanho de lote 1 e bloco de quadro 1.
Nota 2: GPUs anteriores à arquitetura Turing não suportam inferência FP16, então GTX 1080 Ti usa FP32.
Nota 3: Medimos apenas a taxa de transferência do tensor. Espera-se que o script de conversão de vídeo fornecido neste repositório seja muito mais lento, porque não utiliza codificação/decodificação de vídeo de hardware e não tem a transferência de tensor feita em threads paralelos. Se você estiver interessado em implementar codificação/decodificação de vídeo de hardware em Python, consulte PyNvCodec.
Shanchuan Lin
Linjie Yang
Imran Saleemi
Soumyadip Sengupta
NCNN C++ Android (@FeiGeChuanShu)
lite.ai.toolkit (@DefTruth)
Demonstração da Gradio Web (@ AK391)
Demonstração do Unity Engine com NatML (@natsuite)
Demonstração MNN C++ (@DefTruth)
Demonstração TNN C++ (@DefTruth)