Descripción general | ¿Por qué Haikú? | Inicio rápido | Instalación | Ejemplos | Manual de usuario | Documentación | Citando haiku
Importante
A partir de julio de 2023, Google DeepMind recomienda que los nuevos proyectos adopten Flax en lugar de Haiku. Flax es una biblioteca de redes neuronales desarrollada originalmente por Google Brain y ahora por Google DeepMind.
Al momento de escribir este artículo, Flax tiene un superconjunto de funciones disponibles en Haiku, un equipo de desarrollo más grande y más activo y una mayor adopción por parte de usuarios fuera de Alphabet. Flax tiene documentación más extensa, ejemplos y una comunidad activa que crea ejemplos de un extremo a otro.
Haiku seguirá siendo compatible con el mejor esfuerzo, sin embargo, el proyecto entrará en modo de mantenimiento, lo que significa que los esfuerzos de desarrollo se centrarán en la corrección de errores y la compatibilidad con las nuevas versiones de JAX.
Se realizarán nuevas versiones para que Haiku siga funcionando con las versiones más recientes de Python y JAX; sin embargo, no agregaremos (ni aceptaremos relaciones públicas) nuevas funciones.
Tenemos un uso significativo de Haiku internamente en Google DeepMind y actualmente planeamos admitir Haiku en este modo de manera indefinida.
El haiku es una herramienta.
Para construir redes neuronales
Piensa: "Soneto para JAX"
Haiku es una biblioteca de redes neuronales simple para JAX desarrollada por algunos de los autores de Sonnet, una biblioteca de redes neuronales para TensorFlow.
La documentación sobre Haiku se puede encontrar en https://dm-haiku.readthedocs.io/.
Desambiguación: si está buscando el sistema operativo Haiku, consulte https://haiku-os.org/.
JAX es una biblioteca de computación numérica que combina NumPy, diferenciación automática y compatibilidad con GPU/TPU de primera clase.
Haiku es una biblioteca de redes neuronales simple para JAX que permite a los usuarios utilizar modelos familiares de programación orientada a objetos y al mismo tiempo permite acceso completo a las transformaciones de funciones puras de JAX.
Haiku proporciona dos herramientas principales: una abstracción de módulo, hk.Module
, y una transformación de función simple, hk.transform
.
hk.Module
son objetos de Python que contienen referencias a sus propios parámetros, otros módulos y métodos que aplican funciones en las entradas del usuario.
hk.transform
convierte funciones que utilizan estos módulos funcionalmente "impuros" orientados a objetos en funciones puras que se pueden utilizar con jax.jit
, jax.grad
, jax.pmap
, etc.
Hay varias bibliotecas de redes neuronales para JAX. ¿Por qué deberías elegir Haiku?
Module
de Sonnet para la gestión del estado y al mismo tiempo conserva el acceso a las transformaciones de funciones de JAX.hk.transform
), Haiku pretende coincidir con la API de Sonnet 2. Los módulos, métodos, nombres de argumentos, valores predeterminados y esquemas de inicialización deben coincidir.hk.next_rng_key()
devuelve una clave rng única.Echemos un vistazo a un ejemplo de red neuronal, función de pérdida y bucle de entrenamiento. (Para obtener más ejemplos, consulte nuestro directorio de ejemplos. El ejemplo MNIST es un buen lugar para comenzar).
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
El núcleo de Haiku es hk.transform
. La función transform
le permite escribir funciones de red neuronal que dependen de parámetros (aquí los pesos de las capas Linear
) sin necesidad de escribir explícitamente el texto estándar para inicializar esos parámetros. transform
hace esto transformando la función en un par de funciones que son puras (como lo requiere JAX) init
y apply
.
init
La función init
, con firma params = init(rng, ...)
(donde ...
son los argumentos de la función no transformada), le permite recopilar el valor inicial de cualquier parámetro en la red. Haiku hace esto ejecutando su función, realizando un seguimiento de los parámetros solicitados a través de hk.get_parameter
(llamado, por ejemplo, por hk.Linear
) y devolviéndolos.
El objeto params
devuelto es una estructura de datos anidada de todos los parámetros de su red, diseñada para que usted la inspeccione y manipule. Concretamente, es una asignación del nombre del módulo a los parámetros del módulo, donde un parámetro de módulo es una asignación del nombre del parámetro al valor del parámetro. Por ejemplo:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
La función apply
, con firma result = apply(params, rng, ...)
, le permite inyectar valores de parámetros en su función. Siempre que se llame hk.get_parameter
, el valor devuelto provendrá de los params
que usted proporcione como entrada para apply
:
loss = loss_fn_t . apply ( params , rng , images , labels )
Tenga en cuenta que, dado que el cálculo real realizado por nuestra función de pérdida no se basa en números aleatorios, no es necesario pasar un generador de números aleatorios, por lo que también podríamos pasar None
para el argumento rng
. (Tenga en cuenta que si su cálculo utiliza números aleatorios, pasar None
por rng
generará un error). En nuestro ejemplo anterior, le pedimos a Haiku que haga esto automáticamente con:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
Dado que apply
es una función pura, podemos pasarla a jax.grad
(o cualquiera de las otras transformaciones de JAX):
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
El ciclo de entrenamiento en este ejemplo es muy simple. Un detalle a tener en cuenta es el uso de jax.tree.map
para aplicar la función sgd
en todas las entradas coincidentes en params
y grads
. El resultado tiene la misma estructura que los params
anteriores y se puede volver a utilizar con apply
.
Haiku está escrito en Python puro, pero depende del código C++ a través de JAX.
Debido a que la instalación de JAX es diferente dependiendo de su versión de CUDA, Haiku no incluye JAX como una dependencia en requirements.txt
.
Primero, siga estas instrucciones para instalar JAX con el soporte de acelerador correspondiente.
Luego, instala Haiku usando pip:
$ pip install git+https://github.com/deepmind/dm-haiku
Alternativamente, puedes instalar a través de PyPI:
$ pip install -U dm-haiku
Nuestros ejemplos se basan en bibliotecas adicionales (por ejemplo, bsuite). Puede instalar el conjunto completo de requisitos adicionales usando pip:
$ pip install -r examples/requirements.txt
En Haiku, todos los módulos son una subclase de hk.Module
. Puede implementar cualquier método que desee (nada está en mayúsculas especiales), pero normalmente los módulos implementan __init__
y __call__
.
Trabajemos en la implementación de una capa lineal:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
Todos los módulos tienen un nombre. Cuando no se pasa ningún argumento name
al módulo, su nombre se infiere del nombre de la clase de Python (por ejemplo, MyLinear
se convierte en my_linear
). Los módulos pueden tener parámetros con nombre a los que se accede mediante hk.get_parameter(param_name, ...)
. Usamos esta API (en lugar de simplemente usar propiedades de objeto) para poder convertir su código en una función pura usando hk.transform
.
Cuando usa módulos, necesita definir funciones y transformarlas en un par de funciones puras usando hk.transform
. Consulte nuestro inicio rápido para obtener más detalles sobre las funciones devueltas por transform
:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
Algunos modelos pueden requerir un muestreo aleatorio como parte del cálculo. Por ejemplo, en codificadores automáticos variacionales con el truco de reparametrización, se necesita una muestra aleatoria de la distribución normal estándar. Para el abandono necesitamos una máscara aleatoria para eliminar unidades de la entrada. El principal obstáculo para que esto funcione con JAX es la gestión de las claves PRNG.
En Haiku proporcionamos una API simple para mantener una secuencia de claves PRNG asociada con módulos: hk.next_rng_key()
(o next_rng_keys()
para múltiples claves):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
Para obtener una visión más completa del trabajo con modelos estocásticos, consulte nuestro ejemplo de VAE.
Nota: hk.next_rng_key()
no es funcionalmente puro, lo que significa que debes evitar usarlo junto con las transformaciones JAX que están dentro de hk.transform
. Para obtener más información y posibles soluciones, consulte los documentos sobre transformaciones Haiku y contenedores disponibles para transformaciones JAX dentro de redes Haiku.
Es posible que algunos modelos quieran mantener algún estado interno mutable. Por ejemplo, en la normalización por lotes se mantiene un promedio móvil de los valores encontrados durante el entrenamiento.
En Haiku proporcionamos una API simple para mantener el estado mutable asociado con los módulos: hk.set_state
y hk.get_state
. Al usar estas funciones, necesita transformar su función usando hk.transform_with_state
ya que la firma del par de funciones devueltas es diferente:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
Si olvida usar hk.transform_with_state
, no se preocupe, imprimiremos un error claro que le indicará hk.transform_with_state
en lugar de eliminar silenciosamente su estado.
jax.pmap
Las funciones puras devueltas por hk.transform
(o hk.transform_with_state
) son totalmente compatibles con jax.pmap
. Para obtener más detalles sobre la programación SPMD con jax.pmap
, consulte aquí.
Un uso común de jax.pmap
con Haiku es el entrenamiento de datos en paralelo en muchos aceleradores, potencialmente en múltiples hosts. Con Haiku, eso podría verse así:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
Para obtener una visión más completa del entrenamiento distribuido de Haiku, consulte nuestro ejemplo de ResNet-50 en ImageNet.
Para citar este repositorio:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
En esta entrada de bibtex, el número de versión debe ser de haiku/__init__.py
y el año corresponde al lanzamiento de código abierto del proyecto.