O JAX Toolbox fornece um CI público, imagens Docker para bibliotecas JAX populares e exemplos JAX otimizados para simplificar e aprimorar sua experiência de desenvolvimento JAX em GPUs NVIDIA. Suporta bibliotecas JAX como MaxText, Paxml e Pallas.
Oferecemos suporte e testamos as seguintes estruturas JAX e arquiteturas de modelo. Mais detalhes sobre cada modelo e containers disponíveis podem ser encontrados nos respectivos READMEs.
Estrutura | Modelos | Casos de uso | Recipiente |
---|---|---|---|
texto máximo | GPT, LLaMA, Gemma, Mistral, Mixtral | pré-treinamento | ghcr.io/nvidia/jax:maxtext |
paxml | GPT, LLaMA, MoE | pré-treinamento, ajuste fino, LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5, ViT | pré-treinamento, ajuste fino | ghcr.io/nvidia/jax:t5x |
t5x | Imagem | pré-treinamento | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
grande visão | Pali Gemma | ajuste fino, avaliação | ghcr.io/nvidia/jax:gemma |
Levanter | GPT, LLaMA, MPT, Mochilas | pré-treinamento, ajuste fino | ghcr.io/nvidia/jax:levanter |
Componentes | Recipiente | Construir | Teste |
---|---|---|---|
ghcr.io/nvidia/jax:base | [sem testes] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [testes desativados] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
Em todos os casos, ghcr.io/nvidia/jax:XXX
aponta para a versão noturna mais recente do contêiner para XXX
. Para uma referência estável, use ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
.
Além do CI público, também executamos testes de CI internos no H100 SXM 80GB e no A100 SXM 80GB.
A imagem JAX é integrada com os seguintes sinalizadores e variáveis de ambiente para ajuste de desempenho de XLA e NCCL:
Bandeiras XLA | Valor | Explicação |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | permite que o XLA mova coletivos de comunicação para aumentar a sobreposição com kernels de computação |
--xla_gpu_enable_triton_gemm | false | use cuBLAS em vez de kernels Trition GeMM |
Variável de ambiente | Valor | Explicação |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | use uma única fila para trabalho de GPU para reduzir a latência das operações de streaming; OK, já que XLA já encomenda lançamentos |
NCCL_NVLS_ENABLE | 0 | Desativa o NVLink SHARP (1). Versões futuras reativarão esse recurso. |
Existem vários outros sinalizadores XLA que os usuários podem definir para melhorar o desempenho. Para obter uma explicação detalhada desses sinalizadores, consulte o documento de desempenho da GPU. Os sinalizadores XLA podem ser ajustados por fluxo de trabalho. Por exemplo, cada script em contrib/gpu/scripts_gpu define seus próprios sinalizadores XLA.
Para obter uma lista de sinalizadores XLA usados anteriormente que não são mais necessários, consulte também a página de desempenho da GPU.
Primeira noite com novo contêiner base | Recipiente base |
---|---|
06/11/2024 | nvidia/cuda:12.6.2-devel-ubuntu22.04 |
25/09/2024 | nvidia/cuda:12.6.1-devel-ubuntu22.04 |
2024-07-24 | nvidia/cuda:12.5.0-devel-ubuntu22.04 |
Consulte esta página para obter mais informações sobre como criar perfil de programas JAX na GPU.
Solução:
docker execute -it --shm-size=1g ...
Explicação: O bus error
pode ocorrer devido à limitação de tamanho de /dev/shm
. Você pode resolver isso aumentando o tamanho da memória compartilhada usando a opção --shm-size
ao iniciar seu contêiner.
Descrição do problema:
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
Solução: atualize o enroot ou aplique um patch de arquivo único conforme mencionado na nota de versão do enroot v3.4.0.
Explicação: O Docker tradicionalmente usa o Docker Schema V2.2 para listas de manifestos multi-arch, mas passou a usar o formato Open Container Initiative (OCI) desde 20.10. Enroot adicionou suporte para formato OCI na versão 3.4.0.
AWS
Adicionar integração EFA
Exemplo de código do SageMaker
GCP
Primeiros passos com aplicativos JAX de vários nós com GPUs NVIDIA no Google Kubernetes Engine
Azul
Acelerando aplicativos de IA usando a estrutura JAX nas máquinas virtuais NDm A100 v4 do Azure
OCI
Executando uma carga de trabalho de aprendizagem profunda com JAX em clusters multi-GPU de vários nós no OCI
JAX | Contêiner NVIDIA NGC
Integração de configuração zero Slurm e OpenMPI
Adicionando operações de GPU personalizadas
Triagem de regressões
Equinócio para JAX: a base de um ecossistema para ciência e aprendizado de máquina
Dimensionando Grok com JAX e H100
JAX Supercharged em GPUs: LLMs de alto desempenho com JAX e OpenXLA
O que há de novo no JAX | GTC Primavera 2024
O que há de novo no JAX | GTC Primavera 2023