MaxText adalah LLM sumber terbuka berperforma tinggi , sangat skalabel , dan ditulis dengan Python/Jax murni dan menargetkan Google Cloud TPU dan GPU untuk pelatihan dan inferensi . MaxText mencapai MFU tinggi dan menskalakan dari host tunggal ke cluster yang sangat besar namun tetap sederhana dan "bebas optimasi" berkat kekuatan Jax dan kompiler XLA.
MaxText bertujuan untuk menjadi titik peluncuran proyek LLM yang ambisius baik dalam penelitian dan produksi. Kami mendorong pengguna untuk memulai dengan bereksperimen dengan MaxText secara langsung, lalu melakukan fork dan memodifikasi MaxText untuk memenuhi kebutuhan mereka.
Kami telah menggunakan MaxText untuk mendemonstrasikan pelatihan berperforma tinggi dan terkonvergensi dengan baik di int8 dan menskalakan pelatihan hingga ~51 ribu chip.
Fitur utama yang didukung:
Untuk pertama kalinya Anda menjalankan MaxText, kami memberikan instruksi khusus.
MaxText mendukung pelatihan dan inferensi berbagai model terbuka. Ikuti panduan pengguna di folder memulai untuk mengetahui lebih lanjut.
Beberapa panduan tambahan yang bermanfaat:
Selain panduan memulai, selalu ada kemampuan MaxText lainnya yang terus ditambahkan! Rangkaian lengkap pengujian end-to-end ada di end_to_end. Kami menjalankannya dengan irama malam. Mereka dapat menjadi sumber yang baik untuk memahami MaxText Alternatifnya, Anda dapat melihat pengujian unit berkelanjutan yang dijalankan hampir terus menerus.
Detail lebih lanjut tentang mereproduksi hasil ini dapat ditemukan di MaxText/configs/README.md.
Jumlah params | Tipe Akselerator | TFLOP/chip/detik | Pemanfaatan model jepit (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71,47% |
64B | v5p-128 | 3.23e+02 | 70,31% |
128B | v5p-256 | 3.15e+02 | 68,68% |
128B | v5p-512 | 3.15e+02 | 68,53% |
256B | v5p-1024 | 3.16e+02 | 68,82% |
512B | v5p-1024 | 2.94e+02 | 63,99% |
1024B | v5p-2048 | 2.49e+02 | 64,05% |
1024B | v5p-4096 | 2.97e+02 | 64,80% |
1160B | v5p-7680 | 2.95e+02 | 64,27% |
1160B | v5p-12288 | 3.04e+02 | 66,23% |
Untuk model 16B, 32B, 64B, dan 128B. Lihat konfigurasi proses penuh di MaxText/configs/v5e/ sebagai 16b.sh
, 32b.sh
, 64b.sh
, 128b.sh
.
Perangkat keras | 16B TFLOP/dtk/chip | 16B MFU | 32B TFLOP/dtk/chip | 32B MFU | 64B TFLOP/dtk/chip | 64B MFU | 128B TFLOP/dtk/chip | 128B MFU |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61,10% | 132 | 66,86% | 118 | 59,90% | 110 | 56,06% |
2x v5e-256 | 117 | 59,37% | 128 | 64,81% | 112 | 56,66% | 110 | 55,82% |
4x v5e-256 | 117 | 59,14% | 126 | 64,10% | 110 | 55,85% | 108 | 54,93% |
8x v5e-256 | 115 | 58,27% | 125 | 63,67% | 108 | 54,96% | 104 | 52,93% |
16x v5e-256 | 111 | 56,56% | 123 | 62,26% | 105 | 53,29% | 100 | 50,86% |
32x v5e-256 | 108 | 54,65% | 119 | 60,40% | 99 | 50,18% | 91 | 46,25% |
MaxText sangat terinspirasi oleh MinGPT/NanoGPT, implementasi GPT mandiri yang elegan yang ditulis dalam PyTorch dan menargetkan GPU Nvidia. MaxText lebih kompleks, mendukung lebih banyak model standar industri dan menskalakan hingga puluhan ribu chip. Pada akhirnya MaxText memiliki MFU lebih dari tiga kali lipat dari 17% yang dilaporkan baru-baru ini dengan basis kode tersebut, dapat diskalakan secara besar-besaran dan mengimplementasikan cache nilai kunci untuk decoding auto-regresif yang efisien.
MaxText lebih mirip dengan Nvidia/Megatron-LM, implementasi LLM yang disetel dengan sangat baik yang menargetkan GPU Nvidia. Kedua implementasi tersebut mencapai MFU yang sebanding. Perbedaan basis kode menyoroti strategi pemrograman yang berbeda. MaxText adalah Python murni, sangat bergantung pada kompiler XLA untuk mencapai kinerja tinggi. Sebaliknya, Megatron-LM adalah campuran Python dan CUDA, mengandalkan kernel CUDA yang dioptimalkan dengan baik untuk mencapai kinerja tinggi.
MaxText juga sebanding dengan Pax. Seperti Pax, MaxText menyediakan implementasi LLM berkinerja tinggi dan terukur di Jax. Pax berfokus pada mengaktifkan parameter konfigurasi yang kuat, memungkinkan pengembang mengubah model dengan mengedit parameter konfigurasi. Sebaliknya, MaxText adalah implementasi sederhana dan konkrit dari berbagai LLM yang mendorong pengguna untuk memperluas dengan melakukan forking dan mengedit langsung kode sumber.
Saat menjalankan pekerjaan Program Tunggal, Banyak Data (SPMD) pada akselerator, keseluruhan proses dapat terhenti jika ada kesalahan atau VM macet/rusak karena alasan tertentu. Dalam skenario ini, menangkap pelacakan tumpukan akan membantu mengidentifikasi dan memecahkan masalah pekerjaan yang berjalan pada VM TPU.
Konfigurasi berikut akan membantu men-debug kesalahan atau ketika suatu program macet atau terhenti di suatu tempat dengan mengumpulkan jejak tumpukan. Ubah nilai parameter sesuai di MaxText/configs/base.yml
:
collect_stack_trace: True
untuk mengaktifkan pengumpulan jejak tumpukan pada kesalahan atau ketika program digantung. Pengaturan ini secara berkala akan membuang jejak program untuk membantu proses debug. Untuk menonaktifkan ini, setel collect_stack_trace: False
.stack_trace_to_cloud: False
untuk menampilkan jejak tumpukan di konsol. stack_trace_to_cloud: True
akan membuat file sementara di /tmp/debugging
di TPU untuk menyimpan jejak tumpukan. Ada agen yang berjalan di VM TPU yang secara berkala akan mengunggah jejak dari direktori sementara ke cloud logging di proyek gcp. Anda dapat melihat jejak di Logs Explorer di Cloud Logging menggunakan kueri berikut: logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
menandakan durasi dalam detik antara setiap peristiwa pengumpulan jejak tumpukan. Menyetel stack_trace_interval_seconds: 600
akan mengumpulkan jejak tumpukan setiap 600 detik (10 menit).Berikut adalah paket PyPI terkait: https://pypi.org/project/cloud-tpu-diagnostics.
Untuk mengompilasi proses pelatihan Anda sebelumnya, kami menyediakan alat train_compile.py
. Alat ini memungkinkan Anda mengkompilasi train_step
utama di train.py
untuk perangkat keras target (misalnya perangkat v5e dalam jumlah besar) tanpa menggunakan cluster penuh.
Anda hanya dapat menggunakan CPU atau satu VM dari keluarga berbeda untuk melakukan pra-kompilasi untuk kluster TPU. Kompilasi ini membantu dua tujuan utama:
Ini akan menandai informasi kehabisan memori (OOM), seperti ketika per_device_batch_size
diatur terlalu tinggi, dengan jejak tumpukan OOM yang identik seolah-olah dikompilasi pada perangkat keras target.
Kompilasi sebelumnya dapat disimpan dan kemudian dimuat untuk waktu startup dan restart yang cepat pada perangkat keras target.
Alat train_compile.py
terkait erat dengan train.py
dan menggunakan file konfigurasi yang sama configs/base.yml
. Meskipun Anda tidak perlu menjalankannya di TPU, Anda perlu menginstal jax[tpu]
selain dependensi lainnya, jadi sebaiknya jalankan setup.sh
untuk menginstalnya jika Anda belum melakukannya.
Setelah menginstal dependensi yang tercantum di atas, Anda siap untuk melakukan kompilasi terlebih dahulu:
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
Ini akan mengkompilasi model MaxText parameter 16B pada 2 pod v5e.
Berikut adalah contoh menyimpan lalu memuat train_step
yang telah dikompilasi, dimulai dengan penyimpanan:
Langkah 1: Jalankan AOT dan simpan fungsi yang dikompilasi
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
Langkah 2: Jalankan train.py dan muat fungsi yang dikompilasi
Untuk memuat train_step yang telah dikompilasi, Anda hanya perlu meneruskan compiled_trainstep_file=my_compiled_train.pickle
ke train.py
:
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Pada langkah penyimpanan contoh 2 di atas, kami menyertakan ekspor flag compiler LIBTPU_INIT_ARGS
dan learning_rate
karena hal tersebut memengaruhi objek yang dikompilasi my_compiled_train.pickle.
Ukuran model (misalnya global_parameter_scale
, max_sequence_length
dan per_device_batch
) ditetapkan ketika Anda pertama kali mengompilasi melalui compile_train.py
, Anda akan melihat kesalahan ukuran jika Anda mencoba menjalankan objek terkompilasi yang disimpan dengan ukuran berbeda dari yang Anda kompilasi. Namun catatan halusnya adalah bahwa jadwal kecepatan pemelajaran juga ditetapkan saat Anda menjalankan compile_train
- yang ditentukan oleh steps
dan learning_rate
. Parameter pengoptimal seperti adam_b1
diteruskan hanya sebagai objek berbentuk ke kompiler - sehingga nilai sebenarnya ditentukan saat Anda menjalankan train.py
, bukan selama kompilasi. Jika Anda meneruskan dalam bentuk yang berbeda (misalnya per_device_batch
), Anda akan mendapatkan pesan kesalahan yang jelas yang melaporkan bahwa tanda tangan yang dikompilasi memiliki bentuk yang diharapkan berbeda dari yang dimasukkan. Jika Anda mencoba menjalankan perangkat keras yang berbeda dari target kompilasi yang diminta melalui compile_topology
, Anda akan mendapatkan pesan kesalahan yang mengatakan ada kegagalan dalam memetakan perangkat dari yang dikompilasi ke perangkat Anda yang sebenarnya. Menggunakan flag XLA atau LIBTPU yang berbeda dari yang dikompilasi mungkin akan berjalan secara diam-diam dengan lingkungan tempat Anda mengkompilasi tanpa kesalahan. Namun tidak ada jaminan perilaku dalam kasus ini; Anda harus menjalankannya di lingkungan yang sama dengan tempat Anda mengkompilasi.
Kompilasi sebelumnya juga didukung untuk GPU dengan beberapa perbedaan dari TPU:
GPU tidak mendukung kompilasi lintas perangkat keras: Sebuah host GPU masih diperlukan untuk menjalankan kompilasi AoT, namun satu host GPU dapat mengkompilasi program untuk cluster yang lebih besar dari perangkat keras yang sama.
Untuk GPU Cloud A3, ukuran "slice" maksimum adalah satu host, dan parameter compile_topology_num_slices
mewakili jumlah mesin A3 yang akan diprakompilasi.
Contoh ini mengilustrasikan tanda yang digunakan untuk kompilasi GPU multihost yang menargetkan cluster yang terdiri dari 4 host A3:
Langkah 1: Jalankan AOT dan simpan fungsi yang dikompilasi
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
Langkah 2: Jalankan train.py dan muat fungsi yang dikompilasi
Untuk memuat train_step yang telah dikompilasi, Anda hanya perlu meneruskan compiled_trainstep_file=my_compiled_train.pickle
ke train.py
:
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Seperti dalam kasus TPU, perhatikan bahwa lingkungan kompilasi harus cocok dengan lingkungan eksekusi, dalam hal ini dengan menyetel XLA_FLAGS
yang sama.
MaxText mendukung pengunggahan otomatis log yang dikumpulkan dalam direktori ke instance Tensorboard di Vertex AI. Ikuti panduan pengguna untuk mengetahui lebih banyak.