Dies ist eine Pytorch -Implementierung des quantisierten Vektor -Variations -Autocoders (https://arxiv.org/abs/1711.00937).
Die ursprüngliche Implementierung des Autors finden Sie hier in TensorFlow mit einem Beispiel, das Sie in einem Jupyter -Notizbuch ausführen können.
Um Abhängigkeiten zu installieren, erstellen Sie eine Conda- oder virtuelle Umgebung mit Python 3 und führen Sie dann pip install -r requirements.txt
aus.
Um die VQ-vae zu führen, laufen Sie einfach python3 main.py
Stellen Sie sicher, dass Sie die Flagge -save
einfügen, wenn Sie Ihr Modell speichern möchten. Sie können auch Parameter in der Befehlszeile hinzufügen. Die Standardwerte sind unten angegeben:
parser . add_argument ( "--batch_size" , type = int , default = 32 )
parser . add_argument ( "--n_updates" , type = int , default = 5000 )
parser . add_argument ( "--n_hiddens" , type = int , default = 128 )
parser . add_argument ( "--n_residual_hiddens" , type = int , default = 32 )
parser . add_argument ( "--n_residual_layers" , type = int , default = 2 )
parser . add_argument ( "--embedding_dim" , type = int , default = 64 )
parser . add_argument ( "--n_embeddings" , type = int , default = 512 )
parser . add_argument ( "--beta" , type = float , default = .25 )
parser . add_argument ( "--learning_rate" , type = float , default = 3e-4 )
parser . add_argument ( "--log_interval" , type = int , default = 50 )
Die VQ VAE hat die folgenden grundlegenden Modellkomponenten:
Encoder
-Klasse, die die Karte x -> z_e
definiertVectorQuantizer
-Klasse, z_e -> z_q
den Encoder -Ausgang in einen diskreten One -Hot -Vektor verwandeltDecoder
-Klasse, die die Karte z_q -> x_hat
definiert und das Originalbild rekonstruiert Die Encoder- / Decoder -Klassen sind Faltungs- und inverse Faltungsstapel, die Restblöcke in ihrer Architektur finden. Siehe Resnet -Papier. Die Restmodelle werden von den ResidualLayer
und ResidualStack
-Klassen definiert.
Diese Komponenten sind in der folgenden Ordnerstruktur organisiert:
models/
- decoder.py -> Decoder
- encoder.py -> Encoder
- quantizer.py -> VectorQuantizer
- residual.py -> ResidualLayer, ResidualStack
- vqvae.py -> VQVAE
Um aus dem latenten Raum zu probieren, fügen wir einen Pixelcnn über die latenten Pixelwerte z_ij
. Der Trick hier ist zu erkennen, dass die VQ -VAE ein Bild auf einen latenten Raum bilden, der die gleiche Struktur wie ein 1 -Kanal -Bild hat. Wenn Sie beispielsweise die Standard -VQ -VAE -Parameter ausführen, werden Sie RGB -Kartenbilder von Form (32,32,3)
auf einen latenten Raum mit Form (8,8,1)
ausführen, was einem 8x8 -Graustufenbild entspricht. Daher können Sie einen Pixelcnn verwenden, um eine Verteilung über die "Pixel" -Werte des 8x8 1-Kanal-Latentenraums anzupassen.
Um den Pixelcnn auf latente Darstellungen zu trainieren, müssen Sie zunächst folgende Schritte befolgen:
np.save
-API. In der quantizer.py
ist dies die Variable min_encoding_indices
.utils.load_latent_block
an.Um den Pixelcnn auszuführen, geben Sie einfach ein
python pixelcnn/gated_pixelcnn.py
sowie alle Parameter (siehe ArgParse -Anweisungen). Der Standarddatensatz ist LATENT_BLOCK
, der nur dann funktioniert, wenn Sie Ihre VQ VAE geschult und die latenten Darstellungen gespeichert haben.