MaxText 是一个高性能、高度可扩展的开源法学硕士,用纯 Python/Jax 编写,针对 Google Cloud TPU 和 GPU 进行训练和推理。由于 Jax 和 XLA 编译器的强大功能,MaxText 实现了高 MFU 并可从单主机扩展到非常大的集群,同时保持简单和“免优化”。
MaxText 旨在成为研究和生产领域雄心勃勃的法学硕士项目的起点。我们鼓励用户首先尝试开箱即用的 MaxText,然后分叉和修改 MaxText 以满足他们的需求。
我们使用 MaxText 演示了 int8 的高性能、良好收敛的训练以及约 51K 芯片的规模训练。
主要支持的功能:
对于您第一次运行 MaxText,我们提供了具体说明。
MaxText支持各种开放模型的训练和推理。请按照入门文件夹中的用户指南了解更多信息。
一些额外有用的指南:
除了入门指南之外,还有其他 MaxText 功能正在不断添加!全套端到端测试位于 end_to_end 中。我们以每晚的节奏运行它们。它们可以成为理解 MaxText 的良好来源。另外,您还可以看到几乎连续运行的连续单元测试。
有关重现这些结果的更多详细信息,请参阅 MaxText/configs/README.md。
参数数量 | 加速器类型 | TFLOP/芯片/秒 | 模型触发器利用率 (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71.47% |
64B | v5p-128 | 3.23e+02 | 70.31% |
128B | v5p-256 | 3.15e+02 | 68.68% |
128B | v5p-512 | 3.15e+02 | 68.53% |
256B | v5p-1024 | 3.16e+02 | 68.82% |
512B | v5p-1024 | 2.94e+02 | 63.99% |
1024B | v5p-2048 | 2.49e+02 | 64.05% |
1024B | v5p-4096 | 2.97e+02 | 64.80% |
1160B | v5p-7680 | 2.95e+02 | 64.27% |
1160B | v5p-12288 | 3.04e+02 | 66.23% |
适用于 16B、32B、64B 和 128B 型号。请参阅 MaxText/configs/v5e/ 中的完整运行配置16b.sh
、 32b.sh
、 64b.sh
、 128b.sh
。
硬件 | 16B TFLOP/秒/芯片 | 16B MFU | 32B TFLOP/秒/芯片 | 32B MFU | 64B TFLOP/秒/芯片 | 64B MFU | 128B TFLOP/秒/芯片 | 128B MFU |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61.10% | 132 | 66.86% | 118 | 59.90% | 110 | 56.06% |
2x v5e-256 | 117 | 59.37% | 128 | 64.81% | 112 | 56.66% | 110 | 55.82% |
4x v5e-256 | 117 | 59.14% | 126 | 64.10% | 110 | 55.85% | 108 | 54.93% |
8x v5e-256 | 115 | 58.27% | 125 | 63.67% | 108 | 54.96% | 104 | 52.93% |
16x v5e-256 | 111 | 56.56% | 123 | 62.26% | 105 | 53.29% | 100 | 50.86% |
32x v5e-256 | 108 | 54.65% | 119 | 60.40% | 99 | 50.18% | 91 | 46.25% |
MaxText 深受 MinGPT/NanoGPT 的启发,这是用 PyTorch 编写并针对 Nvidia GPU 的优雅的独立 GPT 实现。 MaxText 更复杂,支持更多行业标准模型并扩展到数万个芯片。最终,MaxText 的 MFU 是最近报告的该代码库 17% 的三倍多,具有大规模可扩展性,并实现了键值缓存以实现高效的自动回归解码。
MaxText 与 Nvidia/Megatron-LM 更相似,这是一个针对 Nvidia GPU 的经过精心调整的 LLM 实现。这两种实现实现了可比的 MFU。代码库的差异凸显了不同的编程策略。 MaxText是纯Python,严重依赖XLA编译器来实现高性能。相比之下,Megatron-LM 是 Python 和 CUDA 的混合体,依靠优化良好的 CUDA 内核来实现高性能。
MaxText 也与 Pax 相当。与 Pax 一样,MaxText 在 Jax 中提供了 LLM 的高性能且可扩展的实现。 Pax专注于启用强大的配置参数,使开发人员能够通过编辑配置参数来更改模型。相比之下,MaxText 是各种 LLM 的简单、具体的实现,鼓励用户通过分叉和直接编辑源代码来扩展。
在加速器上运行单程序、多数据 (SPMD) 作业时,如果出现任何错误或任何虚拟机因某种原因挂起/崩溃,整个进程可能会挂起。在这种情况下,捕获堆栈跟踪将有助于识别和解决 TPU VM 上运行的作业的问题。
以下配置将有助于通过收集堆栈跟踪来调试故障或程序卡住或挂起的情况。在MaxText/configs/base.yml
中相应地更改参数值:
collect_stack_trace: True
以在出现故障或程序挂起时启用堆栈跟踪收集。此设置将定期转储程序的跟踪以帮助调试。要禁用此功能,请设置collect_stack_trace: False
。stack_trace_to_cloud: False
以在控制台上显示堆栈跟踪。 stack_trace_to_cloud: True
将在 TPU 中的/tmp/debugging
中创建一个临时文件来存储堆栈跟踪。 TPU VM 上运行的代理会定期将跟踪从临时目录上传到 gcp 项目中的云日志记录。您可以使用以下查询在 Cloud Logging 的日志浏览器中查看跟踪: logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
表示每个堆栈跟踪收集事件之间的持续时间(以秒为单位)。设置stack_trace_interval_seconds: 600
将每 600 秒(10 分钟)收集一次堆栈跟踪。以下是相关的 PyPI 包:https://pypi.org/project/cloud-tpu-diagnostics。
为了提前编译您的训练运行,我们提供了一个工具train_compile.py
。该工具允许您在不使用完整集群的情况下为目标硬件(例如大量v5e设备)编译train.py
中的主要train_step
。
您可以仅使用不同系列的 CPU 或单个 VM 来为 TPU 集群进行预编译。此编译有助于实现两个主要目标:
它将标记任何内存不足 (OOM) 信息,例如当per_device_batch_size
设置得太高时,并使用相同的 OOM 堆栈跟踪,就好像它是在目标硬件上编译的一样。
可以保存提前编译,然后加载,以便在目标硬件上快速启动和重新启动。
工具train_compile.py
与train.py
紧密链接,并使用相同的配置文件configs/base.yml
。虽然您不需要在 TPU 上运行,但除了其他依赖项之外,您确实需要安装jax[tpu]
,因此我们建议您运行setup.sh
来安装这些依赖项(如果您尚未这样做)。
安装完上面列出的依赖项后,就可以提前编译了:
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
这将在 2 个 v5e pod 上编译 16B 参数 MaxText 模型。
这是一个保存然后加载编译的train_step
的示例,从保存开始:
第 1 步:运行 AOT 并保存编译后的函数
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
第2步:运行train.py并加载编译后的函数
要加载编译后的train_step,只需将compiled_trainstep_file=my_compiled_train.pickle
传递到train.py
中:
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
在上面示例 2 的保存步骤中,我们包括导出编译器标志LIBTPU_INIT_ARGS
和learning_rate
,因为它们会影响编译对象my_compiled_train.pickle.
当您最初通过compile_train.py
进行编译时,模型的大小(例如global_parameter_scale
、 max_sequence_length
和per_device_batch
)是固定的,如果您尝试以与编译时不同的大小运行保存的编译对象,您将看到大小错误。然而,一个微妙的注意事项是,当您运行compile_train
时,学习率计划也是固定的 - 这是由steps
和learning_rate
决定的。优化器参数(例如adam_b1
仅作为成形对象传递给编译器 - 因此它们的实际值是在运行train.py
时确定的,而不是在编译期间确定的。如果您确实传递了不同的形状(例如per_device_batch
),您将收到一条明确的错误消息,报告编译后的签名具有与输入不同的预期形状。如果您尝试在与通过compile_topology
请求的编译目标不同的硬件上运行,您将收到一条错误消息,指出无法将编译设备映射到真实设备。使用与编译时不同的 XLA 标志或 LIBTPU 可能会在您编译的环境中静默运行而不会出现错误。然而,在这种情况下,不保证行为;您应该在编译时的相同环境中运行。
GPU 也支持提前编译,但与 TPU 有所不同:
GPU不支持跨硬件编译:仍然需要GPU主机来运行AoT编译,但单个GPU主机可以为相同硬件的更大集群编译程序。
对于 A3 Cloud GPU,最大“切片”大小是单个主机, compile_topology_num_slices
参数表示要预编译的 A3 机器的数量。
此示例说明了用于针对 4 个 A3 主机的集群的多主机 GPU 编译的标志:
第 1 步:运行 AOT 并保存编译后的函数
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
第2步:运行train.py并加载编译后的函数
要加载编译后的train_step,只需将compiled_trainstep_file=my_compiled_train.pickle
传递到train.py
中:
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
与 TPU 情况一样,请注意编译环境必须与执行环境匹配,在本例中通过设置相同的XLA_FLAGS
。
MaxText 支持将目录中收集的日志自动上传到 Vertex AI 中的 Tensorboard 实例。请按照用户指南了解更多信息。