DLT
1.0.0
[ICCV 23] DLT:使用聯合離散連續擴散佈局變壓器產生條件佈局
該存儲庫是 DLT 論文的官方實現。請參閱論文以了解更多詳細信息,並參閱項目頁面以了解總體概述。
無條件 | 類別 | 類別+尺寸 |
---|---|---|
所有相關要求都列在environment.yml 中。我們建議使用 conda 創建適當的環境並安裝依賴項:
conda env create -f environment.yml
conda activate dlt
請在以下網頁下載公用資料集。將其放入您的資料夾中並相應地更新./dlt/configs/remote/dataset_config.yaml
。
您可以使用 configs 資料夾中的任何配置腳本來訓練模型。例如,如果要在 publaynet 資料集上訓練提供的 DLT 模型,指令如下:
cd dlt
python main.py --config configs/remote/dlt_publaynet_config.yaml --workdir < WORKDIR >
請注意,程式碼與加速器無關。如果您不想將結果記錄到 wandb,只需在 args 中設定--workdir test
即可。
若要產生用於在測試集上進行評估的樣本,請執行下列步驟:
# put weights in config.logs folder
DATASET = " publaynet " # or "rico" or "magazine"
python generate_samples.py --config configs/remote/dlt_{ $DATASET }_config.yaml \
--workdir < WORKDIR > --epoch < EPOCH > --cond_type < COND_TYPE > \
--save True
# get all the metrics
# update path to pickle file in dlt/evaluation/metric_comp.py
./download_fid_model.sh
python metric_comp.py
其中<COND_TYPE>
可以是:(all, Whole_box, loc) - (無條件,類別,類別+大小), <EPOCH>
是要評估的模型的紀元號, <WORKDIR>
是路徑保存模型權重的資料夾(例如rico_final)。如果save
True,產生的樣本將保存在logs/<WORKDIR>/samples
資料夾中。
它的輸出是帶有生成樣本的 pickle 檔案。您可以使用它來計算指標。
訓練後包含權重的資料夾具有以下結構:
logs
├── magazine_final
│ ├── checkpoints
│ └── samples
├── publaynet_final
│ ├── checkpoints
│ └── samples
└── rico_final
├── checkpoints
└── samples
如果您發現此程式碼對您的研究有用,請引用我們的論文:
@misc{levi2023dlt,
title={DLT: Conditioned layout generation with Joint Discrete-Continuous Diffusion Layout Transformer},
author={Elad Levi and Eli Brosh and Mykola Mykhailych and Meir Perez},
year={2023},
eprint={2303.03755},
archivePrefix={arXiv},
primaryClass={cs.CV}
}