Image : Ces personnes ne sont pas réelles – elles ont été produites par notre générateur qui permet de contrôler différents aspects de l'image.
Ce référentiel contient l'implémentation officielle de TensorFlow du document suivant :
Une architecture génératrice basée sur le style pour les réseaux adverses génératifs
Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)
https://arxiv.org/abs/1812.04948Résumé : Nous proposons une architecture génératrice alternative pour les réseaux adverses génératifs, empruntant à la littérature sur le transfert de style. La nouvelle architecture conduit à une séparation automatiquement apprise et non supervisée des attributs de haut niveau (par exemple, pose et identité lors d'un entraînement sur des visages humains) et à une variation stochastique des images générées (par exemple, taches de rousseur, cheveux), et elle permet une visualisation intuitive et à grande échelle. contrôle spécifique de la synthèse. Le nouveau générateur améliore l'état de l'art en termes de mesures traditionnelles de qualité de distribution, conduit à des propriétés d'interpolation manifestement meilleures et démêle également mieux les facteurs de variation latents. Pour quantifier la qualité de l'interpolation et le démêlage, nous proposons deux nouvelles méthodes automatisées applicables à toute architecture de générateur. Enfin, nous introduisons un nouvel ensemble de données de visages humains très varié et de haute qualité.
Pour toute demande commerciale, veuillez visiter notre site Web et soumettre le formulaire : NVIDIA Research Licensing
★★★ NOUVEAU : StyleGAN2-ADA-PyTorch est maintenant disponible ; voir la liste complète des versions ici ★★★
Le matériel lié à notre article est disponible via les liens suivants :
Article : https://arxiv.org/abs/1812.04948
Vidéo : https://youtu.be/kSLJriaOumA
Code : https://github.com/NVlabs/stylegan
FFHQ : https://github.com/NVlabs/ffhq-dataset
Du matériel supplémentaire peut être trouvé sur Google Drive :
Chemin | Description |
---|---|
StyleGAN | Dossier principal. |
├ stylegan-paper.pdf | Version haute qualité du PDF papier. |
├ stylegan-vidéo.mp4 | Version haute qualité de la vidéo du résultat. |
├images | Exemples d'images produites à l'aide de notre générateur. |
│ ├ images représentatives | Images de haute qualité à utiliser dans des articles, des articles de blog, etc. |
│ └ 100 000 images générées | 100 000 images générées pour différents niveaux de troncature. |
│ ├ffhq-1024x1024 | Généré à l'aide de l'ensemble de données Flickr-Faces-HQ à 1024×1024. |
│ ├ chambres-256x256 | Généré à l'aide de l'ensemble de données LSUN Bedroom à 256 × 256. |
│ ├ voitures-512x384 | Généré à l'aide de l'ensemble de données LSUN Car à 512 × 384. |
│ └ chats-256x256 | Généré à l'aide de l'ensemble de données LSUN Cat à 256 × 256. |
├ vidéos | Exemples de vidéos réalisées à l'aide de notre générateur. |
│ └ clips vidéo de haute qualité | Segments individuels de la vidéo résultante au format MP4 de haute qualité. |
├ ensemble de données ffhq | Données brutes pour l'ensemble de données Flickr-Faces-HQ. |
└ réseaux | Réseaux pré-entraînés en tant qu'instances marinées de dnnlib.tflib.Network. |
├ stylegan-ffhq-1024x1024.pkl | StyleGAN formé avec l'ensemble de données Flickr-Faces-HQ à 1024 × 1024. |
├ stylegan-celebahq-1024x1024.pkl | StyleGAN formé avec l'ensemble de données CelebA-HQ à 1024 × 1024. |
├ stylegan-chambres-256x256.pkl | StyleGAN formé avec l'ensemble de données LSUN Bedroom à 256 × 256. |
├ stylegan-cars-512x384.pkl | StyleGAN formé avec l'ensemble de données LSUN Car à 512 × 384. |
├ stylegan-cats-256x256.pkl | StyleGAN formé avec l'ensemble de données LSUN Cat à 256 × 256. |
└ métriques | Réseaux auxiliaires pour les métriques de qualité et de démêlage. |
├ inception_v3_features.pkl | Classificateur Inception-v3 standard qui génère un vecteur de caractéristiques brut. |
├ vgg16_zhang_perceptual.pkl | Métrique LPIPS standard pour estimer la similarité perceptuelle. |
├ celebahq-classifier-00-male.pkl | Classificateur binaire formé pour détecter un seul attribut de CelebA-HQ. |
└ ⋯ | Veuillez consulter la liste des fichiers pour les réseaux restants. |
Tout le matériel, à l'exclusion de l'ensemble de données Flickr-Faces-HQ, est mis à disposition sous licence Creative Commons BY-NC 4.0 par NVIDIA Corporation. Vous pouvez utiliser, redistribuer et adapter le matériel à des fins non commerciales , à condition de donner le crédit approprié en citant notre article et en indiquant toutes les modifications que vous avez apportées.
Pour obtenir des informations sur la licence concernant l'ensemble de données FFHQ, veuillez vous référer au référentiel Flickr-Faces-HQ.
inception_v3_features.pkl
et inception_v3_softmax.pkl
sont dérivés du réseau Inception-v3 pré-entraîné par Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens et Zbigniew Wojna. Le réseau était initialement partagé sous licence Apache 2.0 sur le référentiel TensorFlow Models.
vgg16.pkl
et vgg16_zhang_perceptual.pkl
sont dérivés du réseau VGG-16 pré-entraîné par Karen Simonyan et Andrew Zisserman. Le réseau a été initialement partagé sous licence Creative Commons BY 4.0 sur la page du projet Very Deep Convolutional Networks for Large-Scale Visual Recognition.
vgg16_zhang_perceptual.pkl
est en outre dérivé des poids LPIPS pré-entraînés par Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman et Oliver Wang. Les poids ont été initialement partagés sous la licence « simplifiée » BSD 2-Clause sur le référentiel PerceptualSimilarity.
Linux et Windows sont pris en charge, mais nous recommandons fortement Linux pour des raisons de performances et de compatibilité.
Installation de Python 3.6 64 bits. Nous recommandons Anaconda3 avec numpy 1.14.3 ou plus récent.
TensorFlow 1.10.0 ou version ultérieure avec prise en charge GPU.
Un ou plusieurs GPU NVIDIA haut de gamme avec au moins 11 Go de DRAM. Nous recommandons NVIDIA DGX-1 avec 8 GPU Tesla V100.
Pilote NVIDIA 391.35 ou version ultérieure, boîte à outils CUDA 9.0 ou version ultérieure, cuDNN 7.3.1 ou version ultérieure.
Un exemple minimal d'utilisation d'un générateur StyleGAN pré-entraîné est donné dans pretrained_example.py. Une fois exécuté, le script télécharge un générateur StyleGAN pré-entraîné depuis Google Drive et l'utilise pour générer une image :
> python pretrained_example.py Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... done Gs Params OutputShape WeightShape --- --- --- --- latents_in - (?, 512) - ... images_out - (?, 3, 1024, 1024) - --- --- --- --- Total 26219627 > ls results example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP
Un exemple plus avancé est donné dans generate_figures.py. Le script reproduit les figures de notre article afin d'illustrer le mélange de styles, les entrées de bruit et la troncature :
> python generate_figures.py results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6 results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_ results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W
Les réseaux pré-entraînés sont stockés sous forme de fichiers pickle standard sur Google Drive :
# Load pre-trained network. url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: _G, _D, Gs = pickle.load(f) # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run. # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
Le code ci-dessus télécharge le fichier et le décompresse pour produire 3 instances de dnnlib.tflib.Network. Pour générer des images, vous souhaiterez généralement utiliser Gs
– les deux autres réseaux sont fournis par souci d’exhaustivité. Pour que pickle.load()
fonctionne, vous devrez avoir le répertoire source dnnlib
dans votre PYTHONPATH et un tf.Session
défini par défaut. La session peut être initialisée en appelant dnnlib.tflib.init_tf()
.
Il existe trois façons d'utiliser le générateur pré-entraîné :
Utilisez Gs.run()
pour un fonctionnement en mode immédiat où les entrées et sorties sont des tableaux numpy :
# Pick latent vector. rnd = np.random.RandomState(5) latents = rnd.randn(1, Gs.input_shape[1]) # Generate image. fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
Le premier argument est un lot de vecteurs de forme latents [num, 512]
. Le deuxième argument est réservé aux étiquettes de classe (non utilisées par StyleGAN). Les arguments de mot-clé restants sont facultatifs et peuvent être utilisés pour modifier davantage l'opération (voir ci-dessous). La sortie est un lot d’images dont le format est dicté par l’argument output_transform
.
Utilisez Gs.get_output_for()
pour incorporer le générateur dans le cadre d'une expression TensorFlow plus grande :
latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) images = tflib.convert_images_to_uint8(images) result_expr.append(inception_clone.get_output_for(images))
Le code ci-dessus provient de metrics/frechet_inception_distance.py. Il génère un lot d'images aléatoires et les transmet directement au réseau Inception-v3 sans avoir à convertir les données en tableaux numpy entre les deux.
Recherchez Gs.components.mapping
et Gs.components.synthesis
pour accéder aux sous-réseaux individuels du générateur. Semblable à Gs
, les sous-réseaux sont représentés comme des instances indépendantes de dnnlib.tflib.Network :
src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
Le code ci-dessus provient de generate_figures.py. Il transforme d'abord un lot de vecteurs latents dans l'espace W intermédiaire à l'aide du réseau de cartographie, puis transforme ces vecteurs en un lot d'images à l'aide du réseau de synthèse. Le tableau dlatents
stocke une copie distincte du même vecteur w pour chaque couche du réseau de synthèse afin de faciliter le mélange de styles.
Les détails exacts du générateur sont définis dans training/networks_stylegan.py (voir G_style
, G_mapping
et G_synthesis
). Les arguments de mot-clé suivants peuvent être spécifiés pour modifier le comportement lors de l'appel run()
et get_output_for()
:
truncation_psi
et truncation_cutoff
contrôlent l'astuce de troncature effectuée par défaut lors de l'utilisation Gs
(ψ=0,7, cutoff=8). Il peut être désactivé en définissant truncation_psi=1
ou is_validation=True
, et la qualité de l'image peut être encore améliorée au prix de la variation en définissant par exemple truncation_psi=0.5
. Notez que la troncature est toujours désactivée lors de l'utilisation directe des sous-réseaux. La moyenne w nécessaire pour effectuer manuellement l'astuce de troncature peut être recherchée à l'aide de Gs.get_var('dlatent_avg')
.
randomize_noise
détermine s'il faut re-randomiser les entrées de bruit pour chaque image générée ( True
, par défaut) ou s'il faut utiliser des valeurs de bruit spécifiques pour l'ensemble du mini-lot ( False
). Les valeurs spécifiques sont accessibles via les instances tf.Variable
trouvées en utilisant [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
.
Lorsque vous utilisez directement le réseau cartographique, vous pouvez spécifier dlatent_broadcast=None
pour désactiver la duplication automatique des dlatents
sur les couches du réseau de synthèse.
Les performances d'exécution peuvent être affinées via structure='fixed'
et dtype='float16'
. Le premier désactive la prise en charge de la croissance progressive, qui n'est pas nécessaire pour un générateur entièrement formé, et le second effectue tous les calculs en utilisant l'arithmétique à virgule flottante demi-précision.
Les scripts de formation et d'évaluation fonctionnent sur des ensembles de données stockés sous forme de TFRecords multi-résolution. Chaque ensemble de données est représenté par un répertoire contenant les mêmes données d'image dans plusieurs résolutions pour permettre un streaming efficace. Il existe un fichier *.tfrecords distinct pour chaque résolution, et si l'ensemble de données contient des étiquettes, elles sont également stockées dans un fichier distinct. Par défaut, les scripts s'attendent à trouver les ensembles de données dans datasets/<NAME>/<NAME>-<RESOLUTION>.tfrecords
. Le répertoire peut être modifié en éditant config.py :
result_dir = 'results' data_dir = 'datasets' cache_dir = 'cache'
Pour obtenir l'ensemble de données FFHQ ( datasets/ffhq
), veuillez vous référer au référentiel Flickr-Faces-HQ.
Pour obtenir l'ensemble de données CelebA-HQ ( datasets/celebahq
), veuillez vous référer au référentiel Progressive GAN.
Pour obtenir d’autres ensembles de données, dont LSUN, veuillez consulter leurs pages de projets correspondantes. Les ensembles de données peuvent être convertis en TFRecords multi-résolution à l'aide du dataset_tool.py fourni :
> python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256 > python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384 > python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256 > python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10 > python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images
Une fois les ensembles de données configurés, vous pouvez entraîner vos propres réseaux StyleGAN comme suit :
Modifiez train.py pour spécifier l'ensemble de données et la configuration de formation en supprimant les commentaires ou en modifiant des lignes spécifiques.
Exécutez le script de formation avec python train.py
.
Les résultats sont écrits dans un répertoire nouvellement créé results/<ID>-<DESCRIPTION>
.
La formation peut prendre plusieurs jours (ou semaines), selon la configuration.
Par défaut, train.py
est configuré pour entraîner le StyleGAN de la plus haute qualité (configuration F dans le tableau 1) pour l'ensemble de données FFHQ à une résolution de 1 024 × 1 024 à l'aide de 8 GPU. Veuillez noter que nous avons utilisé 8 GPU dans toutes nos expériences. L'entraînement avec moins de GPU peut ne pas produire des résultats identiques. Si vous souhaitez comparer avec notre technique, nous vous recommandons fortement d'utiliser le même nombre de GPU.
Temps de formation prévus pour la configuration par défaut utilisant les GPU Tesla V100 :
GPU | 1024×1024 | 512×512 | 256×256 |
---|---|---|---|
1 | 41 jours 4 heures | 24 jours 21 heures | 14 jours 22 heures |
2 | 21 jours 22 heures | 13 jours 7 heures | 9 jours 5 heures |
4 | 11 jours 8 heures | 7 jours 0 heures | 4 jours 21 heures |
8 | 6 jours 14 heures | 4 jours 10 heures | 3 jours 8 heures |
Les métriques de qualité et de démêlage utilisées dans notre article peuvent être évaluées à l'aide de run_metrics.py. Par défaut, le script évaluera la distance de démarrage de Fréchet ( fid50k
) pour le générateur FFHQ pré-entraîné et écrira les résultats dans un répertoire nouvellement créé sous results
. Le comportement exact peut être modifié en supprimant les commentaires ou en modifiant des lignes spécifiques dans run_metrics.py.
Temps d'évaluation et résultats attendus pour le générateur FFHQ pré-entraîné utilisant un GPU Tesla V100 :
Métrique | Temps | Résultat | Description |
---|---|---|---|
fid50k | 16 minutes | 4.4159 | Fréchet Inception Distance à partir de 50 000 images. |
ppl_zfull | 55 minutes | 664.8854 | Longueur du chemin perceptuel pour les chemins complets en Z . |
ppl_wfull | 55 minutes | 233.3059 | Longueur du chemin perceptuel pour les chemins complets en W . |
ppl_zend | 55 minutes | 666.1057 | Longueur du chemin perceptuel pour les extrémités du chemin dans Z . |
ppl_wend | 55 minutes | 197.2266 | Longueur du chemin perceptuel pour les extrémités du chemin dans W . |
ls | 10 heures | z: 165.0106 w: 3.7447 | Séparabilité linéaire en Z et W . |
Veuillez noter que les résultats exacts peuvent varier d'une exécution à l'autre en raison de la nature non déterministe de TensorFlow.
Nous remercions Jaakko Lehtinen, David Luebke et Tuomas Kynkäänniemi pour leurs discussions approfondies et leurs commentaires utiles ; Janne Hellsten, Tero Kuosmanen et Pekka Jänis pour l'infrastructure de calcul et l'aide à la publication du code.