Ашкан Гандж 1 · Ханг Су 2 · Тянь Го 1
1 Вустерский политехнический институт 2 Исследования Nvidia
Мы выпустили улучшенную версию HybridDepth, которая теперь доступна с новыми функциями и оптимизированной производительностью!
Эта работа представляет HybridDepth. HybridDepth — это практическое решение для оценки глубины, основанное на изображениях фокусного стека, полученных с камеры. Этот подход превосходит современные модели в нескольких известных наборах данных, включая NYU V2, DDFF12 и ARKitScenes.
30 октября 2024 г .: Выпущена вторая версия HybridDepth с улучшенной производительностью и предварительно обученными весами.
30 октября 2024 г .: Интегрированная поддержка TorchHub для упрощения загрузки моделей и получения выводов.
25 июля 2024 г .: Первый выпуск предварительно обученных моделей.
23 июля 2024 г .: репозиторий GitHub и модель HybridDepth запущены в эксплуатацию.
Быстро начните работу с HybridDepth с помощью блокнота Colab.
Вы можете выбрать предварительно обученную модель непосредственно с помощью TorchHub.
Доступные предварительно обученные модели:
HybridDepth_NYU5
: предварительно обучен на наборе данных NYU Depth V2 с использованием входного стека с 5 фокусами, с обучением как ветви DFF, так и уточняющего слоя.
HybridDepth_NYU10
: предварительно обучен на наборе данных NYU Depth V2 с использованием входного стека из 10 фокусов, с обучением как ветви DFF, так и уточняющего слоя.
HybridDepth_DDFF5
: предварительно обучен на наборе данных DDFF с использованием входного 5-фокального стека.
HybridDepth_NYU_PretrainedDFV5
: предварительное обучение только на уточняющем слое с набором данных NYU Depth V2 с использованием 5-фокального стека после предварительного обучения с помощью DFV.
model_name = 'HybridDepth_NYU_PretrainedDFV5' #change thismodel = torch.hub.load('cake-lab/HybridDepth', model_name, pretrained=True)model.eval()
Клонируйте репозиторий и установите зависимости:
git-клон https://github.com/cake-lab/HybridDepth.gitcd HybridDepth conda env create -f Environment.yml Конда активирует гибридную глубину
Загрузите предварительно тренированные веса:
Загрузите веса для модели по ссылкам ниже и поместите их в каталог checkpoints
:
HybridDepth_NYU_FocalStack5
HybridDepth_NYU_FocalStack10
HybridDepth_DDFF_FocalStack5
HybridDepth_NYU_PretrainedDFV_FocalStack5
Прогноз
Для вывода вы можете запустить следующий код:
# Загрузите модель checkpointmodel_path = 'checkpoints/NYUBest5.ckpt'model = DepthNetModule.load_from_checkpoint(model_path)model.eval()model = model.to('cuda')
После загрузки модели используйте следующий код для обработки входных изображений и получения карты глубины:
Примечание . В настоящее время функция prepare_input_image
поддерживает только изображения .jpg
. Измените функцию, если вам нужна поддержка других форматов изображений.
from utils.io import configure_input_imagedata_dir = 'каталог изображений фокусного стека' # Путь к изображениям фокусного стека в папке# Загрузите изображения фокусного стекаfocal_stack, rgb_img, focus_dist = подготовить_input_image(data_dir)# Запустите вывод с помощью torch.no_grad(): out = model (rgb_img, focus_stack, focus_dist)metric_глубина = out[0].squeeze().cpu().numpy() # Глубина метрики
Пожалуйста, сначала загрузите гири для модели по ссылкам ниже и поместите их в каталог checkpoints
:
HybridDepth_NYU_FocalStack5
HybridDepth_NYU_FocalStack10
HybridDepth_DDFF_FocalStack5
HybridDepth_NYU_PretrainedDFV_FocalStack5
NYU Depth V2 : Загрузите набор данных, следуя инструкциям, представленным здесь.
DDFF12 : Загрузите набор данных, следуя инструкциям, приведенным здесь.
ARKitScenes : загрузите набор данных, следуя инструкциям, представленным здесь.
Настройте файл конфигурации config.yaml
в каталоге configs
. Предварительно настроенные файлы для каждого набора данных доступны в каталоге configs
, где вы можете указать пути, настройки модели и другие гиперпараметры. Вот пример конфигурации:
data: class_path: dataloader.dataset.NYUDataModule # Путь к вашему модулю загрузчика данных в dataset.py init_args:nyuv2_data_root: "path/to/NYUv2" # Путь к конкретному набору данныхimg_size: [480, 640] # Отрегулируйте в соответствии с вашими требованиями DataModuleremove_white_border: Truenum_workers: 0 # Установите значение 0, если используете синтетические datause_labels: Truemodel: invert_length: True # Set значение True, если модель выводит инвертированную глубинуckpt_path: контрольные точки/checkpoint.ckpt
Укажите файл конфигурации в скрипте test.sh
:
python cli_run.py test --config configs/config_file_name.yaml
Затем выполните оценку с помощью:
CD-скрипты ш оценить.ш
Установите необходимый пакет на базе CUDA для синтеза изображений:
установка python utils/synthetic/gauss_psf/setup.py
При этом будет установлен пакет, необходимый для синтеза изображений.
Настройте файл конфигурации config.yaml
в каталоге configs
, указав путь к набору данных, размер пакета и другие параметры обучения. Ниже приведен пример конфигурации для обучения с набором данных NYUv2:
модель: invert_length: True # скорость обучения lr: 3e-4 # Отрегулируйте по мере необходимости # снижение веса wd: 0.001 # Откорректируйте по мере необходимостиdata: class_path: dataloader.dataset.NYUDataModule # Путь к вашему модулю загрузчика данных в dataset.py init_args:nyuv2_data_root: "path/to/NYUv2" # pathimg_size набора данных: [480, 640] # Корректировка для NYUDataModuleremove_white_border: Truebatch_size: 24 # Корректировка на основе доступной памятиnum_workers: 0 # Установите значение 0, если используются синтетические datause_labels: Trueckpt_path: null
Укажите файл конфигурации в скрипте train.sh
:
python cli_run.py train --config configs/config_file_name.yaml
Выполните команду обучения:
CD-скрипты ш поезд.ш
Если наша работа поможет вам в ваших исследованиях, пожалуйста, ссылайтесь на нее следующим образом:
@misc{ganj2024hybrideeprobustmetriclength, title={HybridDepth: надежное объединение метрических глубин путем использования глубины от фокуса и априорных значений одиночного изображения}, автор={Ашкан Гандж и Ханг Су и Тянь Го}, год={2024}, eprint={2407.18443} , archivePrefix={arXiv}, PrimaryClass={cs.CV}, url={https://arxiv.org/abs/2407.18443}, }