개요 | 왜 하이쿠인가? | 빠른 시작 | 설치 | 예 | 사용 설명서 | 문서 | 하이쿠를 인용하다
중요한
2023년 7월부터 Google DeepMind는 새로운 프로젝트에 Haiku 대신 Flax를 채택할 것을 권장합니다. Flax는 원래 Google Brain에서 개발했으며 현재는 Google DeepMind에서 개발한 신경망 라이브러리입니다.
이 글을 쓰는 시점에서 Flax는 Haiku에서 사용할 수 있는 기능의 상위 집합을 보유하고 있으며 더 크고 활동적인 개발 팀이 있으며 Alphabet 외부 사용자의 채택도 더 많습니다. Flax에는 보다 광범위한 문서, 예제 및 엔드투엔드 예제를 생성하는 활발한 커뮤니티가 있습니다.
Haiku는 최선의 지원을 받을 예정이지만 프로젝트는 유지 관리 모드로 전환됩니다. 즉, 개발 노력은 버그 수정 및 JAX의 새 릴리스와의 호환성에 중점을 둘 것입니다.
Haiku가 최신 버전의 Python 및 JAX와 계속 작동하도록 새 릴리스가 만들어질 예정이지만, 새로운 기능을 추가하거나 이에 대한 PR을 수락하지 않을 것입니다.
우리는 Google DeepMind 내부적으로 하이쿠를 많이 사용하고 있으며 현재 이 모드에서 하이쿠를 무기한 지원할 계획입니다.
하이쿠는 도구이다
신경망 구축용
생각해 보세요: "JAX용 소네트"
Haiku는 TensorFlow용 신경망 라이브러리인 Sonnet의 일부 저자가 개발한 JAX용 간단한 신경망 라이브러리입니다.
하이쿠에 관한 문서는 https://dm-haiku.readthedocs.io/에서 찾을 수 있습니다.
명확성: Haiku 운영 체제를 찾고 있다면 https://haiku-os.org/를 참조하세요.
JAX는 NumPy, 자동 미분 및 최고 수준의 GPU/TPU 지원을 결합한 수치 컴퓨팅 라이브러리입니다.
Haiku는 사용자가 익숙한 객체 지향 프로그래밍 모델을 사용하는 동시에 JAX의 순수 함수 변환에 대한 전체 액세스를 허용하는 JAX용 간단한 신경망 라이브러리입니다.
Haiku는 모듈 추상화인 hk.Module
과 간단한 함수 변환인 hk.transform
두 가지 핵심 도구를 제공합니다.
hk.Module
은 자체 매개변수, 기타 모듈 및 사용자 입력에 함수를 적용하는 메소드에 대한 참조를 보유하는 Python 객체입니다.
hk.transform
이러한 객체 지향적이고 기능적으로 "불순한" 모듈을 사용하는 함수를 jax.jit
, jax.grad
, jax.pmap
등과 함께 사용할 수 있는 순수 함수로 변환합니다.
JAX용 신경망 라이브러리가 많이 있습니다. 하이쿠를 선택해야 하는 이유는 무엇인가요?
Module
기반 프로그래밍 모델을 유지합니다.hk.transform
) 외에 Haiku는 Sonnet 2의 API와 일치하는 것을 목표로 합니다. 모듈, 메서드, 인수 이름, 기본값 및 초기화 체계가 일치해야 합니다.hk.next_rng_key()
고유한 rng 키를 반환합니다.신경망, 손실 함수 및 훈련 루프의 예를 살펴보겠습니다. (더 많은 예제를 보려면 예제 디렉터리를 참조하세요. MNIST 예제는 시작하기에 좋은 곳입니다.)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
하이쿠의 핵심은 hk.transform
입니다. transform
함수를 사용하면 해당 매개변수를 초기화하기 위한 상용구를 명시적으로 작성할 필요 없이 매개변수(여기서는 Linear
레이어의 가중치)에 의존하는 신경망 함수를 작성할 수 있습니다. transform
함수를 순수 (JAX에서 요구하는 대로) init
및 apply
함수 쌍으로 변환하여 이를 수행합니다.
init
params = init(rng, ...)
(여기서 ...
변환되지 않은 함수에 대한 인수)라는 서명이 있는 init
함수를 사용하면 네트워크에 있는 모든 매개변수의 초기 값을 수집 할 수 있습니다. Haiku는 함수를 실행하고 hk.get_parameter
(예: hk.Linear
에 의해 호출됨)를 통해 요청된 매개변수를 추적하고 이를 반환함으로써 이를 수행합니다.
반환된 params
객체는 네트워크에 있는 모든 매개변수의 중첩된 데이터 구조이며, 사용자가 검사하고 조작할 수 있도록 설계되었습니다. 구체적으로 모듈 이름을 모듈 매개변수로 매핑하는 것이며, 여기서 모듈 매개변수는 매개변수 이름을 매개변수 값으로 매핑하는 것입니다. 예를 들어:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
서명 result = apply(params, rng, ...)
인 apply
함수를 사용하면 매개변수 값을 함수에 주입 할 수 있습니다. hk.get_parameter
호출될 때마다 반환된 값은 apply
입력으로 제공한 params
에서 나옵니다.
loss = loss_fn_t . apply ( params , rng , images , labels )
손실 함수에 의해 수행되는 실제 계산은 난수에 의존하지 않기 때문에 난수 생성기를 전달할 필요가 없으므로 rng
인수에 None
전달할 수도 있습니다. (계산에서 난수를 사용 하는 경우 rng
에 None
전달하면 오류가 발생합니다.) 위의 예에서 우리는 Haiku에게 다음을 사용하여 이 작업을 자동으로 수행하도록 요청합니다.
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
apply
순수 함수이므로 이를 jax.grad
(또는 JAX의 다른 변환 중 하나)에 전달할 수 있습니다.
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
이 예의 훈련 루프는 매우 간단합니다. 주목해야 할 한 가지 세부 사항은 jax.tree.map
을 사용하여 params
및 grads
의 일치하는 모든 항목에 sgd
함수를 적용한다는 것입니다. 결과는 이전 params
와 동일한 구조를 가지며 apply
와 함께 다시 사용할 수 있습니다.
Haiku는 순수 Python으로 작성되었지만 JAX를 통한 C++ 코드에 의존합니다.
JAX 설치는 CUDA 버전에 따라 다르기 때문에 Haiku는 JAX를 requirements.txt
의 종속성으로 나열하지 않습니다.
먼저 다음 지침에 따라 관련 가속기 지원과 함께 JAX를 설치하세요.
그런 다음 pip를 사용하여 Haiku를 설치합니다.
$ pip install git+https://github.com/deepmind/dm-haiku
또는 PyPI를 통해 설치할 수 있습니다.
$ pip install -U dm-haiku
우리의 예제는 추가 라이브러리(예: bsuite)에 의존합니다. pip를 사용하여 추가 요구 사항 전체 세트를 설치할 수 있습니다.
$ pip install -r examples/requirements.txt
Haiku에서 모든 모듈은 hk.Module
의 하위 클래스입니다. 원하는 메서드를 구현할 수 있지만(특별한 경우는 없음) 일반적으로 모듈은 __init__
및 __call__
구현합니다.
선형 레이어를 구현해 보겠습니다.
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
모든 모듈에는 이름이 있습니다. name
인수가 모듈에 전달되지 않으면 해당 이름은 Python 클래스의 이름에서 유추됩니다(예를 들어 MyLinear
my_linear
됩니다). 모듈에는 hk.get_parameter(param_name, ...)
사용하여 액세스되는 명명된 매개변수가 있을 수 있습니다. 우리는 (객체 속성을 사용하는 대신) 이 API를 사용하여 hk.transform
사용하여 코드를 순수 함수로 변환할 수 있습니다.
모듈을 사용할 때 함수를 정의하고 hk.transform
사용하여 이를 순수 함수 쌍으로 변환해야 합니다. transform
에서 반환된 함수에 대한 자세한 내용은 빠른 시작을 참조하세요.
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
일부 모델에는 계산의 일부로 무작위 샘플링이 필요할 수 있습니다. 예를 들어, 재매개변수화 트릭을 사용하는 변형 자동 인코더에서는 표준 정규 분포의 무작위 표본이 필요합니다. 드롭아웃의 경우 입력에서 단위를 삭제하는 무작위 마스크가 필요합니다. JAX로 이 작업을 수행하는 데 있어 주요 장애물은 PRNG 키를 관리하는 것입니다.
Haiku에서는 모듈과 관련된 PRNG 키 시퀀스를 유지하기 위한 간단한 API를 제공합니다: hk.next_rng_key()
(또는 다중 키의 경우 next_rng_keys()
):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
확률론적 모델 작업에 대한 자세한 내용은 VAE 예제를 참조하세요.
참고: hk.next_rng_key()
기능적으로 순수하지 않습니다. 즉, hk.transform
내부에 있는 JAX 변환과 함께 사용하지 않아야 한다는 의미입니다. 자세한 내용과 가능한 해결 방법은 Haiku 변환 및 Haiku 네트워크 내부의 JAX 변환에 사용 가능한 래퍼에 대한 문서를 참조하세요.
일부 모델은 내부의 변경 가능한 상태를 유지하기를 원할 수 있습니다. 예를 들어 배치 정규화에서는 훈련 중에 발생한 값의 이동 평균이 유지됩니다.
Haiku에서는 모듈과 관련된 변경 가능한 상태를 유지하기 위한 간단한 API인 hk.set_state
및 hk.get_state
제공합니다. 이러한 함수를 사용할 때 반환된 함수 쌍의 서명이 다르기 때문에 hk.transform_with_state
사용하여 함수를 변환해야 합니다.
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
hk.transform_with_state
사용을 잊어버린 경우 걱정하지 마십시오. 상태를 자동으로 삭제하는 대신 hk.transform_with_state
가리키는 명확한 오류가 인쇄됩니다.
jax.pmap
사용한 분산 교육 hk.transform
(또는 hk.transform_with_state
)에서 반환된 순수 함수는 jax.pmap
과 완전히 호환됩니다. jax.pmap
사용한 SPMD 프로그래밍에 대한 자세한 내용은 여기를 참조하세요.
Haiku와 함께 jax.pmap
사용하는 일반적인 용도 중 하나는 잠재적으로 여러 호스트에 걸쳐 많은 가속기에 대한 데이터 병렬 교육입니다. 하이쿠를 사용하면 다음과 같이 보일 수 있습니다.
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
분산형 Haiku 교육에 대한 더 자세한 내용을 보려면 ImageNet의 ResNet-50 예제를 살펴보세요.
이 저장소를 인용하려면:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
이 bibtex 항목에서 버전 번호는 haiku/__init__.py
에서 가져오고 연도는 프로젝트의 오픈 소스 릴리스에 해당합니다.