MaxText 是一位高效能、高度可擴展的開源法學碩士,以純 Python/Jax 編寫,針對 Google Cloud TPU 和 GPU 進行訓練和推理。由於 Jax 和 XLA 編譯器的強大功能,MaxText 實現了高 MFU 並可從單主機擴展到非常大的集群,同時保持簡單和「免最佳化」。
MaxText 旨在成為研究和生產領域雄心勃勃的法學碩士課程的起點。我們鼓勵用戶先嘗試開箱即用的 MaxText,然後分叉和修改 MaxText 以滿足他們的需求。
我們使用 MaxText 演示了 int8 的高效能、良好收斂的訓練以及約 51K 晶片的規模訓練。
主要支援的功能:
對於您第一次運行 MaxText,我們提供了具體說明。
MaxText支援各種開放模型的訓練和推理。請按照入門資料夾中的使用者指南以了解更多資訊。
一些額外有用的指南:
除了入門指南之外,還有其他 MaxText 功能正在不斷添加!全套端對端測試位於 end_to_end 中。我們以每晚的步調運行它們。它們可以是理解 MaxText 的良好來源。
有關重現這些結果的更多詳細信息,請參閱 MaxText/configs/README.md。
參數數量 | 加速器類型 | TFLOP/晶片/秒 | 模型觸發器利用率 (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71.47% |
64B | v5p-128 | 3.23e+02 | 70.31% |
128B | v5p-256 | 3.15e+02 | 68.68% |
128B | v5p-512 | 3.15e+02 | 68.53% |
256B | v5p-1024 | 3.16e+02 | 68.82% |
512B | v5p-1024 | 2.94e+02 | 63.99% |
1024B | v5p-2048 | 2.49e+02 | 64.05% |
1024B | v5p-4096 | 2.97e+02 | 64.80% |
1160B | v5p-7680 | 2.95e+02 | 64.27% |
1160B | v5p-12288 | 3.04e+02 | 66.23% |
適用於 16B、32B、64B 和 128B 型號。請參閱 MaxText/configs/v5e/ 中的完整運行配置16b.sh
、 32b.sh
、 64b.sh
、 128b.sh
。
硬體 | 16B TFLOP/秒/晶片 | 16B MFU | 32B TFLOP/秒/晶片 | 32B MFU | 64B TFLOP/秒/晶片 | 64B MFU | 128B TFLOP/秒/晶片 | 128B MFU |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61.10% | 132 | 66.86% | 118 | 59.90% | 110 | 56.06% |
2x v5e-256 | 117 | 59.37% | 128 | 64.81% | 112 | 56.66% | 110 | 55.82% |
4x v5e-256 | 117 | 59.14% | 126 | 64.10% | 110 | 55.85% | 108 | 54.93% |
8x v5e-256 | 115 | 58.27% | 125 | 63.67% | 108 | 54.96% | 104 | 52.93% |
16x v5e-256 | 111 | 56.56% | 123 | 62.26% | 105 | 53.29% | 100 | 50.86% |
32x v5e-256 | 108 | 54.65% | 119 | 60.40% | 99 | 50.18% | 91 | 46.25% |
MaxText 深受 MinGPT/NanoGPT 的啟發,這是用 PyTorch 編寫並針對 Nvidia GPU 的優雅的獨立 GPT 實現。 MaxText 更複雜,支援更多行業標準模型並擴展到數萬個晶片。最終,MaxText 的 MFU 是最近報告的該程式碼庫 17% 的三倍多,具有大規模可擴展性,並實現了鍵值快取以實現高效的自動回歸解碼。
MaxText 與 Nvidia/Megatron-LM 更相似,這是一個針對 Nvidia GPU 的精心調整的 LLM 實作。這兩種實作實現了可比較的 MFU。程式碼庫的差異凸顯了不同的程式策略。 MaxText是純Python,嚴重依賴XLA編譯器來實現高效能。相較之下,Megatron-LM 是 Python 和 CUDA 的混合體,依靠優化良好的 CUDA 核心來實現高效能。
MaxText 也與 Pax 相當。與 Pax 一樣,MaxText 在 Jax 中提供了 LLM 的高效能且可擴展的實作。 Pax專注於啟用強大的配置參數,使開發人員能夠透過編輯配置參數來更改模型。相較之下,MaxText 是各種 LLM 的簡單、具體的實現,鼓勵用戶透過分叉和直接編輯原始碼來擴展。
在加速器上執行單程式、多資料 (SPMD) 作業時,如果出現任何錯誤或任何虛擬機器因某種原因掛起/崩潰,整個進程可能會掛起。在這種情況下,擷取堆疊追蹤將有助於識別和解決 TPU VM 上執行的作業的問題。
以下配置將有助於透過收集堆疊追蹤來調試故障或程式卡住或掛起的情況。在MaxText/configs/base.yml
中相應地更改參數值:
collect_stack_trace: True
以在發生故障或程式掛起時啟用堆疊追蹤收集。此設定將定期轉儲程式的追蹤以幫助調試。若要停用此功能,請設定collect_stack_trace: False
。stack_trace_to_cloud: False
以在控制台上顯示堆疊追蹤。 stack_trace_to_cloud: True
將在 TPU 中的/tmp/debugging
中建立一個暫存檔案來儲存堆疊追蹤。 TPU VM 上執行的代理程式會定期將追蹤從臨時目錄上傳到 gcp 專案中的雲端日誌記錄。您可以使用以下查詢在 Cloud Logging 的日誌瀏覽器中查看追蹤: logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
表示每個堆疊追蹤收集事件之間的持續時間(以秒為單位)。設定stack_trace_interval_seconds: 600
將每 600 秒(10 分鐘)收集一次堆疊追蹤。以下是相關的 PyPI 套件:https://pypi.org/project/cloud-tpu-diagnostics。
為了提前編譯您的訓練運行,我們提供了一個工具train_compile.py
。此工具可讓您在不使用完整叢集的情況下為目標硬體(例如大量v5e裝置)編譯train.py
中的主要train_step
。
您可以只使用不同系列的 CPU 或單一 VM 來為 TPU 叢集進行預編譯。此編譯有助於實現兩個主要目標:
它將標記任何記憶體不足 (OOM) 訊息,例如當per_device_batch_size
設定得太高時,並使用相同的 OOM 堆疊跟踪,就好像它是在目標硬體上編譯的一樣。
可以保存提前編譯,然後加載,以便在目標硬體上快速啟動和重新啟動。
工具train_compile.py
與train.py
緊密鏈接,並使用相同的設定檔configs/base.yml
。雖然您不需要在 TPU 上運行,但除了其他依賴項之外,您確實需要安裝jax[tpu]
,因此我們建議您執行setup.sh
來安裝這些依賴項(如果您尚未這樣做)。
安裝完上面列出的依賴項後,就可以事先編譯了:
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
這將在 2 個 v5e pod 上編譯 16B 參數 MaxText 模型。
這是一個儲存然後載入編譯的train_step
的範例,從儲存開始:
第 1 步:執行 AOT 並儲存編譯後的函數
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
步驟2:運行train.py並載入編譯後的函數
要載入編譯後的train_step,只需將compiled_trainstep_file=my_compiled_train.pickle
傳遞到train.py
中:
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
在上面範例 2 的儲存步驟中,我們包含匯出編譯器標誌LIBTPU_INIT_ARGS
和learning_rate
,因為它們會影響編譯物件my_compiled_train.pickle.
當您最初透過compile_train.py
進行編譯時,模型的大小(例如global_parameter_scale
、 max_sequence_length
和per_device_batch
)是固定的,如果您嘗試以與編譯時不同的大小執行已儲存的編譯對象,您將看到大小錯誤。然而,一個微妙的注意事項是,當您運行compile_train
時,學習率計劃也是固定的 - 這是由steps
和learning_rate
決定的。最佳化器參數(例如adam_b1
僅作為成形物件傳遞給編譯器 - 因此它們的實際值是在執行train.py
時確定的,而不是在編譯期間確定的。如果您確實傳遞了不同的形狀(例如per_device_batch
),您將收到一條明確的錯誤訊息,報告編譯後的簽名具有與輸入不同的預期形狀。如果您嘗試在與透過compile_topology
請求的編譯目標不同的硬體上執行,您將收到錯誤訊息,指出無法將編譯裝置對應到真實裝置。使用與編譯時不同的 XLA 標誌或 LIBTPU 可能會在您編譯的環境中靜默運行而不會出現錯誤。然而,在這種情況下,不保證行為;您應該在編譯時的相同環境中執行。
GPU 也支援提前編譯,但與 TPU 有所不同:
GPU不支援跨硬體編譯:仍需要GPU主機來執行AoT編譯,但單一GPU主機可以為相同硬體的更大叢集編譯程式。
對於 A3 Cloud GPU,最大「切片」大小是單一主機, compile_topology_num_slices
參數表示要預先編譯的 A3 機器的數量。
此範例說明了用於針對 4 個 A3 主機的叢集的多主機 GPU 編譯的標誌:
第 1 步:執行 AOT 並儲存編譯後的函數
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
步驟2:運行train.py並載入編譯後的函數
要載入編譯後的train_step,只需將compiled_trainstep_file=my_compiled_train.pickle
傳遞到train.py
中:
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
與 TPU 情況一樣,請注意編譯環境必須與執行環境匹配,在本例中透過設定相同的XLA_FLAGS
。
MaxText 支援將目錄中收集的日誌自動上傳至 Vertex AI 中的 Tensorboard 實例。請按照使用者指南了解更多資訊。