概述|為什麼是俳句? |快速入門|安裝|範例|使用者手冊|文檔|引用俳句
重要的
自 2023 年 7 月起,Google DeepMind 建議新專案採用 Flax 而非 Haiku。 Flax 是一個神經網路函式庫,最初由 Google Brain 開發,現在由 Google DeepMind 開發。
在撰寫本文時,Flax 已經擁有 Haiku 中可用功能的超集,擁有更大、更活躍的開發團隊,並且得到了 Alphabet 以外用戶的更多採用。 Flax 擁有更廣泛的文件、範例和建立端到端範例的活躍社群。
Haiku 將繼續盡力支持,但該專案將進入維護模式,這意味著開發工作將集中在錯誤修復和與 JAX 新版本的兼容性上。
將發布新版本以使 Haiku 能夠與較新版本的 Python 和 JAX 一起使用,但我們不會添加(或接受 PR)新功能。
我們在 Google DeepMind 內部大量使用 Haiku,目前計劃無限期地支援這種模式的 Haiku。
俳句是一種工具
用於建構神經網絡
思考:“JAX 十四行詩”
Haiku 是一個簡單的 JAX 神經網路函式庫,由 Sonnet(TensorFlow 的神經網路函式庫)的一些作者所開發。
有關 Haiku 的文檔可以在 https://dm-haiku.readthedocs.io/ 找到。
消歧義:如果您正在尋找 Haiku 作業系統,請參閱 https://haiku-os.org/。
JAX 是一個數值運算庫,結合了 NumPy、自動微分和一流的 GPU/TPU 支援。
Haiku 是一個簡單的 JAX 神經網路函式庫,使用戶能夠使用熟悉的物件導向程式設計模型,同時允許完全存取 JAX 的純函數轉換。
Haiku 提供了兩個核心工具:模組抽象hk.Module
和簡單的函數轉換hk.transform
。
hk.Module
是 Python 對象,它們保存對其自身參數、其他模組以及對使用者輸入應用函數的方法的參考。
hk.transform
將使用這些物件導向的、功能上「不純」的模組的函數轉換為可與jax.jit
、 jax.grad
、 jax.pmap
等一起使用的純函數。
JAX 有許多神經網路函式庫。為什麼要選擇俳句?
Module
程式設計模型以進行狀態管理,同時保留對 JAX 函數轉換的存取。hk.transform
)之外,Haiku 的目標是匹配 Sonnet 2 的 API。hk.next_rng_key()
傳回一個唯一的 rng 鍵。讓我們來看看神經網路、損失函數和訓練循環的範例。 (有關更多範例,請參閱我們的範例目錄。MNIST 範例是一個很好的起點。)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
Haiku 的核心是hk.transform
。 transform
函數可讓您編寫依賴參數(此處為Linear
層的權重)的神經網路函數,而不需要您明確編寫用於初始化這些參數的樣板檔案。 transform
透過將函數轉換為一對純函數(根據 JAX 的要求) init
和apply
來實現此目的。
init
init
函數的簽章為params = init(rng, ...)
(其中...
是未轉換函數的參數),可讓您收集網路中任何參數的初始值。 Haiku 透過執行您的函數、追蹤透過hk.get_parameter
(例如hk.Linear
呼叫)請求的任何參數並將它們傳回給您來實現此目的。
傳回的params
是網路中所有參數的嵌套資料結構,專為您檢查和操作而設計。具體來說,它是模組名稱到模組參數的映射,其中模組參數是參數名稱到參數值的映射。例如:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
apply
函數的簽章為result = apply(params, rng, ...)
,可讓您將參數值注入到函數中。每當呼叫hk.get_parameter
時,傳回的值將來自您作為apply
輸入提供的params
:
loss = loss_fn_t . apply ( params , rng , images , labels )
請注意,由於損失函數執行的實際計算不依賴隨機數,因此不需要傳入隨機數產生器,因此我們也可以為rng
參數傳入None
。 (請注意,如果您的計算確實使用隨機數,則為rng
傳遞None
將導致引發錯誤。)在上面的範例中,我們要求 Haiku 自動為我們執行此操作:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
由於apply
是一個純函數,我們可以將其傳遞給jax.grad
(或任何 JAX 的其他轉換):
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
此範例中的訓練循環非常簡單。需要注意的一個細節是使用jax.tree.map
將sgd
函數應用於params
和grads
中的所有匹配條目。結果與前面的params
具有相同的結構,並且可以再次與apply
一起使用。
Haiku 是用純 Python 寫的,但依賴透過 JAX 的 C++ 程式碼。
由於 JAX 安裝會根據您的 CUDA 版本而有所不同,因此 Haiku 不會在requirements.txt
中將 JAX 列為依賴項。
首先,請按照以下說明安裝具有相關加速器支援的 JAX。
然後,使用 pip 安裝 Haiku:
$ pip install git+https://github.com/deepmind/dm-haiku
或者,您可以透過 PyPI 安裝:
$ pip install -U dm-haiku
我們的範例依賴其他庫(例如 bsuite)。您可以使用 pip 安裝全套附加要求:
$ pip install -r examples/requirements.txt
在俳句中,所有模組都是hk.Module
的子類別。您可以實現任何您喜歡的方法(沒有什麼特殊情況),但通常模組實作__init__
和__call__
。
讓我們來實作一個線性層:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
所有模組都有一個名稱。當沒有name
參數傳遞給模組時,其名稱是從 Python 類別的名稱推斷出來的(例如MyLinear
變成my_linear
)。模組可以具有使用hk.get_parameter(param_name, ...)
存取的命名參數。我們使用此 API(而不僅僅是使用物件屬性),以便我們可以使用hk.transform
將您的程式碼轉換為純函數。
使用模組時,您需要定義函數並使用hk.transform
將它們轉換為一對純函數。有關從transform
返回的函數的更多詳細信息,請參閱我們的快速入門:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
某些模型可能需要隨機取樣作為計算的一部分。例如,在具有重參數化技巧的變分自動編碼器中,需要來自標準常態分佈的隨機樣本。對於 dropout,我們需要一個隨機遮罩來從輸入中刪除單位。使用 JAX 進行這項工作的主要障礙是 PRNG 金鑰的管理。
在 Haiku 中,我們提供了一個簡單的 API 來維護與模組關聯的 PRNG 鍵序列: hk.next_rng_key()
(或用於多個鍵的next_rng_keys()
):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
若要更完整地了解如何使用隨機模型,請參閱我們的 VAE 範例。
注意: hk.next_rng_key()
在功能上不是純粹的,這意味著您應該避免將它與hk.transform
內的 JAX 轉換一起使用。如需更多資訊和可能的解決方法,請參閱有關 Haiku 轉換的文檔以及 Haiku 網路內 JAX 轉換的可用包裝器。
某些模型可能想要維護一些內部的可變狀態。例如,在批量歸一化中,維護訓練期間遇到的值的移動平均值。
在 Haiku 中,我們提供了一個簡單的 API 來維護與模組關聯的可變狀態: hk.set_state
和hk.get_state
。使用這些函數時,您需要使用hk.transform_with_state
轉換函數,因為傳回的函數對的簽名不同:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
如果您忘記使用hk.transform_with_state
請不要擔心,我們將列印一個明確的錯誤,並將您指向hk.transform_with_state
而不是默默地刪除您的狀態。
jax.pmap
進行分散式訓練從hk.transform
(或hk.transform_with_state
)傳回的純函數與jax.pmap
完全相容。有關使用jax.pmap
進行 SPMD 編程的更多詳細信息,請參閱此處。
jax.pmap
與 Haiku 的常見用途是在許多加速器上(可能跨多個主機)進行資料並行訓練。對於俳句,這可能看起來像這樣:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
要更完整地了解分散式俳句訓練,請查看我們的 ImageNet 上的 ResNet-50 範例。
引用這個儲存庫:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
在此 bibtex 條目中,版本號碼來自haiku/__init__.py
,年份對應於該專案的開源版本。