Aperçu | Pourquoi Haïku ? | Démarrage rapide | Installation | Exemples | Manuel d'utilisation | Documents | Citer des haïkus
Important
Depuis juillet 2023, Google DeepMind recommande aux nouveaux projets d'adopter le lin au lieu du haïku. Flax est une bibliothèque de réseaux neuronaux développée à l'origine par Google Brain et maintenant par Google DeepMind.
Au moment de la rédaction de cet article, Flax propose un surensemble de fonctionnalités disponibles dans Haiku, une équipe de développement plus grande et plus active et une plus grande adoption par les utilisateurs en dehors d'Alphabet. Flax dispose d'une documentation plus complète, d'exemples et d'une communauté active créant des exemples de bout en bout.
Haiku restera pris en charge au mieux, mais le projet entrera en mode maintenance, ce qui signifie que les efforts de développement seront concentrés sur les corrections de bugs et la compatibilité avec les nouvelles versions de JAX.
De nouvelles versions seront créées pour que Haiku continue de fonctionner avec les versions plus récentes de Python et JAX, mais nous n'ajouterons pas (ni n'accepterons de PR) de nouvelles fonctionnalités.
Nous utilisons de manière significative Haiku en interne chez Google DeepMind et prévoyons actuellement de prendre en charge Haiku dans ce mode indéfiniment.
Le haïku est un outil
Pour construire des réseaux de neurones
Pensez : "Sonnet pour JAX"
Haiku est une simple bibliothèque de réseaux neuronaux pour JAX développée par certains des auteurs de Sonnet, une bibliothèque de réseaux neuronaux pour TensorFlow.
La documentation sur Haiku peut être trouvée sur https://dm-haiku.readthedocs.io/.
Désambiguïsation : si vous recherchez Haiku, le système d'exploitation, veuillez consulter https://haiku-os.org/.
JAX est une bibliothèque de calcul numérique qui combine NumPy, différenciation automatique et prise en charge GPU/TPU de première classe.
Haiku est une simple bibliothèque de réseaux neuronaux pour JAX qui permet aux utilisateurs d'utiliser des modèles de programmation orientés objet familiers tout en permettant un accès complet aux transformations de fonctions pures de JAX.
Haiku fournit deux outils de base : une abstraction de module, hk.Module
, et une simple transformation de fonction, hk.transform
.
Les hk.Module
sont des objets Python qui contiennent des références à leurs propres paramètres, à d'autres modules et à des méthodes qui appliquent des fonctions aux entrées utilisateur.
hk.transform
transforme les fonctions qui utilisent ces modules fonctionnellement « impurs » orientés objet en fonctions pures qui peuvent être utilisées avec jax.jit
, jax.grad
, jax.pmap
, etc.
Il existe un certain nombre de bibliothèques de réseaux neuronaux pour JAX. Pourquoi choisir Haïku ?
Module
de Sonnet pour la gestion des états tout en conservant l'accès aux transformations de fonctions de JAX.hk.transform
), Haiku vise à correspondre à l'API de Sonnet 2. Les modules, méthodes, noms d'arguments, valeurs par défaut et schémas d'initialisation doivent correspondre.hk.next_rng_key()
renvoie une clé rng unique.Jetons un coup d'œil à un exemple de réseau neuronal, de fonction de perte et de boucle d'entraînement. (Pour plus d'exemples, consultez notre répertoire d'exemples. L'exemple MNIST est un bon point de départ.)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
Le cœur de Haiku est hk.transform
. La fonction transform
vous permet d'écrire des fonctions de réseau neuronal qui s'appuient sur des paramètres (ici les poids des couches Linear
) sans vous obliger à écrire explicitement le passe-partout pour initialiser ces paramètres. transform
fait cela en transformant la fonction en une paire de fonctions pures (comme l'exige JAX) init
et apply
.
init
La fonction init
, avec la signature params = init(rng, ...)
(où ...
sont les arguments de la fonction non transformée), vous permet de collecter la valeur initiale de tous les paramètres du réseau. Haiku fait cela en exécutant votre fonction, en gardant une trace de tous les paramètres demandés via hk.get_parameter
(appelé par exemple par hk.Linear
) et en vous les renvoyant.
L'objet params
renvoyé est une structure de données imbriquée de tous les paramètres de votre réseau, conçue pour que vous puissiez l'inspecter et la manipuler. Concrètement, il s'agit d'un mappage du nom du module aux paramètres du module, où un paramètre de module est un mappage du nom du paramètre à la valeur du paramètre. Par exemple:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
La fonction apply
, avec la signature result = apply(params, rng, ...)
, vous permet d' injecter des valeurs de paramètres dans votre fonction. Chaque fois que hk.get_parameter
est appelé, la valeur renvoyée proviendra des params
que vous fournissez en entrée à apply
:
loss = loss_fn_t . apply ( params , rng , images , labels )
Notez que puisque le calcul réel effectué par notre fonction de perte ne repose pas sur des nombres aléatoires, il n'est pas nécessaire de transmettre un générateur de nombres aléatoires, nous pourrions donc également transmettre None
pour l'argument rng
. (Notez que si votre calcul utilise des nombres aléatoires, passer None
pour rng
entraînera une erreur.) Dans notre exemple ci-dessus, nous demandons à Haiku de le faire automatiquement pour nous avec :
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
Puisque apply
est une fonction pure, nous pouvons la transmettre à jax.grad
(ou à l'une des autres transformations de JAX) :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
La boucle de formation dans cet exemple est très simple. Un détail à noter est l'utilisation de jax.tree.map
pour appliquer la fonction sgd
à toutes les entrées correspondantes dans params
et grads
. Le résultat a la même structure que les params
précédents et peut à nouveau être utilisé avec apply
.
Haiku est écrit en Python pur, mais dépend du code C++ via JAX.
Étant donné que l'installation de JAX est différente selon votre version de CUDA, Haiku ne répertorie pas JAX comme dépendance dans requirements.txt
.
Tout d’abord, suivez ces instructions pour installer JAX avec la prise en charge de l’accélérateur approprié.
Ensuite, installez Haiku en utilisant pip :
$ pip install git+https://github.com/deepmind/dm-haiku
Alternativement, vous pouvez installer via PyPI :
$ pip install -U dm-haiku
Nos exemples s'appuient sur des bibliothèques supplémentaires (par exemple bsuite). Vous pouvez installer l'ensemble complet des exigences supplémentaires à l'aide de pip :
$ pip install -r examples/requirements.txt
Dans Haiku, tous les modules sont une sous-classe de hk.Module
. Vous pouvez implémenter n'importe quelle méthode de votre choix (rien n'est spécifique), mais généralement les modules implémentent __init__
et __call__
.
Passons à l'implémentation d'une couche linéaire :
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
Tous les modules ont un nom. Lorsqu'aucun argument name
n'est passé au module, son nom est déduit du nom de la classe Python (par exemple MyLinear
devient my_linear
). Les modules peuvent avoir des paramètres nommés accessibles à l'aide de hk.get_parameter(param_name, ...)
. Nous utilisons cette API (plutôt que d'utiliser simplement les propriétés de l'objet) afin de pouvoir convertir votre code en une fonction pure à l'aide de hk.transform
.
Lorsque vous utilisez des modules, vous devez définir des fonctions et les transformer en une paire de fonctions pures à l'aide de hk.transform
. Consultez notre guide de démarrage rapide pour plus de détails sur les fonctions renvoyées par transform
:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
Certains modèles peuvent nécessiter un échantillonnage aléatoire dans le cadre du calcul. Par exemple, dans les auto-encodeurs variationnels dotés de l'astuce de reparamétrisation, un échantillon aléatoire de la distribution normale standard est nécessaire. Pour l'abandon, nous avons besoin d'un masque aléatoire pour supprimer les unités de l'entrée. Le principal obstacle à ce fonctionnement avec JAX réside dans la gestion des clés PRNG.
Dans Haiku, nous fournissons une API simple pour maintenir une séquence de clés PRNG associée aux modules : hk.next_rng_key()
(ou next_rng_keys()
pour plusieurs clés) :
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
Pour un aperçu plus complet de l'utilisation de modèles stochastiques, veuillez consulter notre exemple VAE.
Remarque : hk.next_rng_key()
n'est pas fonctionnellement pur, ce qui signifie que vous devez éviter de l'utiliser avec les transformations JAX qui se trouvent à l'intérieur hk.transform
. Pour plus d'informations et les solutions de contournement possibles, veuillez consulter la documentation sur les transformations Haiku et les wrappers disponibles pour les transformations JAX au sein des réseaux Haiku.
Certains modèles peuvent vouloir conserver un état interne mutable. Par exemple, dans la normalisation par lots, une moyenne mobile des valeurs rencontrées lors de la formation est maintenue.
Dans Haiku, nous fournissons une API simple pour maintenir l'état mutable associé aux modules : hk.set_state
et hk.get_state
. Lorsque vous utilisez ces fonctions, vous devez transformer votre fonction en utilisant hk.transform_with_state
puisque la signature de la paire de fonctions renvoyée est différente :
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
Si vous oubliez d'utiliser hk.transform_with_state
ne vous inquiétez pas, nous imprimerons une erreur claire vous indiquant hk.transform_with_state
plutôt que de supprimer silencieusement votre état.
jax.pmap
Les fonctions pures renvoyées par hk.transform
(ou hk.transform_with_state
) sont entièrement compatibles avec jax.pmap
. Pour plus de détails sur la programmation SPMD avec jax.pmap
, regardez ici.
Une utilisation courante de jax.pmap
avec Haiku concerne la formation parallèle de données sur de nombreux accélérateurs, potentiellement sur plusieurs hôtes. Avec Haiku, cela pourrait ressembler à ceci :
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
Pour un aperçu plus complet de la formation Haiku distribuée, jetez un œil à notre exemple ResNet-50 sur ImageNet.
Pour citer ce référentiel :
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
Dans cette entrée bibtex, le numéro de version est censé provenir de haiku/__init__.py
, et l'année correspond à la version open source du projet.