Visão geral | Por que Haiku? | Início rápido | Instalação | Exemplos | Manual do usuário | Documentação | Citando Haiku
Importante
A partir de julho de 2023, o Google DeepMind recomenda que novos projetos adotem Flax em vez de Haiku. Flax é uma biblioteca de redes neurais desenvolvida originalmente pelo Google Brain e agora pelo Google DeepMind.
No momento em que este artigo foi escrito, o Flax tinha um superconjunto de recursos disponíveis no Haiku, uma equipe de desenvolvimento maior e mais ativa e mais adoção por usuários fora do Alphabet. Flax tem documentação mais extensa, exemplos e uma comunidade ativa criando exemplos de ponta a ponta.
O Haiku permanecerá com suporte de melhor esforço, porém o projeto entrará em modo de manutenção, o que significa que os esforços de desenvolvimento serão focados em correções de bugs e compatibilidade com novas versões do JAX.
Novos lançamentos serão feitos para manter o Haiku funcionando com versões mais recentes do Python e JAX, porém não adicionaremos (ou aceitaremos PRs para) novos recursos.
Temos um uso significativo do Haiku internamente no Google DeepMind e atualmente planejamos oferecer suporte ao Haiku nesse modo indefinidamente.
Haiku é uma ferramenta
Para construir redes neurais
Pense: "Soneto para JAX"
Haiku é uma biblioteca de rede neural simples para JAX desenvolvida por alguns dos autores do Sonnet, uma biblioteca de rede neural para TensorFlow.
A documentação sobre Haiku pode ser encontrada em https://dm-haiku.readthedocs.io/.
Desambiguação: se você estiver procurando pelo sistema operacional Haiku, consulte https://haiku-os.org/.
JAX é uma biblioteca de computação numérica que combina NumPy, diferenciação automática e suporte GPU/TPU de primeira classe.
Haiku é uma biblioteca de rede neural simples para JAX que permite aos usuários usar modelos familiares de programação orientada a objetos, ao mesmo tempo que permite acesso total às transformações de funções puras do JAX.
O Haiku fornece duas ferramentas principais: uma abstração de módulo, hk.Module
, e uma transformação de função simples, hk.transform
.
hk.Module
s são objetos Python que contêm referências a seus próprios parâmetros, outros módulos e métodos que aplicam funções nas entradas do usuário.
hk.transform
transforma funções que usam esses módulos funcionalmente "impuros" orientados a objetos em funções puras que podem ser usadas com jax.jit
, jax.grad
, jax.pmap
, etc.
Existem várias bibliotecas de redes neurais para JAX. Por que você deve escolher o Haiku?
Module
do Sonnet para gerenciamento de estado, enquanto mantém o acesso às transformações de funções do JAX.hk.transform
), o Haiku visa corresponder à API do Sonnet 2. Módulos, métodos, nomes de argumentos, padrões e esquemas de inicialização devem corresponder.hk.next_rng_key()
retorna uma chave rng exclusiva.Vamos dar uma olhada em um exemplo de rede neural, função de perda e loop de treinamento. (Para mais exemplos, consulte nosso diretório de exemplos. O exemplo MNIST é um bom lugar para começar.)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
O núcleo do Haiku é hk.transform
. A função transform
permite escrever funções de rede neural que dependem de parâmetros (aqui os pesos das camadas Linear
) sem exigir que você escreva explicitamente o padrão para inicializar esses parâmetros. transform
faz isso transformando a função em um par de funções que são puras (conforme exigido pelo JAX) init
e apply
.
init
A função init
, com assinatura params = init(rng, ...)
(onde ...
são os argumentos para a função não transformada), permite coletar o valor inicial de qualquer parâmetro na rede. O Haiku faz isso executando sua função, acompanhando quaisquer parâmetros solicitados através de hk.get_parameter
(chamado, por exemplo, por hk.Linear
) e retornando-os para você.
O objeto params
retornado é uma estrutura de dados aninhada de todos os parâmetros da sua rede, projetada para você inspecionar e manipular. Concretamente, é um mapeamento do nome do módulo para os parâmetros do módulo, onde um parâmetro do módulo é um mapeamento do nome do parâmetro para o valor do parâmetro. Por exemplo:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
A função apply
, com assinatura result = apply(params, rng, ...)
, permite injetar valores de parâmetro em sua função. Sempre que hk.get_parameter
for chamado, o valor retornado virá dos params
que você forneceu como entrada para apply
:
loss = loss_fn_t . apply ( params , rng , images , labels )
Observe que, como o cálculo real executado por nossa função de perda não depende de números aleatórios, passar um gerador de números aleatórios é desnecessário, portanto, também poderíamos passar None
para o argumento rng
. (Observe que se o seu cálculo usar números aleatórios, passar None
para rng
causará um erro.) Em nosso exemplo acima, pedimos ao Haiku para fazer isso automaticamente para nós com:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
Como apply
é uma função pura, podemos passá-la para jax.grad
(ou qualquer outra transformação do JAX):
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
O ciclo de treinamento neste exemplo é muito simples. Um detalhe a ser observado é o uso de jax.tree.map
para aplicar a função sgd
em todas as entradas correspondentes em params
e grads
. O resultado tem a mesma estrutura dos params
anteriores e pode ser usado novamente com apply
.
Haiku é escrito em Python puro, mas depende de código C++ via JAX.
Como a instalação do JAX é diferente dependendo da sua versão CUDA, o Haiku não lista o JAX como uma dependência em requirements.txt
.
Primeiro, siga estas instruções para instalar o JAX com o suporte do acelerador relevante.
Em seguida, instale o Haiku usando pip:
$ pip install git+https://github.com/deepmind/dm-haiku
Alternativamente, você pode instalar via PyPI:
$ pip install -U dm-haiku
Nossos exemplos dependem de bibliotecas adicionais (por exemplo, bsuite). Você pode instalar o conjunto completo de requisitos adicionais usando pip:
$ pip install -r examples/requirements.txt
No Haiku, todos os módulos são uma subclasse de hk.Module
. Você pode implementar qualquer método que desejar (nada é especial), mas normalmente os módulos implementam __init__
e __call__
.
Vamos trabalhar na implementação de uma camada linear:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
Todos os módulos têm um nome. Quando nenhum argumento name
é passado para o módulo, seu nome é inferido do nome da classe Python (por exemplo, MyLinear
torna-se my_linear
). Os módulos podem ter parâmetros nomeados que são acessados usando hk.get_parameter(param_name, ...)
. Usamos essa API (em vez de apenas usar propriedades do objeto) para que possamos converter seu código em uma função pura usando hk.transform
.
Ao usar módulos você precisa definir funções e transformá-las em um par de funções puras usando hk.transform
. Consulte nosso guia de início rápido para obter mais detalhes sobre as funções retornadas de transform
:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
Alguns modelos podem exigir amostragem aleatória como parte do cálculo. Por exemplo, em autoencoders variacionais com o truque de reparametrização, é necessária uma amostra aleatória da distribuição normal padrão. Para dropout, precisamos de uma máscara aleatória para eliminar unidades da entrada. O principal obstáculo para fazer isso funcionar com JAX está no gerenciamento de chaves PRNG.
No Haiku, fornecemos uma API simples para manter uma sequência de teclas PRNG associada aos módulos: hk.next_rng_key()
(ou next_rng_keys()
para múltiplas chaves):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
Para uma visão mais completa do trabalho com modelos estocásticos, consulte nosso exemplo VAE.
Nota: hk.next_rng_key()
não é funcionalmente puro, o que significa que você deve evitar usá-lo junto com as transformações JAX que estão dentro de hk.transform
. Para obter mais informações e possíveis soluções alternativas, consulte a documentação sobre transformações do Haiku e wrappers disponíveis para transformações JAX dentro de redes Haiku.
Alguns modelos podem querer manter algum estado interno mutável. Por exemplo, na normalização em lote, é mantida uma média móvel dos valores encontrados durante o treinamento.
No Haiku, fornecemos uma API simples para manter o estado mutável associado aos módulos: hk.set_state
e hk.get_state
. Ao usar essas funções você precisa transformar sua função usando hk.transform_with_state
pois a assinatura do par de funções retornado é diferente:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
Se você esquecer de usar hk.transform_with_state
não se preocupe, imprimiremos um erro claro apontando para hk.transform_with_state
em vez de descartar silenciosamente seu estado.
jax.pmap
As funções puras retornadas de hk.transform
(ou hk.transform_with_state
) são totalmente compatíveis com jax.pmap
. Para mais detalhes sobre a programação SPMD com jax.pmap
, veja aqui.
Um uso comum de jax.pmap
com Haiku é para treinamento paralelo de dados em muitos aceleradores, potencialmente em vários hosts. Com o Haiku, isso pode ser assim:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
Para uma visão mais completa do treinamento distribuído do Haiku, dê uma olhada em nosso exemplo ResNet-50 no ImageNet.
Para citar este repositório:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
Nesta entrada do bibtex, o número da versão deve ser de haiku/__init__.py
e o ano corresponde ao lançamento do código aberto do projeto.