fidjax
1.0.0
Clean implementation of the Frechet Inception Distance in JAX.
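For reference, FID is the Fréchet distance between two Gaussians fitted to Inception activations: one fitted to the reference images (mean $\mu_r$, covariance $\Sigma_r$) and one fitted to the generated samples ($\mu_g$, $\Sigma_g$). This is the standard definition, given here as background rather than taken from the fidjax source:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
$$

Lower is better, and identical statistics give a score of 0.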
Supports loading the Inception weights and reference statistics through the `pathlib` API.

1️⃣ FID JAX is a single file, so you can copy it directly into your project directory. Alternatively, install the package:
```sh
pip install fidjax
```
2️⃣ Download the Inception weights (credits to Matthias Wright):
```sh
wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle?dl=1
```
3️⃣ Download the ImageNet reference stats of the desired resolution (generate your own for other datasets):
```sh
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
```
4️⃣ Compute activations, statistics, and scores in JAX:
```python
import fidjax
import numpy as np

weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax.FID(weights, reference)

fid_total = 50000  # Total number of samples to evaluate.
fid_batch = 1000  # Samples passed to compute_acts per call.

acts = []
for _ in range(fid_total // fid_batch):
  samples = ...  # (B, H, W, 3) jnp.uint8, with B = fid_batch
  acts.append(fid.compute_acts(samples))
stats = fid.compute_stats(acts)
score = fid.compute_score(stats)
print(float(score))  # FID
```
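`compute_score` compares the statistics of the generated samples with the reference statistics. To illustrate that computation, here is a NumPy sketch of the standard Fréchet distance between two Gaussians, each given as a mean vector and a covariance matrix (the same kind of `mu, sigma` pair that `compute_stats` returns in the custom-dataset example). The function name `frechet_distance` is hypothetical, and this is an independent illustration rather than fidjax's implementation.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance (as used by FID) between N(mu1, sigma1) and N(mu2, sigma2).

    Hypothetical helper: an illustrative NumPy version of the standard
    formula, not fidjax's own code.
    """
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    diff = mu1 - mu2

    # Symmetric square root of sigma1 from its eigendecomposition.
    w1, v1 = np.linalg.eigh(sigma1)
    sqrt_sigma1 = (v1 * np.sqrt(np.clip(w1, 0.0, None))) @ v1.T

    # sigma1 @ sigma2 has the same eigenvalues as sqrt_sigma1 @ sigma2 @ sqrt_sigma1,
    # which is symmetric PSD, so Tr((sigma1 sigma2)^{1/2}) is the sum of the
    # square roots of its eigenvalues.
    inner = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
    inner = (inner + inner.T) / 2.0  # Remove floating-point asymmetry.
    eigvals = np.clip(np.linalg.eigvalsh(inner), 0.0, None)
    trace_covmean = np.sum(np.sqrt(eigvals))

    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_covmean)
```

For example, two Gaussians with identity covariance and means (0, 0) and (3, 4) give 25.0 up to rounding: the covariance terms cancel and only the squared distance between the means remains.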
Scores from FID JAX compared with OpenAI's TensorFlow implementation:

| Dataset | Model | FID JAX | OpenAI TF |
|---|---|---|---|
| ImageNet 256 | ADM (guided, upsampled) | 3.937 | 3.943 |
To load the files from cloud storage, pass a `pathlib.Path` implementation that supports your storage backend. For example, for GCS:
```python
import elements  # pip install elements
import fidjax

weights = elements.Path('gs://bucket/fid/inception_v3_weights_fid.pickle')
reference = elements.Path('gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz')
fid = fidjax.FID(weights, reference)
```
Generate reference statistics for custom datasets:
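```python
import fidjax
import numpy as np

weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax.FID(weights)

acts = fid.compute_acts(images)  # images: (N, H, W, 3) jnp.uint8 from your dataset
mu, sigma = fid.compute_stats(acts)
# Keyword arguments store the arrays under the keys 'mu' and 'sigma'.
np.savez('reference.npz', mu=mu, sigma=sigma)
```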
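The reference statistics are the mean and covariance of the Inception activations. As a rough sketch of what `compute_stats` presumably produces, the NumPy helpers below derive `mu` and `sigma` from an `(N, D)` activation array and store them under the keys `mu` and `sigma`. The helper names are hypothetical, and it is an assumption that fidjax reads exactly these two keys from a reference file.

```python
import numpy as np


def stats_from_acts(acts):
    """Return (mu, sigma) for activations of shape (N, D).

    Hypothetical helper: mu is the per-feature mean and sigma the (D, D)
    sample covariance, with rows treated as samples (rowvar=False).
    """
    acts = np.asarray(acts, dtype=np.float64)
    mu = acts.mean(axis=0)
    sigma = np.cov(acts, rowvar=False)
    return mu, sigma


def save_reference(file, mu, sigma):
    """Write the stats to an .npz archive under the keys 'mu' and 'sigma'."""
    # Keyword arguments become the array names inside the archive; passing a
    # dict positionally would instead store a pickled object under 'arr_0'.
    np.savez(file, mu=mu, sigma=sigma)


def load_reference(file):
    """Read back the 'mu' and 'sigma' arrays written by save_reference."""
    with np.load(file) as data:
        return data["mu"], data["sigma"]
```

Both `np.savez` and `np.load` accept either a filename such as `'reference.npz'` or an open binary file object.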
Please file an issue on GitHub.