JAX Toolbox 提供公共 CI、流行 JAX 庫的 Docker 映像以及優化的 JAX 範例,以簡化和增強您在 NVIDIA GPU 上的 JAX 開發體驗。它支援 MaxText、Paxml 和 Pallas 等 JAX 庫。
我們支援並測試以下 JAX 框架和模型架構。有關每個模型和可用容器的更多詳細資訊可以在各自的自述文件中找到。
框架 | 型號 | 使用案例 | 容器 |
---|---|---|---|
最大文字 | GPT、LLaMA、Gemma、Mistral、Mixtral | 預訓練 | ghcr.io/nvidia/jax:maxtext |
帕XML | GPT、駱駝、教育部 | 預訓練、微調、LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5、ViT | 預訓練、微調 | ghcr.io/nvidia/jax:t5x |
t5x | 影像 | 預訓練 | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
大願景 | 巴利語傑瑪 | 微調、評估 | ghcr.io/nvidia/jax:gemma |
萊萬特 | GPT、LLaMA、MPT、背包 | 預訓練、微調 | ghcr.io/nvidia/jax:levanter |
成分 | 容器 | 建造 | 測試 |
---|---|---|---|
ghcr.io/nvidia/jax:base | [沒有測試] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [測試已停用] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
在所有情況下, ghcr.io/nvidia/jax:XXX
XXX 都指向XXX
容器的最新夜間建置。要獲得穩定的參考,請使用ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
。
除了公共 CI 之外,我們還在 H100 SXM 80GB 和 A100 SXM 80GB 上執行內部 CI 測試。
JAX 映像嵌入了以下標誌和環境變量,用於 XLA 和 NCCL 的效能調整:
XLA 旗幟 | 價值 | 解釋 |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | 允許 XLA 行動通訊集合以增加與計算內核的重疊 |
--xla_gpu_enable_triton_gemm | false | 使用 cuBLAS 取代 Trition GeMM 內核 |
環境變數 | 價值 | 解釋 |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | 使用單一佇列進行 GPU 工作以降低流程操作的延遲;好的,因為 XLA 已經下單了 |
NCCL_NVLS_ENABLE | 0 | 停用 NVLink SHARP (1)。未來的版本將重新啟用此功能。 |
使用者可以設定各種其他 XLA 標誌來提高效能。有關這些標誌的詳細說明,請參閱 GPU 效能文件。 XLA 標誌可根據工作流程進行調整。例如,contrib/gpu/scripts_gpu 中的每個腳本都會設定自己的 XLA 標誌。
有關不再需要的先前使用的 XLA 標誌的列表,請另參閱 GPU 效能頁面。
第一個晚上使用新的基礎容器 | 基礎容器 |
---|---|
2024-11-06 | 英偉達/cuda:12.6.2-devel-ubuntu22.04 |
2024-09-25 | 英偉達/cuda:12.6.1-devel-ubuntu22.04 |
2024-07-24 | 英偉達/cuda:12.5.0-devel-ubuntu22.04 |
有關如何在 GPU 上分析 JAX 程式的更多信息,請參閱此頁面。
解決方案:
docker run -it --shm-size=1g ...
解釋:由於/dev/shm
的大小限制,可能會發生bus error
。您可以透過在啟動容器時使用--shm-size
選項增加共享記憶體大小來解決此問題。
問題描述:
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
解決方案:升級 enroot 或套用 enroot v3.4.0 發行說明中提到的單一檔案修補程式。
說明: Docker 傳統上使用 Docker Schema V2.2 來表示多架構清單列表,但自 20.10 起已改用開放容器計劃 (OCI) 格式。 Enroot在3.4.0版本中新增了對OCI格式的支援。
AWS
添加 EFA 集成
SageMaker 程式碼範例
GCP
開始在 Google Kubernetes Engine 上使用 NVIDIA GPU 的 JAX 多節點應用程式
天藍色
在 Azure NDm A100 v4 虛擬機器上使用 JAX 框架加速 AI 應用程式
OCI
在 OCI 上的多節點多 GPU 叢集上使用 JAX 運行深度學習工作負載
賈克斯| NVIDIA NGC 容器
Slurm 和 OpenMPI 零配置集成
新增自訂 GPU 操作
將迴歸進行分類
Equinox for JAX:科學與機器學習生態系的基礎
使用 JAX 和 H100 擴展 Grok
GPU 上的 JAX 增壓:使用 JAX 和 OpenXLA 的高效能法學碩士
JAX 的新增功能 | 2024 年春季 GTC
JAX 的新增功能 | 2023 年春季 GTC