Implementação Pytorch para papel Repensando a segmentação de imagens interativas com baixa latência, alta qualidade e prompts diversos, CVPR 2024.
Qin Liu, Jaemin Cho, Mohit Bansal, Marc Niethammer
UNC-Chapel Hill
O código é testado com python=3.10
, torch=2.2.0
, torchvision=0.17.0
.
git clone https://github.com/uncbiag/SegNext
cd SegNext
Agora, crie um novo ambiente conda e instale os pacotes necessários de acordo.
conda create -n segnext python=3.10
conda activate segnext
conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
Primeiro, baixe três pesos de modelo: vitb_sax1 (408M), vitb_sax2 (435M) e vitb_sax2_ft (435M). Esses pesos serão salvos automaticamente na pasta weights
.
python download.py
Execute a GUI interativa com os pesos baixados. Os assets
contêm imagens para demonstração.
./run_demo.sh
Treinamos e testamos nosso método em três conjuntos de dados: DAVIS, COCO+LVIS e HQSeg-44K.
Conjunto de dados | Descrição | Link para baixar |
---|---|---|
DAVIS | 345 imagens com um objeto cada (teste) | DAVIS.zip (43 MB) |
HQSeg-44K | 44320 imagens (trem); 1537 imagens (valor) | site oficial |
COCO+LVIS* | 99 mil imagens com 1,5 milhão de instâncias (trem) | imagens LVIS originais + anotações combinadas |
Não se esqueça de alterar os caminhos para os conjuntos de dados em config.yml após baixar e descompactar.
(*) Para preparar COCO+LVIS, você precisa baixar o LVIS v1.0 original e, em seguida, baixar e descompactar as anotações pré-processadas que são obtidas combinando o conjunto de dados COCO e LVIS na pasta com LVIS v1.0. (As anotações combinadas são preparadas pelo RITM.)
Fornecemos um script ( run_eval.sh
) para avaliar nossos modelos apresentados. O comando a seguir executa a avaliação NoC em todos os conjuntos de dados de teste.
python ./segnext/scripts/evaluate_model.py --gpus=0 --checkpoint=./weights/vitb_sa2_cocolvis_hq44k_epoch_0.pth --datasets=DAVIS,HQSeg44K
Trem Conjunto de dados | Modelo | HQSeg-44K | DAVIS | ||||||
---|---|---|---|---|---|---|---|---|---|
5-mIoU | NãoC90 | NãoC95 | NºF95 | 5-mIoU | NãoC90 | NãoC95 | NºF95 | ||
C+L | vitb-sax1 (408 MB) | 85,41 | 7,47 | 11.94 | 731 | 90,13 | 5,46 | 13h31 | 177 |
C+L | vitb-sax2 (435 MB) | 85,71 | 7.18 | 11.52 | 700 | 89,85 | 5,34 | 12h80 | 163 |
C+L+HQ | vitb-sax2 (435 MB) | 91,75 | 5.32 | 9.42 | 583 | 91,87 | 4,43 | 10,73 | 123 |
Para avaliação de latência SAT, consulte eval_sat_latency.ipynb.
Fornecemos um script ( run_train.sh
) para treinar nossos modelos no conjunto de dados HQSeg-44K. Você pode começar a treinar com os seguintes comandos. Por padrão, usamos 4 GPUs A6000 para treinamento.
# train vitb-sax1 model on coco+lvis
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_cocolvis_sax1.py
torchrun --nproc-per-node=4 --master-port 29504 ./segnext/train.py ${MODEL_CONFIG} --batch-size=16 --gpus=0,1,2,3
# train vitb-sax2 model on coco+lvis
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_cocolvis_sax2.py
torchrun --nproc-per-node=4 --master-port 29505 ./segnext/train.py ${MODEL_CONFIG} --batch-size=16 --gpus=0,1,2,3
# finetune vitb-sax2 model on hqseg-44k
MODEL_CONFIG=./segnext/models/default/plainvit_base1024_hqseg44k_sax2.py
torchrun --nproc-per-node=4 --master-port 29506 ./segnext/train.py ${MODEL_CONFIG} --batch-size=12 --gpus=0,1,2,3 --weights ./weights/vitb_sa2_cocolvis_epoch_90.pth
@article { liu2024rethinking ,
title = { Rethinking Interactive Image Segmentation with Low Latency, High Quality, and Diverse Prompts } ,
author = { Liu, Qin and Cho, Jaemin and Bansal, Mohit and Niethammer, Marc } ,
journal = { arXiv preprint arXiv:2404.00741 } ,
year = { 2024 }
}