Schnellstart | Transformationen | Installationsanleitung | Neuronale Netzbibliotheken | Änderungsprotokolle | Referenzdokumente
JAX ist eine Python-Bibliothek für beschleunigerorientierte Array-Berechnung und Programmtransformation, die für leistungsstarkes numerisches Rechnen und groß angelegtes maschinelles Lernen entwickelt wurde.
Mit seiner aktualisierten Version von Autograd kann JAX automatisch zwischen nativen Python- und NumPy-Funktionen unterscheiden. Es kann durch Schleifen, Verzweigungen, Rekursionen und Abschlüsse differenzieren und Ableitungen von Ableitungen von Ableitungen annehmen. Es unterstützt die Reverse-Mode-Differenzierung (auch Backpropagation genannt) über grad
sowie die Forward-Mode-Differenzierung, und die beiden können beliebig in beliebiger Reihenfolge zusammengesetzt werden.
Neu ist, dass JAX XLA verwendet, um Ihre NumPy-Programme auf GPUs und TPUs zu kompilieren und auszuführen. Die Kompilierung erfolgt standardmäßig unter der Haube, wobei Bibliotheksaufrufe just-in-time kompiliert und ausgeführt werden. Mit JAX können Sie aber auch Ihre eigenen Python-Funktionen mithilfe einer Ein-Funktions-API, jit
just-in-time in XLA-optimierte Kernel kompilieren. Kompilierung und automatische Differenzierung können beliebig zusammengestellt werden, sodass Sie anspruchsvolle Algorithmen ausdrücken und maximale Leistung erzielen können, ohne Python zu verlassen. Mit pmap
können Sie sogar mehrere GPUs oder TPU-Kerne gleichzeitig programmieren und im Ganzen differenzieren.
Wenn Sie etwas tiefer graben, werden Sie sehen, dass JAX tatsächlich ein erweiterbares System für zusammensetzbare Funktionstransformationen ist. Sowohl grad
als auch jit
sind Beispiele für solche Transformationen. Andere sind vmap
für die automatische Vektorisierung und pmap
für die parallele SPMD-Programmierung (Single-Program Multiple Data) mehrerer Beschleuniger, weitere werden folgen.
Dies ist ein Forschungsprojekt, kein offizielles Google-Produkt. Erwarten Sie Insekten und scharfe Kanten. Bitte helfen Sie, indem Sie es ausprobieren, Fehler melden und uns Ihre Meinung mitteilen!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
Steigen Sie direkt ein, indem Sie ein Notebook in Ihrem Browser verwenden, das mit einer Google Cloud-GPU verbunden ist. Hier sind einige Einsteiger-Notizbücher:
grad
für Differenzierung, jit
für Kompilierung und vmap
für VektorisierungJAX läuft jetzt auf Cloud TPUs. Um die Vorschau auszuprobieren, sehen Sie sich die Cloud TPU Colabs an.
Für einen tieferen Einblick in JAX:
Im Kern ist JAX ein erweiterbares System zur Transformation numerischer Funktionen. Hier sind vier Transformationen von Hauptinteresse: grad
, jit
, vmap
und pmap
.
grad
JAX hat ungefähr die gleiche API wie Autograd. Die beliebteste Funktion ist grad
für Farbverläufe im Umkehrmodus:
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
Sie können mit grad
zu jeder Bestellung differenzieren.
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
Für eine erweiterte Autodiff-Funktion können Sie jax.vjp
für Vektor-Jacobian-Produkte im Rückwärtsmodus und jax.jvp
für Jacobi-Vektor-Produkte im Vorwärtsmodus verwenden. Beide können untereinander und mit anderen JAX-Transformationen beliebig zusammengesetzt werden. Hier ist eine Möglichkeit, diese zusammenzusetzen, um eine Funktion zu erstellen, die vollständige Hesse-Matrizen effizient berechnet:
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
Wie bei Autograd steht es Ihnen frei, die Differenzierung mit Python-Kontrollstrukturen zu verwenden:
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
Weitere Informationen finden Sie in den Referenzdokumenten zur automatischen Differenzierung und im JAX Autodiff Cookbook.
jit
Sie können XLA verwenden, um Ihre Funktionen durchgängig mit jit
zu kompilieren, entweder als @jit
Dekorator oder als Funktion höherer Ordnung.
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
Sie können jit
und grad
sowie jede andere JAX-Transformation beliebig kombinieren.
Die Verwendung von jit
schränkt die Art des Python-Kontrollflusses ein, den die Funktion verwenden kann. Weitere Informationen finden Sie im Tutorial zu Kontrollfluss und logischen Operatoren mit JIT.
vmap
vmap
ist die vektorisierende Karte. Es verfügt über die bekannte Semantik der Abbildung einer Funktion entlang von Array-Achsen, aber anstatt die Schleife außen zu lassen, verschiebt es die Schleife für eine bessere Leistung nach unten in die primitiven Operationen einer Funktion.
Durch die Verwendung von vmap
können Sie sich das Mitschleppen von Batch-Dimensionen in Ihrem Code ersparen. Betrachten Sie zum Beispiel diese einfache , nicht gestapelte Vorhersagefunktion für neuronale Netze:
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
Wir schreiben stattdessen oft jnp.dot(activations, W)
um eine Batch-Dimension auf der linken Seite von activations
zu ermöglichen, aber wir haben diese spezielle Vorhersagefunktion so geschrieben, dass sie nur auf einzelne Eingabevektoren anwendbar ist. Wenn wir diese Funktion auf einen Stapel von Eingaben gleichzeitig anwenden wollten, könnten wir semantisch einfach schreiben
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
Aber ein Beispiel nach dem anderen durch das Netzwerk zu schicken, wäre langsam! Es ist besser, die Berechnung zu vektorisieren, sodass wir auf jeder Ebene eine Matrix-Matrix-Multiplikation statt einer Matrix-Vektor-Multiplikation durchführen.
Die vmap
-Funktion erledigt diese Transformation für uns. Das heißt, wenn wir schreiben
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
dann verschiebt die vmap
-Funktion die äußere Schleife innerhalb der Funktion und unsere Maschine führt am Ende Matrix-Matrix-Multiplikationen genau so aus, als hätten wir die Stapelung von Hand durchgeführt.
Es ist einfach genug, ein einfaches neuronales Netzwerk ohne vmap
manuell zu stapeln, aber in anderen Fällen kann die manuelle Vektorisierung unpraktisch oder unmöglich sein. Nehmen Sie das Problem der effizienten Berechnung von Gradienten pro Beispiel: Das heißt, wir möchten für einen festen Satz von Parametern den Gradienten unserer Verlustfunktion berechnen, der bei jedem Beispiel in einem Stapel separat ausgewertet wird. Mit vmap
ist es ganz einfach:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
Natürlich kann vmap
mit jit
, grad
und jeder anderen JAX-Transformation beliebig zusammengesetzt werden! Wir verwenden vmap
mit automatischer Differenzierung im Vorwärts- und Rückwärtsmodus für schnelle Jacobi- und Hesse-Matrixberechnungen in jax.jacfwd
, jax.jacrev
und jax.hessian
.
pmap
Verwenden Sie für die parallele Programmierung mehrerer Beschleuniger, z. B. mehrerer GPUs, pmap
. Mit pmap
schreiben Sie Single-Program-Multiple-Data-Programme (SPMD), einschließlich schneller paralleler kollektiver Kommunikationsoperationen. Die Anwendung von pmap
bedeutet, dass die von Ihnen geschriebene Funktion von XLA kompiliert wird (ähnlich wie jit
), dann repliziert und parallel auf allen Geräten ausgeführt wird.
Hier ist ein Beispiel auf einer 8-GPU-Maschine:
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
Neben der Darstellung reiner Karten können Sie auch schnelle kollektive Kommunikationsvorgänge zwischen Geräten nutzen:
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
Für anspruchsvollere Kommunikationsmuster können Sie sogar pmap
Funktionen verschachteln.
Da sich alles zusammensetzt, können Sie durch parallele Berechnungen differenzieren:
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
Bei der Differenzierung einer pmap
Funktion im umgekehrten Modus (z. B. mit grad
) wird der Rückwärtsdurchlauf der Berechnung genau wie der Vorwärtsdurchlauf parallelisiert.
Weitere Informationen finden Sie im SPMD-Kochbuch und im SPMD-MNIST-Klassifikator-Beispiel von Grund auf.
Für einen ausführlicheren Überblick über aktuelle Fallstricke mit Beispielen und Erklärungen empfehlen wir dringend die Lektüre des Fallstrick-Notizbuchs. Einige Besonderheiten:
is
bleiben nicht erhalten). Wenn Sie eine JAX-Transformation für eine unreine Python-Funktion verwenden, wird möglicherweise eine Fehlermeldung wie Exception: Can't lift Traced...
oder Exception: Different traces at same level
angezeigt.x[i] += y
werden nicht unterstützt, es gibt jedoch funktionale Alternativen. Unter einem jit
werden diese funktionalen Alternativen Puffer automatisch vor Ort wiederverwenden.jax.lax
.float32
). Um Werte mit doppelter Genauigkeit (64-Bit, z. B. float64
) zu aktivieren, muss beim Start die Variable jax_enable_x64
festgelegt werden (oder die Umgebungsvariable JAX_ENABLE_X64=True
festgelegt werden). . Auf TPU verwendet JAX standardmäßig 32-Bit-Werte für alles außer internen temporären Variablen in „matmul-ähnlichen“ Vorgängen wie jax.numpy.dot
und lax.conv
. Diese Operationen verfügen über einen precision
, der zur Annäherung an 32-Bit-Operationen über drei bfloat16-Durchgänge verwendet werden kann, was möglicherweise zu einer langsameren Laufzeit führt. Nicht-matmul-Operationen auf der TPU sind bei Implementierungen, bei denen oft die Geschwindigkeit wichtiger ist als die Genauigkeit, geringer, sodass Berechnungen auf der TPU in der Praxis weniger präzise sind als ähnliche Berechnungen auf anderen Backends.np.add(1, np.array([2], np.float32)).dtype
ist float64
statt float32
.jit
, schränken die Verwendung des Python-Kontrollflusses ein. Es gibt immer laute Fehlermeldungen, wenn etwas schiefgeht. Möglicherweise müssen Sie static_argnums
-Parameter von jit
, strukturierte Kontrollflussprimitive wie lax.scan
oder einfach jit
für kleinere Unterfunktionen verwenden. Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
CPU | Ja | Ja | Ja | Ja | Ja | Ja |
NVIDIA-GPU | Ja | Ja | NEIN | n / A | NEIN | Experimental- |
Google TPU | Ja | n / A | n / A | n / A | n / A | n / A |
AMD-GPU | Ja | NEIN | Experimental- | n / A | NEIN | NEIN |
Apple-GPU | n / A | NEIN | n / A | Experimental- | n / A | n / A |
Intel-GPU | Experimental- | n / A | n / A | n / A | NEIN | NEIN |
Plattform | Anweisungen |
---|---|
CPU | pip install -U jax |
NVIDIA-GPU | pip install -U "jax[cuda12]" |
Google TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
AMD-GPU (Linux) | Verwenden Sie Docker, vorgefertigte Räder oder erstellen Sie aus dem Quellcode. |
Mac-GPU | Befolgen Sie die Anweisungen von Apple. |
Intel-GPU | Befolgen Sie die Anweisungen von Intel. |
Informationen zu alternativen Installationsstrategien finden Sie in der Dokumentation. Dazu gehören das Kompilieren aus dem Quellcode, die Installation mit Docker, die Verwendung anderer Versionen von CUDA, ein von der Community unterstützter Conda-Build und Antworten auf einige häufig gestellte Fragen.
Mehrere Google-Forschungsgruppen bei Google DeepMind und Alphabet entwickeln und teilen Bibliotheken zum Training neuronaler Netze in JAX. Wenn Sie eine voll ausgestattete Bibliothek für das Training neuronaler Netze mit Beispielen und Anleitungen wünschen, probieren Sie Flax und seine Dokumentationsseite aus.
Im Abschnitt „JAX-Ökosystem“ auf der JAX-Dokumentationsseite finden Sie eine Liste JAX-basierter Netzwerkbibliotheken, darunter Optax für die Gradientenverarbeitung und -optimierung, chex für zuverlässigen Code und Tests sowie Equinox für neuronale Netze. (Sehen Sie sich hier den Vortrag zum NeurIPS 2020 JAX Ecosystem bei DeepMind an, um weitere Einzelheiten zu erfahren.)
Um dieses Repository zu zitieren:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
Im obigen Bibtex-Eintrag sind die Namen in alphabetischer Reihenfolge, die Versionsnummer soll die von jax/version.py sein und das Jahr entspricht der Open-Source-Veröffentlichung des Projekts.
Eine neue Version von JAX, die nur die automatische Differenzierung und Kompilierung in XLA unterstützt, wurde in einem Artikel beschrieben, der auf der SysML 2018 erschien. Wir arbeiten derzeit daran, die Ideen und Fähigkeiten von JAX in einem umfassenderen und aktuelleren Artikel abzudecken.
Einzelheiten zur JAX-API finden Sie in der Referenzdokumentation.
Informationen zum Einstieg als JAX-Entwickler finden Sie in der Entwicklerdokumentation.