Это реализация Pytorch квантового вариационного автоэкодора (https://arxiv.org/abs/1711.00937).
Вы можете найти оригинальную реализацию автора в Tensorflow здесь с примером, который вы можете запустить в ноутбуке Jupyter.
Чтобы установить зависимости, создайте Conda или виртуальную среду с Python 3, а затем запустите pip install -r requirements.txt
.
Чтобы запустить vq-vae, просто запустите python3 main.py
Обязательно включите флаг -save
, если вы хотите сохранить свою модель. Вы также можете добавить параметры в командной строке. Значения по умолчанию указаны ниже:
parser . add_argument ( "--batch_size" , type = int , default = 32 )
parser . add_argument ( "--n_updates" , type = int , default = 5000 )
parser . add_argument ( "--n_hiddens" , type = int , default = 128 )
parser . add_argument ( "--n_residual_hiddens" , type = int , default = 32 )
parser . add_argument ( "--n_residual_layers" , type = int , default = 2 )
parser . add_argument ( "--embedding_dim" , type = int , default = 64 )
parser . add_argument ( "--n_embeddings" , type = int , default = 512 )
parser . add_argument ( "--beta" , type = float , default = .25 )
parser . add_argument ( "--learning_rate" , type = float , default = 3e-4 )
parser . add_argument ( "--log_interval" , type = int , default = 50 )
VQ VAE имеет следующие фундаментальные компоненты модели:
Encoder
, который определяет карту x -> z_e
VectorQuantizer
, который преобразует вывод энкодера в дискретный вектор с одним горячим, который является индексом ближайшего вектора встраивания z_e -> z_q
Decoder
, который определяет карту z_q -> x_hat
и реконструирует исходное изображение Классы энкодера / декодера представляют собой сверточные и обратные сверточные стеки, которые включают остаточные блоки в их архитектуре, см. Resnet Paper. Остатовые модели определяются классами ResidualLayer
и ResidualStack
.
Эти компоненты организованы в следующей структуре папок:
models/
- decoder.py -> Decoder
- encoder.py -> Encoder
- quantizer.py -> VectorQuantizer
- residual.py -> ResidualLayer, ResidualStack
- vqvae.py -> VQVAE
Чтобы попробовать из скрытого пространства, мы устанавливаем Pixelcnn по скрытым значениям пикселя z_ij
. Хитрость здесь заключается в том, что VQ VQ VAE отображает изображение в скрытое пространство, которое имеет ту же структуру, что и изображение 1 канала. Например, если вы запустите параметры VQ VQ VQ VQ VQ VAE, вы будете карты RGB изображения формы (32,32,3)
в скрытое пространство с формой (8,8,1)
, что эквивалентно изображению серого масштаба 8x8. Следовательно, вы можете использовать Pixelcnn, чтобы соответствовать распределению по значениям «пикселя» 1-канального скрытого пространства 8x8.
Чтобы тренировать Pixelcnn на скрытые представления, вам сначала нужно выполнить эти шаги:
np.save
API. В quantizer.py
это переменная min_encoding_indices
.utils.load_latent_block
FUNCTION.Чтобы запустить Pixelcnn, просто введите
python pixelcnn/gated_pixelcnn.py
а также любые параметры (см. Заявления Argparse). Набор данных по умолчанию - LATENT_BLOCK
, который будет работать только в том случае, если вы обучили свой vq vae и сохранили скрытые представления.