Implementación de Image GPT en PyTorch, basada en el documento Generative Pretraining from Pixels (Chen et al.) y el código adjunto.
Completaciones generadas por modelos de medias imágenes del conjunto de prueba. Se ingresa la primera columna; La última columna es la imagen original.
iGPT-S previamente entrenado en CIFAR10. Las terminaciones son bastante deficientes ya que el modelo solo se entrenó en CIFAR10, no en todo ImageNet.
sklearn.cluster.MiniBatchKMeans
). Según su publicación de blog, el modelo más grande, iGPT-L (1,4 M de parámetros), fue entrenado durante 2500 días V100. Al reducir en gran medida la cantidad de atención, la cantidad de capas y el tamaño de entrada (lo que afecta el tamaño del modelo cuadráticamente), podemos entrenar nuestro propio modelo (parámetros de 26 K) en Fashion-MNIST en una sola NVIDIA 2070 en menos de 2 horas.
Algunos modelos previamente entrenados se encuentran en el directorio models
. Ejecute ./download.sh
para descargar el modelo iGPT-S previamente entrenado cifar10
.
Las imágenes se descargan y los centroides se calculan utilizando k -means con grupos num_clusters
. Estos centroides se utilizan para cuantificar las imágenes antes de introducirlas en el modelo.
# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8
# creates data/_centroids.npy
Nota: Utilice los mismos num_clusters
que num_vocab
en su modelo .
Los modelos se pueden entrenar usando src/run.py
con el subcomando train
.
Los modelos se pueden entrenar previamente especificando un conjunto de datos y una configuración de modelo. configs/s_gen.yml
corresponde a iGPT-S del artículo, configs/xxs_gen.yml
es un modelo extra pequeño para probar conjuntos de datos de juguetes con cómputo limitado.
python src/run.py --dataset mnist train configs/xxs_gen.yml
Los modelos previamente entrenados se pueden ajustar pasando la ruta al punto de control previamente entrenado a --pretrained
, junto con el archivo de configuración y el conjunto de datos.
python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt `
Se pueden crear figuras como las que se ven arriba usando imágenes aleatorias del conjunto de prueba:
# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
Gifs como el que se ve en mi tweet se pueden hacer así:
# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt