DALL-E Open-AI di Mesh-Tensorflow.
Jika efisiensinya sama dengan GPT-Neo, repo ini seharusnya dapat melatih model hingga, dan lebih besar, dari ukuran DALL-E (12B params) Open-AI.
Tidak ada model terlatih... Belum.
Terima kasih kepada Ben Wang atas implementasi tf vae serta membuat versi mtf berfungsi, dan Aran Komatsuzaki atas bantuannya dalam membangun mtf VAE dan jalur input.
git clone https://github.com/EleutherAI/GPTNeo
cd GPTNeo
pip3 install -r requirements.txt
Berjalan pada TPU, belum diuji pada GPU tetapi seharusnya berfungsi secara teori . Contoh konfigurasi dirancang untuk dijalankan pada pod TPU v3-32.
Untuk menyiapkan TPU, daftar ke Google Cloud Platform, dan buat bucket penyimpanan.
Buat VM Anda melalui google Shell ( https://ssh.cloud.google.com/
) dengan ctpu up --vm-only
sehingga dapat terhubung ke bucket Google dan TPU Anda dan siapkan repo seperti di atas.
DALLE memerlukan VAE terlatih untuk mengompresi gambar menjadi token. Untuk menjalankan prapelatihan VAE, sesuaikan parameter di configs/vae_example.json
ke jalur glob yang menunjuk ke kumpulan data jpg, dan sesuaikan ukuran gambar ke ukuran yang sesuai.
"dataset": {
"train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg",
"eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg",
"image_size": 32
}
Setelah semuanya siap, buat TPU Anda, lalu jalankan:
python train_vae_tf.py --tpu your_tpu_name --model vae_example
Tensor gambar log pelatihan dan nilai kerugian, untuk memeriksa kemajuan, Anda dapat menjalankan:
tensorboard --logdir your_model_dir
Setelah VAE dilatih sebelumnya, Anda dapat melanjutkan ke DALL-E.
Saat ini kami sedang melatih kumpulan data dummy. Kumpulan data publik berskala besar untuk DALL-E sedang dalam pengerjaan. Sementara itu, untuk menghasilkan beberapa data dummy, jalankan:
python src/data/create_tfrecords.py
Ini akan mengunduh CIFAR-10, dan menghasilkan beberapa teks acak untuk dijadikan masukan teks.
Kumpulan data khusus harus diformat dalam sebuah folder, dengan file jsonl di folder akar yang berisi data keterangan dan jalur ke masing-masing gambar, sebagai berikut:
Folder structure:
data_folder
jsonl_file
folder_1
img1
img2
...
folder_2
img1
img2
...
...
jsonl structure:
{"image_path": folder_1/img1, "caption": "some words"}
{"image_path": folder_2/img2, "caption": "more words"}
...
Anda kemudian dapat menggunakan fungsi create_paired_dataset
di src/data/create_tfrecords.py
untuk menyandikan kumpulan data ke dalam tfrecords untuk digunakan dalam pelatihan.
Setelah kumpulan data dibuat, salin ke keranjang Anda dengan gsutil:
gsutil cp -r DALLE-tfrecords gs://neo-datasets/
Dan terakhir, jalankan pelatihan dengan
python train_dalle.py --tpu your_tpu_name --model dalle_example
VAE:
{
"model_type": "vae",
"dataset": {
"train_path": "gs://neo-datasets/CIFAR-10-images/train/**/*.jpg", # glob path to training images
"eval_path": "gs://neo-datasets/CIFAR-10-images/test/**/*.jpg", # glob path to eval images
"image_size": 32 # size of images (all images will be cropped / padded to this size)
},
"train_batch_size": 32,
"eval_batch_size": 32,
"predict_batch_size": 32,
"steps_per_checkpoint": 1000, # how often to save a checkpoint
"iterations": 500, # number of batches to infeed to the tpu at a time. Must be < steps_per_checkpoint
"train_steps": 100000, # total training steps
"eval_steps": 0, # run evaluation for this many steps every steps_per_checkpoint
"model_path": "gs://neo-models/vae_test2/", # directory in which to save the model
"mesh_shape": "data:16,model:2", # mapping of processors to named dimensions - see mesh-tensorflow repo for more info
"layout": "batch_dim:data", # which named dimensions of the model to split across the mesh - see mesh-tensorflow repo for more info
"num_tokens": 512, # vocab size
"dim": 512,
"hidden_dim": 64, # size of hidden dim
"n_channels": 3, # number of input channels
"bf_16": false, # if true, the model is trained with bfloat16 precision
"lr": 0.001, # learning rate [by default learning rate starts at this value, then decays to 10% of this value over the course of the training]
"num_layers": 3, # number of blocks in the encoder / decoder
"train_gumbel_hard": true, # whether to use hard or soft gumbel_softmax
"eval_gumbel_hard": true
}
DALL-E:
{
"model_type": "dalle",
"dataset": {
"train_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords", # glob path to tfrecords data
"eval_path": "gs://neo-datasets/DALLE-tfrecords/*.tfrecords",
"image_size": 32 # size of images (all images will be cropped / padded to this size)
},
"train_batch_size": 32, # see above
"eval_batch_size": 32,
"predict_batch_size": 32,
"steps_per_checkpoint": 1000,
"iterations": 500,
"train_steps": 100000,
"predict_steps": 0,
"eval_steps": 0,
"n_channels": 3,
"bf_16": false,
"lr": 0.001,
"model_path": "gs://neo-models/dalle_test/",
"mesh_shape": "data:16,model:2",
"layout": "batch_dim:data",
"n_embd": 512, # size of embedding dim
"text_vocab_size": 50258, # vocabulary size of the text tokenizer
"image_vocab_size": 512, # vocabulary size of the vae - should equal num_tokens above
"text_seq_len": 256, # length of text inputs (all inputs longer / shorter will be truncated / padded)
"n_layers": 6,
"n_heads": 4, # number of attention heads. For best performance, n_embd / n_heads should equal 128
"vae_model": "vae_example" # path to or name of vae model config
}