Ce référentiel contient des définitions de modèle PyTorch, des poids pré-entraînés et un code d'inférence/d'échantillonnage pour notre article explorant la formation faible à forte du transformateur de diffusion pour la génération de texte en image 4K. Vous pouvez trouver plus de visualisations sur notre page de projet.
PixArt-Σ : formation de faible à fort du transformateur de diffusion pour la génération de texte en image 4K
Junsong Chen*, Chongjian Ge*, Enze Xie*†, Yue Wu*, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, Zhenguo Li
Laboratoire de l'Arche de Noé de Huawei, DLUT, HKU, HKUST
En nous inspirant du précédent projet PixArt-α, nous essaierons de garder ce dépôt aussi simple que possible afin que tous les membres de la communauté PixArt puissent l'utiliser.
? diffusers
utilisant des patchs pour une expérience rapide !-Principal
-Conseils
-Autres
Modèle | Longueur du jeton T5 | VAE | 2K/4K |
---|---|---|---|
PixArt-Σ | 300 | SDXL | ✅ |
PixArt-α | 120 | SD1.5 |
Modèle | Échantillon-1 | Échantillon-2 | Échantillon-3 |
---|---|---|---|
PixArt-Σ | |||
PixArt-α | |||
Rapide | Gros plan, homme barbu aux cheveux gris dans les années 60, observant les passants, en manteau de laine et béret marron , lunettes, cinématique. | Plan du corps, une Française, Photographie, Fond de rues françaises, contre-jour, éclairage de bord, Fujifilm. | Vidéo photoréaliste en gros plan de deux navires pirates s'affrontant alors qu'ils naviguent dans une tasse de café . |
conda create -n pixart python==3.9.0
conda activate pixart
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
git clone https://github.com/PixArt-alpha/PixArt-sigma.git
cd PixArt-sigma
pip install -r requirements.txt
Tout d'abord.
Nous lançons un nouveau dépôt pour créer une base de code plus conviviale et plus compatible. La structure principale du modèle est la même que celle de PixArt-α, vous pouvez toujours développer votre base de fonctions sur le dépôt d'origine. lso, Ce dépôt prendra en charge PixArt-alpha à l'avenir .
Conseil
Vous pouvez désormais entraîner votre modèle sans extraction préalable de fonctionnalités . Nous reformons la structure des données dans la base de code PixArt-α, afin que tout le monde puisse commencer à s'entraîner, à inférer et à visualiser dès le début sans aucune douleur.
Téléchargez d'abord l'ensemble de données sur les jouets. La structure de l'ensemble de données pour la formation est :
cd ./pixart-sigma-toy-dataset
Dataset Structure
├──InternImgs/ (images are saved here)
│ ├──000000000000.png
│ ├──000000000001.png
│ ├──......
├──InternData/
│ ├──data_info.json (meta data)
Optional(?)
│ ├──img_sdxl_vae_features_1024resolution_ms_new (run tools/extract_caption_feature.py to generate caption T5 features, same name as images except .npz extension)
│ │ ├──000000000000.npy
│ │ ├──000000000001.npy
│ │ ├──......
│ ├──caption_features_new
│ │ ├──000000000000.npz
│ │ ├──000000000001.npz
│ │ ├──......
│ ├──sharegpt4v_caption_features_new (run tools/extract_caption_feature.py to generate caption T5 features, same name as images except .npz extension)
│ │ ├──000000000000.npz
│ │ ├──000000000001.npz
│ │ ├──......
# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers
# PixArt-Sigma checkpoints
python tools/download.py # environment eg. HF_ENDPOINT=https://hf-mirror.com can use for HuggingFace mirror
Sélection du fichier de configuration souhaité dans le répertoire des fichiers de configuration.
python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345
train_scripts/train.py
configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py
--load-from output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth
--work-dir output/your_first_pixart-exp
--debug
Pour commencer, installez d’abord les dépendances requises. Assurez-vous d'avoir téléchargé les fichiers de point de contrôle de models (à venir) dans le dossier output/pretrained_models
, puis exécutez-les sur votre ordinateur local :
# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pixart_sigma_sdxlvae_T5_diffusers
# PixArt-Sigma checkpoints
python tools/download.py
# demo launch
python scripts/interface.py --model_path output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth --image_size 512 --port 11223
Important
Améliorez vos diffusers
pour rendre le PixArtSigmaPipeline
disponible !
pip install git+https://github.com/huggingface/diffusers
Pour diffusers<0.28.0
, consultez ce script pour obtenir de l'aide.
import torch
from diffusers import Transformer2DModel , PixArtSigmaPipeline
device = torch . device ( "cuda:0" if torch . cuda . is_available () else "cpu" )
weight_dtype = torch . float16
transformer = Transformer2DModel . from_pretrained (
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS" ,
subfolder = 'transformer' ,
torch_dtype = weight_dtype ,
use_safetensors = True ,
)
pipe = PixArtSigmaPipeline . from_pretrained (
"PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" ,
transformer = transformer ,
torch_dtype = weight_dtype ,
use_safetensors = True ,
)
pipe . to ( device )
# Enable memory optimizations.
# pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe ( prompt ). images [ 0 ]
image . save ( "./catcus.png" )
pip install git+https://github.com/huggingface/diffusers
# PixArt-Sigma 1024px
DEMO_PORT=12345 python app/app_pixart_sigma.py
# PixArt-Sigma One step Sampler(DMD)
DEMO_PORT=12345 python app/app_pixart_dmd.py
Jetons un coup d'œil à un exemple simple utilisant http://your-server-ip:12345
.
Téléchargez directement depuis Hugging Face
ou courir avec :
pip install git+https://github.com/huggingface/diffusers
python tools/convert_pixart_to_diffusers.py --orig_ckpt_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS.pth --dump_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS --only_transformer=True --image_size=1024 --version sigma
Tous les modèles seront automatiquement téléchargés ici. Vous pouvez également choisir de télécharger manuellement à partir de cette URL.
Modèle | #Params | Chemin du point de contrôle | Télécharger dans OpenXLab |
---|---|---|---|
T5 & SDXL-VAE | 4,5 milliards | Diffuseurs : pixart_sigma_sdxlvae_T5_diffusers | à venir |
PixArt-Σ-256 | 0,6 B | pth : PixArt-Sigma-XL-2-256x256.pth Diffuseurs : PixArt-Sigma-XL-2-256x256 | à venir |
PixArt-Σ-512 | 0,6 B | pth : PixArt-Sigma-XL-2-512-MS.pth Diffuseurs : PixArt-Sigma-XL-2-512-MS | à venir |
PixArt-α-512-DMD | 0,6 B | Diffuseurs : PixArt-Alpha-DMD-XL-2-512x512 | à venir |
PixArt-Σ-1024 | 0,6 B | pth : PixArt-Sigma-XL-2-1024-MS.pth Diffuseurs : PixArt-Sigma-XL-2-1024-MS | à venir |
PixArt-Σ-2K | 0,6 B | pth : PixArt-Sigma-XL-2-2K-MS.pth Diffuseurs : PixArt-Sigma-XL-2-2K-MS | à venir |
Nous ferons de notre mieux pour libérer
@misc{chen2024pixartsigma,
title={PixArt-Sigma: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation},
author={Junsong Chen and Chongjian Ge and Enze Xie and Yue Wu and Lewei Yao and Xiaozhe Ren and Zhongdao Wang and Ping Luo and Huchuan Lu and Zhenguo Li},
year={2024},
eprint={2403.04692},
archivePrefix={arXiv},
primaryClass={cs.CV}