Implementación de Pytorch para repensar la segmentación de imágenes interactivas en papel con baja latencia, alta calidad y diversas indicaciones, CVPR 2024.
Qin Liu, Jaemin Cho, Mohit Bansal, Marc Niethammer
UNC-Chapel Hill
El código se prueba con python=3.10
, torch=2.2.0
, torchvision=0.17.0
.
git clone https://github.com/uncbiag/SegNext
cd SegNext
Ahora, cree un nuevo entorno conda e instale los paquetes necesarios en consecuencia.
conda create -n segnext python=3.10
conda activate segnext
conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
Primero, descargue tres pesos modelo: vitb_sax1 (408M), vitb_sax2 (435M) y vitb_sax2_ft (435M). Estos pesos se guardarán automáticamente en la carpeta weights
.
python download.py
Ejecute la GUI interactiva con los pesos descargados. Los assets
contienen imágenes para demostración.
./run_demo.sh
Entrenamos y probamos nuestro método en tres conjuntos de datos: DAVIS, COCO+LVIS y HQSeg-44K.
Conjunto de datos | Descripción | Enlace de descarga |
---|---|---|
DAVIS | 345 imágenes con un objeto cada una (prueba) | DAVIS.zip (43 MB) |
HQSeg-44K | 44320 imágenes (tren); 1537 imágenes (valor) | sitio oficial |
COCO+LVIS* | 99.000 imágenes con 1,5 millones de instancias (tren) | imágenes LVIS originales + anotaciones combinadas |
No olvide cambiar las rutas a los conjuntos de datos en config.yml después de descargarlos y descomprimirlos.
(*) Para preparar COCO+LVIS, necesita descargar LVIS v1.0 original, luego descargar y descomprimir las anotaciones preprocesadas que se obtienen combinando el conjunto de datos COCO y LVIS en la carpeta con LVIS v1.0. (Las anotaciones combinadas son preparadas por RITM).
Proporcionamos un script ( run_eval.sh
) para evaluar nuestros modelos presentados. El siguiente comando ejecuta la evaluación NoC en todos los conjuntos de datos de prueba.
python ./segnext/scripts/evaluate_model.py --gpus=0 --checkpoint=./weights/vitb_sa2_cocolvis_hq44k_epoch_0.pth --datasets=DAVIS,HQSeg44K
Tren Conjunto de datos | Modelo | HQSeg-44K | DAVIS | ||||||
---|---|---|---|---|---|---|---|---|---|
5 millones de UI | NoC90 | NoC95 | NoF95 | 5 millones de UI | NoC90 | NoC95 | NoF95 | ||
C+L | vitb-sax1 (408 MB) | 85,41 | 7.47 | 11.94 | 731 | 90.13 | 5.46 | 13.31 | 177 |
C+L | vitb-sax2 (435 MB) | 85,71 | 7.18 | 11.52 | 700 | 89,85 | 5.34 | 12.80 | 163 |
C+L+HQ | vitb-sax2 (435 MB) | 91,75 | 5.32 | 9.42 | 583 | 91,87 | 4.43 | 10.73 | 123 |
Para la evaluación de la latencia SAT, consulte eval_sat_latency.ipynb.
Proporcionamos un script ( run_train.sh
) para entrenar nuestros modelos en el conjunto de datos HQSeg-44K. Puedes empezar a entrenar con los siguientes comandos. Por defecto utilizamos 4 GPU A6000 para el entrenamiento.
# train vitb-sax1 model on coco+lvis
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_cocolvis_sax1.py
torchrun --nproc-per-node=4 --master-port 29504 ./segnext/train.py ${MODEL_CONFIG} --batch-size=16 --gpus=0,1,2,3
# train vitb-sax2 model on coco+lvis
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_cocolvis_sax2.py
torchrun --nproc-per-node=4 --master-port 29505 ./segnext/train.py ${MODEL_CONFIG} --batch-size=16 --gpus=0,1,2,3
# finetune vitb-sax2 model on hqseg-44k
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_hqseg44k_sax2.py
torchrun --nproc-per-node=4 --master-port 29506 ./segnext/train.py ${MODEL_CONFIG} --batch-size=12 --gpus=0,1,2,3 --weights ./weights/vitb_sa2_cocolvis_epoch_90.pth
@article { liu2024rethinking ,
title = { Rethinking Interactive Image Segmentation with Low Latency, High Quality, and Diverse Prompts } ,
author = { Liu, Qin and Cho, Jaemin and Bansal, Mohit and Niethammer, Marc } ,
journal = { arXiv preprint arXiv:2404.00741 } ,
year = { 2024 }
}