JAX Toolbox는 공개 CI, 인기 있는 JAX 라이브러리용 Docker 이미지, 최적화된 JAX 예제를 제공하여 NVIDIA GPU에서 JAX 개발 환경을 단순화하고 향상시킵니다. MaxText, Paxml 및 Pallas와 같은 JAX 라이브러리를 지원합니다.
우리는 다음 JAX 프레임워크와 모델 아키텍처를 지원하고 테스트합니다. 각 모델과 사용 가능한 컨테이너에 대한 자세한 내용은 해당 README에서 확인할 수 있습니다.
뼈대 | 모델 | 사용 사례 | 컨테이너 |
---|---|---|---|
최대텍스트 | GPT, LLaMA, 젬마, 미스트랄, 믹스트랄 | 사전 훈련 | ghcr.io/nvidia/jax:maxtext |
paxml | GPT, LLaMA, MoE | 사전 훈련, 미세 조정, LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5, ViT | 사전 훈련, 미세 조정 | ghcr.io/nvidia/jax:t5x |
t5x | 이미지 | 사전 훈련 | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
큰 비전 | 팔리젬마 | 미세 조정, 평가 | ghcr.io/nvidia/jax:gemma |
도망자 | GPT, LLaMA, MPT, 백팩 | 사전 훈련, 미세 조정 | ghcr.io/nvidia/jax:levanter |
구성요소 | 컨테이너 | 짓다 | 시험 |
---|---|---|---|
ghcr.io/nvidia/jax:base | [테스트 없음] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [테스트 비활성화됨] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
모든 경우에 ghcr.io/nvidia/jax:XXX
XXX
에 대한 컨테이너의 최신 야간 빌드를 가리킵니다. 안정적인 참조를 위해서는 ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
사용하세요.
공개 CI 외에도 H100 SXM 80GB, A100 SXM 80GB에 대한 내부 CI 테스트도 진행합니다.
JAX 이미지에는 XLA 및 NCCL의 성능 조정을 위한 다음 플래그 및 환경 변수가 포함되어 있습니다.
XLA 플래그 | 값 | 설명 |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | XLA는 통신 집단을 이동하여 컴퓨팅 커널과의 중복을 늘릴 수 있습니다. |
--xla_gpu_enable_triton_gemm | false | Trition GeMM 커널 대신 cuBLAS 사용 |
환경변수 | 값 | 설명 |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | 스트림 작업의 대기 시간을 낮추기 위해 GPU 작업에 단일 대기열을 사용합니다. XLA 주문이 이미 시작되었으므로 괜찮습니다. |
NCCL_NVLS_ENABLE | 0 | NVLink SHARP(1)를 비활성화합니다. 향후 릴리스에서는 이 기능을 다시 활성화할 예정입니다. |
성능 향상을 위해 사용자가 설정할 수 있는 다양한 XLA 플래그가 있습니다. 이러한 플래그에 대한 자세한 설명은 GPU 성능 문서를 참조하세요. XLA 플래그는 워크플로별로 조정할 수 있습니다. 예를 들어 contrib/gpu/scripts_gpu의 각 스크립트는 자체 XLA 플래그를 설정합니다.
이전에 사용되었지만 더 이상 필요하지 않은 XLA 플래그 목록은 GPU 성능 페이지를 참조하세요.
새로운 기본 컨테이너를 사용하여 매일 밤 처음으로 | 기본 용기 |
---|---|
2024-11-06 | 엔비디아/쿠다:12.6.2-devel-ubuntu22.04 |
2024-09-25 | 엔비디아/쿠다:12.6.1-devel-ubuntu22.04 |
2024-07-24 | 엔비디아/쿠다:12.5.0-devel-ubuntu22.04 |
GPU에서 JAX 프로그램을 프로파일링하는 방법에 대한 자세한 내용은 이 페이지를 참조하세요.
해결책:
docker run -it --shm-size=1g ...
설명: /dev/shm
의 크기 제한으로 인해 bus error
발생할 수 있습니다. 컨테이너를 시작할 때 --shm-size
옵션을 사용하여 공유 메모리 크기를 늘려 이 문제를 해결할 수 있습니다.
문제 설명:
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
해결책: enroot v3.4.0 릴리스 노트에 언급된 대로 enroot를 업그레이드하거나 단일 파일 패치를 적용하십시오.
설명: Docker는 전통적으로 멀티 아키텍처 매니페스트 목록에 Docker Schema V2.2를 사용했지만 20.10 이후 OCI(Open Container Initiative) 형식을 사용하도록 전환했습니다. Enroot는 버전 3.4.0에서 OCI 형식에 대한 지원을 추가했습니다.
AWS
EFA 통합 추가
SageMaker 코드 샘플
GCP
Google Kubernetes Engine에서 NVIDIA GPU를 사용하는 JAX 다중 노드 애플리케이션 시작하기
하늘빛
Azure의 NDm A100 v4 가상 머신에서 JAX 프레임워크를 사용하여 AI 애플리케이션 가속화
OCI
OCI의 다중 노드 다중 GPU 클러스터에서 JAX를 사용하여 딥 러닝 워크로드 실행
잭스 | NVIDIA NGC 컨테이너
Slurm 및 OpenMPI 제로 구성 통합
커스텀 GPU 작업 추가
회귀 분류
JAX용 Equinox: 과학 및 기계 학습을 위한 생태계의 기초
JAX 및 H100을 사용하여 Grok 확장
GPU에서 강화된 JAX: JAX 및 OpenXLA를 갖춘 고성능 LLM
JAX의 새로운 기능 | GTC 2024년 봄
JAX의 새로운 기능 | GTC 2023년 봄