JAX Toolbox 提供公共 CI、流行 JAX 库的 Docker 映像以及优化的 JAX 示例,以简化和增强您在 NVIDIA GPU 上的 JAX 开发体验。它支持 MaxText、Paxml 和 Pallas 等 JAX 库。
我们支持并测试以下 JAX 框架和模型架构。有关每个模型和可用容器的更多详细信息可以在各自的自述文件中找到。
框架 | 型号 | 使用案例 | 容器 |
---|---|---|---|
最大文本 | GPT、LLaMA、Gemma、Mistral、Mixtral | 预训练 | ghcr.io/nvidia/jax:maxtext |
帕XML | GPT、骆驼、教育部 | 预训练、微调、LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5、ViT | 预训练、微调 | ghcr.io/nvidia/jax:t5x |
t5x | 图像 | 预训练 | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
大愿景 | 巴利语杰玛 | 微调、评估 | ghcr.io/nvidia/jax:gemma |
莱万特 | GPT、LLaMA、MPT、背包 | 预训练、微调 | ghcr.io/nvidia/jax:levanter |
成分 | 容器 | 建造 | 测试 |
---|---|---|---|
ghcr.io/nvidia/jax:base | [没有测试] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [测试已禁用] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
在所有情况下, ghcr.io/nvidia/jax:XXX
XXX 都指向XXX
容器的最新夜间构建。要获得稳定的参考,请使用ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
。
除了公共 CI 之外,我们还在 H100 SXM 80GB 和 A100 SXM 80GB 上运行内部 CI 测试。
JAX 映像嵌入了以下标志和环境变量,用于 XLA 和 NCCL 的性能调整:
XLA 旗帜 | 价值 | 解释 |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | 允许 XLA 移动通信集合以增加与计算内核的重叠 |
--xla_gpu_enable_triton_gemm | false | 使用 cuBLAS 代替 Trition GeMM 内核 |
环境变量 | 价值 | 解释 |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | 使用单个队列进行 GPU 工作以降低流操作的延迟;好的,因为 XLA 已经下订单了 |
NCCL_NVLS_ENABLE | 0 | 禁用 NVLink SHARP (1)。未来的版本将重新启用此功能。 |
用户可以设置各种其他 XLA 标志来提高性能。有关这些标志的详细说明,请参阅 GPU 性能文档。 XLA 标志可以根据工作流程进行调整。例如,contrib/gpu/scripts_gpu 中的每个脚本都会设置自己的 XLA 标志。
有关不再需要的以前使用的 XLA 标志的列表,另请参阅 GPU 性能页面。
第一个晚上使用新的基础容器 | 基础容器 |
---|---|
2024-11-06 | 英伟达/cuda:12.6.2-devel-ubuntu22.04 |
2024-09-25 | 英伟达/cuda:12.6.1-devel-ubuntu22.04 |
2024-07-24 | 英伟达/cuda:12.5.0-devel-ubuntu22.04 |
有关如何在 GPU 上分析 JAX 程序的更多信息,请参阅此页面。
解决方案:
docker run -it --shm-size=1g ...
解释:由于/dev/shm
的大小限制,可能会发生bus error
。您可以通过在启动容器时使用--shm-size
选项增加共享内存大小来解决此问题。
问题描述:
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
解决方案:升级 enroot 或应用 enroot v3.4.0 发行说明中提到的单文件补丁。
说明: Docker 传统上使用 Docker Schema V2.2 来表示多架构清单列表,但自 20.10 起已改用开放容器计划 (OCI) 格式。 Enroot在3.4.0版本中添加了对OCI格式的支持。
AWS
添加 EFA 集成
SageMaker 代码示例
GCP
开始在 Google Kubernetes Engine 上使用 NVIDIA GPU 的 JAX 多节点应用程序
天蓝色
在 Azure NDm A100 v4 虚拟机上使用 JAX 框架加速 AI 应用程序
OCI
在 OCI 上的多节点多 GPU 集群上使用 JAX 运行深度学习工作负载
贾克斯| NVIDIA NGC 容器
Slurm 和 OpenMPI 零配置集成
添加自定义 GPU 操作
对回归进行分类
Equinox for JAX:科学和机器学习生态系统的基础
使用 JAX 和 H100 扩展 Grok
GPU 上的 JAX 增压:使用 JAX 和 OpenXLA 的高性能法学硕士
JAX 的新增功能 | 2024 年春季 GTC
JAX 的新增功能 | 2023 年春季 GTC