fidjax
1.0.0
การใช้งาน Frechet Inception Distance อย่างสะอาดใน JAX
pathlib
API1️⃣ FID JAX เป็นไฟล์เดียว ดังนั้นคุณจึงสามารถคัดลอกไปยังไดเร็กทอรีโปรเจ็กต์ของคุณได้ หรือคุณสามารถติดตั้งแพ็คเกจ:
pip install fidjax
2️⃣ ดาวน์โหลดตุ้มน้ำหนัก Inception (เครดิตโดย Matthias Wright):
wget https://www.dropbox.com/s/xt6zvlvt22dcwck/inception_v3_weights_fid.pickle ? dl=1
3️⃣ ดาวน์โหลดสถิติอ้างอิง ImageNet ของความละเอียดที่ต้องการ (สร้างของคุณเองสำหรับชุดข้อมูลอื่นๆ):
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz
4️⃣ คำนวณการเปิดใช้งาน สถิติ และคะแนนใน JAX:
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
reference = './VIRTUAL_imagenet128_labeled.npz'
fid = fidjax . FID ( weights , reference )
fid_total = 50000
fid_batch = 1000
acts = []
for range ( fid_total // fid_batch ):
samples = ... # (B, H, W, 3) jnp.uint8
acts . append ( fid . compute_acts ( samples ))
stats = fid . compute_stats ( acts )
score = fid . compute_score ( stats )
print ( float ( score )) # FID
ชุดข้อมูล | แบบอย่าง | ฟิด แจ๊กซ์ | OpenAI TF |
---|---|---|---|
อิมเมจเน็ต 256 | ADM (นำทาง อัปแซมเพิล) | 3.937 | 3.943 |
ชี้ไปที่ไฟล์ผ่านการใช้งาน pathlib.Path
ที่รองรับที่เก็บข้อมูลบนคลาวด์ของคุณ ตัวอย่างเช่นสำหรับ GCS:
import elements # pip install elements
import fidjax
weights = elements . Path ( 'gs://bucket/fid/inception_v3_weights_fid.pickle' )
reference = elements . Path ( 'gs://bucket/fid/VIRTUAL_imagenet128_labeled.npz' )
fid = fidjax . FID ( weights , reference )
สร้างสถิติอ้างอิงสำหรับชุดข้อมูลที่กำหนดเอง:
import fidjax
import numpy as np
weights = './inception_v3_weights_fid.pickle?dl=1'
fid = fidjax . FID ( weights )
acts = fid . compute_acts ( images )
mu , sigma = fid . compute_stats ( acts )
np . savez ( 'reference.npz' , { 'mu' : mu , 'sigma' : sigma })
กรุณายื่นปัญหาบน Github