Anglais | Chine
Dépôt officiel du document Robust High-Resolution Video Matting with Temporal Guidance. RVM est spécialement conçu pour les tapis vidéo humains robustes. Contrairement aux modèles neuronaux existants qui traitent les images comme des images indépendantes, RVM utilise un réseau neuronal récurrent pour traiter les vidéos avec mémoire temporelle. RVM peut effectuer un maillage en temps réel sur n'importe quelle vidéo sans entrées supplémentaires. Il atteint 4K 76FPS et HD 104FPS sur un GPU Nvidia GTX 1080 Ti. Le projet a été développé chez ByteDance Inc.
[03 novembre 2021] Correction d'un bug dans train.py.
[16 septembre 2021] Le code est réédité sous licence GPL-3.0.
[25 août 2021] Le code source et les modèles pré-entraînés sont publiés.
[27 juillet 2021] L'article est accepté par WACV 2022.
Regardez la vidéo showreel (YouTube, Bilibili) pour voir les performances du modèle.
Toutes les séquences de la vidéo sont disponibles sur Google Drive.
Démo webcam : exécutez le modèle en direct dans votre navigateur. Visualisez les états récurrents.
Démo Colab : testez notre modèle sur vos propres vidéos avec GPU gratuit.
Nous recommandons les modèles MobileNetv3 pour la plupart des cas d'utilisation. Les modèles ResNet50 sont la variante la plus grande avec de légères améliorations de performances. Notre modèle est disponible sur différents frameworks d'inférence. Consultez la documentation sur l’inférence pour plus d’instructions.
Cadre | Télécharger | Remarques |
PyTorch | rvm_mobilenetv3.pth rvm_resnet50.pth | Poids officiels pour PyTorch. Doc |
TorchHub | Rien à télécharger. | Le moyen le plus simple d'utiliser notre modèle dans votre projet PyTorch. Doc |
TorchScript | rvm_mobilenetv3_fp32.torchscript rvm_mobilenetv3_fp16.torchscript rvm_resnet50_fp32.torchscript rvm_resnet50_fp16.torchscript | En cas d'inférence sur mobile, envisagez d'exporter vous-même des modèles quantifiés int8. Doc |
ONNX | rvm_mobilenetv3_fp32.onnx rvm_mobilenetv3_fp16.onnx rvm_resnet50_fp32.onnx rvm_resnet50_fp16.onnx | Testé sur ONNX Runtime avec les backends CPU et CUDA. Les modèles fournis utilisent l'opset 12. Doc, Exporter. |
TensorFlow | rvm_mobilenetv3_tf.zip rvm_resnet50_tf.zip | Modèle enregistré TensorFlow 2. Doc |
TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | Exécutez le modèle sur le Web. Démo, code de démarrage |
CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlmodèle rvm_mobilenetv3_1280x720_s0.375_int8.mlmodèle rvm_mobilenetv3_1920x1080_s0.25_fp16.mlmodèle rvm_mobilenetv3_1920x1080_s0.25_int8.mlmodèle | CoreML ne prend pas en charge la résolution dynamique. D'autres résolutions peuvent être exportées vous-même. Les modèles nécessitent iOS 13+. s désigne downsample_ratio . Doc, exportateur |
Tous les modèles sont disponibles dans Google Drive et Baidu Pan (code : gym7).
Installer les dépendances :
pip install -r exigences_inference.txt
Chargez le modèle :
import torchfrom model import MattingNetworkmodel = MattingNetwork('mobilenetv3').eval().cuda() # ou "resnet50"model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
Pour convertir des vidéos, nous fournissons une API de conversion simple :
from inference import convert_videoconvert_video(model, # Le modèle, peut être sur n'importe quel appareil (cpu ou cuda).input_source='input.mp4', # Un fichier vidéo ou une séquence d'images directory.output_type='video', # Choisissez "video " ou "png_sequence"output_composition='com.mp4', # Chemin du fichier si vidéo ; chemin du répertoire si séquence png.output_alpha="pha.mp4", # [Facultatif] Sortie de la prédiction alpha brute.output_foreground="fgr.mp4", # [Facultatif] Sortie de la prédiction brute du premier plan.output_video_mbps=4, # Sortie vidéo mbps Non nécessaire pour la séquence png.downsample_ratio=Aucun, # Un hyperparamètre à ajuster ou à utiliser. Aucun pour auto.seq_chunk=12, # Traitez n images à la fois pour un meilleur parallélisme.)
Ou écrivez votre propre code d'inférence :
depuis torch.utils.data import DataLoader depuis torchvision.transforms import ToTensor depuis inference_utils import VideoReader, VideoWriterreader = VideoReader('input.mp4', transform=ToTensor())writer = VideoWriter('output.mp4', frame_rate=30)bgr = torch .tensor([.47, 1, .6]).view(3, 1, 1).cuda() # Fond vert.rec = [Aucun] * 4 # Etats récurrents initiaux.downsample_ratio = 0,25 # Ajustez en fonction de votre vidéo.avec torch.no_grad():pour src dans DataLoader(lecteur) : # Tenseur RVB normalisé à 0 ~ 1.fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # Cycle le récurrent states.com = fgr * pha + bgr * (1 - pha) # Composite sur fond vert. writer.write(com) # Cadre d'écriture.
Les modèles et l'API du convertisseur sont également disponibles via TorchHub.
# Chargez le model.model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # ou "resnet50"# Convertisseur API.convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
Veuillez consulter la documentation d'inférence pour plus de détails sur l'hyperparamètre downsample_ratio
, plus d'arguments de convertisseur et une utilisation plus avancée.
Veuillez vous référer à la documentation de formation pour former et évaluer votre propre modèle.
La vitesse est mesurée avec inference_speed_test.py
pour référence.
GPU | dType | HD (1920 x 1080) | 4K (3840 x 2160) |
---|---|---|---|
RTX3090 | PC16 | 172 images par seconde | 154 images par seconde |
RTX 2060 Super | PC16 | 134 images par seconde | 108 images par seconde |
GTX 1080Ti | FP32 | 104 images par seconde | 74 images par seconde |
Remarque 1 : la HD utilise downsample_ratio=0.25
, la 4K utilise downsample_ratio=0.125
. Tous les tests utilisent la taille de lot 1 et le fragment de trame 1.
Remarque 2 : les GPU antérieurs à l'architecture Turing ne prennent pas en charge l'inférence FP16, donc GTX 1080 Ti utilise FP32.
Remarque 3 : Nous mesurons uniquement le débit du tenseur. Le script de conversion vidéo fourni dans ce référentiel devrait être beaucoup plus lent, car il n'utilise pas d'encodage/décodage vidéo matériel et n'effectue pas le transfert de tenseur sur des threads parallèles. Si vous souhaitez implémenter l'encodage/décodage vidéo matériel en Python, veuillez vous référer à PyNvCodec.
Shanchuan Lin
Linjie Yang
Imran Saleemi
Soumyadip Sengupta
NCNN C++ Android (@FeiGeChuanShu)
lite.ai.toolkit (@DefTruth)
Démo Web Gradio (@AK391)
Démo Unity Engine avec NatML (@natsuite)
Démo MNN C++ (@DefTruth)
Démo TNN C++ (@DefTruth)