Ce référentiel contient la base de code des modèles de cohérence, implémentés à l'aide de PyTorch pour mener des expériences à grande échelle sur ImageNet-64, LSUN Bedroom-256 et LSUN Cat-256. Nous avons basé notre référentiel sur openai/guided-diffusion, qui a été initialement publié sous licence MIT. Nos modifications ont permis la prise en charge de la distillation de cohérence, de la formation à la cohérence, ainsi que de plusieurs algorithmes d'échantillonnage et d'édition abordés dans l'article.
Le référentiel des expériences CIFAR-10 se trouve en JAX et peut être trouvé sur openai/consistency_models_cifar10.
Nous avons publié des points de contrôle pour les principaux modèles dans le document. Avant d'utiliser ces modèles, veuillez consulter la carte de modèle correspondante pour comprendre l'utilisation prévue et les limites de ces modèles.
Voici les liens de téléchargement pour chaque point de contrôle de modèle :
GED sur ImageNet-64 : edm_imagenet64_ema.pt
CD sur ImageNet-64 avec métrique l2 : cd_imagenet64_l2.pt
CD sur ImageNet-64 avec métrique LPIPS : cd_imagenet64_lpips.pt
CT sur ImageNet-64 : ct_imagenet64.pt
EDM sur LSUN Chambre-256 : edm_bed256_ema.pt
CD sur LSUN Bedroom-256 avec métrique l2 : cd_room256_l2.pt
CD sur LSUN Bedroom-256 avec métrique LPIPS : cd_room256_lpips.pt
CT sur LSUN Chambre-256 : ct_chambre256.pt
EDM sur LSUN Cat-256 : edm_cat256_ema.pt
CD sur LSUN Cat-256 avec métrique l2 : cd_cat256_l2.pt
CD sur LSUN Cat-256 avec métrique LPIPS : cd_cat256_lpips.pt
CT sur LSUN Cat-256 : ct_cat256.pt
Pour installer tous les packages de cette base de code ainsi que leurs dépendances, exécutez
pip install -e .
Pour installer avec Docker, exécutez les commandes suivantes :
cd docker && faire construire && faire exécuter
Nous fournissons des exemples de formation EDM, de distillation de cohérence, de formation à la cohérence, de génération en une seule étape et de génération en plusieurs étapes dans scripts/launch.sh.
Pour comparer différents modèles génératifs, nous utilisons FID, Precision, Recall et Inception Score. Ces métriques peuvent toutes être calculées à l'aide de lots d'échantillons stockés dans des fichiers .npz
(numpy). On peut évaluer des échantillons avec cm/evaluations/evaluator.py de la même manière que décrit dans openai/guided-diffusion, avec des lots d'ensembles de données de référence fournis.
Les modèles de cohérence sont pris en charge dans ? diffuseurs via la classe ConsistencyModelPipeline
. Ci-dessous, nous fournissons un exemple :
importer torchfrom diffuseurs import ConsistencyModelPipelinedevice = "cuda"# Charger le cd_imagenet64_l2 checkpoint.model_id_or_path = "openai/diffusers-cd_imagenet64_l2"pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)pipe.to(device)# Onestep Samplingimage = pipe(num_inference_steps=1).images[0]image.save("consistency_model_onestep_sample.png")# Échantillonnage en une étape, génération d'images conditionnelles par classe# L'étiquette de classe ImageNet-64 145 correspond aux manchots royauxclass_id = 145class_id = torch.tensor(class_id, dtype=torch.long)image = pipe(num_inference_steps=1, class_labels=class_id).images[0]image.save("consistency_model_onestep_sample_penguin.png")# Échantillonnage en plusieurs étapes, génération d'images conditionnelles de classe# Les pas de temps peuvent être explicitement spécifiés ; les pas de temps particuliers ci-dessous proviennent du dépôt Github d'origine.# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77image = pipe(timesteps=[22, 0], class_labels=class_id) .images[0]image.save("consistency_model_multistep_sample_penguin.png")
Vous pouvez accélérer davantage le processus d'inférence en utilisant torch.compile()
sur pipe.unet
(uniquement pris en charge à partir de PyTorch 2.0). Pour plus de détails, veuillez consulter la documentation officielle. Ce soutien a été apporté à ? diffuseurs par dg845 et ayushtues.
Si vous trouvez cette méthode et/ou ce code utile, pensez à citer
@article{song2023consistency, title={Modèles de cohérence}, author={Song, Yang et Dhariwal, Prafulla et Chen, Mark et Sutskever, Ilya}, journal={arXiv preprint arXiv:2303.01469}, year={2023}, }