MaxText は、純粋な Python/Jax で書かれた高性能でスケーラブルなオープンソースLLM で、トレーニングと推論のために Google Cloud TPU と GPU をターゲットとしています。 MaxText は、Jax と XLA コンパイラーのおかげで、シンプルかつ「最適化不要」を維持しながら、高い MFU を達成し、単一ホストから非常に大規模なクラスターまで拡張できます。
MaxText は、研究と生産の両方における野心的な LLM プロジェクトの出発点となることを目指しています。まずはすぐに MaxText を試してから、ニーズに合わせて MaxText をフォークして変更することをお勧めします。
MaxText を使用して、int8 での高性能でよく収束したトレーニングを実証し、トレーニングを最大 51,000 チップまで拡張しました。
サポートされている主な機能:
初めて MaxText を実行する場合は、具体的な手順を説明します。
MaxText は、さまざまなオープン モデルのトレーニングと推論をサポートします。詳細については、「はじめに」フォルダーのユーザー ガイドに従ってください。
いくつかの追加の役立つガイド:
入門ガイドに加えて、他の MaxText 機能も常に追加されています。エンドツーエンド テストの完全なスイートは end_to_end にあります。私たちはそれらを毎晩のペースで実行します。これらは、MaxText を理解するための良い情報源になる可能性があります。あるいは、ほぼ継続的に実行される継続的な単体テストを確認することもできます。
これらの結果の再現に関する詳細については、MaxText/configs/README.md を参照してください。
パラメータの数 | アクセルタイプ | TFLOP/チップ/秒 | モデル フロップ使用率 (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71.47% |
64B | v5p-128 | 3.23e+02 | 70.31% |
128B | v5p-256 | 3.15e+02 | 68.68% |
128B | v5p-512 | 3.15e+02 | 68.53% |
256B | v5p-1024 | 3.16e+02 | 68.82% |
512B | v5p-1024 | 2.94e+02 | 63.99% |
1024B | v5p-2048 | 2.49e+02 | 64.05% |
1024B | v5p-4096 | 2.97e+02 | 64.80% |
1160B | v5p-7680 | 2.95e+02 | 64.27% |
1160B | v5p-12288 | 3.04e+02 | 66.23% |
16B、32B、64B、128Bモデル用。 MaxText/configs/v5e/ にある完全な実行構成 ( 16b.sh
、 32b.sh
、 64b.sh
、 128b.sh
を参照してください。
ハードウェア | 16B TFLOP/秒/チップ | 16B MFU | 32B TFLOP/秒/チップ | 32B MFU | 64B TFLOP/秒/チップ | 64B MFU | 128B TFLOP/秒/チップ | 128B MFU |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61.10% | 132 | 66.86% | 118 | 59.90% | 110 | 56.06% |
2x v5e-256 | 117 | 59.37% | 128 | 64.81% | 112 | 56.66% | 110 | 55.82% |
4x v5e-256 | 117 | 59.14% | 126 | 64.10% | 110 | 55.85% | 108 | 54.93% |
8x v5e-256 | 115 | 58.27% | 125 | 63.67% | 108 | 54.96% | 104 | 52.93% |
16x v5e-256 | 111 | 56.56% | 123 | 62.26% | 105 | 53.29% | 100 | 50.86% |
32x v5e-256 | 108 | 54.65% | 119 | 60.40% | 99 | 50.18% | 91 | 46.25% |
MaxText は、PyTorch で記述され、Nvidia GPU をターゲットとするエレガントなスタンドアロン GPT 実装である MinGPT/NanoGPT から多大な影響を受けています。 MaxText はより複雑で、より多くの業界標準モデルをサポートし、数万チップまで拡張できます。最終的に、MaxText は、そのコードベースで最近報告された 17% の 3 倍を超える MFU を持ち、非常にスケーラブルで、効率的な自動回帰デコードのためのキーと値のキャッシュを実装しています。
MaxText は、Nvidia GPU をターゲットとして非常によく調整された LLM 実装である Nvidia/Megatron-LM によく似ています。 2 つの実装は同等の MFU を実現します。コードベースの違いは、プログラミング戦略の違いを浮き彫りにします。 MaxText は純粋な Python であり、高いパフォーマンスを実現するために XLA コンパイラーに大きく依存しています。対照的に、Megatron-LM は Python と CUDA を組み合わせたもので、適切に最適化された CUDA カーネルに依存して高いパフォーマンスを実現します。
MaxText も Pax に匹敵します。 Pax と同様に、MaxText は、Jax での LLM の高性能でスケーラブルな実装を提供します。 Pax は強力な構成パラメーターを有効にすることに重点を置き、開発者が構成パラメーターを編集してモデルを変更できるようにします。対照的に、MaxText はさまざまな LLM のシンプルで具体的な実装であり、ユーザーがソース コードをフォークして直接編集することによって拡張することを奨励します。
アクセラレータで単一プログラム、複数データ (SPMD) ジョブを実行する場合、エラーが発生したり、何らかの理由で VM がハング/クラッシュしたりすると、プロセス全体がハングする可能性があります。このシナリオでは、スタック トレースをキャプチャすると、TPU VM で実行されているジョブの問題を特定してトラブルシューティングするのに役立ちます。
次の構成は、スタック トレースを収集することにより、障害のデバッグや、プログラムがどこかでスタックまたはハングした場合に役立ちます。 MaxText/configs/base.yml
でパラメータ値を適宜変更します。
collect_stack_trace: True
を設定します。この設定により、デバッグに役立つプログラムのトレースが定期的にダンプされます。これを無効にするには、 collect_stack_trace: False
を設定します。stack_trace_to_cloud: False
を設定すると、スタック トレースがコンソールに表示されます。 stack_trace_to_cloud: True
スタック トレースを保存するために、TPU の/tmp/debugging
に一時ファイルを作成します。 TPU VM 上で実行されているエージェントがあり、一時ディレクトリから gcp プロジェクトのクラウド ログにトレースを定期的にアップロードします。次のクエリを使用して、Cloud Logging のログ エクスプローラでトレースを表示できます。 logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
各スタック トレース収集イベント間の期間を秒単位で示します。 stack_trace_interval_seconds: 600
設定すると、600 秒 (10 分) ごとにスタック トレースが収集されます。関連する PyPI パッケージは次のとおりです: https://pypi.org/project/cloud-tpu-diagnostics。
トレーニングの実行を事前にコンパイルするために、ツールtrain_compile.py
が提供されています。このツールを使用すると、クラスター全体を使用せずに、ターゲット ハードウェア (多数の v5e デバイスなど) 用にtrain.py
のメインtrain_step
コンパイルできます。
TPU クラスターのプリコンパイルには、別のファミリーの CPU または単一の VM のみを使用できます。このコンパイルは、次の 2 つの主な目的に役立ちます。
per_device_batch_size
の設定が高すぎる場合など、メモリ不足 (OOM) 情報にフラグを立て、ターゲット ハードウェア上でコンパイルされたかのように同一の OOM スタック トレースを表示します。
事前コンパイルを保存してロードすると、ターゲット ハードウェアでの起動と再起動の時間が短縮されます。
train_compile.py
ツールはtrain.py
と密接にリンクされており、同じ構成ファイルconfigs/base.yml
を使用します。 TPU 上で実行する必要はありませんが、他の依存関係に加えてjax[tpu]
をインストールする必要があるため、まだインストールしていない場合は、 setup.sh
実行してこれらをインストールすることをお勧めします。
上記の依存関係をインストールしたら、事前にコンパイルする準備が整います。
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
これにより、2 つの v5e ポッド上で 16B パラメーターの MaxText モデルがコンパイルされます。
以下は、コンパイルされたtrain_step
保存してロードする例です。保存から始めます。
ステップ 1: AOT を実行し、コンパイルされた関数を保存する
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
ステップ 2: train.py を実行し、コンパイルされた関数をロードする
コンパイルされた train_step をロードするには、次のように、 compiled_trainstep_file=my_compiled_train.pickle
train.py
に渡すだけです。
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
上記の例 2 の保存ステップでは、コンパイラ フラグLIBTPU_INIT_ARGS
とlearning_rate
エクスポートを含めました。これは、これらがコンパイル済みオブジェクトmy_compiled_train.pickle.
モデルのサイズ (例: global_parameter_scale
、 max_sequence_length
、 per_device_batch
) は、 最初にcompile_train.py
でコンパイルするときに固定されます。コンパイルしたときと異なるサイズで保存されたコンパイル済みオブジェクトを実行しようとすると、サイズ エラーが表示されます。ただし、微妙な注意点は、学習率スケジュールも、 compile_train
実行するときに固定されることです。これは、 steps
とlearning_rate
の両方によって決定されます。 adam_b1
などのオプティマイザ パラメータは、整形オブジェクトとしてのみコンパイラに渡されます。そのため、実際の値は、コンパイル時ではなく、 train.py
の実行時に決定されます。異なる形状 (例: per_device_batch
) を渡すと、コンパイルされた署名の予期された形状が入力されたものとは異なることを報告する明確なエラー メッセージが表示されます。 compile_topology
経由で要求されたコンパイル ターゲットとは異なるハードウェアで実行しようとすると、コンパイル済みデバイスから実際のデバイスへのデバイスのマッピングに失敗したことを示すエラーが表示されます。コンパイルされたものとは異なる XLA フラグまたは LIBTPU を使用すると、コンパイルした環境でエラーなしでサイレントに実行される可能性があります。ただし、この場合の動作は保証されません。コンパイルしたのと同じ環境で実行する必要があります。
事前コンパイルは GPU でもサポートされていますが、TPU とはいくつかの違いがあります。
GPU はハードウェア間でのコンパイルをサポートしません。AoT コンパイルを実行するには GPU ホストが依然として必要ですが、単一の GPU ホストで同じハードウェアのより大きなクラスターのプログラムをコンパイルできます。
A3 クラウド GPU の場合、最大「スライス」サイズは単一ホストであり、 compile_topology_num_slices
パラメーターはプリコンパイルする A3 マシンの数を表します。
この例は、4 つの A3 ホストのクラスターをターゲットとするマルチホスト GPU コンパイルに使用するフラグを示しています。
ステップ 1: AOT を実行し、コンパイルされた関数を保存する
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
ステップ 2: train.py を実行し、コンパイルされた関数をロードする
コンパイルされた train_step をロードするには、次のように、 compiled_trainstep_file=my_compiled_train.pickle
train.py
に渡すだけです。
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
TPU の場合と同様に、コンパイル環境は実行環境と一致する必要があることに注意してください。この場合は、同じXLA_FLAGS
を設定します。
MaxText は、ディレクトリに収集されたログの Vertex AI の Tensorboard インスタンスへの自動アップロードをサポートしています。詳細については、ユーザーガイドに従ってください。