fidjax
1.0.0
Saubere Implementierung der Frechet Inception Distance in JAX.
pathlib
API laden.1️⃣ FID JAX ist eine einzelne Datei, Sie können sie also einfach in Ihr Projektverzeichnis kopieren. Oder Sie können das Paket installieren:
pip install fidjax
2️⃣ Laden Sie die Inception-Gewichte herunter (Credits an Matthias Wright):
wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle ? dl=1
3️⃣ Laden Sie die ImageNet-Referenzstatistiken der gewünschten Auflösung herunter (erstellen Sie Ihre eigenen für andere Datensätze):
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
4️⃣ Berechnen Sie Aktivierungen, Statistiken und Ergebnisse in JAX:
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax . FID ( weights , reference )
fid_total = 50000
fid_batch = 1000
acts = []
for range ( fid_total // fid_batch ):
samples = ... # (B, H, W, 3) jnp.uint8
acts . append ( fid . compute_acts ( samples ))
stats = fid . compute_stats ( acts )
score = fid . compute_score ( stats )
print ( float ( score )) # FID
Datensatz | Modell | FID JAX | OpenAI TF |
---|---|---|---|
ImageNet 256 | ADM (geführt, hochgesampelt) | 3.937 | 3.943 |
Verweisen Sie über eine pathlib.Path
-Implementierung auf die Dateien, die Ihren Cloud-Speicher unterstützen. Zum Beispiel für GCS:
import elements # pip install elements
import fidjax
weights = elements . Path ( 'gs://bucket/fid/inception_v3_weights_fid.pickle' )
reference = elements . Path ( 'gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz' )
fid = fidjax . FID ( weights , reference )
Referenzstatistiken für benutzerdefinierte Datensätze generieren:
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax . FID ( weights )
acts = fid . compute_acts ( images )
mu , sigma = fid . compute_stats ( acts )
np . savez ( 'reference.npz' , { 'mu' : mu , 'sigma' : sigma })
Bitte reichen Sie ein Problem auf Github ein.