Реализация PyTorch Image GPT, основанная на документе «Генераторное предварительное обучение по пикселям» (Чен и др.) и сопроводительном коде.
Сгенерированные моделью дополнения полуизображений из тестового набора. Первый столбец является входным; последний столбец — исходное изображение
iGPT-S предварительно обучен на CIFAR10. Завершение довольно плохое, поскольку модель была обучена только на CIFAR10, а не на всей ImageNet.
sklearn.cluster.MiniBatchKMeans
). Согласно их сообщению в блоге, самая крупная модель iGPT-L (параметры 1,4 млн) обучалась в течение 2500 V100 дней. Значительно уменьшив количество головок внимания, количество слоев и размер входных данных (что квадратично влияет на размер модели), мы можем обучить нашу собственную модель (26 тыс. параметров) на Fashion-MNIST на одном NVIDIA 2070 менее чем за 2 часа.
Некоторые предварительно обученные модели находятся в каталоге models
. Запустите ./download.sh
, чтобы загрузить предварительно обученную модель iGPT-S cifar10
.
Изображения загружаются, а центроиды вычисляются с использованием k -средних с кластерами num_clusters
. Эти центроиды используются для квантования изображений перед их передачей в модель.
# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8
# creates data/_centroids.npy
Примечание. Используйте в своей модели те же num_clusters
, что и num_vocab
.
Модели можно обучать с помощью src/run.py
с подкомандой train
.
Модели можно предварительно обучить, указав набор данных и конфигурацию модели. configs/s_gen.yml
соответствует iGPT-S из статьи, configs/xxs_gen.yml
— это очень маленькая модель для примерки игрушечных наборов данных с ограниченными вычислениями.
python src/run.py --dataset mnist train configs/xxs_gen.yml
Предварительно обученные модели можно точно настроить, передав путь к предварительно обученной контрольной точке в --pretrained
вместе с файлом конфигурации и набором данных.
python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt `
Фигуры, подобные показанным выше, можно создать, используя случайные изображения из тестового набора:
# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
Гифки, подобные той, что видна в моем твите, можно сделать так:
# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt