Implémentation PyTorch d'Image GPT, basée sur l'article Generative Pretraining from Pixels (Chen et al.) et le code qui l'accompagne.
Complétions générées par le modèle de demi-images à partir de l'ensemble de test. La première colonne est saisie ; la dernière colonne est l'image originale
iGPT-S pré-entraîné sur CIFAR10. Les résultats sont assez médiocres car le modèle n'a été formé que sur CIFAR10, pas sur l'intégralité d'ImageNet.
sklearn.cluster.MiniBatchKMeans
.) Selon leur article de blog, le plus grand modèle, iGPT-L (1,4 M de paramètres), a été formé pendant 2 500 jours V100. En réduisant considérablement le nombre de têtes d'attention, le nombre de couches et la taille d'entrée (ce qui affecte la taille du modèle de manière quadratique), nous pouvons entraîner notre propre modèle (26 000 paramètres) sur Fashion-MNIST sur un seul NVIDIA 2070 en moins de 2 heures.
Certains modèles pré-entraînés se trouvent dans le répertoire models
. Exécutez ./download.sh
pour télécharger le modèle iGPT-S pré-entraîné cifar10
.
Les images sont téléchargées et les centroïdes sont calculés à l'aide de k -means avec des clusters num_clusters
. Ces centroïdes sont utilisés pour quantifier les images avant qu'elles ne soient introduites dans le modèle.
# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8
# creates data/_centroids.npy
Remarque : utilisez les mêmes num_clusters
que num_vocab
dans votre modèle .
Les modèles peuvent être entraînés à l'aide de src/run.py
avec la sous-commande train
.
Les modèles peuvent être pré-entraînés en spécifiant un ensemble de données et une configuration de modèle. configs/s_gen.yml
correspond à iGPT-S du document, configs/xxs_gen.yml
est un très petit modèle pour essayer des ensembles de données de jouets avec un calcul limité.
python src/run.py --dataset mnist train configs/xxs_gen.yml
Les modèles pré-entraînés peuvent être affinés en transmettant le chemin d'accès au point de contrôle pré-entraîné à --pretrained
, ainsi que le fichier de configuration et l'ensemble de données.
python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt `
Des figures comme celles vues ci-dessus peuvent être créées à l’aide d’images aléatoires provenant de l’ensemble de test :
# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
Les gifs comme celui vu dans mon tweet peuvent être créés comme ceci :
# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt