JAX Toolbox bietet ein öffentliches CI, Docker-Images für beliebte JAX-Bibliotheken und optimierte JAX-Beispiele, um Ihre JAX-Entwicklungserfahrung auf NVIDIA-GPUs zu vereinfachen und zu verbessern. Es unterstützt JAX-Bibliotheken wie MaxText, Paxml und Pallas.
Wir unterstützen und testen die folgenden JAX-Frameworks und Modellarchitekturen. Weitere Details zu den einzelnen Modellen und verfügbaren Containern finden Sie in den jeweiligen READMEs.
Rahmen | Modelle | Anwendungsfälle | Container |
---|---|---|---|
maxtext | GPT, LLaMA, Gemma, Mistral, Mixtral | Vorschulung | ghcr.io/nvidia/jax:maxtext |
paxml | GPT, LLaMA, MoE | Vortraining, Feinabstimmung, LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5, ViT | Vortraining, Feinabstimmung | ghcr.io/nvidia/jax:t5x |
t5x | Bild | Vorschulung | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
große Vision | PaliGemma | Feinabstimmung, Bewertung | ghcr.io/nvidia/jax:gemma |
levanter | GPT, LLaMA, MPT, Rucksäcke | Vorschulung, Feinabstimmung | ghcr.io/nvidia/jax:levanter |
Komponenten | Container | Bauen | Prüfen |
---|---|---|---|
ghcr.io/nvidia/jax:base | [keine Tests] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [Tests deaktiviert] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
In allen Fällen verweist ghcr.io/nvidia/jax:XXX
auf den neuesten nächtlichen Build des Containers für XXX
. Für eine stabile Referenz verwenden Sie ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
.
Zusätzlich zum öffentlichen CI führen wir auch interne CI-Tests für H100 SXM 80 GB und A100 SXM 80 GB durch.
Das JAX-Image ist mit den folgenden Flags und Umgebungsvariablen zur Leistungsoptimierung von XLA und NCCL eingebettet:
XLA-Flaggen | Wert | Erläuterung |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | ermöglicht es XLA, Kommunikationskollektive zu verschieben, um die Überlappung mit Rechenkernen zu erhöhen |
--xla_gpu_enable_triton_gemm | false | Verwenden Sie cuBLAS anstelle von Trition GeMM-Kerneln |
Umgebungsvariable | Wert | Erläuterung |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | Verwenden Sie eine einzelne Warteschlange für die GPU-Arbeit, um die Latenz von Stream-Vorgängen zu verringern. OK, da XLA bereits Markteinführungen bestellt |
NCCL_NVLS_ENABLE | 0 | Deaktiviert NVLink SHARP (1). Zukünftige Versionen werden diese Funktion wieder aktivieren. |
Es gibt verschiedene andere XLA-Flags, die Benutzer setzen können, um die Leistung zu verbessern. Eine ausführliche Erläuterung dieser Flags finden Sie im Dokument zur GPU-Leistung. XLA-Flags können pro Workflow angepasst werden. Beispielsweise setzt jedes Skript in contrib/gpu/scripts_gpu seine eigenen XLA-Flags.
Eine Liste der zuvor verwendeten XLA-Flags, die nicht mehr benötigt werden, finden Sie auch auf der Seite zur GPU-Leistung.
Erste Nacht mit neuem Basisbehälter | Basisbehälter |
---|---|
06.11.2024 | nvidia/cuda:12.6.2-devel-ubuntu22.04 |
25.09.2024 | nvidia/cuda:12.6.1-devel-ubuntu22.04 |
24.07.2024 | nvidia/cuda:12.5.0-devel-ubuntu22.04 |
Weitere Informationen zum Profilieren von JAX-Programmen auf der GPU finden Sie auf dieser Seite.
Lösung:
docker run -it --shm-size=1g ...
Erläuterung: Der bus error
kann aufgrund der Größenbeschränkung von /dev/shm
auftreten. Sie können dieses Problem beheben, indem Sie beim Starten Ihres Containers die Größe des gemeinsam genutzten Speichers mithilfe der Option --shm-size
erhöhen.
Problembeschreibung:
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
Lösung: Aktualisieren Sie enroot oder wenden Sie einen Einzeldatei-Patch an, wie im Versionshinweis zu enroot v3.4.0 beschrieben.
Erläuterung: Docker verwendet traditionell Docker Schema V2.2 für Multi-Arch-Manifestlisten, ist jedoch seit 20.10 auf die Verwendung des Open Container Initiative (OCI)-Formats umgestiegen. Enroot hat in Version 3.4.0 Unterstützung für das OCI-Format hinzugefügt.
AWS
EFA-Integration hinzufügen
SageMaker-Codebeispiel
GCP
Erste Schritte mit JAX-Multi-Node-Anwendungen mit NVIDIA-GPUs auf Google Kubernetes Engine
Azurblau
Beschleunigen von KI-Anwendungen mithilfe des JAX-Frameworks auf den virtuellen Maschinen NDm A100 v4 von Azure
OCI
Ausführen einer Deep-Learning-Workload mit JAX auf Multi-Node-Multi-GPU-Clustern auf OCI
JAX | NVIDIA NGC-Container
Slurm- und OpenMPI-Zero-Config-Integration
Hinzufügen benutzerdefinierter GPU-Operationen
Triaging von Regressionen
Equinox für JAX: Die Grundlage eines Ökosystems für Wissenschaft und maschinelles Lernen
Grok mit JAX und H100 skalieren
JAX Supercharged auf GPUs: Hochleistungs-LLMs mit JAX und OpenXLA
Was ist neu in JAX | AGB Frühjahr 2024
Was ist neu in JAX | AGB Frühjahr 2023