Este repositorio contiene el código base para modelos de consistencia, implementado usando PyTorch para realizar experimentos a gran escala en ImageNet-64, LSUN Bedroom-256 y LSUN Cat-256. Hemos basado nuestro repositorio en openai/difusión guiada, que se lanzó inicialmente bajo la licencia del MIT. Nuestras modificaciones han permitido el soporte para la destilación de consistencia, el entrenamiento de consistencia, así como varios algoritmos de muestreo y edición discutidos en el artículo.
El repositorio de experimentos CIFAR-10 está en JAX y se puede encontrar en openai/consistency_models_cifar10.
Hemos publicado puntos de control para los modelos principales en el documento. Antes de utilizar estos modelos, revise la tarjeta del modelo correspondiente para comprender el uso previsto y las limitaciones de estos modelos.
Aquí están los enlaces de descarga para cada modelo de punto de control:
EDM en ImageNet-64: edm_imagenet64_ema.pt
CD en ImageNet-64 con métrica l2: cd_imagenet64_l2.pt
CD en ImageNet-64 con métrica LPIPS: cd_imagenet64_lpips.pt
TC en ImageNet-64: ct_imagenet64.pt
EDM en LSUN Dormitorio-256: edm_bedroom256_ema.pt
CD en LSUN Bedroom-256 con métrica l2: cd_bedroom256_l2.pt
CD en LSUN Bedroom-256 con métrica LPIPS: cd_bedroom256_lpips.pt
CT en LSUN Dormitorio-256: ct_bedroom256.pt
EDM en LSUN Cat-256: edm_cat256_ema.pt
CD en LSUN Cat-256 con métrica l2: cd_cat256_l2.pt
CD en LSUN Cat-256 con métrica LPIPS: cd_cat256_lpips.pt
CT en LSUN Cat-256: ct_cat256.pt
Para instalar todos los paquetes en este código base junto con sus dependencias, ejecute
instalación de pip -e.
Para instalar con Docker, ejecute los siguientes comandos:
cd docker && hacer compilación && hacer ejecutar
Proporcionamos ejemplos de capacitación EDM, destilación de consistencia, capacitación de consistencia, generación de un solo paso y generación de varios pasos en scripts/launch.sh.
Para comparar diferentes modelos generativos, utilizamos FID, Precision, Recall y Inception Score. Todas estas métricas se pueden calcular utilizando lotes de muestras almacenadas en archivos .npz
(numpy). Se pueden evaluar muestras con cm/evaluaciones/evaluator.py de la misma manera que se describe en openai/guided-diffusion, con lotes de conjuntos de datos de referencia proporcionados allí.
Los modelos de consistencia son compatibles con ? difusores a través de la clase ConsistencyModelPipeline
. A continuación proporcionamos un ejemplo:
importar antorcha de difusores importar ConsistencyModelPipelinedevice = "cuda"# Cargar el cd_imagenet64_l2 checkpoint.model_id_or_path = "openai/diffusers-cd_imagenet64_l2"pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)pipe.to(device)# Onestep Samplingimage = pipe(num_inference_steps=1).images[0]image.save("consistency_model_onestep_sample.png")# Muestreo en un solo paso, generación de imágenes condicional de clase# La etiqueta de clase 145 de ImageNet-64 corresponde a los pingüinos reyclass_id = 145class_id = torch.tensor( class_id, dtype=torch.long)imagen = pipe(num_inference_steps=1, class_labels=class_id).images[0]image.save("consistency_model_onestep_sample_penguin.png")# Muestreo de varios pasos, generación de imágenes condicionales de clase# Los pasos de tiempo se pueden especificar explícitamente; los pasos de tiempo particulares a continuación son del repositorio original de Github.# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77image = pipe(timesteps=[22, 0], class_labels=class_id) .images[0]image.save("consistency_model_multistep_sample_penguin.png")
Puede acelerar aún más el proceso de inferencia utilizando torch.compile()
en pipe.unet
(solo compatible con PyTorch 2.0). Para obtener más detalles, consulte la documentación oficial. Este apoyo fue contribuido a ? difusores de dg845 y ayushtues.
Si encuentra útil este método y/o código, considere citar
@artículo{song2023consistency, título={Modelos de consistencia}, autor={Song, Yang y Dhariwal, Prafulla y Chen, Mark y Sutskever, Ilya}, diario={arXiv preprint arXiv:2303.01469}, año={2023}, }