fidjax
1.0.0
Implementación limpia de la distancia de inicio de Frechet en JAX.
pathlib
.1️⃣ FID JAX es un archivo único, por lo que puedes copiarlo al directorio de tu proyecto. O puedes instalar el paquete:
pip install fidjax
2️⃣ Descargue los pesos iniciales (créditos a Matthias Wright):
wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle ? dl=1
3️⃣ Descargue las estadísticas de referencia de ImageNet de la resolución deseada (genere las suyas propias para otros conjuntos de datos):
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
4️⃣ Calcula activaciones, estadísticas y puntuaciones en JAX:
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax . FID ( weights , reference )
fid_total = 50000
fid_batch = 1000
acts = []
for range ( fid_total // fid_batch ):
samples = ... # (B, H, W, 3) jnp.uint8
acts . append ( fid . compute_acts ( samples ))
stats = fid . compute_stats ( acts )
score = fid . compute_score ( stats )
print ( float ( score )) # FID
Conjunto de datos | Modelo | FID JAX | OpenAI TF |
---|---|---|---|
ImagenNet 256 | ADM (guiado, con muestreo mejorado) | 3.937 | 3.943 |
Señale los archivos a través de una implementación pathlib.Path
que admita su almacenamiento en la nube. Por ejemplo para GCS:
import elements # pip install elements
import fidjax
weights = elements . Path ( 'gs://bucket/fid/inception_v3_weights_fid.pickle' )
reference = elements . Path ( 'gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz' )
fid = fidjax . FID ( weights , reference )
Genere estadísticas de referencia para conjuntos de datos personalizados:
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax . FID ( weights )
acts = fid . compute_acts ( images )
mu , sigma = fid . compute_stats ( acts )
np . savez ( 'reference.npz' , { 'mu' : mu , 'sigma' : sigma })
Presente un problema en Github.