DLT
1.0.0
[ICCV 23] DLT:使用联合离散连续扩散布局变压器生成条件布局
该存储库是 DLT 论文的官方实现。请参阅论文以了解更多详细信息,并参阅项目页面以了解总体概述。
无条件 | 类别 | 类别+尺寸 |
---|---|---|
所有相关要求都列在environment.yml 中。我们建议使用 conda 创建适当的环境并安装依赖项:
conda env create -f environment.yml
conda activate dlt
请在以下网页下载公共数据集。将其放入您的文件夹中并相应地更新./dlt/configs/remote/dataset_config.yaml
。
您可以使用 configs 文件夹中的任何配置脚本来训练模型。例如,如果要在 publaynet 数据集上训练提供的 DLT 模型,命令如下:
cd dlt
python main.py --config configs/remote/dlt_publaynet_config.yaml --workdir < WORKDIR >
请注意,代码与加速器无关。如果您不想将结果记录到 wandb,只需在 args 中设置--workdir test
即可。
要生成用于在测试集上进行评估的样本,请执行以下步骤:
# put weights in config.logs folder
DATASET = " publaynet " # or "rico" or "magazine"
python generate_samples.py --config configs/remote/dlt_{ $DATASET }_config.yaml \
--workdir < WORKDIR > --epoch < EPOCH > --cond_type < COND_TYPE > \
--save True
# get all the metrics
# update path to pickle file in dlt/evaluation/metric_comp.py
./download_fid_model.sh
python metric_comp.py
其中<COND_TYPE>
可以是:(all, Whole_box, loc) - (无条件,类别,类别+大小), <EPOCH>
是要评估的模型的纪元号, <WORKDIR>
是路径保存模型权重的文件夹(例如 rico_final)。如果save
True,生成的样本将保存在logs/<WORKDIR>/samples
文件夹中。
它的输出是带有生成样本的 pickle 文件。您可以使用它来计算指标。
训练后包含权重的文件夹具有以下结构:
logs
├── magazine_final
│ ├── checkpoints
│ └── samples
├── publaynet_final
│ ├── checkpoints
│ └── samples
└── rico_final
├── checkpoints
└── samples
如果您发现此代码对您的研究有用,请引用我们的论文:
@misc{levi2023dlt,
title={DLT: Conditioned layout generation with Joint Discrete-Continuous Diffusion Layout Transformer},
author={Elad Levi and Eli Brosh and Mykola Mykhailych and Meir Perez},
year={2023},
eprint={2303.03755},
archivePrefix={arXiv},
primaryClass={cs.CV}
}