Dieses Repository enthält die Codebasis für Konsistenzmodelle, die mit PyTorch zur Durchführung groß angelegter Experimente auf ImageNet-64, LSUN Bedroom-256 und LSUN Cat-256 implementiert wurden. Wir haben unser Repository auf openai/guided-diffusion basiert, das ursprünglich unter der MIT-Lizenz veröffentlicht wurde. Unsere Modifikationen haben die Unterstützung der Konsistenzdestillation, des Konsistenztrainings sowie mehrerer im Artikel diskutierter Sampling- und Bearbeitungsalgorithmen ermöglicht.
Das Repository für CIFAR-10-Experimente befindet sich in JAX und ist unter openai/consistency_models_cifar10 zu finden.
Wir haben in der Zeitung Prüfpunkte für die Hauptmodelle veröffentlicht. Bevor Sie diese Modelle verwenden, lesen Sie bitte die entsprechende Modellkarte, um den Verwendungszweck und die Einschränkungen dieser Modelle zu verstehen.
Hier sind die Download-Links für jeden Modellkontrollpunkt:
EDM auf ImageNet-64: edm_imagenet64_ema.pt
CD auf ImageNet-64 mit l2-Metrik: cd_imagenet64_l2.pt
CD auf ImageNet-64 mit LPIPS-Metrik: cd_imagenet64_lpips.pt
CT auf ImageNet-64: ct_imagenet64.pt
EDM auf LSUN Schlafzimmer-256: edm_schlafzimmer256_ema.pt
CD auf LSUN Bedroom-256 mit l2-Metrik: cd_schlafzimmer256_l2.pt
CD auf LSUN Bedroom-256 mit LPIPS-Metrik: cd_schlafzimmer256_lpips.pt
CT auf LSUN Schlafzimmer-256: ct_schlafzimmer256.pt
EDM auf LSUN Cat-256: edm_cat256_ema.pt
CD auf LSUN Cat-256 mit l2-Metrik: cd_cat256_l2.pt
CD auf LSUN Cat-256 mit LPIPS-Metrik: cd_cat256_lpips.pt
CT auf LSUN Cat-256: ct_cat256.pt
Um alle Pakete in dieser Codebasis zusammen mit ihren Abhängigkeiten zu installieren, führen Sie Folgendes aus:
pip install -e .
Führen Sie zur Installation mit Docker die folgenden Befehle aus:
cd docker && make build && make run
Wir bieten Beispiele für EDM-Training, Konsistenzdestillation, Konsistenztraining, Einzelschrittgenerierung und Mehrschrittgenerierung in scripts/launch.sh.
Um verschiedene generative Modelle zu vergleichen, verwenden wir FID, Precision, Recall und Inception Score. Diese Metriken können alle anhand von Probenstapeln berechnet werden, die in .npz
Dateien (Numpy) gespeichert sind. Man kann Proben mit cm/evaluations/evaluator.py auf die gleiche Weise auswerten, wie es in openai/guided-diffusion beschrieben ist, mit darin bereitgestellten Referenzdatensatz-Batches.
Konsistenzmodelle werden unterstützt in ? Diffusoren über die ConsistencyModelPipeline
-Klasse. Nachfolgend stellen wir ein Beispiel vor:
import Torchfrom diffusers import ConsistencyModelPipelinedevice = "cuda"# Laden Sie den cd_imagenet64_l2 checkpoint.model_id_or_path = "openai/diffusers-cd_imagenet64_l2"pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, Torch_dtype=torch.float16)pipe.to(device)# Onestep Samplingimage = pipe(num_inference_steps=1).images[0]image.save("consistency_model_onestep_sample.png")# One-Step-Sampling, klassenbedingte Bildgenerierung# ImageNet-64-Klassenbezeichnung 145 entspricht Königspinguinenclass_id = 145class_id = Torch.tensor(class_id, dtype=torch.long)image = pipe(num_inference_steps=1, class_labels=class_id).images[0]image.save("consistency_model_onestep_sample_penguin.png")# Mehrstufige Abtastung, klassenbedingte Bildgenerierung# Zeitschritte können explizit angegeben werden; Die einzelnen Zeitschritte unten stammen aus dem ursprünglichen Github-Repo. # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77image = pipe(timesteps=[22, 0], class_labels=class_id) .images[0]image.save("consistency_model_multistep_sample_penguin.png")
Sie können den Inferenzprozess weiter beschleunigen, indem Sie torch.compile()
auf pipe.unet
verwenden (wird nur ab PyTorch 2.0 unterstützt). Weitere Einzelheiten finden Sie in der offiziellen Dokumentation. Diese Unterstützung wurde zu ? beigetragen? Diffusoren von dg845 und ayushtues.
Wenn Sie diese Methode und/oder diesen Code nützlich finden, ziehen Sie bitte eine Zitierung in Betracht
@article{song2023consistency, title={Konsistenzmodelle}, Autor={Song, Yang und Dhariwal, Prafulla und Chen, Mark und Sutskever, Ilya}, Zeitschrift={arXiv preprint arXiv:2303.01469}, Jahr={2023}, }