概述|为什么是俳句? |快速入门|安装|示例|用户手册|文档|引用俳句
重要的
自 2023 年 7 月起,Google DeepMind 建议新项目采用 Flax 而不是 Haiku。 Flax 是一个神经网络库,最初由 Google Brain 开发,现在由 Google DeepMind 开发。
在撰写本文时,Flax 已经拥有 Haiku 中可用功能的超集,拥有更大、更活跃的开发团队,并且得到了 Alphabet 以外用户的更多采用。 Flax 拥有更广泛的文档、示例和创建端到端示例的活跃社区。
Haiku 将继续尽力支持,但该项目将进入维护模式,这意味着开发工作将集中在错误修复和与 JAX 新版本的兼容性上。
将发布新版本以使 Haiku 能够与较新版本的 Python 和 JAX 配合使用,但是我们不会添加(或接受 PR)新功能。
我们在 Google DeepMind 内部大量使用 Haiku,目前计划无限期地支持这种模式的 Haiku。
俳句是一种工具
用于构建神经网络
思考:“JAX 十四行诗”
Haiku 是一个简单的 JAX 神经网络库,由 Sonnet(TensorFlow 的神经网络库)的一些作者开发。
有关 Haiku 的文档可以在 https://dm-haiku.readthedocs.io/ 找到。
消歧义:如果您正在寻找 Haiku 操作系统,请参阅 https://haiku-os.org/。
JAX 是一个数值计算库,结合了 NumPy、自动微分和一流的 GPU/TPU 支持。
Haiku 是一个简单的 JAX 神经网络库,使用户能够使用熟悉的面向对象编程模型,同时允许完全访问 JAX 的纯函数转换。
Haiku 提供了两个核心工具:模块抽象hk.Module
和简单的函数转换hk.transform
。
hk.Module
是 Python 对象,它们保存对其自身参数、其他模块以及对用户输入应用函数的方法的引用。
hk.transform
将使用这些面向对象的、功能上“不纯”的模块的函数转换为可与jax.jit
、 jax.grad
、 jax.pmap
等一起使用的纯函数。
JAX 有许多神经网络库。为什么要选择俳句?
Module
编程模型以进行状态管理,同时保留对 JAX 函数转换的访问。hk.transform
)之外,Haiku 的目标是匹配 Sonnet 2 的 API。模块、方法、参数名称、默认值和初始化方案应该匹配。hk.next_rng_key()
返回一个唯一的 rng 键。让我们看一下神经网络、损失函数和训练循环的示例。 (有关更多示例,请参阅我们的示例目录。MNIST 示例是一个很好的起点。)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
Haiku 的核心是hk.transform
。 transform
函数允许您编写依赖于参数(此处为Linear
层的权重)的神经网络函数,而不需要您显式编写用于初始化这些参数的样板文件。 transform
通过将函数转换为一对纯函数(根据 JAX 的要求) init
和apply
来实现此目的。
init
init
函数的签名为params = init(rng, ...)
(其中...
是未转换函数的参数),允许您收集网络中任何参数的初始值。 Haiku 通过运行您的函数、跟踪通过hk.get_parameter
(例如hk.Linear
调用)请求的任何参数并将它们返回给您来实现此目的。
返回的params
对象是网络中所有参数的嵌套数据结构,专为您检查和操作而设计。具体来说,它是模块名称到模块参数的映射,其中模块参数是参数名称到参数值的映射。例如:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
apply
函数的签名为result = apply(params, rng, ...)
,允许您将参数值注入到函数中。每当调用hk.get_parameter
时,返回的值将来自您作为apply
输入提供的params
:
loss = loss_fn_t . apply ( params , rng , images , labels )
请注意,由于损失函数执行的实际计算不依赖于随机数,因此不需要传入随机数生成器,因此我们也可以为rng
参数传入None
。 (请注意,如果您的计算确实使用随机数,则为rng
传递None
将导致引发错误。)在上面的示例中,我们要求 Haiku 自动为我们执行此操作:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
由于apply
是一个纯函数,我们可以将其传递给jax.grad
(或任何 JAX 的其他转换):
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
此示例中的训练循环非常简单。需要注意的一个细节是使用jax.tree.map
将sgd
函数应用于params
和grads
中的所有匹配条目。结果与前面的params
具有相同的结构,并且可以再次与apply
一起使用。
Haiku 是用纯 Python 编写的,但依赖于通过 JAX 的 C++ 代码。
由于 JAX 安装根据您的 CUDA 版本而有所不同,因此 Haiku 不会在requirements.txt
中将 JAX 列为依赖项。
首先,按照以下说明安装具有相关加速器支持的 JAX。
然后,使用 pip 安装 Haiku:
$ pip install git+https://github.com/deepmind/dm-haiku
或者,您可以通过 PyPI 安装:
$ pip install -U dm-haiku
我们的示例依赖于其他库(例如 bsuite)。您可以使用 pip 安装全套附加要求:
$ pip install -r examples/requirements.txt
在俳句中,所有模块都是hk.Module
的子类。您可以实现任何您喜欢的方法(没有什么特殊情况),但通常模块实现__init__
和__call__
。
让我们来实现一个线性层:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
所有模块都有一个名称。当没有name
参数传递给模块时,其名称是从 Python 类的名称推断出来的(例如MyLinear
变为my_linear
)。模块可以具有使用hk.get_parameter(param_name, ...)
访问的命名参数。我们使用此 API(而不仅仅是使用对象属性),以便我们可以使用hk.transform
将您的代码转换为纯函数。
使用模块时,您需要定义函数并使用hk.transform
将它们转换为一对纯函数。有关从transform
返回的函数的更多详细信息,请参阅我们的快速入门:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
某些模型可能需要随机采样作为计算的一部分。例如,在具有重参数化技巧的变分自动编码器中,需要来自标准正态分布的随机样本。对于 dropout,我们需要一个随机掩码来从输入中删除单位。使用 JAX 进行这项工作的主要障碍是 PRNG 密钥的管理。
在 Haiku 中,我们提供了一个简单的 API 来维护与模块关联的 PRNG 键序列: hk.next_rng_key()
(或用于多个键的next_rng_keys()
):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
要更完整地了解如何使用随机模型,请参阅我们的 VAE 示例。
注意: hk.next_rng_key()
在功能上不是纯粹的,这意味着您应该避免将它与hk.transform
内的 JAX 转换一起使用。有关更多信息和可能的解决方法,请参阅有关 Haiku 转换的文档以及 Haiku 网络内 JAX 转换的可用包装器。
某些模型可能想要维护一些内部的可变状态。例如,在批量归一化中,维护训练期间遇到的值的移动平均值。
在 Haiku 中,我们提供了一个简单的 API 来维护与模块关联的可变状态: hk.set_state
和hk.get_state
。使用这些函数时,您需要使用hk.transform_with_state
转换函数,因为返回的函数对的签名不同:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
如果您忘记使用hk.transform_with_state
请不要担心,我们将打印一个明确的错误,将您指向hk.transform_with_state
而不是默默地删除您的状态。
jax.pmap
进行分布式训练从hk.transform
(或hk.transform_with_state
)返回的纯函数与jax.pmap
完全兼容。有关使用jax.pmap
进行 SPMD 编程的更多详细信息,请参阅此处。
jax.pmap
与 Haiku 的一种常见用途是在许多加速器上(可能跨多个主机)进行数据并行训练。对于俳句,这可能看起来像这样:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
要更完整地了解分布式俳句训练,请查看我们的 ImageNet 上的 ResNet-50 示例。
引用这个存储库:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
在此 bibtex 条目中,版本号来自haiku/__init__.py
,年份对应于该项目的开源版本。