英語 | 中国語
論文「時間的ガイダンスを備えた堅牢な高解像度ビデオ マッティング」の公式リポジトリ。 RVM は、堅牢なヒューマン ビデオ マッティング用に特別に設計されています。フレームを独立した画像として処理する既存のニューラル モデルとは異なり、RVM はリカレント ニューラル ネットワークを使用して一時メモリでビデオを処理します。 RVM は、追加の入力を必要とせずに、あらゆるビデオに対してリアルタイムでマット処理を実行できます。 Nvidia GTX 1080 Ti GPU で4K 76FPSおよびHD 104FPSを実現します。このプロジェクトはByteDance Inc.で開発されました。
[2021 11 03] train.pyのバグを修正しました。
[2021 年 9 月 16 日] コードが GPL-3.0 ライセンスに基づいて再リリースされました。
[2021 8 25] ソースコードと事前学習済みモデルを公開しました。
[2021 7 27] WACV 2022 に論文が採択されました。
モデルのパフォーマンスを確認するには、ショーリール ビデオ (YouTube、Bilibili) をご覧ください。
ビデオ内のすべての映像は Google ドライブで利用できます。
ウェブカメラ デモ: ブラウザでモデルをライブで実行します。再発する状態を視覚化します。
Colab デモ: 無料の GPU を使用して、独自のビデオでモデルをテストします。
ほとんどの使用例には MobileNetv3 モデルをお勧めします。 ResNet50 モデルは、パフォーマンスがわずかに向上した、より大きなバリアントです。私たちのモデルはさまざまな推論フレームワークで利用できます。詳細については、推論のドキュメントを参照してください。
フレームワーク | ダウンロード | 注意事項 |
パイトーチ | rvm_mobilenetv3.pth RVM_resnet50.pth | PyTorch の公式の重み。ドクター |
トーチハブ | ダウンロードするものはありません。 | PyTorch プロジェクトでモデルを使用する最も簡単な方法。ドクター |
トーチスクリプト | rvm_mobilenetv3_fp32.torchscript rvm_mobilenetv3_fp16.torchscript rvm_resnet50_fp32.torchscript rvm_resnet50_fp16.torchscript | モバイルで推論する場合は、int8 量子化モデルを自分でエクスポートすることを検討してください。ドクター |
ONNX | rvm_mobilenetv3_fp32.onnx rvm_mobilenetv3_fp16.onnx rvm_resnet50_fp32.onnx rvm_resnet50_fp16.onnx | CPU および CUDA バックエンドを使用した ONNX ランタイムでテスト済み。提供されたモデルは opset 12 を使用します。Doc、Exporter。 |
TensorFlow | RVM_mobilenetv3_tf.zip RVM_resnet50_tf.zip | TensorFlow 2 の保存されたモデル。ドクター |
TensorFlow.js | rvm_mobilenetv3_tfjs_int8.zip | Web 上でモデルを実行します。デモ、スターターコード |
CoreML | rvm_mobilenetv3_1280x720_s0.375_fp16.mlモデル RVM_mobilenetv3_1280x720_s0.375_int8.mlモデル rvm_mobilenetv3_1920x1080_s0.25_fp16.mlモデル RVM_mobilenetv3_1920x1080_s0.25_int8.mlmodel | CoreML は動的解像度をサポートしていません。他の解像度は自分でエクスポートできます。モデルには iOS 13 以降が必要です。 s downsample_ratio 示します。ドクター、エクスポーター |
すべてのモデルは Google ドライブと Baidu Pan (コード: Gym7) で利用できます。
依存関係をインストールします。
pip install -rrequirements_inference.txt
モデルをロードします。
import torchfrom モデル import MattingNetworkmodel = MattingNetwork('mobilenetv3').eval().cuda() # または "resnet50"model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))
ビデオを変換するために、シンプルな変換 API が提供されています。
from inference import Convert_videoconvert_video(model, # モデルは任意のデバイス (cpu または cuda) 上にあります。input_source='input.mp4', # ビデオ ファイルまたはイメージ シーケンス ディレクトリ。output_type='video', # 「video」を選択します" または "png_sequence"output_composition='com.mp4', # ビデオの場合はファイル パス、png sequence.output_alpha="pha.mp4" の場合はディレクトリ パス[オプション] 生のアルファ予測を出力します。output_foreground="fgr.mp4", # [オプション] 生のフォアグラウンド予測を出力します。output_video_mbps=4, # 出力ビデオ mbps は png シーケンスには必要ありません。downsample_ratio=None, # ハイパーパラメータauto.seq_chunk=12 に調整するか、None を使用します。 # 並列処理を向上させるために、一度に n フレームを処理します。)
または、独自の推論コードを作成します。
from torch.utils.data import DataLoaderfrom torchvision.transforms import ToTensorfrom inference_utils import VideoReader, VideoWriterreader = VideoReader('input.mp4', transform=ToTensor())writer = VideoWriter('output.mp4', Frame_rate=30)bgr = torch .tensor([.47, 1, .6]).view(3, 1, 1).cuda() # 緑色の背景.rec = [なし] * 4 # 初期反復状態.downsample_ratio = 0.25 # ビデオに基づいて調整します。with torch.no_grad():for src in DataLoader(reader): # 0 ~ 1 に正規化された RGB テンソル。fgr, pha, *rec = model(src.cuda(), *rec, downsample_ratio) # 繰り返し状態を循環させます。com = fgr * pha + bgr * (1 - pha) # 緑色の背景に合成します。 Writer.write(com) # フレームを書き込みます。
モデルとコンバーター API は、TorchHub からも入手できます。
# model をロードします。model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3") # または "resnet50"# コンバーター API.convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
downsample_ratio
ハイパーパラメータ、その他のコンバータ引数、およびより高度な使用法の詳細については、推論ドキュメントを参照してください。
独自のモデルをトレーニングおよび評価するには、トレーニング ドキュメントを参照してください。
速度はinference_speed_test.py
で計測していますので参考にしてください。
GPU | dタイプ | HD (1920x1080) | 4K (3840x2160) |
---|---|---|---|
RTX3090 | FP16 | 172FPS | 154FPS |
RTX 2060 スーパー | FP16 | 134FPS | 108FPS |
GTX1080Ti | FP32 | 104FPS | 74FPS |
注 1: HD はdownsample_ratio=0.25
を使用し、4K はdownsample_ratio=0.125
を使用します。すべてのテストではバッチ サイズ 1 とフレーム チャンク 1 を使用します。
注 2: Turing アーキテクチャより前の GPU は FP16 推論をサポートしていないため、GTX 1080 Ti は FP32 を使用します。
注 3: テンソル スループットのみを測定します。このリポジトリで提供されているビデオ変換スクリプトは、ハードウェア ビデオ エンコード/デコードを利用しておらず、並列スレッドでテンソル転送が行われていないため、かなり遅くなることが予想されます。 Python でのハードウェア ビデオ エンコード/デコードの実装に興味がある場合は、PyNvCodec を参照してください。
林山川
ヤン・リンジエ
イムラン・サレミ
ソウミャディップ・セングプタ
NCNN C++ Android (@FeiGeChuanShu)
lite.ai.toolkit (@DefTruth)
Gradio Web デモ (@AK391)
NatML を使用した Unity Engine のデモ (@natsuite)
MNN C++ デモ (@DefTruth)
TNN C++ デモ (@DefTruth)