Übersicht | Warum Haiku? | Schnellstart | Installation | Beispiele | Benutzerhandbuch | Dokumentation | Haiku zitieren
Wichtig
Ab Juli 2023 empfiehlt Google DeepMind, dass neue Projekte Flax anstelle von Haiku übernehmen. Flax ist eine neuronale Netzwerkbibliothek, die ursprünglich von Google Brain und jetzt von Google DeepMind entwickelt wurde.
Zum Zeitpunkt des Verfassens dieses Artikels verfügt Flax über einen Übersatz der in Haiku verfügbaren Funktionen, ein größeres und aktiveres Entwicklungsteam und eine größere Akzeptanz bei Benutzern außerhalb von Alphabet. Flax verfügt über eine umfangreichere Dokumentation, Beispiele und eine aktive Community, die End-to-End-Beispiele erstellt.
Haiku wird weiterhin nach besten Kräften unterstützt, das Projekt wird jedoch in den Wartungsmodus wechseln, was bedeutet, dass sich die Entwicklungsbemühungen auf Fehlerbehebungen und Kompatibilität mit neuen JAX-Versionen konzentrieren werden.
Es werden neue Versionen veröffentlicht, damit Haiku weiterhin mit neueren Versionen von Python und JAX funktioniert. Wir werden jedoch keine neuen Funktionen hinzufügen (oder PRs dafür akzeptieren).
Wir nutzen Haiku intern bei Google DeepMind in erheblichem Umfang und planen derzeit, Haiku in diesem Modus auf unbestimmte Zeit zu unterstützen.
Haiku ist ein Werkzeug
Zum Aufbau neuronaler Netze
Denken Sie: „Sonett für JAX“
Haiku ist eine einfache neuronale Netzwerkbibliothek für JAX, die von einigen Autoren von Sonnet, einer neuronalen Netzwerkbibliothek für TensorFlow, entwickelt wurde.
Dokumentation zu Haiku finden Sie unter https://dm-haiku.readthedocs.io/.
Begriffsklärung: Wenn Sie nach dem Betriebssystem Haiku suchen, schauen Sie sich bitte https://haiku-os.org/ an.
JAX ist eine numerische Computerbibliothek, die NumPy, automatische Differenzierung und erstklassige GPU/TPU-Unterstützung kombiniert.
Haiku ist eine einfache neuronale Netzwerkbibliothek für JAX, die es Benutzern ermöglicht, bekannte objektorientierte Programmiermodelle zu verwenden und gleichzeitig vollen Zugriff auf die reinen Funktionstransformationen von JAX zu ermöglichen.
Haiku bietet zwei Kernwerkzeuge: eine Modulabstraktion, hk.Module
, und eine einfache Funktionstransformation, hk.transform
.
hk.Module
s sind Python-Objekte, die Verweise auf ihre eigenen Parameter, andere Module und Methoden enthalten, die Funktionen auf Benutzereingaben anwenden.
hk.transform
wandelt Funktionen, die diese objektorientierten, funktional „unreinen“ Module verwenden, in reine Funktionen um, die mit jax.jit
, jax.grad
, jax.pmap
usw. verwendet werden können.
Es gibt eine Reihe neuronaler Netzwerkbibliotheken für JAX. Warum sollten Sie sich für Haiku entscheiden?
Module
Programmiermodell von Sonnet für die Zustandsverwaltung bei und behält gleichzeitig den Zugriff auf die Funktionstransformationen von JAX bei.hk.transform
) zielt Haiku darauf ab, mit der API von Sonnet 2 übereinzustimmen. Module, Methoden, Argumentnamen, Standardeinstellungen und Initialisierungsschemata sollten übereinstimmen.hk.next_rng_key()
einen eindeutigen RNG-Schlüssel zurück.Schauen wir uns ein Beispiel für ein neuronales Netzwerk, eine Verlustfunktion und eine Trainingsschleife an. (Weitere Beispiele finden Sie in unserem Beispielverzeichnis. Das MNIST-Beispiel ist ein guter Ausgangspunkt.)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
Der Kern von Haiku ist hk.transform
. Mit der transform
können Sie neuronale Netzwerkfunktionen schreiben, die auf Parametern (hier den Gewichtungen der Linear
Schichten) basieren, ohne dass Sie das Boilerplate zum Initialisieren dieser Parameter explizit schreiben müssen. transform
tut dies, indem es die Funktion in ein Paar reiner Funktionen (wie von JAX gefordert) init
und apply
umwandelt.
init
Mit der Funktion init
mit der Signatur params = init(rng, ...)
(wobei ...
die Argumente der nicht transformierten Funktion sind) können Sie den Anfangswert aller Parameter im Netzwerk erfassen . Haiku tut dies, indem es Ihre Funktion ausführt, alle über hk.get_parameter
(aufgerufen z. B. hk.Linear
) angeforderten Parameter verfolgt und sie an Sie zurückgibt.
Das zurückgegebene params
-Objekt ist eine verschachtelte Datenstruktur aller Parameter in Ihrem Netzwerk, die Sie überprüfen und bearbeiten können. Konkret handelt es sich um eine Zuordnung von Modulnamen zu Modulparametern, wobei ein Modulparameter eine Zuordnung von Parameternamen zu Parameterwerten ist. Zum Beispiel:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
Mit der apply
Funktion mit der Signatur result = apply(params, rng, ...)
können Sie Parameterwerte in Ihre Funktion einfügen . Immer wenn hk.get_parameter
aufgerufen wird, stammt der zurückgegebene Wert aus den params
die Sie als Eingabe zum apply
bereitstellen:
loss = loss_fn_t . apply ( params , rng , images , labels )
Beachten Sie, dass die Übergabe eines Zufallszahlengenerators nicht erforderlich ist, da die tatsächliche Berechnung unserer Verlustfunktion nicht auf Zufallszahlen basiert. Wir könnten also auch None
für das rng
-Argument übergeben. (Beachten Sie, dass, wenn Ihre Berechnung Zufallszahlen verwendet , die Übergabe von None
für rng
zu einem Fehler führt.) In unserem Beispiel oben bitten wir Haiku, dies automatisch für uns zu tun mit:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
Da es sich apply
um eine reine Funktion handelt, können wir sie an jax.grad
(oder eine der anderen JAX-Transformationen) übergeben:
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
Die Trainingsschleife in diesem Beispiel ist sehr einfach. Ein zu beachtendes Detail ist die Verwendung von jax.tree.map
um die sgd
-Funktion auf alle übereinstimmenden Einträge in params
und grads
anzuwenden. Das Ergebnis hat die gleiche Struktur wie die vorherigen params
und kann erneut mit apply
verwendet werden.
Haiku ist in reinem Python geschrieben, hängt aber von C++-Code über JAX ab.
Da die JAX-Installation je nach CUDA-Version unterschiedlich ist, listet Haiku JAX nicht als Abhängigkeit in requirements.txt
auf.
Befolgen Sie zunächst diese Anweisungen, um JAX mit der entsprechenden Beschleunigerunterstützung zu installieren.
Dann installieren Sie Haiku mit pip:
$ pip install git+https://github.com/deepmind/dm-haiku
Alternativ können Sie auch über PyPI installieren:
$ pip install -U dm-haiku
Unsere Beispiele basieren auf zusätzlichen Bibliotheken (z. B. bsuite). Sie können den gesamten Satz zusätzlicher Anforderungen mit pip installieren:
$ pip install -r examples/requirements.txt
In Haiku sind alle Module eine Unterklasse von hk.Module
. Sie können jede beliebige Methode implementieren (nichts ist in Sonderfällen angegeben), aber normalerweise implementieren Module __init__
und __call__
.
Lassen Sie uns die Implementierung einer linearen Ebene durchgehen:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
Alle Module haben einen Namen. Wenn dem Modul kein name
übergeben wird, wird sein Name aus dem Namen der Python-Klasse abgeleitet (z. B. wird MyLinear
zu my_linear
). Module können benannte Parameter haben, auf die mit hk.get_parameter(param_name, ...)
zugegriffen wird. Wir verwenden diese API (anstatt nur Objekteigenschaften zu verwenden), damit wir Ihren Code mithilfe von hk.transform
in eine reine Funktion umwandeln können.
Wenn Sie Module verwenden, müssen Sie Funktionen definieren und diese mithilfe von hk.transform
in ein Paar reiner Funktionen umwandeln. Weitere Informationen zu den von transform
zurückgegebenen Funktionen finden Sie in unserem Schnellstart:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
Bei einigen Modellen ist im Rahmen der Berechnung möglicherweise eine Zufallsstichprobe erforderlich. Beispielsweise wird bei Variations-Autoencodern mit dem Reparametrisierungstrick eine Zufallsstichprobe aus der Standardnormalverteilung benötigt. Für den Dropout benötigen wir eine Zufallsmaske, um Einheiten aus der Eingabe zu entfernen. Die größte Hürde, damit dies mit JAX funktioniert, liegt in der Verwaltung der PRNG-Schlüssel.
In Haiku stellen wir eine einfache API zum Verwalten einer PRNG-Tastenfolge bereit, die Modulen zugeordnet ist: hk.next_rng_key()
(oder next_rng_keys()
für mehrere Schlüssel):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
Einen ausführlicheren Einblick in die Arbeit mit stochastischen Modellen finden Sie in unserem VAE-Beispiel.
Hinweis: hk.next_rng_key()
ist nicht funktionell rein, was bedeutet, dass Sie es nicht zusammen mit JAX-Transformationen verwenden sollten, die sich in hk.transform
befinden. Weitere Informationen und mögliche Problemumgehungen finden Sie in den Dokumenten zu Haiku-Transformationen und verfügbaren Wrappern für JAX-Transformationen in Haiku-Netzwerken.
Einige Modelle möchten möglicherweise einen internen, veränderlichen Zustand beibehalten. Beispielsweise wird bei der Batch-Normalisierung ein gleitender Durchschnitt der während des Trainings ermittelten Werte beibehalten.
In Haiku stellen wir eine einfache API zur Aufrechterhaltung des veränderlichen Zustands bereit, die den Modulen hk.set_state
und hk.get_state
zugeordnet ist. Wenn Sie diese Funktionen verwenden, müssen Sie Ihre Funktion mit hk.transform_with_state
transformieren, da die Signatur des zurückgegebenen Funktionspaars unterschiedlich ist:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
Wenn Sie vergessen, hk.transform_with_state
zu verwenden, machen Sie sich keine Sorgen, wir geben einen eindeutigen Fehler aus, der Sie auf hk.transform_with_state
verweist, anstatt Ihren Status stillschweigend zu löschen.
jax.pmap
Die von hk.transform
(oder hk.transform_with_state
) zurückgegebenen reinen Funktionen sind vollständig kompatibel mit jax.pmap
. Weitere Informationen zur SPMD-Programmierung mit jax.pmap
finden Sie hier.
Eine häufige Verwendung von jax.pmap
mit Haiku ist das datenparallele Training auf vielen Beschleunigern, möglicherweise über mehrere Hosts hinweg. Bei Haiku könnte das so aussehen:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
Für einen umfassenderen Einblick in das verteilte Haiku-Training werfen Sie einen Blick auf unser Beispiel „ResNet-50 auf ImageNet“.
Um dieses Repository zu zitieren:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
In diesem Bibtex-Eintrag soll die Versionsnummer von haiku/__init__.py
stammen und das Jahr entspricht der Open-Source-Veröffentlichung des Projekts.