MaxText เป็น LLM โอเพ่นซอร์ส ประสิทธิภาพสูง ปรับขนาด ได้ สูง เขียนด้วย Python/Jax ล้วนๆ และกำหนดเป้าหมาย Google Cloud TPU และ GPU สำหรับ การฝึกอบรม และ การอนุมาน MaxText บรรลุ MFU ในระดับสูงและปรับขนาดได้ตั้งแต่โฮสต์เดี่ยวไปจนถึงคลัสเตอร์ขนาดใหญ่มาก ในขณะที่ยังคงความเรียบง่ายและ "ไม่มีการเพิ่มประสิทธิภาพ" ด้วยพลังของ Jax และคอมไพเลอร์ XLA
MaxText มุ่งหวังที่จะเป็นจุดเริ่มต้นสำหรับโครงการ LLM ที่ทะเยอทะยานทั้งในการวิจัยและการผลิต เราขอแนะนำให้ผู้ใช้เริ่มต้นด้วยการทดลองกับ MaxText ทันที จากนั้นแยกและแก้ไข MaxText เพื่อตอบสนองความต้องการของพวกเขา
เราใช้ MaxText เพื่อสาธิตการฝึกอบรมที่มีประสิทธิภาพสูงและผสมผสานกันอย่างดีใน int8 และขยายขนาดการฝึกอบรมเป็นชิป ~ 51,000
คุณสมบัติหลักที่รองรับ:
สำหรับการใช้งาน MaxText เป็นครั้งแรก เราจะให้คำแนะนำเฉพาะแก่คุณ
MaxText รองรับการฝึกอบรมและการอนุมานของโมเดลเปิดต่างๆ ปฏิบัติตามคู่มือผู้ใช้ในโฟลเดอร์เริ่มต้นใช้งานเพื่อทราบข้อมูลเพิ่มเติม
คำแนะนำที่เป็นประโยชน์เพิ่มเติมบางส่วน:
นอกเหนือจากคู่มือการเริ่มต้นใช้งานแล้ว ยังมีความสามารถอื่นๆ ของ MaxText ที่เพิ่มเข้ามาอยู่เสมอ! ชุดการทดสอบแบบ end-to-end ทั้งหมดอยู่ใน end_to_end เราดำเนินการตามจังหวะทุกคืน พวกเขาสามารถเป็นแหล่งที่ดีสำหรับการทำความเข้าใจ MaxText หรือคุณสามารถดูการทดสอบหน่วยต่อเนื่องซึ่งดำเนินการเกือบอย่างต่อเนื่อง
รายละเอียดเพิ่มเติมเกี่ยวกับการทำซ้ำผลลัพธ์เหล่านี้สามารถพบได้ใน MaxText/configs/README.md
จำนวนพารามิเตอร์ | ประเภทคันเร่ง | TFLOP/ชิป/วินาที | การใช้โมเดลฟล็อปส์ (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71.47% |
64B | v5p-128 | 3.23e+02 | 70.31% |
128B | v5p-256 | 3.15e+02 | 68.68% |
128B | v5p-512 | 3.15e+02 | 68.53% |
256B | v5p-1024 | 3.16e+02 | 68.82% |
512B | v5p-1024 | 2.94e+02 | 63.99% |
1024B | v5p-2048 | 2.49e+02 | 64.05% |
1024B | v5p-4096 | 2.97e+02 | 64.80% |
1160B | v5p-7680 | 2.95e+02 | 64.27% |
1160B | v5p-12288 | 3.04e+02 | 66.23% |
สำหรับรุ่น 16B, 32B, 64B และ 128B ดูการกำหนดค่าการทำงานเต็มรูปแบบใน MaxText/configs/v5e/ เป็น 16b.sh
, 32b.sh
, 64b.sh
, 128b.sh
ฮาร์ดแวร์ | 16B TFLOP/วินาที/ชิป | 16B มฟล | 32B TFLOP/วินาที/ชิป | 32B มฟล | 64B TFLOP/วินาที/ชิป | 64B มฟล | 128B TFLOP/วินาที/ชิป | 128B มฟล |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61.10% | 132 | 66.86% | 118 | 59.90% | 110 | 56.06% |
2x v5e-256 | 117 | 59.37% | 128 | 64.81% | 112 | 56.66% | 110 | 55.82% |
4x v5e-256 | 117 | 59.14% | 126 | 64.10% | 110 | 55.85% | 108 | 54.93% |
8x v5e-256 | 115 | 58.27% | 125 | 63.67% | 108 | 54.96% | 104 | 52.93% |
16x v5e-256 | 111 | 56.56% | 123 | 62.26% | 105 | 53.29% | 100 | 50.86% |
32x v5e-256 | 108 | 54.65% | 119 | 60.40% | 99 | 50.18% | 91 | 46.25% |
MaxText ได้รับแรงบันดาลใจอย่างมากจาก MinGPT/NanoGPT การใช้งาน GPT แบบสแตนด์อโลนที่หรูหราที่เขียนด้วย PyTorch และกำหนดเป้าหมายไปที่ Nvidia GPU MaxText มีความซับซ้อนมากขึ้น โดยรองรับโมเดลมาตรฐานอุตสาหกรรมมากขึ้นและปรับขนาดชิปได้หลายหมื่นตัว ท้ายที่สุด MaxText มี MFU มากกว่าสามเท่าของ 17% ที่รายงานล่าสุดด้วยโค้ดเบสนั้น สามารถปรับขนาดได้จำนวนมาก และใช้แคชคีย์-ค่าเพื่อการถอดรหัสแบบถดถอยอัตโนมัติที่มีประสิทธิภาพ
MaxText นั้นคล้ายคลึงกับ Nvidia/Megatron-LM มากกว่า ซึ่งเป็นการใช้งาน LLM ที่ได้รับการปรับแต่งอย่างดีโดยมีเป้าหมายไปที่ Nvidia GPU การใช้งานทั้งสองแบบบรรลุผล MFU ที่เทียบเคียงได้ ความแตกต่างในฐานรหัสเน้นย้ำถึงกลยุทธ์การเขียนโปรแกรมที่แตกต่างกัน MaxText เป็น Python ล้วนๆ ซึ่งอาศัยคอมไพเลอร์ XLA อย่างมากเพื่อให้ได้ประสิทธิภาพสูง ในทางตรงกันข้าม Megatron-LM เป็นการผสมผสานระหว่าง Python และ CUDA โดยอาศัยเคอร์เนล CUDA ที่ได้รับการปรับปรุงมาเป็นอย่างดีเพื่อให้ได้ประสิทธิภาพสูง
MaxText ยังเทียบได้กับ Pax เช่นเดียวกับ Pax MaxText มอบการใช้งาน LLM ใน Jax ที่มีประสิทธิภาพสูงและปรับขนาดได้ Pax มุ่งเน้นไปที่การเปิดใช้งานพารามิเตอร์การกำหนดค่าที่มีประสิทธิภาพ ช่วยให้นักพัฒนาสามารถเปลี่ยนโมเดลโดยการแก้ไขพารามิเตอร์การกำหนดค่า ในทางตรงกันข้าม MaxText เป็นการใช้งาน LLM ต่างๆ ที่เรียบง่ายและเป็นรูปธรรม ซึ่งสนับสนุนให้ผู้ใช้ขยายโดยการฟอร์กและแก้ไขซอร์สโค้ดโดยตรง
เมื่อรันงาน Single Program, Multiple Data (SPMD) บนตัวเร่งความเร็ว กระบวนการโดยรวมอาจหยุดทำงานหากมีข้อผิดพลาดหรือ VM ค้าง/ขัดข้องด้วยเหตุผลบางประการ ในสถานการณ์นี้ การจับร่องรอยสแต็กจะช่วยระบุและแก้ไขปัญหาสำหรับงานที่ทำงานบน TPU VM
การกำหนดค่าต่อไปนี้จะช่วยในการดีบักข้อผิดพลาดหรือเมื่อโปรแกรมค้างหรือหยุดทำงานที่ใดที่หนึ่งโดยการรวบรวมการติดตามสแต็ก เปลี่ยนค่าพารามิเตอร์ตาม MaxText/configs/base.yml
:
collect_stack_trace: True
เพื่อเปิดใช้งานการรวบรวมการติดตามสแต็กจากข้อบกพร่องหรือเมื่อโปรแกรมหยุดทำงาน การตั้งค่านี้จะถ่ายโอนข้อมูลการติดตามของโปรแกรมเป็นระยะๆ เพื่อช่วยในการแก้ไขจุดบกพร่อง หากต้องการปิดใช้งาน ให้ตั้งค่า collect_stack_trace: False
stack_trace_to_cloud: False
เพื่อแสดงการติดตามสแต็กบนคอนโซล stack_trace_to_cloud: True
จะสร้างไฟล์ชั่วคราวใน /tmp/debugging
ใน TPU เพื่อจัดเก็บการติดตามสแต็ก มีตัวแทนที่ทำงานบน TPU VM ซึ่งจะอัปโหลดการติดตามจากไดเรกทอรีชั่วคราวไปยังการบันทึกในระบบคลาวด์เป็นระยะในโปรเจ็กต์ gcp คุณสามารถดูการติดตามใน Logs Explorer บน Cloud Logging ได้โดยใช้คำค้นหาต่อไปนี้ logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
หมายถึงระยะเวลาเป็นวินาทีระหว่างแต่ละเหตุการณ์การรวบรวมการติดตามสแต็ก การตั้งค่า stack_trace_interval_seconds: 600
จะรวบรวมการติดตามสแต็กทุกๆ 600 วินาที (10 นาที)นี่คือแพ็คเกจ PyPI ที่เกี่ยวข้อง: https://pypi.org/project/cloud-tpu-diagnostics
เพื่อรวบรวมการฝึกซ้อมล่วงหน้า เรามีเครื่องมือ train_compile.py
เครื่องมือนี้ช่วยให้คุณสามารถคอมไพล์ train_step
หลักใน train.py
สำหรับฮาร์ดแวร์เป้าหมาย (เช่น อุปกรณ์ v5e จำนวนมาก) โดยไม่ต้องใช้คลัสเตอร์แบบเต็ม
คุณสามารถใช้ CPU หรือ VM เดียวจากตระกูลอื่นเพื่อคอมไพล์ล่วงหน้าสำหรับคลัสเตอร์ TPU การรวบรวมนี้ช่วยได้สองเป้าหมายหลัก:
โดยจะตั้งค่าสถานะข้อมูลหน่วยความจำไม่เพียงพอ (OOM) เช่น เมื่อตั้ง per_device_batch_size
สูงเกินไป โดยมีการติดตามสแต็ก OOM ที่เหมือนกันราวกับว่าถูกคอมไพล์บนฮาร์ดแวร์เป้าหมาย
การคอมไพล์ล่วงหน้าสามารถบันทึกแล้วโหลดเพื่อการเริ่มต้นและรีสตาร์ทอย่างรวดเร็วบนฮาร์ดแวร์เป้าหมาย
เครื่องมือ train_compile.py
เชื่อมโยงอย่างแน่นหนากับ train.py
และใช้ไฟล์การกำหนดค่าเดียวกัน configs/base.yml
แม้ว่าคุณไม่จำเป็นต้องรันบน TPU แต่คุณจำเป็นต้องติดตั้ง jax[tpu]
นอกเหนือจากการขึ้นต่อกันอื่นๆ ดังนั้นเราขอแนะนำให้รัน setup.sh
เพื่อติดตั้งสิ่งเหล่านี้หากคุณยังไม่ได้ดำเนินการ
หลังจากติดตั้งการขึ้นต่อกันตามรายการข้างต้น คุณก็พร้อมที่จะคอมไพล์ล่วงหน้า:
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
สิ่งนี้จะรวบรวมโมเดล MaxText พารามิเตอร์ 16B บนพ็อด 2 v5e
นี่คือตัวอย่างที่บันทึกแล้วโหลด train_step
ที่คอมไพล์แล้ว โดยเริ่มจากการบันทึก:
ขั้นตอนที่ 1: เรียกใช้ AOT และบันทึกฟังก์ชันที่คอมไพล์แล้ว
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
ขั้นตอนที่ 2: เรียกใช้ train.py และโหลดฟังก์ชันที่คอมไพล์แล้ว
หากต้องการโหลด train_step ที่คอมไพล์แล้ว คุณเพียงแค่ต้องส่ง compiled_trainstep_file=my_compiled_train.pickle
ไปที่ train.py
:
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
ในขั้นตอนการบันทึกของตัวอย่างที่ 2 ด้านบน เราได้รวมการส่งออกแฟล็กคอมไพเลอร์ LIBTPU_INIT_ARGS
และ learning_rate
เนื่องจากสิ่งเหล่านี้ส่งผลต่ออ็อบเจ็กต์ที่คอมไพล์ my_compiled_train.pickle.
ขนาดของโมเดล (เช่น global_parameter_scale
, max_sequence_length
และ per_device_batch
) ได้รับการแก้ไขแล้วเมื่อคุณคอมไพล์ครั้งแรกผ่าน compile_train.py
คุณจะเห็นข้อผิดพลาดเกี่ยวกับขนาดหากคุณพยายามเรียกใช้อ็อบเจ็กต์ที่คอมไพล์ที่บันทึกไว้ด้วยขนาดที่แตกต่างจากที่คุณคอมไพล์ด้วย อย่างไรก็ตาม หมายเหตุเล็กๆ น้อยๆ ก็คือ ตารางอัตราการเรียนรู้ จะได้รับการแก้ไขเช่นกันเมื่อคุณรัน compile_train
ซึ่งถูกกำหนดโดยทั้ง steps
และ learning_rate
พารามิเตอร์ของเครื่องมือเพิ่มประสิทธิภาพ เช่น adam_b1
จะถูกส่งผ่านเป็นวัตถุที่มีรูปทรงไปยังคอมไพเลอร์เท่านั้น ดังนั้นค่าที่แท้จริงจะถูกกำหนดเมื่อคุณรัน train.py
ไม่ใช่ในระหว่างการคอมไพล์ หากคุณส่งผ่านรูปร่างที่แตกต่างกัน (เช่น per_device_batch
) คุณจะได้รับข้อความแสดงข้อผิดพลาดที่ชัดเจนว่าลายเซ็นที่คอมไพล์มีรูปร่างที่คาดหวังแตกต่างจากที่ป้อน หากคุณพยายามทำงานบนฮาร์ดแวร์ที่แตกต่างจากเป้าหมายการคอมไพล์ที่ร้องขอผ่าน compile_topology
คุณจะได้รับข้อผิดพลาดแจ้งว่ามีความล้มเหลวในการแมปอุปกรณ์จากการคอมไพล์กับอุปกรณ์จริงของคุณ การใช้แฟล็ก XLA หรือ LIBTPU ที่แตกต่างจากที่คอมไพล์อาจจะทำงานแบบเงียบๆ กับสภาพแวดล้อมที่คุณคอมไพล์โดยไม่มีข้อผิดพลาด อย่างไรก็ตาม ไม่มีการรับประกันพฤติกรรมในกรณีนี้ คุณควรทำงานในสภาพแวดล้อมเดียวกับที่คุณคอมไพล์มา
การคอมไพล์ล่วงหน้ายังรองรับ GPU ที่มีความแตกต่างบางประการจาก TPU:
GPU ไม่รองรับการคอมไพล์ข้ามฮาร์ดแวร์: โฮสต์ GPU ยังคงจำเป็นต้องใช้การคอมไพล์ AoT แต่โฮสต์ GPU เดียวสามารถคอมไพล์โปรแกรมสำหรับคลัสเตอร์ที่ใหญ่กว่าของฮาร์ดแวร์เดียวกันได้
สำหรับ A3 Cloud GPU ขนาด "slice" สูงสุดคือโฮสต์เดียว และพารามิเตอร์ compile_topology_num_slices
แสดงถึงจำนวนเครื่อง A3 ที่จะคอมไพล์ล่วงหน้า
ตัวอย่างนี้แสดงแฟล็กที่จะใช้สำหรับการคอมไพล์ GPU หลายโฮสต์โดยกำหนดเป้าหมายคลัสเตอร์ที่มีโฮสต์ A3 4 ตัว:
ขั้นตอนที่ 1: เรียกใช้ AOT และบันทึกฟังก์ชันที่คอมไพล์แล้ว
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
ขั้นตอนที่ 2: เรียกใช้ train.py และโหลดฟังก์ชันที่คอมไพล์แล้ว
หากต้องการโหลด train_step ที่คอมไพล์แล้ว คุณเพียงแค่ต้องส่ง compiled_trainstep_file=my_compiled_train.pickle
ไปที่ train.py
:
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
เช่นเดียวกับในกรณี TPU โปรดทราบว่าสภาพแวดล้อมการคอมไพล์จะต้องตรงกับสภาพแวดล้อมการดำเนินการ ในกรณีนี้โดยการตั้งค่า XLA_FLAGS
เดียวกัน
MaxText รองรับการอัปโหลดบันทึกที่รวบรวมในไดเรกทอรีไปยังอินสแตนซ์ Tensorboard ใน Vertex AI โดยอัตโนมัติ ปฏิบัติตามคู่มือผู้ใช้เพื่อทราบข้อมูลเพิ่มเติม