JAX Toolbox は、パブリック CI、一般的な JAX ライブラリの Docker イメージ、最適化された JAX サンプルを提供して、NVIDIA GPU での JAX 開発エクスペリエンスを簡素化および強化します。 MaxText、Paxml、Pallas などの JAX ライブラリをサポートします。
私たちは、次の JAX フレームワークとモデル アーキテクチャをサポートし、テストします。各モデルと利用可能なコンテナの詳細については、それぞれの README を参照してください。
フレームワーク | モデル | ユースケース | 容器 |
---|---|---|---|
マックステキスト | GPT、LLaMA、ジェマ、ミストラル、ミストラル | 事前トレーニング | ghcr.io/nvidia/jax:maxtext |
paxml | GPT、LLaMA、MoE | 事前トレーニング、微調整、LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5、ViT | 事前トレーニング、微調整 | ghcr.io/nvidia/jax:t5x |
t5x | イマージェン | 事前トレーニング | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
大きなビジョン | パリジェンマ | 微調整、評価 | ghcr.io/nvidia/jax:gemma |
レヴァンター | GPT、LLaMA、MPT、バックパック | 事前トレーニング、微調整 | ghcr.io/nvidia/jax:levanter |
コンポーネント | 容器 | 建てる | テスト |
---|---|---|---|
ghcr.io/nvidia/jax:base | [テストなし] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [テストが無効になっています] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
いずれの場合も、 ghcr.io/nvidia/jax:XXX
、 XXX
のコンテナの最新の夜間ビルドを指します。安定したリファレンスとしては、 ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
使用してください。
パブリック CI に加えて、H100 SXM 80GB および A100 SXM 80GB で内部 CI テストも実行します。
JAX イメージには、XLA および NCCL のパフォーマンス チューニングのために次のフラグと環境変数が埋め込まれています。
XLA フラグ | 価値 | 説明 |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | XLA が通信集合体を移動して、コンピューティング カーネルとの重複を増やすことができるようになります。 |
--xla_gpu_enable_triton_gemm | false | Trition GeMM カーネルの代わりに cuBLAS を使用する |
環境変数 | 価値 | 説明 |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | GPU 作業に単一のキューを使用して、ストリーム操作のレイテンシを短縮します。 XLA がすでに発売を発注しているため、OK |
NCCL_NVLS_ENABLE | 0 | NVLink SHARP を無効にします (1)。将来のリリースでは、この機能が再び有効になる予定です。 |
パフォーマンスを向上させるためにユーザーが設定できるさまざまな XLA フラグが他にもあります。これらのフラグの詳細については、GPU パフォーマンスのドキュメントを参照してください。 XLA フラグはワークフローごとに調整できます。たとえば、contrib/gpu/scripts_gpu 内の各スクリプトは独自の XLA フラグを設定します。
以前に使用され、不要になった XLA フラグのリストについては、GPU パフォーマンスのページも参照してください。
新しいベースコンテナを使用した最初の夜間 | ベースコンテナ |
---|---|
2024-11-06 | nvidia/cuda:12.6.2-devel-ubuntu22.04 |
2024-09-25 | nvidia/cuda:12.6.1-devel-ubuntu22.04 |
2024-07-24 | nvidia/cuda:12.5.0-devel-ubuntu22.04 |
GPU で JAX プログラムをプロファイリングする方法の詳細については、このページを参照してください。
解決:
docker run -it --shm-size=1g ...
説明: /dev/shm
のサイズ制限が原因でbus error
が発生する可能性があります。この問題は、コンテナーの起動時に--shm-size
オプションを使用して共有メモリのサイズを増やすことで解決できます。
問題の説明:
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
解決策: enroot をアップグレードするか、enroot v3.4.0 リリース ノートに記載されている単一ファイルのパッチを適用します。
説明: Docker は従来、マルチ アーキテクチャ マニフェスト リストに Docker Schema V2.2 を使用してきましたが、20.10 以降は Open Container Initiative (OCI) 形式の使用に切り替えました。 Enroot は、バージョン 3.4.0 で OCI 形式のサポートを追加しました。
AWS
EFA統合を追加
SageMaker コードサンプル
GCP
Google Kubernetes Engine 上の NVIDIA GPU を使用した JAX マルチノード アプリケーションの入門
アズール
Azure の NDm A100 v4 仮想マシン上で JAX フレームワークを使用して AI アプリケーションを高速化する
OCI
OCI 上のマルチノード マルチ GPU クラスター上で JAX を使用したディープ ラーニング ワークロードを実行する
ジャックス | NVIDIA NGC コンテナ
Slurm と OpenMPI のゼロ構成統合
カスタム GPU 演算の追加
回帰のトリアージ
JAX の Equinox: 科学と機械学習のためのエコシステムの基盤
JAX および H100 を使用した Grok のスケーリング
GPU 上でスーパーチャージされた JAX: JAX および OpenXLA を使用した高性能 LLM
JAX の新機能 | GTC 2024 年春
JAX の新機能 | GTC 2023 年春