Implémentation propre de Frechet Inception Distance dans JAX.
pathlib
.1️⃣ FID JAX est un fichier unique, vous pouvez donc simplement le copier dans le répertoire de votre projet. Ou vous pouvez installer le package :
pip install fidjax
2️⃣ Téléchargez les poids de création (crédits à Matthias Wright) :
wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle ? dl=1
3️⃣ Téléchargez les statistiques de référence ImageNet de la résolution souhaitée (générez les vôtres pour d'autres ensembles de données) :
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
4️⃣ Calculez les activations, les statistiques et les scores dans JAX :
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax . FID ( weights , reference )
fid_total = 50000
fid_batch = 1000
acts = []
for range ( fid_total // fid_batch ):
samples = ... # (B, H, W, 3) jnp.uint8
acts . append ( fid . compute_acts ( samples ))
stats = fid . compute_stats ( acts )
score = fid . compute_score ( stats )
print ( float ( score )) # FID
Ensemble de données | Modèle | FID-JAX | OpenAI TF |
---|---|---|---|
ImageNet 256 | ADM (guidé, suréchantillonné) | 3.937 | 3.943 |
Pointez vers les fichiers via une implémentation pathlib.Path
qui prennent en charge votre stockage Cloud. Par exemple pour GCS :
import elements # pip install elements
import fidjax
weights = elements . Path ( 'gs://bucket/fid/inception_v3_weights_fid.pickle' )
reference = elements . Path ( 'gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz' )
fid = fidjax . FID ( weights , reference )
Générez des statistiques de référence pour les ensembles de données personnalisés :
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax . FID ( weights )
acts = fid . compute_acts ( images )
mu , sigma = fid . compute_stats ( acts )
np . savez ( 'reference.npz' , { 'mu' : mu , 'sigma' : sigma })
Veuillez signaler un problème sur Github.