Démarrage rapide | Transformations | Guide d'installation | Bibliothèques de réseaux neuronaux | Journaux des modifications | Documents de référence
JAX est une bibliothèque Python pour le calcul de tableaux orientés accélérateur et la transformation de programmes, conçue pour le calcul numérique haute performance et l'apprentissage automatique à grande échelle.
Avec sa version mise à jour d'Autograd, JAX peut automatiquement différencier les fonctions natives Python et NumPy. Il peut se différencier par des boucles, des branches, des récursions et des fermetures, et il peut prendre des dérivées de dérivées de dérivées. Il prend en charge la différenciation en mode inverse (c'est-à-dire la rétropropagation) via grad
ainsi que la différenciation en mode direct, et les deux peuvent être composés arbitrairement dans n'importe quel ordre.
La nouveauté est que JAX utilise XLA pour compiler et exécuter vos programmes NumPy sur des GPU et TPU. La compilation s'effectue sous le capot par défaut, les appels de bibliothèque étant compilés et exécutés juste à temps. Mais JAX vous permet également de compiler juste à temps vos propres fonctions Python dans des noyaux optimisés pour XLA à l'aide d'une API à fonction unique, jit
. La compilation et la différenciation automatique peuvent être composées arbitrairement, afin que vous puissiez exprimer des algorithmes sophistiqués et obtenir des performances maximales sans quitter Python. Vous pouvez même programmer plusieurs GPU ou cœurs TPU à la fois en utilisant pmap
et différencier l'ensemble.
Creusez un peu plus et vous verrez que JAX est vraiment un système extensible pour les transformations de fonctions composables. grad
et jit
sont tous deux des exemples de telles transformations. D'autres sont vmap
pour la vectorisation automatique et pmap
pour la programmation parallèle à programme unique et données multiples (SPMD) de plusieurs accélérateurs, et d'autres sont à venir.
Il s'agit d'un projet de recherche et non d'un produit Google officiel. Attendez-vous à des bugs et des arêtes vives. S'il vous plaît, aidez-nous en l'essayant, en signalant les bugs et en nous faisant savoir ce que vous en pensez !
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
Lancez-vous directement en utilisant un ordinateur portable dans votre navigateur, connecté à un GPU Google Cloud. Voici quelques cahiers de démarrage :
grad
pour la différenciation, jit
pour la compilation et vmap
pour la vectorisationJAX fonctionne désormais sur Cloud TPU. Pour essayer l'aperçu, consultez les Colabs Cloud TPU.
Pour une plongée plus approfondie dans JAX :
À la base, JAX est un système extensible pour transformer des fonctions numériques. Voici quatre transformations de premier intérêt : grad
, jit
, vmap
et pmap
.
grad
JAX a à peu près la même API qu'Autograd. La fonction la plus populaire est grad
pour les dégradés en mode inverse :
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
Vous pouvez différencier n'importe quelle commande avec grad
.
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
Pour une comparaison automatique plus avancée, vous pouvez utiliser jax.vjp
pour les produits vectoriels-jacobiens en mode inverse et jax.jvp
pour les produits vectoriels jacobiens en mode avant. Les deux peuvent être composés arbitrairement l'un avec l'autre et avec d'autres transformations JAX. Voici une façon de les composer pour créer une fonction qui calcule efficacement des matrices de Hesse complètes :
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
Comme avec Autograd, vous êtes libre d'utiliser la différenciation avec les structures de contrôle Python :
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
Consultez les documents de référence sur la différenciation automatique et le livre de recettes JAX Autodiff pour en savoir plus.
jit
Vous pouvez utiliser XLA pour compiler vos fonctions de bout en bout avec jit
, utilisé soit comme décorateur @jit
, soit comme fonction d'ordre supérieur.
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
Vous pouvez mélanger jit
et grad
et toute autre transformation JAX comme bon vous semble.
L'utilisation jit
impose des contraintes sur le type de flux de contrôle Python que la fonction peut utiliser ; consultez le didacticiel sur le flux de contrôle et les opérateurs logiques avec JIT pour en savoir plus.
vmap
vmap
est la carte de vectorisation. Il a la sémantique familière du mappage d'une fonction le long des axes d'un tableau, mais au lieu de garder la boucle à l'extérieur, il pousse la boucle vers les opérations primitives d'une fonction pour de meilleures performances.
L'utilisation vmap
peut vous éviter d'avoir à transporter des dimensions de lot dans votre code. Par exemple, considérons cette simple fonction de prédiction de réseau neuronal non batch :
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
Nous écrivons souvent à la place jnp.dot(activations, W)
pour permettre une dimension de lot sur le côté gauche des activations
, mais nous avons écrit cette fonction de prédiction particulière pour qu'elle s'applique uniquement aux vecteurs d'entrée uniques. Si nous voulions appliquer cette fonction à un lot d’entrées à la fois, sémantiquement nous pourrions simplement écrire
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
Mais transmettre un exemple à la fois sur le réseau serait lent ! Il est préférable de vectoriser le calcul, de sorte qu'à chaque couche, nous effectuions une multiplication matrice-matrice plutôt qu'une multiplication matrice-vecteur.
La fonction vmap
effectue cette transformation pour nous. Autrement dit, si nous écrivons
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
alors la fonction vmap
poussera la boucle externe à l'intérieur de la fonction, et notre machine finira par exécuter des multiplications matrice-matrice exactement comme si nous avions effectué le traitement par lots à la main.
Il est assez simple de regrouper manuellement un simple réseau neuronal sans vmap
, mais dans d'autres cas, la vectorisation manuelle peut s'avérer peu pratique, voire impossible. Prenons le problème du calcul efficace des gradients par exemple : c'est-à-dire que pour un ensemble fixe de paramètres, nous voulons calculer le gradient de notre fonction de perte évaluée séparément pour chaque exemple d'un lot. Avec vmap
, c'est simple :
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
Bien sûr, vmap
peut être arbitrairement composé avec jit
, grad
et toute autre transformation JAX ! Nous utilisons vmap
avec une différenciation automatique en mode avant et arrière pour des calculs rapides de matrices jacobiennes et hessiennes dans jax.jacfwd
, jax.jacrev
et jax.hessian
.
pmap
Pour la programmation parallèle de plusieurs accélérateurs, comme plusieurs GPU, utilisez pmap
. Avec pmap
vous écrivez des programmes SPMD (single-program multiple-data), y compris des opérations de communication collective parallèles rapides. L'application de pmap
signifie que la fonction que vous écrivez est compilée par XLA (de la même manière que jit
), puis répliquée et exécutée en parallèle sur tous les appareils.
Voici un exemple sur une machine à 8 GPU :
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
En plus d'exprimer des cartes pures, vous pouvez utiliser des opérations de communication collective rapides entre appareils :
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
Vous pouvez même imbriquer des fonctions pmap
pour des modèles de communication plus sophistiqués.
Tout se compose, vous êtes donc libre de différencier via des calculs parallèles :
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
Lors de la différenciation en mode inverse d'une fonction pmap
(par exemple avec grad
), la passe arrière du calcul est parallélisée tout comme la passe avant.
Consultez le livre de recettes SPMD et l'exemple de classificateur SPMD MNIST à partir de zéro pour en savoir plus.
Pour une étude plus approfondie des pièges actuels, avec des exemples et des explications, nous vous recommandons fortement de lire le Gotchas Notebook. Quelques points marquants :
is
ne sont pas préservés). Si vous utilisez une transformation JAX sur une fonction Python impure, vous pourriez voir une erreur du type Exception: Can't lift Traced...
ou Exception: Different traces at same level
.x[i] += y
, ne sont pas prises en charge, mais il existe des alternatives fonctionnelles. Sous un jit
, ces alternatives fonctionnelles réutiliseront automatiquement les tampons sur place.jax.lax
.float32
), et pour activer la double précision (64 bits, par exemple float64
), il faut définir la variable jax_enable_x64
au démarrage (ou définir la variable d'environnement JAX_ENABLE_X64=True
) . Sur TPU, JAX utilise des valeurs 32 bits par défaut pour tout, sauf les variables temporaires internes dans les opérations de type « matmul », telles que jax.numpy.dot
et lax.conv
. Ces opérations ont un paramètre precision
qui peut être utilisé pour approximer les opérations 32 bits via trois passes bfloat16, avec un coût d'exécution éventuellement plus lent. Les opérations non matmul sur TPU sont inférieures aux implémentations qui mettent souvent l'accent sur la vitesse plutôt que sur la précision, donc en pratique, les calculs sur TPU seront moins précis que les calculs similaires sur d'autres backends.np.add(1, np.array([2], np.float32)).dtype
est float64
plutôt que float32
.jit
, limitent la manière dont vous pouvez utiliser le flux de contrôle Python. Vous obtiendrez toujours des erreurs bruyantes si quelque chose ne va pas. Vous devrez peut-être utiliser le paramètre static_argnums
de jit
, des primitives de flux de contrôle structuré comme lax.scan
, ou simplement utiliser jit
sur des sous-fonctions plus petites. Linuxx86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
Processeur | Oui | Oui | Oui | Oui | Oui | Oui |
GPU NVIDIA | Oui | Oui | Non | n / A | Non | expérimental |
Google TPU | Oui | n / A | n / A | n / A | n / A | n / A |
GPU AMD | Oui | Non | expérimental | n / A | Non | Non |
GPU Apple | n / A | Non | n / A | expérimental | n / A | n / A |
GPU Intel | expérimental | n / A | n / A | n / A | Non | Non |
Plate-forme | Instructions |
---|---|
Processeur | pip install -U jax |
GPU NVIDIA | pip install -U "jax[cuda12]" |
Google TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
GPU AMD (Linux) | Utilisez Docker, des roues prédéfinies ou créez à partir des sources. |
GPU Mac | Suivez les instructions d'Apple. |
GPU Intel | Suivez les instructions d'Intel. |
Consultez la documentation pour plus d’informations sur les stratégies d’installation alternatives. Celles-ci incluent la compilation à partir des sources, l'installation avec Docker, l'utilisation d'autres versions de CUDA, une version conda prise en charge par la communauté et les réponses à certaines questions fréquemment posées.
Plusieurs groupes de recherche Google chez Google DeepMind et Alphabet développent et partagent des bibliothèques pour la formation des réseaux de neurones dans JAX. Si vous souhaitez une bibliothèque complète pour la formation aux réseaux neuronaux avec des exemples et des guides pratiques, essayez Flax et son site de documentation.
Consultez la section Écosystème JAX sur le site de documentation JAX pour obtenir une liste des bibliothèques réseau basées sur JAX, qui incluent Optax pour le traitement et l'optimisation des gradients, chex pour le code et les tests fiables et Equinox pour les réseaux de neurones. (Regardez l'écosystème JAX NeurIPS 2020 chez DeepMind ici pour plus de détails.)
Pour citer ce référentiel :
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
Dans l'entrée bibtex ci-dessus, les noms sont classés par ordre alphabétique, le numéro de version est censé être celui de jax/version.py et l'année correspond à la version open source du projet.
Une version naissante de JAX, prenant uniquement en charge la différenciation et la compilation automatiques vers XLA, a été décrite dans un article paru à SysML 2018. Nous travaillons actuellement à couvrir les idées et les capacités de JAX dans un article plus complet et à jour.
Pour plus de détails sur l'API JAX, consultez la documentation de référence.
Pour démarrer en tant que développeur JAX, consultez la documentation du développeur.