JAX Toolbox fournit un CI public, des images Docker pour les bibliothèques JAX populaires et des exemples JAX optimisés pour simplifier et améliorer votre expérience de développement JAX sur les GPU NVIDIA. Il prend en charge les bibliothèques JAX telles que MaxText, Paxml et Pallas.
Nous prenons en charge et testons les frameworks JAX et les architectures de modèles suivants. Plus de détails sur chaque modèle et les conteneurs disponibles peuvent être trouvés dans leurs README respectifs.
Cadre | Modèles | Cas d'utilisation | Récipient |
---|---|---|---|
texte maximum | GPT, LLaMA, Gemma, Mistral, Mixtral | pré-entraînement | ghcr.io/nvidia/jax:maxtext |
paxml | GPT, LLaMA, ministère de l’Environnement | pré-entraînement, mise au point, LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5, ViT | pré-formation, mise au point | ghcr.io/nvidia/jax:t5x |
t5x | Image | pré-formation | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
grande vision | PaliGemma | mise au point, évaluation | ghcr.io/nvidia/jax:gemma |
levantin | GPT, LLaMA, MPT, Sacs à dos | pré-entraînement, mise au point | ghcr.io/nvidia/jax:levanter |
Composants | Récipient | Construire | Test |
---|---|---|---|
ghcr.io/nvidia/jax:base | [pas de tests] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [tests désactivés] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
Dans tous les cas, ghcr.io/nvidia/jax:XXX
pointe vers la dernière version nocturne du conteneur pour XXX
. Pour une référence stable, utilisez ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
.
En plus du CI public, nous effectuons également des tests CI internes sur H100 SXM 80 Go et A100 SXM 80 Go.
L'image JAX est intégrée aux indicateurs et variables d'environnement suivants pour l'optimisation des performances de XLA et NCCL :
Drapeaux XLA | Valeur | Explication |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | permet à XLA de déplacer les collectifs de communication pour augmenter le chevauchement avec les noyaux de calcul |
--xla_gpu_enable_triton_gemm | false | utiliser cuBLAS au lieu des noyaux Trition GeMM |
Variable d'environnement | Valeur | Explication |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | utiliser une seule file d'attente pour le travail GPU afin de réduire la latence des opérations de flux ; OK puisque XLA commande déjà des lancements |
NCCL_NVLS_ENABLE | 0 | Désactive NVLink SHARP (1). Les prochaines versions réactiveront cette fonctionnalité. |
Il existe divers autres indicateurs XLA que les utilisateurs peuvent définir pour améliorer les performances. Pour une explication détaillée de ces indicateurs, veuillez vous référer à la documentation sur les performances du GPU. Les indicateurs XLA peuvent être réglés par flux de travail. Par exemple, chaque script dans contrib/gpu/scripts_gpu définit ses propres indicateurs XLA.
Pour une liste des indicateurs XLA précédemment utilisés qui ne sont plus nécessaires, veuillez également vous référer à la page des performances du GPU.
Première nuit avec un nouveau conteneur de base | Conteneur de base |
---|---|
2024-11-06 | nvidia/cuda:12.6.2-devel-ubuntu22.04 |
2024-09-25 | nvidia/cuda:12.6.1-devel-ubuntu22.04 |
2024-07-24 | nvidia/cuda:12.5.0-devel-ubuntu22.04 |
Consultez cette page pour plus d'informations sur la façon de profiler les programmes JAX sur GPU.
Solution:
docker run -it --shm-size=1g ...
Explication : L' bus error
peut se produire en raison de la limitation de taille de /dev/shm
. Vous pouvez résoudre ce problème en augmentant la taille de la mémoire partagée à l'aide de l'option --shm-size
lors du lancement de votre conteneur.
Description du problème :
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
Solution : Mettez à niveau enroot ou appliquez un correctif à fichier unique comme mentionné dans la note de version enroot v3.4.0.
Explication : Docker utilise traditionnellement Docker Schema V2.2 pour les listes de manifestes multi-arch, mais est passé au format Open Container Initiative (OCI) depuis la version 20.10. Enroot a ajouté la prise en charge du format OCI dans la version 3.4.0.
AWS
Ajouter l'intégration EFA
Exemple de code SageMaker
GCP
Premiers pas avec les applications multi-nœuds JAX avec GPU NVIDIA sur Google Kubernetes Engine
Azuré
Accélération des applications d'IA à l'aide du framework JAX sur les machines virtuelles NDm A100 v4 d'Azure
OCI
Exécuter une charge de travail de deep learning avec JAX sur des clusters multi-nœuds multi-GPU sur OCI
JAX | Conteneur NVIDIA NGC
Intégration sans configuration Slurm et OpenMPI
Ajout d'opérations GPU personnalisées
Trier les régressions
Equinox pour JAX : la fondation d'un écosystème pour la science et l'apprentissage automatique
Faire évoluer Grok avec JAX et H100
JAX Supercharged sur les GPU : LLM hautes performances avec JAX et OpenXLA
Quoi de neuf dans JAX | CGV Printemps 2024
Quoi de neuf dans JAX | CGV Printemps 2023