JAX Toolbox proporciona un CI público, imágenes Docker para bibliotecas JAX populares y ejemplos JAX optimizados para simplificar y mejorar su experiencia de desarrollo JAX en GPU NVIDIA. Admite bibliotecas JAX como MaxText, Paxml y Pallas.
Admitimos y probamos los siguientes marcos JAX y arquitecturas de modelos. Se pueden encontrar más detalles sobre cada modelo y contenedores disponibles en sus respectivos README.
Estructura | Modelos | Casos de uso | Recipiente |
---|---|---|---|
texto máximo | GPT, LLaMA, Gemma, Mistral, Mixtral | preentrenamiento | ghcr.io/nvidia/jax:maxtext |
paxml | GPT, LLaMA, Ministerio de Educación | preentrenamiento, ajuste fino, LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5, ViT | preentrenamiento, ajuste fino | ghcr.io/nvidia/jax:t5x |
t5x | Imagen | pre-entrenamiento | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
gran visión | PaliGemma | ajuste, evaluación | ghcr.io/nvidia/jax:gemma |
levantero | GPT, LLaMA, MPT, Mochilas | preentrenamiento, ajuste fino | ghcr.io/nvidia/jax:levanter |
Componentes | Recipiente | Construir | Prueba |
---|---|---|---|
ghcr.io/nvidia/jax:base | [sin pruebas] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [pruebas deshabilitadas] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
En todos los casos, ghcr.io/nvidia/jax:XXX
apunta a la última compilación nocturna del contenedor para XXX
. Para obtener una referencia estable, utilice ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
.
Además del CI público, también realizamos pruebas de CI internas en H100 SXM 80GB y A100 SXM 80GB.
La imagen JAX está integrada con los siguientes indicadores y variables de entorno para ajustar el rendimiento de XLA y NCCL:
Banderas XLA | Valor | Explicación |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | permite que XLA mueva colectivos de comunicación para aumentar la superposición con los núcleos informáticos |
--xla_gpu_enable_triton_gemm | false | use cuBLAS en lugar de núcleos Trition GeMM |
Variable de entorno | Valor | Explicación |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | utilice una única cola para el trabajo de GPU para reducir la latencia de las operaciones de transmisión; OK ya que XLA ya ordena lanzamientos |
NCCL_NVLS_ENABLE | 0 | Desactiva NVLink SHARP (1). Las versiones futuras volverán a habilitar esta función. |
Hay varios otros indicadores XLA que los usuarios pueden configurar para mejorar el rendimiento. Para obtener una explicación detallada de estos indicadores, consulte el documento de rendimiento de la GPU. Los indicadores XLA se pueden ajustar por flujo de trabajo. Por ejemplo, cada script en contrib/gpu/scripts_gpu establece sus propios indicadores XLA.
Para obtener una lista de indicadores XLA utilizados anteriormente que ya no son necesarios, consulte también la página de rendimiento de GPU.
Primera noche con nuevo contenedor base | Contenedor base |
---|---|
2024-11-06 | nvidia/cuda:12.6.2-devel-ubuntu22.04 |
2024-09-25 | nvidia/cuda:12.6.1-devel-ubuntu22.04 |
2024-07-24 | nvidia/cuda:12.5.0-devel-ubuntu22.04 |
Consulte esta página para obtener más información sobre cómo crear perfiles de programas JAX en GPU.
Solución:
ventana acoplable ejecutar -it --shm-size=1g ...
Explicación: El bus error
puede ocurrir debido a la limitación de tamaño de /dev/shm
. Puede solucionar esto aumentando el tamaño de la memoria compartida usando la opción --shm-size
al iniciar su contenedor.
Descripción del problema:
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
Solución: actualice enroot o aplique un parche de archivo único como se menciona en la nota de la versión de enroot v3.4.0.
Explicación: Docker ha utilizado tradicionalmente Docker Schema V2.2 para listas de manifiestos de múltiples arcos, pero ha pasado a utilizar el formato Open Container Initiative (OCI) desde 20.10. Enroot agregó soporte para el formato OCI en la versión 3.4.0.
AWS
Agregar integración de EFA
Ejemplo de código de SageMaker
PCG
Primeros pasos con aplicaciones JAX de múltiples nodos con GPU NVIDIA en Google Kubernetes Engine
Azur
Aceleración de aplicaciones de IA utilizando el marco JAX en máquinas virtuales NDm A100 v4 de Azure
OCI
Ejecución de una carga de trabajo de aprendizaje profundo con JAX en clústeres multinodo y multiGPU en OCI
JAX | Contenedor NVIDIA NGC
Integración de configuración cero de Slurm y OpenMPI
Agregar operaciones de GPU personalizadas
Triaje de regresiones
Equinox para JAX: la base de un ecosistema para la ciencia y el aprendizaje automático
Escalando Grok con JAX y H100
JAX sobrealimentado en GPU: LLM de alto rendimiento con JAX y OpenXLA
Novedades de JAX | GTC Primavera 2024
Novedades de JAX | GTC Primavera 2023