Perpustakaan haiku menggunakan operator xmap
/ pjit
di JAX untuk model paralelisme transformator.
Skema paralelisme mirip dengan Megatron-LM asli, yang efisien pada TPU karena jaringan mesh 2d berkecepatan tinggi. Ada juga versi model eksperimental yang mengimplementasikan sharding gaya ZeRo.
Pustaka ini dirancang untuk skalabilitas hingga sekitar 40 miliar parameter pada TPUv3, yang lebih dari itu harus digunakan strategi paralelisme yang berbeda. Lihat implementasi lain seperti GPT-NeoX atau DeepSpeed untuk mengetahui hal tersebut.
Salah satu arah penelitian di masa depan adalah mengintegrasikan basis kode ini dengan gerombolan-jax, untuk mencapai skalabilitas lebih lanjut dengan paralelisme pipa.
12-07-21 : Menambahkan panduan untuk menyempurnakan
Model pembuatan teks autoregresif dengan 6 miliar parameter dilatih di The Pile.
Unduh bobot ramping (hanya bobot bf16, untuk inferensi, 9GB)
Unduh bobot penuh (termasuk parameter pengoptimal, 61GB)
Pos pemeriksaan yang dilatih sebagian
Demo kolab
Demo web
Postingan blog Aran
Proyek ini tidak akan mungkin terwujud tanpa komputasi yang disediakan secara murah hati oleh TPU Research Cloud dengan bantuan dari EleutherAI.
Terima kasih kepada tim Cloud TPU di Google yang telah memberikan akses awal ke Cloud TPU VM alpha (kini tersedia untuk umum!)
Terima kasih kepada semua orang yang telah membantu dengan satu atau lain cara (diurutkan berdasarkan abjad):
Bobot GPT-J-6B dilisensikan di bawah Lisensi Apache versi 2.0.
Hiperparameter | Nilai |
---|---|
n_parameter | 6.053.381.344 |
n_lapisan | 28* |
d_model | 4.096 |
d_ff | 16.384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2.048 |
n_vocab | 50.257 (tokenizer yang sama dengan GPT-2/3) |
pengkodean posisi | Pengkodean posisi putar (RoPE) |
Dimensi tali | 64 |
*
setiap lapisan terdiri dari satu blok feedforward dan satu blok perhatian mandiri
Model terdiri dari 28 layer dengan dimensi model 4096, dan dimensi feedforward 16384. Dimensi model dibagi menjadi 16 head, masing-masing berdimensi 256. Rotary position coding (RoPE) diterapkan pada 64 dimensi masing-masing head. . Model dilatih dengan kosakata tokenisasi 50257, menggunakan kumpulan BPE yang sama dengan GPT-2/GPT-3.
Model diurutkan secara kasar berdasarkan performa, atau berdasarkan FLOP jika tidak tersedia.
Model | beban | Pelatihan FLOP | LAMBADA PPL ↓ | LAMBADA Acc ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Ukuran Kumpulan Data (GB) |
---|---|---|---|---|---|---|---|---|
Peluang | ✔ | 0 | ~banyak | ~0% | 50% | 25% | 25% | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51,6% | 52,9% | 43,4% | 70,5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51,21% | 59,4% | 50,9% | 70,8% | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57,2% | 55,0% | 48,9% | 71,1% | 825 |
Megatron-2.5B* | ✘ | 2.4e21 | ----- | 61,7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62,2% | 56,5% | 55,8% | 73,0% | 825 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63,6% | 58,7% | 54,7% | 75,1% | ~800 |
GPT-3-Babbage‡ | ✘ | ----- | 5.58 | 62,4% | 59,0% | 54,5% | 75,5% | ----- |
Megatron-8.3B* | ✘ | 7.8e21 | ----- | 66,5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67,1% | 62,3% | 62,8% | 75,6% | ~800 |
Megatron-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1.5e22 | 3,99 | 69,7% | 65,3% | 66,1% | 76,5% | 825 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70,3% | 64,5% | 67,4% | 78,0% | ~800 |
GPT-3-Curie‡ | ✘ | ----- | 4.00 | 69,3% | 65,6% | 68,5% | 77,9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72,5% | 67,9% | 70,9% | 78,5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76,2% | 70,2% | 78,9% | 81,0% | ~800 |
GPT-3-Davinci‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
Gopher 230B* | ✘ | 6.31E+23 | ----- | 74,50% | 70,10% | 79,20% | 81,80% | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76,6% | 73,0% | 80,2% | 82,0% | ----- |
*
mewakili nomor evaluasi yang dilaporkan oleh penulisnya masing-masing, semua nomor lainnya disediakan dengan menjalankan lm-evaluation-harness baik dengan bobot yang dirilis atau dengan akses API. Karena perbedaan implementasi yang tidak kentara serta kerangka tugas zero shot yang berbeda, hal ini mungkin tidak dapat dibandingkan secara langsung. Lihat posting blog ini untuk lebih jelasnya.
†
Model Megatron-11B tidak memberikan metrik yang sebanding, dan beberapa implementasi yang menggunakan bobot yang dirilis tidak mereproduksi kualitas dan evaluasi pembangkitan. (lihat 1 2 3) Oleh karena itu, evaluasi tidak dilakukan.
‡
Model ini telah dilatih dengan data yang berisi kemungkinan kontaminasi set pengujian. Model OpenAI GPT-3 gagal menghapus duplikat data pelatihan untuk set pengujian tertentu, sedangkan model GPT-Neo serta model ini dilatih di The Pile, yang belum dihapus duplikatnya terhadap set pengujian mana pun.
Sebagian besar skrip dalam repositori ini dirancang untuk dijalankan di TPU, yang dalam arsitektur TPU-VM merupakan mesin virtual yang dapat menjalankan kode arbitrer. Sebagian besar skrip dirancang untuk menjalankan TPU, SSH ke dalamnya untuk mengatur dependensi dan menyalin kode dari direktori lokal, dan kemudian memulai pekerja Ray yang dapat menerima panggilan RPC.
TPUVM menangani langkah-langkah dan evaluasi pelatihan model yang berjalan, penyimpanan dan pemuatan pos pemeriksaan, sedangkan program driver python menangani pemuatan data dan orkestrasi umum (seperti kapan harus menyimpan pos pemeriksaan, dll).
Artinya sebagian besar skrip ( train.py
, eval_harness.py
dll) diharapkan dapat berjalan pada mesin virtual GCE di wilayah yang sama dengan TPU, untuk meminimalkan latensi RPC dan biaya transfer data. Skrip lain (biasanya skrip yang tidak menggunakan argumen --tpu
, seperti device_sample.py
, device_serve.py
, atau device_train.py
) diharapkan dijalankan langsung di TPUVM. Skrip device_* hanya berfungsi pada v3-8 dan tidak pada pod yang lebih besar.
Selain itu, terdapat contoh ( resharding_example.py
) tentang cara mengonversi pos pemeriksaan yang disediakan (yang memiliki 8 shard pada kasus GPT-J-6B) ke angka yang lebih kecil, misalnya saat dijalankan pada GPU.
Untuk menyempurnakan model, jalankan device_train.py
pada VM TPU. Dengan menggunakan TPU v3-8, Anda dapat melakukan penyesuaian dengan kecepatan ~5000 token/detik, yang seharusnya cukup untuk kumpulan data berukuran kecil hingga menengah.
Silakan baca panduan langkah demi langkah untuk petunjuk penyempurnaan menyeluruh.
Perhatikan perpustakaan ini memiliki beberapa persyaratan khusus untuk versi JAX. Khususnya, untuk menggunakan model v1 (termasuk GPT-J 6B), diperlukan jax==0.2.12
. Ini pada gilirannya tergantung pada jaxlib==0.1.68
. Jika ini tidak dilakukan, Anda akan mendapatkan kesalahan xmap samar
Namun, untuk menggunakan kode model v2 (tidak ada bobot yang dirilis ke publik), versi JAX terbaru dapat digunakan.
Mengutip repositori ini:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Mengutip bobot GPT-J-6B:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Jika Anda menggunakan repositori ini atau beban terlatih lainnya untuk melakukan sesuatu yang keren, kami akan senang mendengarnya. Jangan ragu untuk membuka terbitan github atau menghubungi melalui email (di profil).