Imagen: Estas personas no son reales; fueron producidas por nuestro generador que permite controlar diferentes aspectos de la imagen.
Este repositorio contiene la implementación oficial de TensorFlow del siguiente documento:
Una arquitectura generadora basada en estilos para redes generativas adversarias
Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)
https://arxiv.org/abs/1812.04948Resumen: Proponemos una arquitectura generadora alternativa para redes generativas adversarias, tomando prestado de la literatura sobre transferencia de estilos. La nueva arquitectura conduce a una separación no supervisada y aprendida automáticamente de atributos de alto nivel (por ejemplo, pose e identidad cuando se entrena en rostros humanos) y una variación estocástica en las imágenes generadas (por ejemplo, pecas, cabello), y permite una escala intuitiva. control específico de la síntesis. El nuevo generador mejora el estado del arte en términos de métricas de calidad de distribución tradicionales, conduce a propiedades de interpolación claramente mejores y también desenreda mejor los factores de variación latentes. Para cuantificar la calidad de la interpolación y el desenredo, proponemos dos métodos nuevos y automatizados que son aplicables a cualquier arquitectura de generador. Finalmente, presentamos un conjunto de datos de rostros humanos nuevo, muy variado y de alta calidad.
Para consultas comerciales, visite nuestro sitio web y envíe el formulario: Licencia de investigación de NVIDIA
★★★ NUEVO: StyleGAN2-ADA-PyTorch ya está disponible; mira la lista completa de versiones aquí ★★★
El material relacionado con nuestro artículo está disponible a través de los siguientes enlaces:
Documento: https://arxiv.org/abs/1812.04948
Vídeo: https://youtu.be/kSLJriaOumA
Código: https://github.com/NVlabs/stylegan
FFHQ: https://github.com/NVlabs/ffhq-dataset
Se puede encontrar material adicional en Google Drive:
Camino | Descripción |
---|---|
EstiloGAN | Carpeta principal. |
├ stylegan-paper.pdf | Versión de alta calidad del PDF en papel. |
├ stylegan-video.mp4 | Versión de alta calidad del vídeo resultante. |
├ imágenes | Imágenes de ejemplo producidas con nuestro generador. |
│ ├ imágenes-representativas | Imágenes de alta calidad para usar en artículos, publicaciones de blogs, etc. |
│ └ 100.000 imágenes generadas | 100.000 imágenes generadas para diferentes cantidades de truncamiento. |
│ ├ ffhq-1024x1024 | Generado utilizando el conjunto de datos Flickr-Faces-HQ a 1024×1024. |
│ ├ dormitorios-256x256 | Generado utilizando el conjunto de datos LSUN Bedroom a 256 × 256. |
│ ├ coches-512x384 | Generado utilizando el conjunto de datos LSUN Car a 512×384. |
│ └ gatos-256x256 | Generado utilizando el conjunto de datos LSUN Cat a 256 × 256. |
├ vídeos | Vídeos de ejemplo producidos con nuestro generador. |
│ └ videoclips-de-alta-calidad | Segmentos individuales del vídeo resultante en formato MP4 de alta calidad. |
├ conjunto de datos ffhq | Datos sin procesar para el conjunto de datos de Flickr-Faces-HQ. |
└ redes | Redes previamente entrenadas como instancias encurtidas de dnnlib.tflib.Network. |
├ estilogan-ffhq-1024x1024.pkl | StyleGAN entrenado con el conjunto de datos Flickr-Faces-HQ a 1024×1024. |
├ stylegan-celebahq-1024x1024.pkl | StyleGAN entrenado con el conjunto de datos CelebA-HQ a 1024×1024. |
├ stylegan-dormitorios-256x256.pkl | StyleGAN entrenado con el conjunto de datos LSUN Bedroom a 256×256. |
├ stylegan-cars-512x384.pkl | StyleGAN entrenado con el conjunto de datos LSUN Car a 512 × 384. |
├ stylegan-gatos-256x256.pkl | StyleGAN entrenado con el conjunto de datos LSUN Cat a 256 × 256. |
└ métricas | Redes auxiliares para las métricas de calidad y desenredado. |
├ inicio_v3_features.pkl | Clasificador estándar Inception-v3 que genera un vector de características sin formato. |
├ vgg16_zhang_perceptual.pkl | Métrica LPIPS estándar para estimar la similitud perceptiva. |
├ celebahq-clasificador-00-male.pkl | Clasificador binario entrenado para detectar un único atributo de CelebA-HQ. |
└ ⋯ | Consulte la lista de archivos para conocer las redes restantes. |
Todo el material, excepto el conjunto de datos Flickr-Faces-HQ, está disponible bajo la licencia Creative Commons BY-NC 4.0 de NVIDIA Corporation. Puede utilizar, redistribuir y adaptar el material para fines no comerciales , siempre que dé el crédito apropiado citando nuestro artículo e indicando cualquier cambio que haya realizado.
Para obtener información sobre la licencia sobre el conjunto de datos FFHQ, consulte el repositorio de Flickr-Faces-HQ.
inception_v3_features.pkl
e inception_v3_softmax.pkl
se derivan de la red Inception-v3 previamente entrenada por Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens y Zbigniew Wojna. La red se compartió originalmente bajo la licencia Apache 2.0 en el repositorio de TensorFlow Models.
vgg16.pkl
y vgg16_zhang_perceptual.pkl
se derivan de la red VGG-16 previamente entrenada por Karen Simonyan y Andrew Zisserman. La red se compartió originalmente bajo la licencia Creative Commons BY 4.0 en la página del proyecto Redes convolucionales muy profundas para el reconocimiento visual a gran escala.
vgg16_zhang_perceptual.pkl
se deriva además de los pesos LPIPS previamente entrenados por Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman y Oliver Wang. Los pesos se compartieron originalmente bajo la licencia "simplificada" de cláusula 2 BSD en el repositorio PerceptualSimilarity.
Tanto Linux como Windows son compatibles, pero recomendamos encarecidamente Linux por motivos de rendimiento y compatibilidad.
Instalación de Python 3.6 de 64 bits. Recomendamos Anaconda3 con numpy 1.14.3 o posterior.
TensorFlow 1.10.0 o posterior con soporte para GPU.
Una o más GPU NVIDIA de gama alta con al menos 11 GB de DRAM. Recomendamos NVIDIA DGX-1 con 8 GPU Tesla V100.
Controlador NVIDIA 391.35 o posterior, kit de herramientas CUDA 9.0 o posterior, cuDNN 7.3.1 o posterior.
En pretrained_example.py se proporciona un ejemplo mínimo del uso de un generador StyleGAN previamente entrenado. Cuando se ejecuta, el script descarga un generador StyleGAN previamente entrenado desde Google Drive y lo utiliza para generar una imagen:
> python pretrained_example.py Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... done Gs Params OutputShape WeightShape --- --- --- --- latents_in - (?, 512) - ... images_out - (?, 3, 1024, 1024) - --- --- --- --- Total 26219627 > ls results example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP
Se proporciona un ejemplo más avanzado en generate_figures.py. El guión reproduce las figuras de nuestro artículo para ilustrar la mezcla de estilos, las entradas de ruido y el truncamiento:
> python generate_figures.py results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6 results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_ results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W
Las redes previamente entrenadas se almacenan como archivos pickle estándar en Google Drive:
# Load pre-trained network. url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: _G, _D, Gs = pickle.load(f) # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run. # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
El código anterior descarga el archivo y lo descomprime para generar 3 instancias de dnnlib.tflib.Network. Para generar imágenes, normalmente querrás utilizar Gs
; las otras dos redes se proporcionan para que estén completas. Para que pickle.load()
funcione, necesitará tener el directorio fuente dnnlib
en su PYTHONPATH y un tf.Session
configurado como predeterminado. La sesión se puede inicializar llamando a dnnlib.tflib.init_tf()
.
Hay tres formas de utilizar el generador previamente entrenado:
Utilice Gs.run()
para operación en modo inmediato donde las entradas y salidas son matrices numerosas:
# Pick latent vector. rnd = np.random.RandomState(5) latents = rnd.randn(1, Gs.input_shape[1]) # Generate image. fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
El primer argumento es un lote de vectores latentes de forma [num, 512]
. El segundo argumento está reservado para etiquetas de clase (no utilizado por StyleGAN). Los argumentos de palabras clave restantes son opcionales y se pueden usar para modificar aún más la operación (ver más abajo). La salida es un lote de imágenes, cuyo formato lo dicta el argumento output_transform
.
Utilice Gs.get_output_for()
para incorporar el generador como parte de una expresión TensorFlow más grande:
latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) images = tflib.convert_images_to_uint8(images) result_expr.append(inception_clone.get_output_for(images))
El código anterior es de metrics/frechet_inception_distance.py. Genera un lote de imágenes aleatorias y las envía directamente a la red Inception-v3 sin tener que convertir los datos en numerosas matrices intermedias.
Busque Gs.components.mapping
y Gs.components.synthesis
para acceder a subredes individuales del generador. De manera similar a Gs
, las subredes se representan como instancias independientes de dnnlib.tflib.Network:
src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
El código anterior es de generate_figures.py. Primero transforma un lote de vectores latentes en el espacio W intermedio usando la red de mapeo y luego convierte estos vectores en un lote de imágenes usando la red de síntesis. La matriz dlatents
almacena una copia separada del mismo vector w para cada capa de la red de síntesis para facilitar la mezcla de estilos.
Los detalles exactos del generador se definen en Training/networks_stylegan.py (consulte G_style
, G_mapping
y G_synthesis
). Se pueden especificar los siguientes argumentos de palabras clave para modificar el comportamiento al llamar run()
y get_output_for()
:
truncation_psi
y truncation_cutoff
controlan el truco de truncamiento que se realiza de forma predeterminada cuando se usa Gs
(ψ=0,7, cutoff=8). Se puede desactivar configurando truncation_psi=1
o is_validation=True
y la calidad de la imagen se puede mejorar aún más a costa de la variación configurando, por ejemplo, truncation_psi=0.5
. Tenga en cuenta que el truncamiento siempre está deshabilitado cuando se utilizan las subredes directamente. El promedio w necesario para realizar manualmente el truco de truncamiento se puede buscar usando Gs.get_var('dlatent_avg')
.
randomize_noise
determina si se deben volver a aleatorizar las entradas de ruido para cada imagen generada ( True
, predeterminado) o si se deben usar valores de ruido específicos para todo el minibatch ( False
). Se puede acceder a los valores específicos a través de las instancias tf.Variable
que se encuentran usando [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
.
Cuando utilice la red de mapeo directamente, puede especificar dlatent_broadcast=None
para deshabilitar la duplicación automática de dlatents
en las capas de la red de síntesis.
El rendimiento en tiempo de ejecución se puede ajustar mediante structure='fixed'
y dtype='float16'
. El primero desactiva el soporte para el crecimiento progresivo, que no es necesario para un generador completamente capacitado, y el segundo realiza todos los cálculos utilizando aritmética de coma flotante de media precisión.
Los scripts de capacitación y evaluación operan en conjuntos de datos almacenados como TFRecords de resolución múltiple. Cada conjunto de datos está representado por un directorio que contiene los mismos datos de imagen en varias resoluciones para permitir una transmisión eficiente. Hay un archivo *.tfrecords separado para cada resolución y, si el conjunto de datos contiene etiquetas, también se almacenan en un archivo separado. De forma predeterminada, los scripts esperan encontrar los conjuntos de datos en datasets/<NAME>/<NAME>-<RESOLUTION>.tfrecords
. El directorio se puede cambiar editando config.py:
result_dir = 'results' data_dir = 'datasets' cache_dir = 'cache'
Para obtener el conjunto de datos FFHQ ( datasets/ffhq
), consulte el repositorio de Flickr-Faces-HQ.
Para obtener el conjunto de datos de CelebA-HQ ( datasets/celebahq
), consulte el repositorio de GAN progresiva.
Para obtener otros conjuntos de datos, incluido LSUN, consulte las páginas de sus proyectos correspondientes. Los conjuntos de datos se pueden convertir a TFRecords de resolución múltiple utilizando el dataset_tool.py proporcionado:
> python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256 > python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384 > python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256 > python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10 > python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images
Una vez configurados los conjuntos de datos, puede entrenar sus propias redes StyleGAN de la siguiente manera:
Edite train.py para especificar el conjunto de datos y la configuración de entrenamiento descomentando o editando líneas específicas.
Ejecute el script de entrenamiento con python train.py
.
Los resultados se escriben en un directorio recién creado results/<ID>-<DESCRIPTION>
.
La capacitación puede tardar varios días (o semanas) en completarse, según la configuración.
De forma predeterminada, train.py
está configurado para entrenar StyleGAN de la más alta calidad (configuración F en la Tabla 1) para el conjunto de datos FFHQ con una resolución de 1024 × 1024 utilizando 8 GPU. Tenga en cuenta que hemos utilizado 8 GPU en todos nuestros experimentos. Es posible que entrenar con menos GPU no produzca resultados idénticos; si desea comparar con nuestra técnica, le recomendamos encarecidamente utilizar la misma cantidad de GPU.
Tiempos de entrenamiento esperados para la configuración predeterminada usando GPU Tesla V100:
GPU | 1024×1024 | 512×512 | 256×256 |
---|---|---|---|
1 | 41 días 4 horas | 24 días 21 horas | 14 días 22 horas |
2 | 21 días 22 horas | 13 días 7 horas | 9 días 5 horas |
4 | 11 días 8 horas | 7 días 0 horas | 4 días 21 horas |
8 | 6 días 14 horas | 4 días 10 horas | 3 días 8 horas |
Las métricas de calidad y desenredo utilizadas en nuestro artículo se pueden evaluar utilizando run_metrics.py. De forma predeterminada, el script evaluará la distancia inicial de Fréchet ( fid50k
) para el generador FFHQ previamente entrenado y escribirá los resultados en un directorio recién creado en results
. El comportamiento exacto se puede cambiar descomentando o editando líneas específicas en run_metrics.py.
Tiempo de evaluación esperado y resultados para el generador FFHQ previamente entrenado usando una GPU Tesla V100:
Métrico | Tiempo | Resultado | Descripción |
---|---|---|---|
fid50k | 16 minutos | 4.4159 | Distancia de inicio de Fréchet utilizando 50.000 imágenes. |
personas_zfull | 55 minutos | 664.8854 | Longitud de ruta perceptual para rutas completas en Z. |
personas_wfull | 55 minutos | 233.3059 | Longitud de ruta perceptual para rutas completas en W. |
personas_zend | 55 minutos | 666.1057 | Longitud de ruta perceptual para puntos finales de ruta en Z. |
personas_wend | 55 minutos | 197.2266 | Longitud de ruta perceptual para puntos finales de ruta en W. |
es | 10 horas | z: 165.0106 ancho: 3,7447 | Separabilidad lineal en Z y W. |
Tenga en cuenta que los resultados exactos pueden variar de una ejecución a otra debido a la naturaleza no determinista de TensorFlow.
Agradecemos a Jaakko Lehtinen, David Luebke y Tuomas Kynkäänniemi por sus discusiones en profundidad y sus útiles comentarios; Janne Hellsten, Tero Kuosmanen y Pekka Jänis por la infraestructura informática y la ayuda con la publicación del código.