이것은 벡터 양자화 된 변형 autoencoder (https://arxiv.org/abs/1711.00937)의 pytorch 구현입니다.
Jupyter 노트북에서 실행할 수있는 예제와 함께 Tensorflow에서 저자의 원래 구현을 찾을 수 있습니다.
종속성을 설치하려면 Python 3을 사용하여 Conda 또는 Virtual Environment를 작성한 다음 pip install -r requirements.txt
실행하십시오.
VQ-VAE를 실행하려면 python3 main.py
실행합니다. 모델을 저장하려면 -save
플래그를 포함하십시오. 명령 줄에 매개 변수를 추가 할 수도 있습니다. 기본값은 아래에 지정되어 있습니다.
parser . add_argument ( "--batch_size" , type = int , default = 32 )
parser . add_argument ( "--n_updates" , type = int , default = 5000 )
parser . add_argument ( "--n_hiddens" , type = int , default = 128 )
parser . add_argument ( "--n_residual_hiddens" , type = int , default = 32 )
parser . add_argument ( "--n_residual_layers" , type = int , default = 2 )
parser . add_argument ( "--embedding_dim" , type = int , default = 64 )
parser . add_argument ( "--n_embeddings" , type = int , default = 512 )
parser . add_argument ( "--beta" , type = float , default = .25 )
parser . add_argument ( "--learning_rate" , type = float , default = 3e-4 )
parser . add_argument ( "--log_interval" , type = int , default = 50 )
VQ VAE에는 다음과 같은 기본 모델 구성 요소가 있습니다.
x -> z_e
정의하는 Encoder
클래스z_e -> z_q
의 인덱스 인 이산 1 홈 벡터로 변환하는 VectorQuantizer
클래스z_q -> x_hat
정의하고 원본 이미지를 재구성하는 Decoder
클래스 인코더 / 디코더 클래스는 컨볼 루션 및 역 컨볼 루션 스택이며, 아키텍처의 잔류 블록을 포함하여 RESNET 용지를 참조하십시오. 잔류 모델은 ResidualLayer
및 ResidualStack
클래스에 의해 정의됩니다.
이러한 구성 요소는 다음 폴더 구조로 구성됩니다.
models/
- decoder.py -> Decoder
- encoder.py -> Encoder
- quantizer.py -> VectorQuantizer
- residual.py -> ResidualLayer, ResidualStack
- vqvae.py -> VQVAE
잠재적 인 공간에서 샘플링하기 위해, 우리는 잠복 픽셀 값 z_ij
에 pixelcnn을 맞추고 있습니다. 여기서 트릭은 VQ VAE가 이미지를 1 채널 이미지와 동일한 구조를 갖는 잠재 공간에 맵핑한다는 것을 인식하고 있습니다. 예를 들어, 기본 VQ VAE 매개 변수를 실행하면 모양 (32,32,3)
의 RGB 맵 이미지 (8,8,1)
가 8x8 그레이 스케일 이미지와 동일합니다. 따라서 Pixelcnn을 사용하여 8x8 1 채널 잠재 공간의 "픽셀"값에 대한 분포에 맞습니다.
잠재적 인 표현으로 Pixelcnn을 훈련 시키려면 먼저 다음 단계를 따라야합니다.
np.save
API로 개별 잠재 공간 표현을 저장하십시오. quantizer.py
에서 이것은 min_encoding_indices
변수입니다.utils.load_latent_block
함수에서 저장된 잠재적 공간 데이터 세트로의 경로를 지정하십시오.pixelcnn을 실행하려면 간단히 입력하십시오
python pixelcnn/gated_pixelcnn.py
모든 매개 변수 (Argparse 문 참조). 기본 데이터 세트는 vq VAE를 훈련시키고 잠재적 인 표현을 저장 한 경우에만 작동하는 LATENT_BLOCK
입니다.