Ini adalah implementasi pytorch dari vektor variasi autoencoder kuantisasi (https://arxiv.org/abs/1711.00937).
Anda dapat menemukan implementasi asli penulis di TensorFlow di sini dengan contoh yang dapat Anda jalankan dalam buku catatan Jupyter.
Untuk menginstal dependensi, buat lingkungan conda atau virtual dengan Python 3 dan kemudian jalankan pip install -r requirements.txt
.
Untuk menjalankan VQ-VAE cukup jalankan python3 main.py
Pastikan untuk memasukkan bendera -save
jika Anda ingin menyimpan model Anda. Anda juga dapat menambahkan parameter di baris perintah. Nilai default ditentukan di bawah ini:
parser . add_argument ( "--batch_size" , type = int , default = 32 )
parser . add_argument ( "--n_updates" , type = int , default = 5000 )
parser . add_argument ( "--n_hiddens" , type = int , default = 128 )
parser . add_argument ( "--n_residual_hiddens" , type = int , default = 32 )
parser . add_argument ( "--n_residual_layers" , type = int , default = 2 )
parser . add_argument ( "--embedding_dim" , type = int , default = 64 )
parser . add_argument ( "--n_embeddings" , type = int , default = 512 )
parser . add_argument ( "--beta" , type = float , default = .25 )
parser . add_argument ( "--learning_rate" , type = float , default = 3e-4 )
parser . add_argument ( "--log_interval" , type = int , default = 50 )
VQ VAE memiliki komponen model fundamental berikut:
Encoder
yang mendefinisikan peta x -> z_e
VectorQuantizer
yang mengubah output encoder menjadi vektor satu -panas diskrit yang merupakan indeks vektor embedding terdekat z_e -> z_q
Decoder
yang mendefinisikan peta z_q -> x_hat
dan merekonstruksi gambar asli Kelas encoder / decoder adalah tumpukan konvolusional konvolusional dan terbalik, yang meliputi blok residu dalam arsitekturnya lihat Paper Resnet. Model residu ditentukan oleh kelas ResidualLayer
dan ResidualStack
.
Komponen -komponen ini diatur dalam struktur folder berikut:
models/
- decoder.py -> Decoder
- encoder.py -> Encoder
- quantizer.py -> VectorQuantizer
- residual.py -> ResidualLayer, ResidualStack
- vqvae.py -> VQVAE
Untuk mencicipi dari ruang laten, kami cocok dengan pixelcnn di atas nilai piksel laten z_ij
. Triknya di sini adalah mengakui bahwa VQ VAE memetakan gambar ke ruang laten yang memiliki struktur yang sama dengan gambar 1 saluran. Misalnya, jika Anda menjalankan parameter VQ VAE default, Anda akan RGB memetakan gambar bentuk (32,32,3)
ke ruang laten dengan bentuk (8,8,1)
, yang setara dengan gambar skala abu -abu 8x8. Oleh karena itu, Anda dapat menggunakan pixelcnn agar sesuai dengan distribusi di atas nilai "piksel" dari ruang laten 8x8 1-channel.
Untuk melatih pixelcnn pada representasi laten, pertama -tama Anda harus mengikuti langkah -langkah ini:
np.save
. Dalam quantizer.py
ini adalah variabel min_encoding_indices
.utils.load_latent_block
.Untuk menjalankan pixelcnn, cukup ketik
python pixelcnn/gated_pixelcnn.py
serta parameter apa pun (lihat pernyataan ArgParse). Dataset default adalah LATENT_BLOCK
yang hanya akan berfungsi jika Anda telah melatih VQ VAE Anda dan menyimpan representasi laten.