빠른 시작 | 변환 | 설치 가이드 | 신경망 라이브러리 | 변경 로그 | 참조 문서
JAX는 고성능 수치 컴퓨팅 및 대규모 기계 학습을 위해 설계된 가속기 중심 배열 계산 및 프로그램 변환을 위한 Python 라이브러리입니다.
업데이트된 Autograd 버전을 통해 JAX는 기본 Python 및 NumPy 기능을 자동으로 구별할 수 있습니다. 루프, 분기, 재귀 및 클로저를 통해 차별화할 수 있으며 파생 상품의 파생 상품을 취할 수 있습니다. grad
및 순방향 미분을 통한 역방향 미분(역전파라고도 함)을 지원하며 두 가지를 임의의 순서로 임의로 구성할 수 있습니다.
새로운 점은 JAX가 XLA를 사용하여 GPU 및 TPU에서 NumPy 프로그램을 컴파일하고 실행한다는 것입니다. 컴파일은 기본적으로 내부적으로 이루어지며, 라이브러리 호출은 적시에 컴파일되고 실행됩니다. 그러나 JAX를 사용하면 단일 함수 API인 jit
사용하여 Python 함수를 XLA 최적화 커널로 JIT(Just-In-Time) 컴파일할 수도 있습니다. 컴파일과 자동 미분을 임의로 구성할 수 있어 Python을 떠나지 않고도 정교한 알고리즘을 표현하고 최대의 성능을 얻을 수 있습니다. pmap
사용하면 여러 GPU 또는 TPU 코어를 한 번에 프로그래밍하고 전체를 차별화할 수도 있습니다.
조금 더 깊이 파고들면 JAX가 실제로 구성 가능한 함수 변환을 위한 확장 가능한 시스템이라는 것을 알 수 있습니다. grad
와 jit
모두 이러한 변환의 인스턴스입니다. 그 외에 자동 벡터화를 위한 vmap
과 다중 가속기의 SPMD(단일 프로그램 다중 데이터) 병렬 프로그래밍을 위한 pmap
이 있으며 앞으로 더 많은 기능이 추가될 예정입니다.
이것은 공식 Google 제품이 아닌 연구 프로젝트입니다. 버그와 날카로운 모서리가 예상됩니다. 사용해 보시고, 버그를 보고하고, 여러분의 생각을 알려주세요!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
Google Cloud GPU에 연결된 브라우저에서 노트북을 사용하여 바로 시작해 보세요. 다음은 몇 가지 시작 노트북입니다.
grad
, 컴파일을 위한 jit
, 벡터화를 위한 vmap
JAX는 이제 Cloud TPU에서 실행됩니다. 미리보기를 사용해 보려면 Cloud TPU Colab을 참조하세요.
JAX에 대해 더 자세히 알아보려면 다음을 수행하세요.
기본적으로 JAX는 숫자 함수를 변환하기 위한 확장 가능한 시스템입니다. 다음은 주요 관심 대상인 4가지 변환인 grad
, jit
, vmap
및 pmap
입니다.
grad
사용한 자동 미분 JAX에는 Autograd와 거의 동일한 API가 있습니다. 가장 널리 사용되는 함수는 역방향 그래디언트에 대한 grad
입니다.
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
grad
사용하면 어떤 순서로든 차별화할 수 있습니다.
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
고급 자동 비교의 경우 역방향 벡터-야코비안 제품에는 jax.vjp
사용하고 순방향 모드 야코비안 벡터 제품에는 jax.jvp
사용할 수 있습니다. 두 개는 서로 임의로 구성될 수 있으며 다른 JAX 변환을 사용하여 구성될 수도 있습니다. 전체 헤세 행렬을 효율적으로 계산하는 함수를 만들기 위해 이를 구성하는 한 가지 방법은 다음과 같습니다.
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
Autograd와 마찬가지로 Python 제어 구조에서도 자유롭게 차별화를 사용할 수 있습니다.
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
자세한 내용은 자동 미분에 대한 참조 문서와 JAX Autodiff Cookbook을 참조하세요.
jit
로 컴파일 XLA를 사용하면 @jit
데코레이터 또는 고차 함수로 사용되는 jit
사용하여 함수를 엔드투엔드로 컴파일할 수 있습니다.
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
원하는 대로 jit
, grad
및 기타 JAX 변환을 혼합할 수 있습니다.
jit
사용하면 함수가 사용할 수 있는 Python 제어 흐름의 종류에 제약이 가해집니다. 자세한 내용은 제어 흐름 및 JIT를 사용한 논리 연산자에 대한 자습서를 참조하세요.
vmap
사용한 자동 벡터화 vmap
벡터화 맵입니다. 이는 배열 축을 따라 함수를 매핑하는 익숙한 의미를 가지지만 루프를 외부에 유지하는 대신 더 나은 성능을 위해 루프를 함수의 기본 작업으로 푸시합니다.
vmap
사용하면 코드에 배치 차원을 가지고 다닐 필요가 없습니다. 예를 들어 다음과 같은 간단한 비 배치 신경망 예측 함수를 생각해 보세요.
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
대신에 activations
의 왼쪽에 배치 차원을 허용하기 위해 jnp.dot(activations, W)
작성하는 경우가 많지만, 이 특정 예측 함수는 단일 입력 벡터에만 적용되도록 작성했습니다. 이 함수를 한 번에 입력 배치에 적용하려면 의미상으로 다음과 같이 작성할 수 있습니다.
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
그러나 한 번에 하나의 예제를 네트워크를 통해 푸시하는 것은 느릴 것입니다! 계산을 벡터화하여 모든 레이어에서 행렬-벡터 곱셈보다는 행렬-행렬 곱셈을 수행하는 것이 더 좋습니다.
vmap
함수는 우리를 위해 이러한 변환을 수행합니다. 즉, 우리가 쓴다면
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
그런 다음 vmap
함수는 함수 내부의 외부 루프를 푸시하고 우리 기계는 마치 우리가 일괄 처리를 수행한 것처럼 정확하게 행렬-행렬 곱셈을 실행하게 됩니다.
vmap
없이 간단한 신경망을 수동으로 일괄 처리하는 것은 쉽지만, 다른 경우 수동 벡터화는 비실용적이거나 불가능할 수 있습니다. 예제별 기울기를 효율적으로 계산하는 문제를 생각해 보겠습니다. 즉, 고정된 매개변수 집합에 대해 배치의 각 예제에서 개별적으로 평가되는 손실 함수의 기울기를 계산하려고 합니다. vmap
사용하면 쉽습니다:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
물론 vmap
jit
, grad
및 기타 JAX 변환을 사용하여 임의로 구성할 수 있습니다! 우리는 jax.jacfwd
, jax.jacrev
및 jax.hessian
에서 빠른 야코비안 및 헤시안 행렬 계산을 위해 정방향 및 역방향 자동 미분 기능이 있는 vmap
사용합니다.
pmap
사용한 SPMD 프로그래밍 다중 GPU와 같은 다중 가속기의 병렬 프로그래밍에는 pmap
사용하십시오. pmap
사용하면 빠른 병렬 집단 통신 작업을 포함하여 단일 프로그램 다중 데이터(SPMD) 프로그램을 작성할 수 있습니다. pmap
적용한다는 것은 작성한 함수가 XLA( jit
와 유사)에 의해 컴파일된 다음 장치 간에 병렬로 복제 및 실행된다는 것을 의미합니다.
다음은 8-GPU 머신의 예입니다.
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
순수 맵을 표현하는 것 외에도 장치 간에 빠른 집단 통신 작업을 사용할 수 있습니다.
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
보다 정교한 통신 패턴을 위해 pmap
기능을 중첩할 수도 있습니다.
모두 구성되므로 병렬 계산을 통해 자유롭게 차별화할 수 있습니다.
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
pmap
함수를 역방향 모드로 미분할 때(예: grad
사용) 계산의 역방향 전달은 순방향 전달과 마찬가지로 병렬화됩니다.
자세한 내용은 SPMD Cookbook 및 처음부터 SPMD MNIST 분류기 예제를 참조하세요.
예와 설명이 포함된 현재 문제에 대한 보다 철저한 조사를 보려면 Gotchas Notebook을 읽어 보시기 바랍니다. 몇 가지 뛰어난 점:
is
사용한 객체 ID 테스트는 유지되지 않습니다). 순수하지 않은 Python 함수에서 JAX 변환을 사용하는 경우 Exception: Can't lift Traced...
또는 Exception: Different traces at same level
같은 오류가 표시될 수 있습니다.x[i] += y
와 같은 배열의 내부 변경 업데이트는 지원되지 않지만 기능적 대안이 있습니다. jit
에서 이러한 기능적 대안은 버퍼를 자동으로 재사용합니다.jax.lax
패키지에 있습니다.float32
) 값을 적용하고 배정밀도(64비트, 예: float64
)를 활성화하려면 시작 시 jax_enable_x64
변수를 설정해야 합니다(또는 환경 변수 JAX_ENABLE_X64=True
). . TPU에서 JAX는 jax.numpy.dot
및 lax.conv
와 같은 'matmul 유사' 작업의 내부 임시 변수를 제외한 모든 항목에 대해 기본적으로 32비트 값을 사용합니다. 이러한 작업에는 3개의 bfloat16 패스를 통해 대략 32비트 작업에 사용할 수 있는 precision
매개변수가 있으며, 런타임 비용이 느려질 수 있습니다. TPU의 비 matmul 작업은 종종 정확성보다 속도를 강조하는 구현보다 낮으므로 실제로 TPU에서의 계산은 다른 백엔드에서의 유사한 계산보다 정확도가 떨어집니다.np.add(1, np.array([2], np.float32)).dtype
float32
가 아닌 float64
입니다.jit
와 같은 일부 변환은 Python 제어 흐름을 사용하는 방법을 제한합니다. 뭔가 잘못되면 항상 큰 오류가 발생합니다. jit
의 static_argnums
매개변수, lax.scan
과 같은 구조화된 제어 흐름 기본 요소를 사용하거나 더 작은 하위 함수에 jit
사용해야 할 수도 있습니다. 리눅스 x86_64 | 리눅스 aarch64 | 맥 x86_64 | 맥 aarch64 | 윈도우 x86_64 | 윈도우 WSL2 x86_64 | |
---|---|---|---|---|---|---|
CPU | 예 | 예 | 예 | 예 | 예 | 예 |
엔비디아 GPU | 예 | 예 | 아니요 | 해당 없음 | 아니요 | 실험적인 |
구글 TPU | 예 | 해당 없음 | 해당 없음 | 해당 없음 | 해당 없음 | 해당 없음 |
AMD GPU | 예 | 아니요 | 실험적인 | 해당 없음 | 아니요 | 아니요 |
애플 GPU | 해당 없음 | 아니요 | 해당 없음 | 실험적인 | 해당 없음 | 해당 없음 |
인텔 GPU | 실험적인 | 해당 없음 | 해당 없음 | 해당 없음 | 아니요 | 아니요 |
플랫폼 | 지침 |
---|---|
CPU | pip install -U jax |
엔비디아 GPU | pip install -U "jax[cuda12]" |
구글 TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
AMD GPU(리눅스) | Docker, 사전 빌드된 휠을 사용하거나 소스에서 빌드하세요. |
맥 GPU | Apple의 지시를 따르십시오. |
인텔 GPU | 인텔의 지침을 따르십시오. |
대체 설치 전략에 대한 자세한 내용은 설명서를 참조하세요. 여기에는 소스에서 컴파일하기, Docker로 설치하기, 다른 버전의 CUDA 사용하기, 커뮤니티에서 지원하는 conda 빌드 및 몇 가지 자주 묻는 질문에 대한 답변이 포함됩니다.
Google DeepMind 및 Alphabet의 여러 Google 연구 그룹은 JAX에서 신경망 훈련을 위한 라이브러리를 개발하고 공유합니다. 예제와 방법 가이드가 포함된 신경망 훈련을 위한 모든 기능을 갖춘 라이브러리를 원한다면 Flax와 해당 문서 사이트를 사용해 보세요.
그라데이션 처리 및 최적화를 위한 Optax, 안정적인 코드 및 테스트를 위한 chex, 신경망을 위한 Equinox를 포함하는 JAX 기반 네트워크 라이브러리 목록을 보려면 JAX 문서 사이트의 JAX 생태계 섹션을 확인하세요. (자세한 내용은 여기에서 DeepMind의 NeurIPS 2020 JAX 생태계 강연을 시청하세요.)
이 저장소를 인용하려면:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
위의 bibtex 항목에서 이름은 알파벳 순서로 되어 있고 버전 번호는 jax/version.py의 번호이며 연도는 프로젝트의 오픈 소스 릴리스에 해당합니다.
XLA에 대한 자동 차별화 및 컴파일만 지원하는 초기 버전의 JAX는 SysML 2018에 게재된 논문에 설명되어 있습니다. 우리는 현재 보다 포괄적이고 최신 논문에서 JAX의 아이디어와 기능을 다루기 위해 노력하고 있습니다.
JAX API에 대한 자세한 내용은 참조 문서를 참조하세요.
JAX 개발자로 시작하려면 개발자 설명서를 참조하세요.