Início rápido | Transformações | Guia de instalação | Bibliotecas de redes neurais | Registros de alterações | Documentos de referência
JAX é uma biblioteca Python para computação de array orientada a aceleradores e transformação de programas, projetada para computação numérica de alto desempenho e aprendizado de máquina em grande escala.
Com sua versão atualizada do Autograd, o JAX pode diferenciar automaticamente funções nativas do Python e do NumPy. Ele pode diferenciar por meio de loops, ramificações, recursão e fechamentos, e pode obter derivadas de derivadas de derivadas. Ele suporta diferenciação de modo reverso (também conhecido como retropropagação) via grad
, bem como diferenciação de modo direto, e os dois podem ser compostos arbitrariamente em qualquer ordem.
A novidade é que JAX usa XLA para compilar e executar seus programas NumPy em GPUs e TPUs. A compilação acontece nos bastidores por padrão, com as chamadas de biblioteca sendo compiladas e executadas just-in-time. Mas o JAX também permite compilar just-in-time suas próprias funções Python em kernels otimizados para XLA usando uma API de função única, jit
. A compilação e a diferenciação automática podem ser compostas arbitrariamente, para que você possa expressar algoritmos sofisticados e obter desempenho máximo sem sair do Python. Você pode até programar vários núcleos de GPUs ou TPU ao mesmo tempo usando pmap
e diferenciar tudo.
Vá um pouco mais fundo e você verá que JAX é realmente um sistema extensível para transformações de funções combináveis. Tanto grad
quanto jit
são exemplos de tais transformações. Outros são vmap
para vetorização automática e pmap
para programação paralela de múltiplos dados de programa único (SPMD) de múltiplos aceleradores, com mais por vir.
Este é um projeto de pesquisa, não um produto oficial do Google. Espere bugs e arestas vivas. Por favor, ajude testando, relatando bugs e nos contando o que você pensa!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
Comece imediatamente usando um notebook no seu navegador, conectado a uma GPU do Google Cloud. Aqui estão alguns cadernos iniciais:
grad
para diferenciação, jit
para compilação e vmap
para vetorizaçãoJAX agora é executado em Cloud TPUs. Para experimentar a visualização, consulte Cloud TPU Colabs.
Para um mergulho mais profundo no JAX:
Basicamente, JAX é um sistema extensível para transformar funções numéricas. Aqui estão quatro transformações de interesse primário: grad
, jit
, vmap
e pmap
.
grad
JAX tem aproximadamente a mesma API do Autograd. A função mais popular é grad
para gradientes de modo reverso:
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
Você pode diferenciar qualquer pedido com grad
.
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
Para autodiff mais avançado, você pode usar jax.vjp
para produtos de vetor Jacobiano de modo reverso e jax.jvp
para produtos de vetor Jacobiano de modo direto. Os dois podem ser compostos arbitrariamente entre si e com outras transformações JAX. Aqui está uma maneira de compô-los para criar uma função que calcule eficientemente matrizes Hessianas completas:
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
Assim como no Autograd, você pode usar diferenciação com estruturas de controle Python:
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
Consulte os documentos de referência sobre diferenciação automática e o JAX Autodiff Cookbook para obter mais informações.
jit
Você pode usar o XLA para compilar suas funções de ponta a ponta com jit
, usado como decorador @jit
ou como uma função de ordem superior.
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
Você pode misturar jit
e grad
e qualquer outra transformação JAX como quiser.
O uso jit
impõe restrições ao tipo de fluxo de controle do Python que a função pode usar; consulte o tutorial sobre fluxo de controle e operadores lógicos com JIT para obter mais informações.
vmap
vmap
é o mapa de vetorização. Ele tem a semântica familiar de mapear uma função ao longo dos eixos do array, mas em vez de manter o loop do lado de fora, ele empurra o loop para as operações primitivas de uma função para obter melhor desempenho.
Usar vmap
pode evitar que você tenha que carregar dimensões de lote em seu código. Por exemplo, considere esta função simples de previsão de rede neural sem lote :
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
Muitas vezes, em vez disso, escrevemos jnp.dot(activations, W)
para permitir uma dimensão de lote no lado esquerdo de activations
, mas escrevemos esta função de previsão específica para ser aplicada apenas a vetores de entrada únicos. Se quiséssemos aplicar esta função a um lote de entradas de uma só vez, semanticamente poderíamos simplesmente escrever
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
Mas enviar um exemplo de cada vez pela rede seria lento! É melhor vetorizar o cálculo, de modo que em cada camada façamos multiplicação matriz-matriz em vez de multiplicação matriz-vetor.
A função vmap
faz essa transformação para nós. Isto é, se escrevermos
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
então a função vmap
empurrará o loop externo para dentro da função, e nossa máquina acabará executando multiplicações matriz-matriz exatamente como se tivéssemos feito o lote manualmente.
É fácil agrupar manualmente uma rede neural simples sem vmap
, mas em outros casos a vetorização manual pode ser impraticável ou impossível. Tomemos o problema de calcular eficientemente gradientes por exemplo: isto é, para um conjunto fixo de parâmetros, queremos calcular o gradiente de nossa função de perda avaliada separadamente em cada exemplo de um lote. Com vmap
, é fácil:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
Claro, vmap
pode ser composto arbitrariamente com jit
, grad
e qualquer outra transformação JAX! Usamos vmap
com diferenciação automática de modo direto e reverso para cálculos rápidos de matriz Jacobiana e Hessiana em jax.jacfwd
, jax.jacrev
e jax.hessian
.
pmap
Para programação paralela de vários aceleradores, como várias GPUs, use pmap
. Com pmap
você escreve programas de múltiplos dados (SPMD) de programa único, incluindo operações rápidas de comunicação coletiva paralela. Aplicar pmap
significará que a função que você escreve será compilada por XLA (semelhante a jit
), depois replicada e executada em paralelo entre dispositivos.
Aqui está um exemplo em uma máquina com 8 GPU:
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
Além de expressar mapas puros, você pode usar operações rápidas de comunicação coletiva entre dispositivos:
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
Você pode até aninhar funções pmap
para padrões de comunicação mais sofisticados.
Tudo compõe, então você está livre para diferenciar por meio de cálculos paralelos:
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
Ao diferenciar uma função pmap
no modo reverso (por exemplo, com grad
), a passagem para trás da computação é paralelizada exatamente como a passagem para frente.
Consulte o SPMD Cookbook e o exemplo do classificador SPMD MNIST do zero para obter mais informações.
Para um levantamento mais completo das pegadinhas atuais, com exemplos e explicações, recomendamos a leitura do Caderno de pegadinhas. Alguns destaques:
is
não é preservado). Se você usar uma transformação JAX em uma função Python impura, poderá ver um erro como Exception: Can't lift Traced...
ou Exception: Different traces at same level
.x[i] += y
, não são suportadas, mas existem alternativas funcionais. Sob um jit
, essas alternativas funcionais reutilizarão buffers no local automaticamente.jax.lax
.float32
) por padrão, e para ativar a precisão dupla (64 bits, por exemplo, float64
) é necessário definir a variável jax_enable_x64
na inicialização (ou definir a variável de ambiente JAX_ENABLE_X64=True
) . Na TPU, o JAX usa valores de 32 bits por padrão para tudo, exceto variáveis temporárias internas em operações 'semelhantes ao matmul', como jax.numpy.dot
e lax.conv
. Essas operações têm um parâmetro precision
que pode ser usado para aproximar operações de 32 bits por meio de três passagens bfloat16, com um custo de tempo de execução possivelmente mais lento. As operações não matmul na TPU são inferiores às implementações que geralmente enfatizam a velocidade em vez da precisão, portanto, na prática, os cálculos na TPU serão menos precisos do que cálculos semelhantes em outros back-ends.np.add(1, np.array([2], np.float32)).dtype
é float64
em vez de float32
.jit
, restringem a forma como você pode usar o fluxo de controle do Python. Você sempre receberá erros graves se algo der errado. Talvez você precise usar o parâmetro static_argnums
do jit
, primitivas de fluxo de controle estruturadas como lax.scan
ou apenas usar jit
em subfunções menores. Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Janelas x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
CPU | sim | sim | sim | sim | sim | sim |
GPU NVIDIA | sim | sim | não | n / D | não | experimental |
TPU do Google | sim | n / D | n / D | n / D | n / D | n / D |
GPU AMD | sim | não | experimental | n / D | não | não |
GPU da Apple | n / D | não | n / D | experimental | n / D | n / D |
GPU Intel | experimental | n / D | n / D | n / D | não | não |
Plataforma | Instruções |
---|---|
CPU | pip install -U jax |
GPU NVIDIA | pip install -U "jax[cuda12]" |
TPU do Google | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
GPU AMD (Linux) | Use Docker, rodas pré-construídas ou crie a partir do código-fonte. |
GPU Mac | Siga as instruções da Apple. |
GPU Intel | Siga as instruções da Intel. |
Consulte a documentação para obter informações sobre estratégias alternativas de instalação. Isso inclui compilar a partir do código-fonte, instalar com Docker, usar outras versões do CUDA, uma construção conda com suporte da comunidade e respostas a algumas perguntas frequentes.
Vários grupos de pesquisa do Google no Google DeepMind e Alphabet desenvolvem e compartilham bibliotecas para treinar redes neurais em JAX. Se você deseja uma biblioteca completa para treinamento de redes neurais com exemplos e guias de procedimentos, experimente o Flax e seu site de documentação.
Confira a seção Ecossistema JAX no site de documentação do JAX para obter uma lista de bibliotecas de rede baseadas em JAX, que inclui Optax para processamento e otimização de gradiente, chex para código e testes confiáveis e Equinox para redes neurais. (Assista à palestra do Ecossistema NeurIPS 2020 JAX na DeepMind aqui para obter detalhes adicionais.)
Para citar este repositório:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
Na entrada bibtex acima, os nomes estão em ordem alfabética, o número da versão deve ser o de jax/version.py e o ano corresponde ao lançamento do código aberto do projeto.
Uma versão nascente do JAX, suportando apenas diferenciação e compilação automática para XLA, foi descrita em um artigo publicado na SysML 2018. Atualmente, estamos trabalhando para cobrir as ideias e capacidades do JAX em um artigo mais abrangente e atualizado.
Para obter detalhes sobre a API JAX, consulte a documentação de referência.
Para começar como desenvolvedor JAX, consulte a documentação do desenvolvedor.