Inicio rápido | Transformaciones | Guía de instalación | Bibliotecas de redes neuronales | Cambiar registros | Documentos de referencia
JAX es una biblioteca de Python para computación de matrices orientada a aceleradores y transformación de programas, diseñada para computación numérica de alto rendimiento y aprendizaje automático a gran escala.
Con su versión actualizada de Autograd, JAX puede diferenciar automáticamente las funciones nativas de Python y NumPy. Puede diferenciar mediante bucles, ramas, recursividad y cierres, y puede tomar derivadas de derivadas de derivadas. Admite la diferenciación en modo inverso (también conocida como retropropagación) mediante grad
así como la diferenciación en modo directo, y ambas pueden componerse arbitrariamente en cualquier orden.
La novedad es que JAX usa XLA para compilar y ejecutar sus programas NumPy en GPU y TPU. La compilación ocurre internamente de forma predeterminada, y las llamadas a la biblioteca se compilan y ejecutan justo a tiempo. Pero JAX también le permite compilar justo a tiempo sus propias funciones de Python en núcleos optimizados para XLA utilizando una API de una sola función, jit
. La compilación y la diferenciación automática se pueden componer de forma arbitraria, por lo que puedes expresar algoritmos sofisticados y obtener el máximo rendimiento sin salir de Python. Incluso puede programar varias GPU o núcleos de TPU a la vez usando pmap
y diferenciarlo todo.
Profundice un poco más y verá que JAX es en realidad un sistema extensible para transformaciones de funciones componibles. Tanto grad
como jit
son ejemplos de tales transformaciones. Otros son vmap
para la vectorización automática y pmap
para la programación paralela de un solo programa y múltiples datos (SPMD) de múltiples aceleradores, y habrá más por venir.
Este es un proyecto de investigación, no un producto oficial de Google. Espere errores y bordes afilados. ¡Ayúdenos probándolo, reportando errores y contándonos lo que piensa!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
Comienza a usar una computadora portátil en tu navegador, conectada a una GPU de Google Cloud. Aquí hay algunos cuadernos de inicio:
grad
para diferenciación, jit
para compilación y vmap
para vectorizaciónJAX ahora se ejecuta en Cloud TPU. Para probar la vista previa, consulte Cloud TPU Colabs.
Para una inmersión más profunda en JAX:
En esencia, JAX es un sistema extensible para transformar funciones numéricas. Aquí hay cuatro transformaciones de principal interés: grad
, jit
, vmap
y pmap
.
grad
JAX tiene aproximadamente la misma API que Autograd. La función más popular es grad
para gradientes en modo inverso:
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
Puedes diferenciar a cualquier pedido con grad
.
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
Para una diferenciación automática más avanzada, puede utilizar jax.vjp
para productos vectoriales jacobianos en modo inverso y jax.jvp
para productos vectoriales jacobianos en modo directo. Los dos pueden componerse arbitrariamente entre sí y con otras transformaciones JAX. Aquí hay una forma de componerlos para crear una función que calcule de manera eficiente matrices de Hesse completas:
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
Al igual que con Autograd, puedes utilizar la diferenciación con las estructuras de control de Python:
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
Consulte los documentos de referencia sobre diferenciación automática y el libro de recetas JAX Autodiff para obtener más información.
jit
Puede usar XLA para compilar sus funciones de un extremo a otro con jit
, usado como decorador @jit
o como una función de orden superior.
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
Puedes mezclar jit
y grad
y cualquier otra transformación JAX como quieras.
El uso de jit
impone restricciones al tipo de flujo de control de Python que la función puede usar; consulte el tutorial sobre flujo de control y operadores lógicos con JIT para obtener más información.
vmap
vmap
es el mapa vectorizador. Tiene la semántica familiar de mapear una función a lo largo de los ejes de una matriz, pero en lugar de mantener el bucle en el exterior, empuja el bucle hacia las operaciones primitivas de una función para un mejor rendimiento.
El uso de vmap
puede evitarle tener que llevar dimensiones de lote en su código. Por ejemplo, considere esta sencilla función de predicción de red neuronal no por lotes :
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
En lugar de eso, a menudo escribimos jnp.dot(activations, W)
para permitir una dimensión de lote en el lado izquierdo de activations
, pero hemos escrito esta función de predicción particular para que se aplique solo a vectores de entrada únicos. Si quisiéramos aplicar esta función a un lote de entradas a la vez, semánticamente podríamos simplemente escribir
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
¡Pero enviar un ejemplo a través de la red a la vez sería lento! Es mejor vectorizar el cálculo, de modo que en cada capa hagamos una multiplicación matriz-matriz en lugar de una multiplicación matriz-vector.
La función vmap
hace esa transformación por nosotros. Es decir, si escribimos
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
entonces la función vmap
empujará el bucle externo dentro de la función, y nuestra máquina terminará ejecutando multiplicaciones matriz-matriz exactamente como si hubiéramos hecho el procesamiento por lotes a mano.
Es bastante fácil agrupar manualmente una red neuronal simple sin vmap
, pero en otros casos la vectorización manual puede resultar poco práctica o imposible. Tomemos el problema de calcular de manera eficiente los gradientes por ejemplo: es decir, para un conjunto fijo de parámetros, queremos calcular el gradiente de nuestra función de pérdida evaluada por separado en cada ejemplo de un lote. Con vmap
, es fácil:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
¡Por supuesto, vmap
se puede componer arbitrariamente con jit
, grad
y cualquier otra transformación JAX! Usamos vmap
con diferenciación automática en modo directo e inverso para cálculos rápidos de matrices jacobianas y hessianas en jax.jacfwd
, jax.jacrev
y jax.hessian
.
pmap
Para la programación paralela de múltiples aceleradores, como múltiples GPU, use pmap
. Con pmap
usted escribe programas de datos múltiples de un solo programa (SPMD), incluidas operaciones rápidas de comunicación colectiva paralela. Aplicar pmap
significará que la función que escriba será compilada por XLA (de manera similar a jit
), luego replicada y ejecutada en paralelo en todos los dispositivos.
Aquí hay un ejemplo en una máquina de 8 GPU:
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
Además de expresar mapas puros, puede utilizar operaciones rápidas de comunicación colectiva entre dispositivos:
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
Incluso puedes anidar funciones pmap
para patrones de comunicación más sofisticados.
Todo se compone, por lo que eres libre de diferenciar mediante cálculos paralelos:
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
Cuando se diferencia en modo inverso una función pmap
(por ejemplo, con grad
), el paso hacia atrás del cálculo se paraleliza al igual que el paso hacia adelante.
Consulte el libro de cocina SPMD y el ejemplo del clasificador SPMD MNIST desde cero para obtener más información.
Para obtener un estudio más completo de los errores actuales, con ejemplos y explicaciones, recomendamos encarecidamente leer el Cuaderno de errores. Algunos destacados:
is
no se conservan). Si utiliza una transformación JAX en una función impura de Python, es posible que vea un error como Exception: Can't lift Traced...
o Exception: Different traces at same level
.x[i] += y
, no son compatibles, pero existen alternativas funcionales. Bajo un jit
, esas alternativas funcionales reutilizarán los buffers en el lugar automáticamente.jax.lax
.float32
) de forma predeterminada, y para habilitar la precisión doble (64 bits, por ejemplo, float64
) es necesario configurar la variable jax_enable_x64
al inicio (o configurar la variable de entorno JAX_ENABLE_X64=True
) . En TPU, JAX usa valores de 32 bits de forma predeterminada para todo, excepto para las variables temporales internas en operaciones tipo matmul, como jax.numpy.dot
y lax.conv
. Esas operaciones tienen un parámetro precision
que se puede utilizar para aproximar operaciones de 32 bits mediante tres pases bfloat16, con un costo de tiempo de ejecución posiblemente más lento. Las operaciones no matmul en TPU se reducen a implementaciones que a menudo enfatizan la velocidad sobre la precisión, por lo que en la práctica los cálculos en TPU serán menos precisos que cálculos similares en otros backends.np.add(1, np.array([2], np.float32)).dtype
es float64
en lugar de float32
.jit
, restringen cómo se puede utilizar el flujo de control de Python. Siempre obtendrás errores ruidosos si algo sale mal. Es posible que tengas que usar el parámetro static_argnums
de jit
, primitivas de flujo de control estructurado como lax.scan
, o simplemente usar jit
en subfunciones más pequeñas. Linuxx86_64 | Linux aarch64 | Macx86_64 | Mac aarch64 | Ventanas x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
UPC | Sí | Sí | Sí | Sí | Sí | Sí |
GPU NVIDIA | Sí | Sí | No | n / A | No | experimental |
Google TPU | Sí | n / A | n / A | n / A | n / A | n / A |
GPU AMD | Sí | No | experimental | n / A | No | No |
GPU de Apple | n / A | No | n / A | experimental | n / A | n / A |
GPU Intel | experimental | n / A | n / A | n / A | No | No |
Plataforma | Instrucciones |
---|---|
UPC | pip install -U jax |
GPU NVIDIA | pip install -U "jax[cuda12]" |
Google TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
GPU AMD (Linux) | Utilice Docker, ruedas prediseñadas o cree desde el código fuente. |
GPU Mac | Sigue las instrucciones de Apple. |
GPU Intel | Siga las instrucciones de Intel. |
Consulte la documentación para obtener información sobre estrategias de instalación alternativas. Estos incluyen compilar desde el código fuente, instalar con Docker, usar otras versiones de CUDA, una compilación conda respaldada por la comunidad y respuestas a algunas preguntas frecuentes.
Varios grupos de investigación de Google en Google DeepMind y Alphabet desarrollan y comparten bibliotecas para entrenar redes neuronales en JAX. Si desea una biblioteca con todas las funciones para el entrenamiento de redes neuronales con ejemplos y guías prácticas, pruebe Flax y su sitio de documentación.
Consulte la sección JAX Ecosystem en el sitio de documentación de JAX para obtener una lista de bibliotecas de red basadas en JAX, que incluye Optax para procesamiento y optimización de gradientes, chex para código y pruebas confiables y Equinox para redes neuronales. (Vea la charla sobre el ecosistema JAX de NeurIPS 2020 en DeepMind aquí para obtener detalles adicionales).
Para citar este repositorio:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
En la entrada bibtex anterior, los nombres están en orden alfabético, el número de versión debe ser el de jax/version.py y el año corresponde al lanzamiento de código abierto del proyecto.
En un artículo que apareció en SysML 2018 se describió una versión incipiente de JAX, que solo admite diferenciación automática y compilación en XLA. Actualmente estamos trabajando para cubrir las ideas y capacidades de JAX en un artículo más completo y actualizado.
Para obtener detalles sobre la API JAX, consulte la documentación de referencia.
Para comenzar como desarrollador de JAX, consulte la documentación para desarrolladores.