Inglés | 中文
Repositorio oficial del artículo Robust High-Resolution Video Matting with Temporal Guidance. RVM está diseñado específicamente para un mateado de vídeo humano robusto. A diferencia de los modelos neuronales existentes que procesan fotogramas como imágenes independientes, RVM utiliza una red neuronal recurrente para procesar vídeos con memoria temporal. RVM puede realizar mates en tiempo real en cualquier vídeo sin entradas adicionales. Alcanza 4K 76FPS y HD 104FPS en una GPU Nvidia GTX 1080 Ti. El proyecto fue desarrollado en ByteDance Inc.
[3 de noviembre de 2021] Se corrigió un error en train.py.
[16 de septiembre de 2021] El código se vuelve a publicar bajo la licencia GPL-3.0.
[25 de agosto de 2021] Se publican el código fuente y los modelos previamente entrenados.
[27 de julio de 2021] El artículo es aceptado por WACV 2022.
Mire el video showreel (YouTube, Bilibili) para ver el desempeño del modelo.
Todas las imágenes del vídeo están disponibles en Google Drive.
Demostración de cámara web: ejecute el modelo en vivo en su navegador. Visualiza estados recurrentes.
Demostración de Colab: pruebe nuestro modelo en sus propios videos con GPU gratuita.
Recomendamos los modelos MobileNetv3 para la mayoría de los casos de uso. Los modelos ResNet50 son la variante más grande con pequeñas mejoras de rendimiento. Nuestro modelo está disponible en varios marcos de inferencia. Consulte la documentación de inferencia para obtener más instrucciones.
Estructura | Descargar | Notas |
PyTorch | rvm_mobilenetv3.pth rvm_resnet50.pth | Pesos oficiales para PyTorch. Doc |
AntorchaHub | Nada que descargar. | La forma más sencilla de utilizar nuestro modelo en su proyecto PyTorch. Doc |
AntorchaScript | rvm_mobilenetv3_fp32.torchscript rvm_mobilenetv3_fp16.torchscript rvm_resnet50_fp32.torchscript rvm_resnet50_fp16.torchscript | Si realiza inferencias en dispositivos móviles, considere exportar usted mismo los modelos cuantificados int8. Doc |
ONNX | rvm_mobilenetv3_fp32.onnx rvm_mobilenetv3_fp16.onnx rvm_resnet50_fp32.onnx rvm_resnet50_fp16.onnx | Probado en ONNX Runtime con CPU y backends CUDA. Los modelos proporcionados utilizan opset 12. Doc, Exportador. |
TensorFlow | rvm_mobilenetv3_tf.zip rvm_resnet50_tf.zip | Modelo guardado de TensorFlow 2. Doc |
TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | Ejecute el modelo en la web. Demostración, código de inicio |
CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodelo rvm_mobilenetv3_1280x720_s0.375_int8.mlmodelo rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodelo rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodelo | CoreML no admite resolución dinámica. Otras resoluciones se pueden exportar usted mismo. Los modelos requieren iOS 13+. s denota downsample_ratio . Doc, Exportador |
Todos los modelos están disponibles en Google Drive y Baidu Pan (código: gym7).
Instalar dependencias:
instalación de pip -r requisitos_inferencia.txt
Cargar el modelo:
importar antorcha del modelo importar MattingNetworkmodel = MattingNetwork('mobilenetv3').eval().cuda() # o "resnet50"model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
Para convertir videos, proporcionamos una API de conversión simple:
desde inferencia import convert_videoconvert_video(model, # El modelo, puede estar en cualquier dispositivo (cpu o cuda).input_source='input.mp4', # Un archivo de video o una secuencia de imágenes directorio.output_type='video', # Elige "video " o "png_sequence"output_composition='com.mp4', # Ruta del archivo si es video; ruta del directorio si png secuencia.output_alpha="pha.mp4", # [Opcional] Genera la predicción alfa sin procesar.output_foreground="fgr.mp4", # [Opcional] Genera la predicción de primer plano sin procesar.output_video_mbps=4, # No es necesario para la secuencia png.downsample_ratio=None, # Un hiperparámetro para ajustar o usar. Ninguno para auto.seq_chunk=12, # Procese n cuadros a la vez para un mejor paralelismo).
O escriba su propio código de inferencia:
desde torch.utils.data importar DataLoaderdesde torchvision.transforms importar ToTensordesde inference_utils importar VideoReader, VideoWriterreader = VideoReader('input.mp4', transform=ToTensor())writer = VideoWriter('output.mp4', frame_rate=30)bgr = torch .tensor([.47, 1, .6]).vista(3, 1, 1).cuda() # Fondo verde.rec = [Ninguno] * 4 # Estados recurrentes iniciales.downsample_ratio = 0.25 # Ajusta según tu video.with torch.no_grad():for src in DataLoader(reader): # Tensor RGB normalizado a 0 ~ 1.fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # Ciclo de estados recurrentes.com = fgr * pha + bgr * (1 - pha) # Compuesto sobre fondo verde. escritor.write(com) # Escribir marco.
Los modelos y la API del convertidor también están disponibles a través de TorchHub.
# Cargue el model.model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # o "resnet50"# Convertidor API.convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
Consulte la documentación de inferencia para obtener detalles sobre el hiperparámetro downsample_ratio
, más argumentos del convertidor y un uso más avanzado.
Consulte la documentación de capacitación para entrenar y evaluar su propio modelo.
La velocidad se mide con inference_speed_test.py
como referencia.
GPU | dTipo | Alta definición (1920x1080) | 4K (3840x2160) |
---|---|---|---|
RTX 3090 | FP16 | 172 fotogramas por segundo | 154 FPS |
RTX 2060 Súper | FP16 | 134 FPS | 108 FPS |
GTX 1080Ti | FP32 | 104 FPS | 74 FPS |
Nota 1: HD usa downsample_ratio=0.25
, 4K usa downsample_ratio=0.125
. Todas las pruebas utilizan el tamaño de lote 1 y el fragmento de marco 1.
Nota 2: Las GPU anteriores a la arquitectura Turing no admiten la inferencia FP16, por lo que la GTX 1080 Ti usa FP32.
Nota 3: solo medimos el rendimiento del tensor. Se espera que el script de conversión de video proporcionado en este repositorio sea mucho más lento, porque no utiliza codificación/decodificación de video por hardware y no realiza la transferencia tensorial en subprocesos paralelos. Si está interesado en implementar la codificación/decodificación de vídeo por hardware en Python, consulte PyNvCodec.
Shan Chuan Lin
Linjie Yang
Imran Saleemi
Soumyadip Sengupta
NCNN C++ Android (@FeiGeChuanShu)
kit de herramientas lite.ai (@DefTruth)
Demostración web de Gradio (@AK391)
Demostración de Unity Engine con NatML (@natsuite)
Demostración de MNN C++ (@DefTruth)
Demostración de TNN C++ (@DefTruth)