Ikhtisar | Mengapa Haiku? | Mulai cepat | Instalasi | Contoh | Panduan pengguna | Dokumentasi | Mengutip Haiku
Penting
Mulai Juli 2023 Google DeepMind merekomendasikan agar proyek baru mengadopsi Flax, bukan Haiku. Flax adalah perpustakaan jaringan saraf yang awalnya dikembangkan oleh Google Brain dan sekarang oleh Google DeepMind.
Pada saat penulisan, Flax memiliki superset fitur yang tersedia di Haiku, tim pengembangan yang lebih besar dan lebih aktif serta lebih banyak adopsi oleh pengguna di luar Alphabet. Flax memiliki dokumentasi yang lebih luas, contoh-contoh dan komunitas aktif yang menciptakan contoh-contoh ujung ke ujung.
Haiku akan tetap mendukung upaya terbaik, namun proyek akan memasuki mode pemeliharaan, yang berarti bahwa upaya pengembangan akan difokuskan pada perbaikan bug dan kompatibilitas dengan rilis baru JAX.
Rilis baru akan dibuat agar Haiku tetap berfungsi dengan versi Python dan JAX yang lebih baru, namun kami tidak akan menambahkan (atau menerima PR untuk) fitur baru.
Kami memiliki penggunaan Haiku secara internal di Google DeepMind secara signifikan dan saat ini berencana untuk mendukung Haiku dalam mode ini tanpa batas waktu.
Haiku adalah sebuah alat
Untuk membangun jaringan saraf
Pikirkan: "Soneta untuk JAX"
Haiku adalah perpustakaan jaringan saraf sederhana untuk JAX yang dikembangkan oleh beberapa penulis Sonnet, perpustakaan jaringan saraf untuk TensorFlow.
Dokumentasi tentang Haiku dapat ditemukan di https://dm-haiku.readthedocs.io/.
Disambiguasi: jika Anda mencari sistem operasi Haiku, silakan lihat https://haiku-os.org/.
JAX adalah perpustakaan komputasi numerik yang menggabungkan NumPy, diferensiasi otomatis, dan dukungan GPU/TPU kelas satu.
Haiku adalah perpustakaan jaringan saraf sederhana untuk JAX yang memungkinkan pengguna untuk menggunakan model pemrograman berorientasi objek yang sudah dikenal sambil memungkinkan akses penuh ke transformasi fungsi murni JAX.
Haiku menyediakan dua alat inti: abstraksi modul, hk.Module
, dan transformasi fungsi sederhana, hk.transform
.
hk.Module
s adalah objek Python yang menyimpan referensi ke parameternya sendiri, modul lain, dan metode yang menerapkan fungsi pada input pengguna.
hk.transform
mengubah fungsi yang menggunakan modul berorientasi objek yang secara fungsional "tidak murni" menjadi fungsi murni yang dapat digunakan dengan jax.jit
, jax.grad
, jax.pmap
, dll.
Ada sejumlah perpustakaan jaringan saraf untuk JAX. Mengapa Anda harus memilih Haiku?
Module
Sonnet untuk manajemen negara sambil mempertahankan akses ke transformasi fungsi JAX.hk.transform
), Haiku bertujuan untuk mencocokkan API Soneta 2. Modul, metode, nama argumen, default, dan skema inisialisasi harus cocok.hk.next_rng_key()
mengembalikan kunci rng unik.Mari kita lihat contoh jaringan saraf, fungsi kerugian, dan loop pelatihan. (Untuk contoh lainnya, lihat direktori contoh kami. Contoh MNIST adalah awal yang baik.)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
Inti dari Haiku adalah hk.transform
. Fungsi transform
memungkinkan Anda menulis fungsi jaringan neural yang mengandalkan parameter (di sini bobot lapisan Linear
) tanpa mengharuskan Anda menulis boilerplate secara eksplisit untuk menginisialisasi parameter tersebut. transform
melakukan ini dengan mengubah fungsi menjadi sepasang fungsi yang murni (seperti yang disyaratkan oleh JAX) init
dan apply
.
init
Fungsi init
, dengan tanda tangan params = init(rng, ...)
(di mana ...
adalah argumen untuk fungsi yang tidak ditransformasikan), memungkinkan Anda mengumpulkan nilai awal parameter apa pun di jaringan. Haiku melakukan ini dengan menjalankan fungsi Anda, melacak setiap parameter yang diminta melalui hk.get_parameter
(disebut dengan misalnya hk.Linear
) dan mengembalikannya kepada Anda.
Objek params
yang dikembalikan adalah struktur data bertingkat dari semua parameter di jaringan Anda, yang dirancang untuk Anda periksa dan manipulasi. Konkritnya adalah pemetaan nama modul ke parameter modul, dimana parameter modul adalah pemetaan nama parameter ke nilai parameter. Misalnya:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
Fungsi apply
, dengan tanda tangan result = apply(params, rng, ...)
, memungkinkan Anda memasukkan nilai parameter ke dalam fungsi Anda. Setiap kali hk.get_parameter
dipanggil, nilai yang dikembalikan akan berasal dari params
yang Anda berikan sebagai masukan untuk apply
:
loss = loss_fn_t . apply ( params , rng , images , labels )
Perhatikan bahwa karena penghitungan sebenarnya yang dilakukan oleh fungsi kerugian kita tidak bergantung pada bilangan acak, meneruskan generator bilangan acak tidak diperlukan, jadi kita juga dapat meneruskan None
untuk argumen rng
. (Perhatikan bahwa jika perhitungan Anda menggunakan angka acak, meneruskan None
untuk rng
akan menyebabkan kesalahan muncul.) Dalam contoh di atas, kami meminta Haiku untuk melakukan ini secara otomatis dengan:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
Karena apply
adalah fungsi murni, kita dapat meneruskannya ke jax.grad
(atau transformasi JAX lainnya):
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
Loop pelatihan dalam contoh ini sangat sederhana. Satu detail yang perlu diperhatikan adalah penggunaan jax.tree.map
untuk menerapkan fungsi sgd
di semua entri yang cocok di params
dan grads
. Hasilnya memiliki struktur yang sama dengan params
sebelumnya dan dapat digunakan kembali dengan apply
.
Haiku ditulis dengan Python murni, tetapi bergantung pada kode C++ melalui JAX.
Karena instalasi JAX berbeda tergantung pada versi CUDA Anda, Haiku tidak mencantumkan JAX sebagai dependensi di requirements.txt
.
Pertama, ikuti petunjuk berikut untuk menginstal JAX dengan dukungan akselerator yang relevan.
Kemudian, instal Haiku menggunakan pip:
$ pip install git+https://github.com/deepmind/dm-haiku
Alternatifnya, Anda dapat menginstal melalui PyPI:
$ pip install -U dm-haiku
Contoh kami mengandalkan perpustakaan tambahan (misalnya bsuite). Anda dapat menginstal seluruh persyaratan tambahan menggunakan pip:
$ pip install -r examples/requirements.txt
Di Haiku, semua modul adalah subkelas dari hk.Module
. Anda dapat menerapkan metode apa pun yang Anda suka (tidak ada kasus khusus), tetapi biasanya modul mengimplementasikan __init__
dan __call__
.
Mari kita bekerja melalui penerapan lapisan linier:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
Semua modul memiliki nama. Ketika tidak ada argumen name
yang diteruskan ke modul, namanya disimpulkan dari nama kelas Python (misalnya MyLinear
menjadi my_linear
). Modul dapat memiliki parameter bernama yang diakses menggunakan hk.get_parameter(param_name, ...)
. Kami menggunakan API ini (bukan hanya menggunakan properti objek) sehingga kami dapat mengonversi kode Anda menjadi fungsi murni menggunakan hk.transform
.
Saat menggunakan modul, Anda perlu mendefinisikan fungsi dan mengubahnya menjadi sepasang fungsi murni menggunakan hk.transform
. Lihat mulai cepat kami untuk detail selengkapnya tentang fungsi yang dikembalikan dari transform
:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
Beberapa model mungkin memerlukan pengambilan sampel acak sebagai bagian dari penghitungan. Misalnya, pada autoencoder variasional dengan trik reparametrisasi, diperlukan sampel acak dari distribusi normal standar. Untuk dropout kita memerlukan masker acak untuk menjatuhkan unit dari input. Kendala utama dalam membuat pekerjaan ini dengan JAX adalah dalam pengelolaan kunci PRNG.
Di Haiku kami menyediakan API sederhana untuk memelihara rangkaian kunci PRNG yang terkait dengan modul: hk.next_rng_key()
(atau next_rng_keys()
untuk beberapa kunci):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
Untuk melihat lebih lengkap cara bekerja dengan model stokastik, silakan lihat contoh VAE kami.
Catatan: hk.next_rng_key()
tidak murni secara fungsional yang berarti Anda harus menghindari menggunakannya bersama transformasi JAX yang ada di dalam hk.transform
. Untuk informasi lebih lanjut dan kemungkinan solusi, silakan baca dokumen tentang transformasi Haiku dan wrapper yang tersedia untuk transformasi JAX di dalam jaringan Haiku.
Beberapa model mungkin ingin mempertahankan keadaan internal yang dapat berubah. Misalnya, dalam normalisasi batch, nilai rata-rata bergerak yang ditemui selama pelatihan dipertahankan.
Di Haiku kami menyediakan API sederhana untuk mempertahankan keadaan yang bisa berubah yang terkait dengan modul: hk.set_state
dan hk.get_state
. Saat menggunakan fungsi-fungsi ini, Anda perlu mengubah fungsi Anda menggunakan hk.transform_with_state
karena tanda tangan dari pasangan fungsi yang dikembalikan berbeda:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
Jika Anda lupa menggunakan hk.transform_with_state
jangan khawatir, kami akan mencetak kesalahan yang jelas yang mengarahkan Anda ke hk.transform_with_state
daripada menghapus status Anda secara diam-diam.
jax.pmap
Fungsi murni yang dikembalikan dari hk.transform
(atau hk.transform_with_state
) sepenuhnya kompatibel dengan jax.pmap
. Untuk lebih jelasnya mengenai pemrograman SPMD dengan jax.pmap
, lihat di sini.
Salah satu penggunaan umum jax.pmap
dengan Haiku adalah untuk pelatihan paralel data pada banyak akselerator, yang mungkin terjadi di banyak host. Dengan Haiku, tampilannya mungkin seperti ini:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
Untuk melihat lebih lengkap pelatihan Haiku terdistribusi, lihat contoh ResNet-50 kami di ImageNet.
Mengutip repositori ini:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
Dalam entri bibtex ini, nomor versi dimaksudkan dari haiku/__init__.py
, dan tahunnya sesuai dengan rilis sumber terbuka proyek.