アシュカン・ガンジ1 ·ハン・スー2 ·ティアン・グオ1
1ウースター工科大学2 Nvidia Research
HybridDepth の改良版をリリースしました。新機能と最適化されたパフォーマンスを備えたバージョンが利用可能になりました。
この作品ではHybridDepthが登場します。 HybridDepth は、カメラからキャプチャされた焦点スタック画像に基づく実用的な深度推定ソリューションです。このアプローチは、NYU V2、DDFF12、ARKitScenes などのいくつかのよく知られたデータセットにわたって最先端のモデルよりも優れたパフォーマンスを発揮します。
2024-10-30 : パフォーマンスが向上し、事前にトレーニングされた重みを備えた HybridDepth のバージョン 2がリリースされました。
2024-10-30 : モデルの読み込みと推論を容易にするために、TorchHub のサポートが統合されました。
2024-07-25 : 事前トレーニングされたモデルの初期リリース。
2024-07-23 : GitHub リポジトリと HybridDepth モデルが公開されました。
Colab ノートブックを使用して、HybridDepth をすぐに始めましょう。
TorchHub を使用して、事前トレーニングされたモデルを直接選択できます。
利用可能な事前トレーニング済みモデル:
HybridDepth_NYU5
: 5 焦点スタック入力を使用して NYU Depth V2 データセットで事前トレーニングされ、DFF ブランチとリファインメント レイヤーの両方がトレーニングされます。
HybridDepth_NYU10
: 10 焦点スタック入力を使用して NYU Depth V2 データセットで事前トレーニングされ、DFF ブランチとリファインメント レイヤーの両方がトレーニングされます。
HybridDepth_DDFF5
: 5 焦点スタック入力を使用して DDFF データセットで事前トレーニングされています。
HybridDepth_NYU_PretrainedDFV5
: DFV での事前トレーニングに続き、5 焦点スタックを使用して NYU Depth V2 データセットを使用してリファインメント レイヤー上でのみ事前トレーニングされます。
model_name = 'HybridDepth_NYU_PretrainedDFV5' #change thismodel = torch.hub.load('cake-lab/HybridDepth', model_name , pretrained=True)model.eval()
リポジトリのクローンを作成し、依存関係をインストールします。
git clone https://github.com/cake-lab/HybridDepth.gitcd HybridDepth conda env create -f 環境.yml conda はハイブリッド深度をアクティブにします
事前にトレーニングされた重みをダウンロードします。
以下のリンクからモデルの重みをダウンロードし、 checkpoints
ディレクトリに配置します。
HybridDepth_NYU_FocalStack5
HybridDepth_NYU_FocalStack10
HybridDepth_DDFF_FocalStack5
HybridDepth_NYU_PretrainedDFV_FocalStack5
予測
推論のために、次のコードを実行できます。
# モデルをロードする Checkpointmodel_path = 'checkpoints/NYUBest5.ckpt'model = DepthNetModule.load_from_checkpoint(model_path)model.eval()model = model.to('cuda')
モデルをロードした後、次のコードを使用して入力画像を処理し、深度マップを取得します。
注: 現在、 prepare_input_image
関数は.jpg
画像のみをサポートしています。他の画像形式のサポートが必要な場合は、関数を変更します。
from utils.io import prepare_input_imagedata_dir = 'フォーカル スタック イメージ ディレクトリ' # フォルダー内のフォーカル スタック イメージへのパス# フォーカル スタック イメージをロードしますfocal_stack, rgb_img, focus_dist = prepare_input_image(data_dir)# torch.no_grad() で推論を実行します: out = model (rgb_img、focal_stack、focus_dist)metric_ Depth = out[0].squeeze().cpu().numpy() # メトリックの深さ
まず、以下のリンクからモデルの重みをダウンロードし、 checkpoints
ディレクトリに配置してください。
HybridDepth_NYU_FocalStack5
HybridDepth_NYU_FocalStack10
HybridDepth_DDFF_FocalStack5
HybridDepth_NYU_PretrainedDFV_FocalStack5
NYU Depth V2 : ここに記載されている手順に従ってデータセットをダウンロードします。
DDFF12 : ここに記載されている手順に従ってデータセットをダウンロードします。
ARKitScenes : ここに記載されている手順に従ってデータセットをダウンロードします。
configs
ディレクトリに構成ファイルconfig.yaml
セットアップします。各データセットの事前構成ファイルはconfigs
ディレクトリで利用でき、パス、モデル設定、その他のハイパーパラメーターを指定できます。構成例を次に示します。
data: class_path: dataloader.dataset.NYUDataModule # dataset.py 内のデータローダー モジュールへのパス init_args:nyuv2_data_root: "path/to/NYUv2" # 特定のデータセットへのパスtimg_size: [480, 640] # DataModule の要件に基づいて調整remove_white_border: Truenum_workers: 0 # 合成データを使用する場合は 0 に設定use_labels: Truemodel: invert_ Depth: True # 設定モデルが反転した Depthckpt_path を出力する場合は True に設定します。チェックポイント/checkpoint.ckpt
test.sh
スクリプトで構成ファイルを指定します。
python cli_run.py test --config configs/config_file_name.yaml
次に、次のようにして評価を実行します。
CD スクリプト sh 評価.sh
画像合成に必要な CUDA ベースのパッケージをインストールします。
python utils/synthetic/gauss_psf/setup.py インストール
画像合成に必要なパッケージがインストールされます。
configs
ディレクトリに構成ファイルconfig.yaml
セットアップし、データセットのパス、バッチ サイズ、その他のトレーニング パラメーターを指定します。以下は、NYUv2 データセットを使用したトレーニングのサンプル構成です。
モデル: invert_ Depth: True # 学習率 lr: 3e-4 # 必要に応じて調整します # 体重減少 wd: 0.001 # 必要に応じて調整data: class_path: dataloader.dataset.NYUDataModule # dataset.py 内のデータローダー モジュールへのパス init_args:nyuv2_data_root: "path/to/NYUv2" # データセット pathimg_size: [480, 640] # NYUDataModuleremove_white_border に合わせて調整: Truebatch_size: 24 # 利用可能なメモリに基づいて調整num_workers: 0 # 合成データを使用する場合は 0 に設定use_labels: Trueckpt_path: null
train.sh
スクリプトで構成ファイルを指定します。
python cli_run.py train --config configs/config_file_name.yaml
トレーニング コマンドを実行します。
CD スクリプト sh train.sh
私たちの研究があなたの研究に役立つ場合は、次のように引用してください。
@misc{ganj2024hybrid Depthrobustmetric Depth, title={HybridDepth: フォーカスと単一画像事前分布からの深度を活用することによる堅牢なメトリクス深度融合}、author={Ashkan Ganj、Hang Su、Tian Guo}、year={2024}、eprint={2407.18443} 、archivePrefix = {arXiv}、primaryClass = {cs.CV}、 URL={https://arxiv.org/abs/2407.18443}、 }