图片:这些人不是真实的——他们是由我们的生成器生成的,可以控制图像的不同方面。
该存储库包含以下论文的官方 TensorFlow 实现:
用于生成对抗网络的基于样式的生成器架构
Tero Karras (NVIDIA)、Samuli Laine (NVIDIA)、Timo Aila (NVIDIA)
https://arxiv.org/abs/1812.04948摘要:我们借鉴风格迁移文献,提出了一种用于生成对抗网络的替代生成器架构。新的架构导致了高级属性(例如,在人脸上训练时的姿势和身份)和生成图像(例如,雀斑、头发)中的随机变化的自动学习、无监督分离,并且它实现了直观的、规模化的。合成的具体控制。新的生成器改进了传统分布质量指标的最先进水平,带来了明显更好的插值属性,并且还更好地消除了潜在的变化因素。为了量化插值质量和解缠结,我们提出了两种适用于任何生成器架构的新的自动化方法。最后,我们引入了一个新的、高度多样化且高质量的人脸数据集。
如需业务咨询,请访问我们的网站并提交表格:NVIDIA 研究许可
★★★ 新功能:StyleGAN2-ADA-PyTorch 现已推出;在这里查看完整的版本列表 ★★★
与我们论文相关的材料可通过以下链接获取:
论文:https://arxiv.org/abs/1812.04948
视频:https://youtu.be/kSLJriaOumA
代码:https://github.com/NVlabs/stylegan
FFHQ:https://github.com/NVlabs/ffhq-dataset
其他材料可以在 Google 云端硬盘上找到:
小路 | 描述 |
---|---|
风格GAN | 主文件夹。 |
├ stylegan-paper.pdf | 论文 PDF 的高质量版本。 |
├ stylegan-video.mp4 | 结果视频的高质量版本。 |
├ 图片 | 使用我们的生成器生成的示例图像。 |
│ ├ 代表性图片 | 用于文章、博客文章等的高质量图像。 |
│ └ 100k 生成图像 | 100,000 张针对不同截断量生成的图像。 |
│ ├ ffhq-1024x1024 | 使用 Flickr-Faces-HQ 数据集以 1024×1024 生成。 |
│ ├ 卧室-256x256 | 使用 256×256 的 LSUN Bedroom 数据集生成。 |
│ ├ 汽车-512x384 | 使用 512×384 的 LSUN Car 数据集生成。 |
│ └ 猫-256x256 | 使用 256×256 的 LSUN Cat 数据集生成。 |
├ 视频 | 使用我们的生成器生成的示例视频。 |
│ └ 高品质视频剪辑 | 结果视频的各个片段为高质量 MP4。 |
├ ffhq-数据集 | Flickr-Faces-HQ 数据集的原始数据。 |
└ 网络 | 预先训练的网络作为 dnnlib.tflib.Network 的 pickle 实例。 |
├ stylegan-ffhq-1024x1024.pkl | StyleGAN 使用 1024×1024 的 Flickr-Faces-HQ 数据集进行训练。 |
├ stylegan-celebahq-1024x1024.pkl | StyleGAN 使用 CelebA-HQ 数据集在 1024×1024 下进行训练。 |
├ stylegan-bedrooms-256x256.pkl | StyleGAN 使用 256×256 的 LSUN Bedroom 数据集进行训练。 |
├ stylegan-cars-512x384.pkl | StyleGAN 使用 512×384 的 LSUN 汽车数据集进行训练。 |
├ stylegan-cats-256x256.pkl | StyleGAN 使用 256×256 的 LSUN Cat 数据集进行训练。 |
└ 指标 | 用于质量和解缠结指标的辅助网络。 |
├ inception_v3_features.pkl | 输出原始特征向量的标准 Inception-v3 分类器。 |
├ vgg16_zhang_perceptual.pkl | 用于估计感知相似性的标准 LPIPS 指标。 |
├ celebahq-classifier-00-male.pkl | 训练二元分类器来检测 CelebA-HQ 的单个属性。 |
└⋯ | 请参阅剩余网络的文件列表。 |
所有材料(不包括 Flickr-Faces-HQ 数据集)均由 NVIDIA Corporation 根据 Creative Commons BY-NC 4.0 许可提供。您可以出于非商业目的使用、重新分发和改编这些材料,只要您通过引用我们的论文并注明您所做的任何更改来给予适当的认可。
有关 FFHQ 数据集的许可信息,请参阅 Flickr-Faces-HQ 存储库。
inception_v3_features.pkl
和inception_v3_softmax.pkl
源自 Christian Szegedy、Vincent Vanhoucke、Sergey Ioffe、Jonathon Shlens 和 Zbigniew Wojna 预训练的 Inception-v3 网络。该网络最初是在 TensorFlow 模型存储库上根据 Apache 2.0 许可证共享的。
vgg16.pkl
和vgg16_zhang_perceptual.pkl
源自 Karen Simonyan 和 Andrew Zisserman 预训练的 VGG-16 网络。该网络最初是在用于大规模视觉识别的非常深卷积网络项目页面上根据 Creative Commons BY 4.0 许可证共享的。
vgg16_zhang_perceptual.pkl
进一步源自 Richard Zhu、Phillip Isola、Alexei A. Efros、Eli Shechtman 和 Oliver Wang 预训练的 LPIPS 权重。这些权重最初是根据 PerceptualSimilarity 存储库上的 BSD 2-Clause“简化”许可证共享的。
Linux 和 Windows 均受支持,但出于性能和兼容性原因,我们强烈推荐使用 Linux。
64 位 Python 3.6 安装。我们推荐使用 Anaconda3 和 numpy 1.14.3 或更高版本。
TensorFlow 1.10.0 或更高版本,支持 GPU。
一个或多个具有至少 11GB DRAM 的高端 NVIDIA GPU。我们推荐配备 8 个 Tesla V100 GPU 的 NVIDIA DGX-1。
NVIDIA 驱动程序 391.35 或更高版本、CUDA 工具包 9.0 或更高版本、cuDNN 7.3.1 或更高版本。
pretrained_example.py 中给出了使用预训练 StyleGAN 生成器的最小示例。执行时,该脚本会从 Google Drive 下载预先训练的 StyleGAN 生成器并使用它生成图像:
> python pretrained_example.py Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... done Gs Params OutputShape WeightShape --- --- --- --- latents_in - (?, 512) - ... images_out - (?, 3, 1024, 1024) - --- --- --- --- Total 26219627 > ls results example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP
generate_figures.py 中给出了更高级的示例。该脚本重现了我们论文中的数据,以说明风格混合、噪声输入和截断:
> python generate_figures.py results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6 results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_ results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W
预先训练的网络作为标准 pickle 文件存储在 Google Drive 上:
# Load pre-trained network. url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: _G, _D, Gs = pickle.load(f) # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run. # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
上面的代码下载该文件并将其解封以生成 dnnlib.tflib.Network 的 3 个实例。要生成图像,您通常需要使用Gs
为了完整性而提供了其他两个网络。为了让pickle.load()
工作,您需要在 PYTHONPATH 中设置dnnlib
源目录,并将tf.Session
设置为默认目录。可以通过调用dnnlib.tflib.init_tf()
来初始化会话。
使用预训练生成器的方式有以下三种:
使用Gs.run()
进行立即模式操作,其中输入和输出是 numpy 数组:
# Pick latent vector. rnd = np.random.RandomState(5) latents = rnd.randn(1, Gs.input_shape[1]) # Generate image. fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
第一个参数是一批形状为[num, 512]
的潜在向量。第二个参数保留用于类标签(StyleGAN 不使用)。其余的关键字参数是可选的,可用于进一步修改操作(见下文)。输出是一批图像,其格式由output_transform
参数决定。
使用Gs.get_output_for()
将生成器合并为更大的 TensorFlow 表达式的一部分:
latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:]) images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True) images = tflib.convert_images_to_uint8(images) result_expr.append(inception_clone.get_output_for(images))
上面的代码来自metrics/frechet_inception_distance.py。它生成一批随机图像并将它们直接输入到 Inception-v3 网络,而无需在中间将数据转换为 numpy 数组。
查找Gs.components.mapping
和Gs.components.synthesis
以访问生成器的各个子网络。与Gs
类似,子网络表示为 dnnlib.tflib.Network 的独立实例:
src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds) src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component] src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
上面的代码来自generate_figures.py。它首先使用映射网络将一批潜在向量转换为中间W空间,然后使用合成网络将这些向量转换为一批图像。 dlatents
数组为合成网络的每一层存储相同w向量的单独副本,以方便风格混合。
生成器的具体细节在training/networks_stylegan.py中定义(参见G_style
、 G_mapping
和G_synthesis
)。可以指定以下关键字参数来修改调用run()
和get_output_for()
时的行为:
truncation_psi
和truncation_cutoff
控制使用Gs
时默认执行的截断技巧(ψ=0.7,cutoff=8)。可以通过设置truncation_psi=1
或is_validation=True
来禁用它,并且可以通过设置例如truncation_psi=0.5
以变化为代价进一步提高图像质量。请注意,直接使用子网时始终禁用截断。可以使用Gs.get_var('dlatent_avg')
查找手动执行截断技巧所需的平均w 。
randomize_noise
确定是否对每个生成的图像使用重新随机化噪声输入( True
,默认值)或是否对整个小批量使用特定噪声值( False
)。可以通过使用[var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
找到的tf.Variable
实例访问特定值。
直接使用映射网络时,可以指定dlatent_broadcast=None
以禁用在合成网络各层上自动复制dlatents
。
可以通过structure='fixed'
和dtype='float16'
微调运行时性能。前者禁用了对渐进增长的支持,这对于完全训练的生成器来说是不需要的,后者使用半精度浮点算术执行所有计算。
训练和评估脚本对存储为多分辨率 TFRecord 的数据集进行操作。每个数据集都由一个目录表示,该目录包含多种分辨率的相同图像数据,以实现高效的流传输。每个分辨率都有一个单独的 *.tfrecords 文件,如果数据集包含标签,它们也会存储在单独的文件中。默认情况下,脚本期望在datasets/<NAME>/<NAME>-<RESOLUTION>.tfrecords
处找到数据集。可以通过编辑 config.py 来更改目录:
result_dir = 'results' data_dir = 'datasets' cache_dir = 'cache'
要获取 FFHQ 数据集 ( datasets/ffhq
),请参阅 Flickr-Faces-HQ 存储库。
要获取 CelebA-HQ 数据集 ( datasets/celebahq
),请参阅 Progressive GAN 存储库。
要获取其他数据集(包括 LSUN),请查阅相应的项目页面。可以使用提供的 dataset_tool.py 将数据集转换为多分辨率 TFRecords:
> python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256 > python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384 > python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256 > python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10 > python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images
设置数据集后,您可以训练自己的 StyleGAN 网络,如下所示:
编辑 train.py 以通过取消注释或编辑特定行来指定数据集和训练配置。
使用python train.py
运行训练脚本。
结果将写入新创建的目录results/<ID>-<DESCRIPTION>
。
培训可能需要几天(或几周)才能完成,具体取决于配置。
默认情况下, train.py
配置为使用 8 个 GPU 以 1024×1024 分辨率为 FFHQ 数据集训练最高质量的 StyleGAN(表 1 中的配置 F)。请注意,我们在所有实验中都使用了 8 个 GPU。使用较少 GPU 进行训练可能不会产生相同的结果 - 如果您想与我们的技术进行比较,我们强烈建议使用相同数量的 GPU。
使用 Tesla V100 GPU 的默认配置的预期训练时间:
GPU | 1024×1024 | 512×512 | 256×256 |
---|---|---|---|
1 | 41天4小时 | 24天21小时 | 14天22小时 |
2 | 21天22小时 | 13天7小时 | 9天5小时 |
4 | 11天8小时 | 7天0小时 | 4天21小时 |
8 | 6天14小时 | 4天10小时 | 3天8小时 |
我们论文中使用的质量和解缠度指标可以使用 run_metrics.py 进行评估。默认情况下,该脚本将评估预训练 FFHQ 生成器的 Fréchet 起始距离 ( fid50k
),并将结果写入results
下新创建的目录。可以通过取消注释或编辑 run_metrics.py 中的特定行来更改确切的行为。
使用一个 Tesla V100 GPU 的预训练 FFHQ 生成器的预期评估时间和结果:
公制 | 时间 | 结果 | 描述 |
---|---|---|---|
50k | 16分钟 | 4.4159 | 使用 50,000 张图像的 Fréchet 起始距离。 |
ppl_zfull | 55分钟 | 664.8854 | Z中完整路径的感知路径长度。 |
ppl_wfull | 55分钟 | 233.3059 | W中完整路径的感知路径长度。 |
ppl_zend | 55分钟 | 666.1057 | Z中路径端点的感知路径长度。 |
ppl_wend | 55分钟 | 197.2266 | W中路径端点的感知路径长度。 |
LS | 10小时 | 号:165.0106 宽:3.7447 | Z和W的线性可分离性。 |
请注意,由于 TensorFlow 的不确定性,每次运行的确切结果可能会有所不同。
我们感谢 Jaakko Lehtinen、David Luebke 和 Tuomas Kynkäänniemi 的深入讨论和有益的评论; Janne Hellsten、Tero Kuosmanen 和 Pekka Jänis 负责计算基础设施并帮助发布代码。