This repo contains the PyTorch implementation of LDMSeg: a simple latent diffusion approach for panoptic segmentation and mask inpainting. The provided code includes both training and evaluation.
A Simple Latent Diffusion Approach for Panoptic Segmentation and Mask Inpainting
Wouter Van Gansbeke and Bert De Brabandere
This paper presents a conditional latent diffusion approach to tackle the task of panoptic segmentation. The aim is to omit the need for specialized architectures (e.g., region proposal networks or object queries), complex loss functions (e.g., Hungarian matching or losses based on bounding boxes), and additional post-processing methods (e.g., clustering, NMS, or object pasting). As a result, we rely on Stable Diffusion, which is a task-agnostic framework. The presented approach consists of two steps: (1) project the panoptic segmentation masks to a latent space with a shallow auto-encoder; (2) train a diffusion model in latent space, conditioned on RGB images.
Key Contributions: Our contributions are threefold:
The code runs with recent PyTorch versions, e.g., 2.0. Further, you can create a Python environment with Anaconda:
conda create -n LDMSeg python=3.11
conda activate LDMSeg
We recommend following the automatic installation (see tools/scripts/install_env.sh). Run the following commands to install the project in editable mode. Note that all dependencies will be installed automatically.
As this might not always work (e.g., due to CUDA or gcc issues), please have a look at the manual installation steps.
python -m pip install -e .
pip install git+https://github.com/facebookresearch/detectron2.git
pip install git+https://github.com/cocodataset/panopticapi.git
The most important packages can be quickly installed with pip as:
pip install torch torchvision einops # Main framework
pip install diffusers transformers xformers accelerate timm # For using pretrained models
pip install scipy opencv-python # For augmentations or loss
pip install pyyaml easydict hydra-core # For using config files
pip install termcolor wandb # For printing and logging
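If you install the dependencies manually, a quick sanity check along these lines can confirm the main packages are visible. Note that some import names differ from the pip package names (opencv-python imports as cv2, pyyaml as yaml, hydra-core as hydra):

```python
import importlib.util

def missing_packages(required):
    """Return the subset of `required` packages that cannot be found."""
    return [pkg for pkg in required if importlib.util.find_spec(pkg) is None]

# Mirrors the pip commands above, using import names rather than pip names.
required = ["torch", "torchvision", "einops", "diffusers", "transformers",
            "scipy", "cv2", "yaml", "easydict", "hydra", "termcolor", "wandb"]
print("Missing packages:", missing_packages(required))
```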
See data/environment.yml for a copy of my environment. We also rely on some dependencies from detectron2 and panopticapi. Please follow their docs.
We currently support the COCO dataset. Please follow the docs for installing the images and their corresponding panoptic segmentation masks. Also, take a look at the ldmseg/data/ directory for a few examples on the COCO dataset. As a side note, the adopted structure should be fairly standard:
.
└── coco
├── annotations
├── panoptic_semseg_train2017
├── panoptic_semseg_val2017
├── panoptic_train2017 -> annotations/panoptic_train2017
├── panoptic_val2017 -> annotations/panoptic_val2017
├── test2017
├── train2017
└── val2017
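As a quick sanity check, a small script along these lines can verify that the expected subdirectories from the tree above exist under your dataset root (the root path in the example is a placeholder you should adapt):

```python
from pathlib import Path

# Directory names follow the COCO tree shown above.
EXPECTED_DIRS = [
    "annotations",
    "panoptic_semseg_train2017",
    "panoptic_semseg_val2017",
    "panoptic_train2017",
    "panoptic_val2017",
    "train2017",
    "val2017",
]

def check_coco_root(root):
    """Return the expected subdirectories that are missing under `root`."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]

missing = check_coco_root("/path/to/datasets/coco")  # adapt to your root
if missing:
    print("Missing directories:", missing)
```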
Last but not least, change the paths in configs/env/root_paths.yml to your dataset root and your desired output directory respectively.
The presented approach is two-pronged: First, we train an auto-encoder to represent segmentation maps in a lower dimensional space (e.g., 64x64). Next, we start from pretrained Latent Diffusion Models (LDM), particularly Stable Diffusion, to train a model which can generate panoptic masks from RGB images.
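To make the dimensions concrete: with a downsampling factor of 8, as in Stable Diffusion's VAE, a 512x512 segmentation map is encoded into a 64x64 latent grid on which the diffusion model operates. The input resolution below is an illustrative assumption:

```python
def latent_size(image_size, downsampling_factor=8):
    """Spatial size of the latent grid for a given input resolution."""
    assert image_size % downsampling_factor == 0
    return image_size // downsampling_factor

print(latent_size(512))  # 512 / 8 = 64, matching the 64x64 latent space above
```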
The models can be trained by running the following commands. By default, we will train on the COCO dataset with the base config file defined in tools/configs/base/base.yaml. Note that this file will be loaded automatically as we rely on the hydra package.
python -W ignore tools/main_ae.py
datasets=coco
base.train_kwargs.fp16=True
base.optimizer_name=adamw
base.optimizer_kwargs.lr=1e-4
base.optimizer_kwargs.weight_decay=0.05
More details on passing arguments can be found in tools/scripts/train_ae.sh. For example, I ran this model for 50k iterations on a single GPU with 23 GB of memory and a total batch size of 16.
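For intuition on why a shallow auto-encoder stays lightweight, conv-layer parameter counts can be tallied directly. The layer shapes below are hypothetical and not the actual LDMSeg architecture:

```python
def conv2d_params(c_in, c_out, k, bias=True):
    """Parameter count of a single Conv2d layer: weights plus optional bias."""
    return c_in * c_out * k * k + (c_out if bias else 0)

# Hypothetical shallow encoder: a few convs down to a small latent channel count.
layers = [(3, 64, 3), (64, 128, 3), (128, 256, 3), (256, 4, 1)]
total = sum(conv2d_params(*layer) for layer in layers)
print(f"{total / 1e6:.2f}M parameters")  # a fraction of typical VAE sizes
```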
python -W ignore tools/main_ldm.py
datasets=coco
base.train_kwargs.gradient_checkpointing=True
base.train_kwargs.fp16=True
base.train_kwargs.weight_dtype=float16
base.optimizer_zero_redundancy=True
base.optimizer_name=adamw
base.optimizer_kwargs.lr=1e-4
base.optimizer_kwargs.weight_decay=0.05
base.scheduler_kwargs.weight='max_clamp_snr'
base.vae_model_kwargs.pretrained_path='$AE_MODEL'
$AE_MODEL denotes the path to the model obtained from the previous step.
More details on passing arguments can be found in tools/scripts/train_diffusion.sh. For example, I ran this model for 200k iterations on 8 GPUs with 16 GB of memory each and a total batch size of 256.
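The total batch size is the product of the number of GPUs, the per-GPU batch size, and any gradient-accumulation steps. The per-GPU batch and accumulation values below are assumptions chosen to multiply out to 256, not the repo's actual settings:

```python
def effective_batch_size(n_gpus, per_gpu_batch, grad_accum_steps=1):
    """Effective (total) batch size across devices and accumulation steps."""
    return n_gpus * per_gpu_batch * grad_accum_steps

# E.g., 8 GPUs, a per-GPU batch of 8, and 4 accumulation steps give 256.
print(effective_batch_size(8, 8, 4))
```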
We're planning to release several trained models. The (class-agnostic) PQ metric is provided on the COCO validation set.
Model | #Params | Dataset | Iters | PQ | SQ | RQ | Download link |
---|---|---|---|---|---|---|---|
AE | ~2M | COCO | 66k | - | - | - | Download (23 MB) |
LDM | ~800M | COCO | 200k | 51.7 | 82.0 | 63.0 | Download (3.3 GB) |
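As a sanity check on the table, panoptic quality factorizes as PQ = SQ x RQ (exactly per category, and approximately for the averaged numbers reported above):

```python
def panoptic_quality(sq, rq):
    """PQ = SQ * RQ, with all quantities expressed in percent."""
    return sq * rq / 100.0

print(round(panoptic_quality(82.0, 63.0), 1))  # 51.7, matching the LDM row
```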
Note: A less powerful AE (i.e., fewer downsampling or upsampling layers) can often benefit inpainting, as we don't perform additional finetuning.
The evaluation should look like:
python -W ignore tools/main_ldm.py
datasets=coco
base.sampling_kwargs.num_inference_steps=50
base.eval_only=True
base.load_path=$PRETRAINED_MODEL_PATH
You can add parameters if necessary. Higher thresholds, such as base.eval_kwargs.count_th=700 or base.eval_kwargs.mask_th=0.9, can further boost the numbers. However, for the evaluation we use standard values: a mask threshold of 0.5 and removal of segments with an area smaller than 512.
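The thresholding described above can be sketched as a simple post-processing filter. The segment representation here (dicts with score and area fields) is an illustrative assumption, not the repo's actual data structure:

```python
def filter_segments(segments, mask_th=0.5, count_th=512):
    """Keep segments passing both the confidence and the minimum-area threshold."""
    return [s for s in segments
            if s["score"] >= mask_th and s["area"] >= count_th]

segments = [
    {"id": 1, "score": 0.9, "area": 2048},  # kept
    {"id": 2, "score": 0.4, "area": 4096},  # dropped: low confidence
    {"id": 3, "score": 0.8, "area": 100},   # dropped: area below threshold
]
print([s["id"] for s in filter_segments(segments)])  # [1]
```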
To evaluate a pretrained model from above, run tools/scripts/eval.sh.
Here, we visualize the results:
If you find this repository useful for your research, please consider citing the following paper:
@article{vangansbeke2024ldmseg,
  title={A Simple Latent Diffusion Approach for Panoptic Segmentation and Mask Inpainting},
  author={Van Gansbeke, Wouter and De Brabandere, Bert},
  journal={arXiv preprint arXiv:2401.10227},
  year={2024}
}
For any enquiries, please contact the main author.
This software is released under a Creative Commons license that allows for personal and research use only. For a commercial license, please contact the authors. You can view a license summary here.
I'm thankful for all the public repositories (see also references in the code), and in particular for the detectron2 and diffusers libraries.