JAX Toolbox menyediakan CI publik, image Docker untuk pustaka JAX populer, dan contoh JAX yang dioptimalkan untuk menyederhanakan dan meningkatkan pengalaman pengembangan JAX Anda pada GPU NVIDIA. Ini mendukung perpustakaan JAX seperti MaxText, Paxml, dan Pallas.
Kami mendukung dan menguji kerangka kerja JAX dan arsitektur model berikut. Detail lebih lanjut tentang setiap model dan container yang tersedia dapat ditemukan di README masing-masing.
Kerangka | Model | Kasus penggunaan | Wadah |
---|---|---|---|
teks maksimal | GPT, LLaMA, Gemma, Mistral, Campuran | pelatihan awal | ghcr.io/nvidia/jax:maxtext |
paxml | GPT, LLaMA, MoE | pra-pelatihan, penyesuaian, LoRA | ghcr.io/nvidia/jax:pax |
t5x | T5, ViT | pra-pelatihan, penyesuaian | ghcr.io/nvidia/jax:t5x |
t5x | Gambar | pra-pelatihan | ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 |
visi besar | PaliGemma | penyesuaian, evaluasi | ghcr.io/nvidia/jax:gemma |
angin timur yg keras | GPT, LLaMA, MPT, Ransel | pra-pelatihan, penyesuaian | ghcr.io/nvidia/jax:levanter |
Komponen | Wadah | Membangun | Tes |
---|---|---|---|
ghcr.io/nvidia/jax:base | [tidak ada tes] | ||
ghcr.io/nvidia/jax:jax | |||
ghcr.io/nvidia/jax:levanter | |||
ghcr.io/nvidia/jax:equinox | [tes dinonaktifkan] | ||
ghcr.io/nvidia/jax:triton | |||
ghcr.io/nvidia/jax:upstream-t5x | |||
ghcr.io/nvidia/jax:t5x | |||
ghcr.io/nvidia/jax:upstream-pax | |||
ghcr.io/nvidia/jax:pax | |||
ghcr.io/nvidia/jax:maxtext | |||
ghcr.io/nvidia/jax:gemma |
Dalam semua kasus, ghcr.io/nvidia/jax:XXX
menunjuk ke build nightly terbaru dari container untuk XXX
. Untuk referensi yang stabil, gunakan ghcr.io/nvidia/jax:XXX-YYYY-MM-DD
.
Selain CI publik, kami juga menjalankan pengujian CI internal pada H100 SXM 80GB dan A100 SXM 80GB.
Gambar JAX disematkan dengan flag dan variabel lingkungan berikut untuk penyetelan kinerja XLA dan NCCL:
Bendera XLA | Nilai | Penjelasan |
---|---|---|
--xla_gpu_enable_latency_hiding_scheduler | true | memungkinkan XLA memindahkan kolektif komunikasi untuk meningkatkan tumpang tindih dengan kernel komputasi |
--xla_gpu_enable_triton_gemm | false | gunakan cuBLAS sebagai ganti kernel Trition GeMM |
Variabel Lingkungan | Nilai | Penjelasan |
---|---|---|
CUDA_DEVICE_MAX_CONNECTIONS | 1 | gunakan antrian tunggal untuk pekerjaan GPU guna menurunkan latensi operasi streaming; OK karena XLA sudah meluncurkan pesanan |
NCCL_NVLS_ENABLE | 0 | Menonaktifkan NVLink SHARP (1). Rilis mendatang akan mengaktifkan kembali fitur ini. |
Ada berbagai tanda XLA lain yang dapat diatur pengguna untuk meningkatkan kinerja. Untuk penjelasan mendetail mengenai tanda-tanda ini, silakan merujuk ke dokumen kinerja GPU. Bendera XLA dapat disetel per alur kerja. Misalnya, setiap skrip di contrib/gpu/scripts_gpu menetapkan tanda XLA-nya sendiri.
Untuk daftar flag XLA yang digunakan sebelumnya dan tidak diperlukan lagi, silakan lihat juga halaman performa GPU.
Malam pertama dengan wadah dasar baru | Wadah dasar |
---|---|
06-11-2024 | nvidia/cuda:12.6.2-devel-ubuntu22.04 |
25-09-2024 | nvidia/cuda:12.6.1-devel-ubuntu22.04 |
24-07-2024 | nvidia/cuda:12.5.0-devel-ubuntu22.04 |
Lihat halaman ini untuk informasi lebih lanjut tentang cara membuat profil program JAX di GPU.
Larutan:
menjalankan buruh pelabuhan -itu --shm-size=1g ...
Penjelasan: bus error
mungkin terjadi karena batasan ukuran /dev/shm
. Anda dapat mengatasinya dengan meningkatkan ukuran memori bersama menggunakan opsi --shm-size
saat meluncurkan container Anda.
Deskripsi masalah:
slurmstepd: error: pyxis: [INFO] Authentication succeeded slurmstepd: error: pyxis: [INFO] Fetching image manifest list slurmstepd: error: pyxis: [INFO] Fetching image manifest slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/returned error code: 404 Not Found
Solusi: Tingkatkan enroot atau terapkan patch file tunggal seperti yang disebutkan dalam catatan rilis enroot v3.4.0.
Penjelasan: Docker secara tradisional menggunakan Docker Schema V2.2 untuk daftar manifes multi-lengkungan tetapi telah beralih menggunakan format Open Container Initiative (OCI) sejak 20.10. Enroot menambahkan dukungan untuk format OCI di versi 3.4.0.
AWS
Tambahkan integrasi EFA
Contoh kode SageMaker
GCP
Memulai aplikasi multi-node JAX dengan GPU NVIDIA di Google Kubernetes Engine
Biru langit
Mempercepat aplikasi AI menggunakan kerangka JAX pada Mesin Virtual NDm A100 v4 Azure
OKI
Menjalankan beban kerja pembelajaran mendalam dengan JAX pada cluster multi-GPU multinode di OCI
JAX | Wadah NVIDIA NGC
Integrasi konfigurasi nol Slurm dan OpenMPI
Menambahkan operasi GPU khusus
Triaging regresi
Equinox untuk JAX: Landasan Ekosistem untuk Sains dan Pembelajaran Mesin
Menskalakan Grok dengan JAX dan H100
JAX Supercharged pada GPU: LLM Performa Tinggi dengan JAX dan OpenXLA
Apa yang Baru di JAX | GTC Musim Semi 2024
Apa yang Baru di JAX | GTC Musim Semi 2023