概要|なぜ俳句なのか? |クイックスタート|インストール|例|ユーザーマニュアル|ドキュメント|俳句を引用する
重要
2023 年 7 月の時点で、Google DeepMind は新しいプロジェクトに Haiku の代わりに Flax を採用することを推奨しています。 Flax は、当初 Google Brain によって開発され、現在は Google DeepMind によって開発されたニューラル ネットワーク ライブラリです。
この記事の執筆時点では、Flax は Haiku で利用できる機能のスーパーセットを備えており、より大規模で活発な開発チームが存在し、Alphabet 以外のユーザーへの採用も増えています。 Flax には、より広範なドキュメント、サンプル、およびエンドツーエンドのサンプルを作成するアクティブなコミュニティがあります。
Haiku は引き続きベストエフォート型でサポートされますが、プロジェクトはメンテナンス モードに入ります。つまり、開発作業はバグ修正と JAX の新しいリリースとの互換性に集中することになります。
新しいリリースは、Haiku が新しいバージョンの Python および JAX で動作し続けるように作成されますが、新しい機能を追加する (または PR を受け入れる) ことはありません。
Google DeepMind では社内で Haiku を頻繁に使用しており、現在、このモードで Haiku を無期限にサポートする予定です。
俳句はツールです
ニューラルネットワーク構築用
「JAX のためのソネット」を考えてみましょう。
Haiku は、TensorFlow 用のニューラル ネットワーク ライブラリである Sonnet の作者の一部によって開発された、JAX 用のシンプルなニューラル ネットワーク ライブラリです。
Haiku に関するドキュメントは、https://dm-haiku.readthedocs.io/ でご覧いただけます。
曖昧さ回避: Haiku オペレーティング システムをお探しの場合は、https://haiku-os.org/ を参照してください。
JAX は、NumPy、自動微分、および一流の GPU/TPU サポートを組み合わせた数値計算ライブラリです。
Haiku は JAX 用のシンプルなニューラル ネットワーク ライブラリであり、これを使用すると、ユーザーは使い慣れたオブジェクト指向プログラミング モデルを使用できると同時に、JAX の純粋な関数変換に完全にアクセスできます。
Haiku は、モジュール抽象化hk.Module
と単純な関数変換hk.transform
という 2 つのコア ツールを提供します。
hk.Module
は、独自のパラメータ、他のモジュール、およびユーザー入力に関数を適用するメソッドへの参照を保持する Python オブジェクトです。
hk.transform
これらのオブジェクト指向の機能的に「不純な」モジュールを使用する関数をjax.jit
、 jax.grad
、 jax.pmap
などで使用できる純粋な関数に変換します。
JAX には多数のニューラル ネットワーク ライブラリがあります。なぜ俳句を選ぶべきなのでしょうか?
Module
ベースのプログラミング モデルを保持します。hk.transform
) 以外では、Haiku は Sonnet 2 の API と一致することを目指しています。モジュール、メソッド、引数名、デフォルト、および初期化スキームは一致する必要があります。hk.next_rng_key()
一意の rng キーを返します。ニューラル ネットワーク、損失関数、トレーニング ループの例を見てみましょう。 (その他の例については、サンプル ディレクトリを参照してください。MNIST のサンプルから始めるとよいでしょう。)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
Haiku の中核はhk.transform
です。 transform
関数を使用すると、パラメータ (ここではLinear
層の重み) に依存するニューラル ネットワーク関数を作成できます。これらのパラメータを初期化するためのボイラープレートを明示的に作成する必要はありません。これは、関数を純粋な(JAX で要求される) init
とapply
関数のペアにtransform
ことによって行われます。
init
シグネチャparams = init(rng, ...)
(ここで...
は未変換関数への引数) を持つinit
関数を使用すると、ネットワーク内の任意のパラメータの初期値を収集できます。 Haiku は、関数を実行し、 hk.get_parameter
(たとえばhk.Linear
によって呼び出される) を通じて要求されたパラメータを追跡し、それらを返すことによってこれを行います。
返されるparams
オブジェクトは、ネットワーク内のすべてのパラメータのネストされたデータ構造であり、検査および操作できるように設計されています。具体的には、これはモジュール名からモジュール パラメータへのマッピングであり、モジュール パラメータはパラメータ名からパラメータ値へのマッピングです。例えば:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
apply
関数をシグネチャresult = apply(params, rng, ...)
で使用すると、関数にパラメーター値を挿入できます。 hk.get_parameter
が呼び出されるたびに、返される値は、 apply
への入力として指定したparams
から取得されます。
loss = loss_fn_t . apply ( params , rng , images , labels )
損失関数によって実行される実際の計算は乱数に依存しないため、乱数ジェネレーターを渡す必要がないため、 rng
引数にNone
を渡すこともできることに注意してください。 (計算で乱数を使用する場合、 rng
にNone
渡すとエラーが発生することに注意してください。) 上記の例では、次のようにしてこれを自動的に実行するように Haiku に依頼しています。
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
apply
純粋な関数であるため、それをjax.grad
(または JAX の他の変換) に渡すことができます。
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
この例のトレーニング ループは非常に単純です。注意すべき詳細の 1 つは、 jax.tree.map
を使用して、 params
およびgrads
内の一致するすべてのエントリにsgd
関数を適用することです。結果は前のparams
と同じ構造を持ち、再度apply
で使用できます。
Haiku は純粋な Python で書かれていますが、JAX を介した C++ コードに依存しています。
JAX のインストールは CUDA バージョンに応じて異なるため、Haiku は、 requirements.txt
に JAX を依存関係としてリストしません。
まず、次の手順に従って、関連するアクセラレータ サポートを備えた JAX をインストールします。
次に、pip を使用して Haiku をインストールします。
$ pip install git+https://github.com/deepmind/dm-haiku
あるいは、PyPI 経由でインストールすることもできます。
$ pip install -U dm-haiku
私たちの例は追加のライブラリ (bsuite など) に依存しています。 pip を使用して、追加要件の完全なセットをインストールできます。
$ pip install -r examples/requirements.txt
Haiku では、すべてのモジュールはhk.Module
のサブクラスです。任意のメソッドを実装できます (特別な場合はありません) が、通常、モジュールは__init__
と__call__
を実装します。
線形層を実装してみましょう。
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
すべてのモジュールには名前があります。 name
引数がモジュールに渡されない場合、その名前は Python クラスの名前から推測されます (たとえば、 MyLinear
my_linear
になります)。モジュールには、 hk.get_parameter(param_name, ...)
を使用してアクセスする名前付きパラメーターを含めることができます。 hk.transform
を使用してコードを純粋な関数に変換できるように、(単にオブジェクト プロパティを使用するのではなく) この API を使用します。
モジュールを使用する場合は、関数を定義し、 hk.transform
を使用してそれらを純粋な関数のペアに変換する必要があります。 transform
から返される関数の詳細については、クイックスタートを参照してください。
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
一部のモデルでは、計算の一部としてランダム サンプリングが必要な場合があります。たとえば、再パラメータ化トリックを備えた変分オートエンコーダでは、標準正規分布からのランダム サンプルが必要です。ドロップアウトの場合、入力からユニットをドロップするためのランダム マスクが必要です。これを JAX で機能させる際の主な障害は、PRNG キーの管理にあります。
Haiku では、モジュールに関連付けられた PRNG キー シーケンスを維持するためのシンプルな API、 hk.next_rng_key()
(複数のキーの場合はnext_rng_keys()
) を提供します。
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
確率モデルの操作の詳細については、VAE の例を参照してください。
注: hk.next_rng_key()
は機能的に純粋ではないため、 hk.transform
内の JAX 変換と一緒に使用することは避けるべきです。詳細と考えられる回避策については、Haiku 変換に関するドキュメントと、Haiku ネットワーク内の JAX 変換に利用可能なラッパーを参照してください。
モデルによっては、内部の変更可能な状態を維持する必要がある場合があります。たとえば、バッチ正規化では、トレーニング中に発生した値の移動平均が維持されます。
Haiku では、モジュールに関連付けられた変更可能な状態を維持するためのシンプルな API ( hk.set_state
およびhk.get_state
を提供します。これらの関数を使用する場合、返される関数のペアの署名が異なるため、 hk.transform_with_state
使用して関数を変換する必要があります。
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
hk.transform_with_state
の使用を忘れた場合でも、心配する必要はありません。状態を黙って削除するのではなく、 hk.transform_with_state
を示す明確なエラーが出力されます。
jax.pmap
を使用した分散トレーニングhk.transform
(またはhk.transform_with_state
) から返される純粋な関数は、 jax.pmap
と完全な互換性があります。 jax.pmap
を使用した SPMD プログラミングの詳細については、ここを参照してください。
Haiku でのjax.pmap
の一般的な使用法の 1 つは、多くのアクセラレータ上で、場合によっては複数のホストにわたるデータ並列トレーニングに使用することです。 Haiku の場合、次のようになります。
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
分散型 Haiku トレーニングの詳細については、ImageNet 上の ResNet-50 の例をご覧ください。
このリポジトリを引用するには:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
この bibtex エントリでは、バージョン番号はhaiku/__init__.py
から取得することを意図しており、年はプロジェクトのオープンソース リリースに対応しています。