Обзор | Почему Хайку? | Быстрый старт | Установка | Примеры | Руководство пользователя | Документация | Цитируя Хайку
Важный
С июля 2023 года Google DeepMind рекомендует в новых проектах использовать Flax вместо Haiku. Flax — это библиотека нейронных сетей, первоначально разработанная Google Brain, а теперь — Google DeepMind.
На момент написания Flax имеет расширенный набор функций, доступных в Haiku, более крупную и активную команду разработчиков и большее распространение среди пользователей за пределами Alphabet. У Flax более обширная документация, примеры и активное сообщество, создающее комплексные примеры.
Haiku по-прежнему будет поддерживаться всеми возможными способами, однако проект перейдет в режим обслуживания, а это означает, что усилия по разработке будут сосредоточены на исправлении ошибок и совместимости с новыми выпусками JAX.
Будут выпускаться новые выпуски, позволяющие Haiku работать с новыми версиями Python и JAX, однако мы не будем добавлять (или принимать PR) новые функции.
Мы активно используем Haiku внутри Google DeepMind и в настоящее время планируем поддерживать Haiku в этом режиме на неопределенный срок.
Хайку — это инструмент
Для построения нейронных сетей
Подумайте: «Сонет для JAX».
Haiku — это простая библиотека нейронных сетей для JAX, разработанная некоторыми авторами Sonnet, библиотеки нейронных сетей для TensorFlow.
Документацию по Haiku можно найти по адресу https://dm-haiku.readthedocs.io/.
Значение: если вы ищете операционную систему Haiku, посетите https://haiku-os.org/.
JAX — это библиотека числовых вычислений, сочетающая в себе NumPy, автоматическое дифференцирование и первоклассную поддержку GPU/TPU.
Haiku — это простая библиотека нейронных сетей для JAX, которая позволяет пользователям использовать знакомые модели объектно-ориентированного программирования , обеспечивая при этом полный доступ к преобразованиям чистых функций JAX.
Haiku предоставляет два основных инструмента: абстракцию модуля hk.Module
и простое преобразование функций hk.transform
.
hk.Module
— это объекты Python, которые содержат ссылки на свои собственные параметры, другие модули и методы, которые применяют функции к пользовательскому вводу.
hk.transform
превращает функции, использующие эти объектно-ориентированные, функционально «нечистые» модули, в чистые функции, которые можно использовать с jax.jit
, jax.grad
, jax.pmap
и т. д.
Для JAX существует ряд библиотек нейронных сетей. Почему вам стоит выбрать Хайку?
Module
Sonnet для управления состоянием, сохраняя при этом доступ к преобразованиям функций JAX.hk.transform
), Haiku стремится соответствовать API Sonnet 2. Модули, методы, имена аргументов, значения по умолчанию и схемы инициализации должны совпадать.hk.next_rng_key()
возвращает уникальный ключ rng.Давайте посмотрим на пример нейронной сети, функции потерь и цикла обучения. (Дополнительные примеры см. в нашем каталоге примеров. Пример MNIST — хорошее место для начала.)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
Ядро Haiku — hk.transform
. Функция transform
позволяет вам писать функции нейронной сети, которые полагаются на параметры (здесь веса Linear
слоев), не требуя явного написания шаблона для инициализации этих параметров. transform
делает это путем преобразования функции в пару чистых (как того требует JAX) функций init
и apply
.
init
Функция init
с сигнатурой params = init(rng, ...)
(где ...
— аргументы непреобразованной функции) позволяет собирать начальное значение любых параметров в сети. Haiku делает это, запуская вашу функцию, отслеживая все параметры, запрошенные через hk.get_parameter
(вызываемый, например, hk.Linear
), и возвращая их вам.
Возвращенный объект params
представляет собой вложенную структуру данных всех параметров вашей сети, предназначенную для проверки и управления вами. Конкретно, это сопоставление имени модуля с параметрами модуля, где параметр модуля представляет собой сопоставление имени параметра со значением параметра. Например:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
Функция apply
с сигнатурой result = apply(params, rng, ...)
позволяет вам вставлять значения параметров в вашу функцию. Всякий раз, когда вызывается hk.get_parameter
, возвращаемое значение будет получено из params
которые вы предоставляете в качестве входных данных для apply
:
loss = loss_fn_t . apply ( params , rng , images , labels )
Обратите внимание: поскольку фактические вычисления, выполняемые нашей функцией потерь, не основаны на случайных числах, передача генератора случайных чисел не требуется, поэтому мы также можем передать None
в качестве аргумента rng
. (Обратите внимание: если в ваших вычислениях используются случайные числа, передача None
для rng
приведет к возникновению ошибки.) В нашем примере выше мы просим Haiku сделать это за нас автоматически с помощью:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
Поскольку apply
— это чистая функция, мы можем передать ее в jax.grad
(или любое другое преобразование JAX):
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
Цикл обучения в этом примере очень прост. Следует отметить одну деталь: использование jax.tree.map
для применения функции sgd
ко всем совпадающим записям в params
и grads
. Результат имеет ту же структуру, что и предыдущие params
, и его снова можно использовать с apply
.
Haiku написан на чистом Python, но зависит от кода C++ через JAX.
Поскольку установка JAX различается в зависимости от вашей версии CUDA, Haiku не указывает JAX как зависимость в файле requirements.txt
.
Сначала следуйте этим инструкциям, чтобы установить JAX с соответствующей поддержкой ускорителя.
Затем установите Haiku с помощью pip:
$ pip install git+https://github.com/deepmind/dm-haiku
Альтернативно вы можете установить через PyPI:
$ pip install -U dm-haiku
Наши примеры основаны на дополнительных библиотеках (например, bsuite). Вы можете установить полный набор дополнительных требований с помощью pip:
$ pip install -r examples/requirements.txt
В Haiku все модули являются подклассами hk.Module
. Вы можете реализовать любой метод, который вам нравится (ничего особенного), но обычно модули реализуют __init__
и __call__
.
Давайте поработаем над реализацией линейного слоя:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
Все модули имеют имя. Если модулю не передается аргумент name
, его имя выводится из имени класса Python (например, MyLinear
становится my_linear
). Модули могут иметь именованные параметры, доступ к которым осуществляется с помощью hk.get_parameter(param_name, ...)
. Мы используем этот API (а не просто свойства объекта), чтобы можно было преобразовать ваш код в чистую функцию с помощью hk.transform
.
При использовании модулей вам необходимо определить функции и преобразовать их в пару чистых функций с помощью hk.transform
. Дополнительную информацию о функциях, возвращаемых transform
см. в нашем кратком руководстве:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
Некоторые модели могут потребовать случайной выборки в рамках вычислений. Например, в вариационных автоэнкодерах с приемом перепараметризации необходима случайная выборка из стандартного нормального распределения. Для исключения нам нужна случайная маска для удаления единиц из ввода. Основное препятствие при работе с JAX заключается в управлении ключами PRNG.
В Haiku мы предоставляем простой API для поддержки последовательности ключей PRNG, связанной с модулями: hk.next_rng_key()
(или next_rng_keys()
для нескольких ключей):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
Более полное представление о работе со стохастическими моделями см. в нашем примере VAE.
Примечание. hk.next_rng_key()
не является функционально чистым, что означает, что вам следует избегать его использования вместе с преобразованиями JAX, которые находятся внутри hk.transform
. Для получения дополнительной информации и возможных обходных путей обратитесь к документации по преобразованиям Haiku и доступным оболочкам для преобразований JAX внутри сетей Haiku.
Некоторые модели могут захотеть сохранить какое-то внутреннее изменяемое состояние. Например, при пакетной нормализации сохраняется скользящее среднее значений, встречающихся во время обучения.
В Haiku мы предоставляем простой API для поддержания изменяемого состояния, связанного с модулями: hk.set_state
и hk.get_state
. При использовании этих функций вам необходимо преобразовать вашу функцию с помощью hk.transform_with_state
поскольку сигнатура возвращаемой пары функций отличается:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
Если вы забудете использовать hk.transform_with_state
, не волнуйтесь, мы напечатаем явную ошибку, указывающую вам на hk.transform_with_state
, а не молча отбросим ваше состояние.
jax.pmap
Чистые функции, возвращаемые из hk.transform
(или hk.transform_with_state
), полностью совместимы с jax.pmap
. Более подробную информацию о программировании SPMD с помощью jax.pmap
можно найти здесь.
Одним из распространенных вариантов использования jax.pmap
с Haiku является параллельное обучение данным на многих ускорителях, возможно, на нескольких хостах. В Haiku это может выглядеть так:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
Чтобы получить более полное представление о распределенном обучении Haiku, взгляните на наш пример ResNet-50 на ImageNet.
Чтобы процитировать этот репозиторий:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
В этой записи bibtex номер версии должен быть от haiku/__init__.py
, а год соответствует выпуску проекта с открытым исходным кодом.