Implementação PyTorch de Image GPT, baseada no papel Generative Pretraining from Pixels (Chen et al.) e código que o acompanha.
Conclusões geradas pelo modelo de meias imagens do conjunto de teste. A primeira coluna é a entrada; a última coluna é a imagem original
iGPT-S pré-treinado em CIFAR10. As conclusões são bastante ruins, pois o modelo foi treinado apenas no CIFAR10, e não em todo o ImageNet.
sklearn.cluster.MiniBatchKMeans
.) De acordo com a postagem do blog, o maior modelo, iGPT-L (parâmetros de 1,4 M), foi treinado por 2.500 dias V100. Ao reduzir bastante o número de cabeça de atenção, o número de camadas e o tamanho da entrada (que afeta o tamanho do modelo quadraticamente), podemos treinar nosso próprio modelo (parâmetros de 26 K) no Fashion-MNIST em um único NVIDIA 2070 em menos de 2 horas.
Alguns modelos pré-treinados estão localizados no diretório models
. Execute ./download.sh
para baixar o modelo iGPT-S pré-treinado cifar10
.
As imagens são baixadas e os centróides são calculados usando k -means com clusters num_clusters
. Esses centróides são usados para quantizar as imagens antes de serem inseridas no modelo.
# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8
# creates data/_centroids.npy
Nota: Use os mesmos num_clusters
que num_vocab
em seu modelo .
Os modelos podem ser treinados usando src/run.py
com o subcomando train
.
Os modelos podem ser pré-treinados especificando um conjunto de dados e uma configuração de modelo. configs/s_gen.yml
corresponde ao iGPT-S do papel, configs/xxs_gen.yml
é um modelo extra pequeno para testar conjuntos de dados de brinquedos com computação limitada.
python src/run.py --dataset mnist train configs/xxs_gen.yml
Os modelos pré-treinados podem ser ajustados passando o caminho para o ponto de verificação pré-treinado para --pretrained
, junto com o arquivo de configuração e o conjunto de dados.
python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt `
Figuras como as vistas acima podem ser criadas usando imagens aleatórias do conjunto de testes:
# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
Gifs como o visto no meu tweet podem ser feitos assim:
# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt