Implementasi PyTorch dari Image GPT, berdasarkan makalah Generative Pretraining from Pixels (Chen et al.) dan kode yang menyertainya.
Penyelesaian setengah gambar yang dihasilkan model dari set pengujian. Kolom pertama adalah masukan; kolom terakhir adalah gambar asli
iGPT-S dilatih sebelumnya di CIFAR10. Penyelesaiannya cukup buruk karena model hanya dilatih di CIFAR10, tidak semua di ImageNet.
sklearn.cluster.MiniBatchKMeans
.) Menurut postingan blog mereka, model terbesar, iGPT-L (parameter 1,4 M), dilatih selama 2500 V100 hari. Dengan sangat mengurangi jumlah head perhatian, jumlah lapisan, dan ukuran input (yang memengaruhi ukuran model secara kuadrat), kita dapat melatih model kita sendiri (parameter 26 K) di Fashion-MNIST pada satu NVIDIA 2070 dalam waktu kurang dari 2 jam.
Beberapa model terlatih terletak di direktori models
. Jalankan ./download.sh
untuk mengunduh model iGPT-S cifar10
yang telah dilatih sebelumnya.
Gambar diunduh, dan centroid dihitung menggunakan k -means dengan cluster num_clusters
. Centroid ini digunakan untuk mengkuantisasi gambar sebelum dimasukkan ke dalam model.
# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8
# creates data/_centroids.npy
Catatan: Gunakan num_clusters
yang sama dengan num_vocab
di model Anda .
Model dapat dilatih menggunakan src/run.py
dengan subperintah train
.
Model dapat dilatih sebelumnya dengan menentukan kumpulan data dan konfigurasi model. configs/s_gen.yml
sesuai dengan iGPT-S dari makalah, configs/xxs_gen.yml
adalah model ekstra kecil untuk mencoba kumpulan data mainan dengan komputasi terbatas.
python src/run.py --dataset mnist train configs/xxs_gen.yml
Model yang telah dilatih sebelumnya dapat disempurnakan dengan meneruskan jalur ke pos pemeriksaan yang telah dilatih sebelumnya ke --pretrained
, bersama dengan file konfigurasi dan kumpulan data.
python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt `
Gambar seperti yang terlihat di atas dapat dibuat menggunakan gambar acak dari set pengujian:
# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
Gif seperti yang terlihat di tweet saya bisa dibuat seperti ini:
# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt