fidjax
1.0.0
JAX 中 Frechet 起始距离的干净实现。
pathlib
API 从 GCS 加载权重。1️⃣ FID JAX 是单个文件,因此您只需将其复制到项目目录即可。或者您可以安装该软件包:
pip install fidjax
2️⃣ 下载初始权重(归功于 Matthias Wright):
wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle ? dl=1
3️⃣ 下载所需分辨率的 ImageNet 参考统计数据(为其他数据集生成您自己的统计数据):
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
4️⃣ 在 JAX 中计算激活、统计数据和分数:
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax . FID ( weights , reference )
fid_total = 50000
fid_batch = 1000
acts = []
for range ( fid_total // fid_batch ):
samples = ... # (B, H, W, 3) jnp.uint8
acts . append ( fid . compute_acts ( samples ))
stats = fid . compute_stats ( acts )
score = fid . compute_score ( stats )
print ( float ( score )) # FID
数据集 | 模型 | FID贾克斯 | OpenAI TF |
---|---|---|---|
图像网256 | ADM(引导,上采样) | 3.937 | 3.943 |
通过支持您的云存储的pathlib.Path
实现指向文件。以 GCS 为例:
import elements # pip install elements
import fidjax
weights = elements . Path ( 'gs://bucket/fid/inception_v3_weights_fid.pickle' )
reference = elements . Path ( 'gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz' )
fid = fidjax . FID ( weights , reference )
生成自定义数据集的参考统计数据:
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax . FID ( weights )
acts = fid . compute_acts ( images )
mu , sigma = fid . compute_stats ( acts )
np . savez ( 'reference.npz' , { 'mu' : mu , 'sigma' : sigma })
请在 Github 上提交问题。