PyTorch-Implementierung von Image GPT, basierend auf Papier Generative Pretraining von Pixels (Chen et al.) und begleitendem Code.
Modellgenerierte Vervollständigungen von Halbbildern aus dem Testsatz. Erste Spalte ist Eingabe; Die letzte Spalte ist das Originalbild
iGPT-S vorab auf CIFAR10 trainiert. Die Abschlüsse sind ziemlich dürftig, da das Modell nur auf CIFAR10 trainiert wurde, nicht auf dem gesamten ImageNet.
sklearn.cluster.MiniBatchKMeans
.) Laut ihrem Blogbeitrag wurde das größte Modell, iGPT-L (1,4 Mio. Parameter), 2500 V100-Tage lang trainiert. Indem wir die Anzahl der Aufmerksamkeitsköpfe, die Anzahl der Schichten und die Eingabegröße (die sich quadratisch auf die Modellgröße auswirkt) stark reduzieren, können wir unser eigenes Modell (26 K Parameter) auf Fashion-MNIST auf einer einzelnen NVIDIA 2070 in weniger als 2 Stunden trainieren.
Einige vorab trainierte Modelle befinden sich im models
. Führen Sie ./download.sh
aus, um das vorab trainierte iGPT-S-Modell cifar10
herunterzuladen.
Bilder werden heruntergeladen und Schwerpunkte werden mit k -means mit num_clusters
-Clustern berechnet. Diese Schwerpunkte werden zur Quantisierung der Bilder verwendet, bevor sie in das Modell eingespeist werden.
# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8
# creates data/_centroids.npy
Hinweis: Verwenden Sie in Ihrem Modell dieselben num_clusters
wie num_vocab
.
Modelle können mit src/run.py
und dem Unterbefehl train
trainiert werden.
Modelle können vorab trainiert werden, indem ein Datensatz und eine Modellkonfiguration angegeben werden. configs/s_gen.yml
entspricht iGPT-S aus dem Artikel, configs/xxs_gen.yml
ist ein besonders kleines Modell zum Ausprobieren von Spielzeugdatensätzen mit begrenzter Rechenleistung.
python src/run.py --dataset mnist train configs/xxs_gen.yml
Vorab trainierte Modelle können optimiert werden, indem der Pfad zum vorab trainierten Prüfpunkt zusammen mit der Konfigurationsdatei und dem Datensatz an --pretrained
übergeben wird.
python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt `
Abbildungen wie die oben gezeigten können mithilfe zufälliger Bilder aus dem Testsatz erstellt werden:
# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
Gifs wie das in meinem Tweet können wie folgt erstellt werden:
# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt