ภาพรวม | ทำไมต้องไฮกุ? - เริ่มต้นอย่างรวดเร็ว | การติดตั้ง | ตัวอย่าง | คู่มือการใช้งาน | เอกสารประกอบ | อ้างถึงไฮกุ
สำคัญ
ในเดือนกรกฎาคม 2023 Google DeepMind แนะนำให้โปรเจ็กต์ใหม่ใช้ Flax แทน Haiku Flax เป็นไลบรารีโครงข่ายประสาทเทียมที่สร้างสรรค์โดย Google Brain และขณะนี้โดย Google DeepMind
ในขณะที่เขียน Flax มีคุณสมบัติที่เหนือกว่าที่มีอยู่ใน Haiku ซึ่งเป็นทีมพัฒนาที่ใหญ่ขึ้นและกระตือรือร้นมากขึ้น และมีการนำไปใช้มากขึ้นกับผู้ใช้ภายนอก Alphabet Flax มีเอกสาร ตัวอย่าง และชุมชนที่กระตือรือร้นที่สร้างตัวอย่างตั้งแต่ต้นจนจบ
Haiku จะยังคงได้รับการสนับสนุนอย่างเต็มที่ อย่างไรก็ตาม โปรเจ็กต์จะเข้าสู่โหมดการบำรุงรักษา ซึ่งหมายความว่าความพยายามในการพัฒนาจะมุ่งเน้นไปที่การแก้ไขข้อบกพร่องและความเข้ากันได้กับ JAX รุ่นใหม่
จะมีการออกรุ่นใหม่เพื่อให้ Haiku ทำงานร่วมกับ Python และ JAX เวอร์ชันใหม่ได้ อย่างไรก็ตาม เราจะไม่เพิ่ม (หรือยอมรับ PR สำหรับ) คุณสมบัติใหม่
เรามีการใช้งาน Haiku จำนวนมากเป็นการภายในที่ Google DeepMind และขณะนี้วางแผนที่จะสนับสนุน Haiku ในโหมดนี้อย่างไม่มีกำหนด
ไฮกุเป็นเครื่องมือ
สำหรับการสร้างโครงข่ายประสาทเทียม
คิดว่า: "โคลงสำหรับ JAX"
Haiku เป็นไลบรารีโครงข่ายประสาทเทียมแบบเรียบง่ายสำหรับ JAX ที่พัฒนาโดยผู้เขียน Sonnet ซึ่งเป็นไลบรารีโครงข่ายประสาทเทียมสำหรับ TensorFlow
สามารถอ่านเอกสารเกี่ยวกับไฮกุได้ที่https://dm-haiku.readthedocs.io/
แก้ความกำกวม: หากคุณกำลังมองหา Haiku ระบบปฏิบัติการ โปรดดู https://haiku-os.org/
JAX เป็นไลบรารีการคำนวณเชิงตัวเลขที่รวม NumPy การสร้างความแตกต่างอัตโนมัติ และการรองรับ GPU/TPU ชั้นหนึ่ง
Haiku เป็นไลบรารีโครงข่ายประสาทเทียมอย่างง่ายสำหรับ JAX ที่ให้ผู้ใช้สามารถใช้ โมเดลการเขียนโปรแกรมเชิงวัตถุ ที่คุ้นเคย ในขณะเดียวกันก็อนุญาตให้เข้าถึงการแปลงฟังก์ชันที่แท้จริงของ JAX ได้อย่างเต็มที่
Haiku มีเครื่องมือหลักสองอย่าง: module abstraction, hk.Module
และการแปลงฟังก์ชันอย่างง่าย hk.transform
hk.Module
s เป็นอ็อบเจ็กต์ Python ที่เก็บการอ้างอิงถึงพารามิเตอร์ของตัวเอง โมดูลอื่นๆ และเมธอดที่ใช้ฟังก์ชันกับอินพุตของผู้ใช้
hk.transform
เปลี่ยนฟังก์ชันที่ใช้โมดูลเชิงวัตถุและเชิงฟังก์ชัน "ไม่บริสุทธิ์" เหล่านี้ให้เป็นฟังก์ชันล้วนๆ ที่สามารถใช้ได้กับ jax.jit
, jax.grad
, jax.pmap
ฯลฯ
มีไลบรารีโครงข่ายประสาทเทียมจำนวนหนึ่งสำหรับ JAX ทำไมคุณถึงเลือกไฮกุ?
Module
ของ Sonnet สำหรับการจัดการสถานะ ในขณะที่ยังคงสามารถเข้าถึงการแปลงฟังก์ชันของ JAX ได้hk.transform
) ไฮกุมุ่งหวังที่จะจับคู่ API ของ Sonnet 2 โมดูล วิธีการ ชื่ออาร์กิวเมนต์ ค่าเริ่มต้น และโครงร่างการเริ่มต้นควรตรงกันhk.next_rng_key()
จะส่งคืนคีย์ rng ที่ไม่ซ้ำใครมาดูตัวอย่างโครงข่ายประสาทเทียม ฟังก์ชันการสูญเสีย และลูปการฝึก (สำหรับตัวอย่างเพิ่มเติม โปรดดูไดเร็กทอรีตัวอย่างของเรา ตัวอย่าง MNIST เป็นจุดเริ่มต้นที่ดี)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
แก่นแท้ของไฮกุคือ hk.transform
ฟังก์ชัน transform
ช่วยให้คุณสามารถเขียนฟังก์ชันโครงข่ายประสาทเทียมที่ต้องอาศัยพารามิเตอร์ (นี่คือน้ำหนักของเลเยอร์ Linear
เส้น) โดยไม่ต้องให้คุณเขียนต้นแบบสำหรับการเริ่มต้นพารามิเตอร์เหล่านั้นอย่างชัดเจน transform
ทำสิ่งนี้โดยการแปลงฟังก์ชันให้เป็นคู่ของฟังก์ชันที่ บริสุทธิ์ (ตามที่ JAX กำหนด) init
และ apply
init
ฟังก์ชัน init
พร้อมด้วย params = init(rng, ...)
(โดยที่ ...
คืออาร์กิวเมนต์ของฟังก์ชันที่ยังไม่แปลง) ช่วยให้คุณสามารถ รวบรวม ค่าเริ่มต้นของพารามิเตอร์ใดๆ ในเครือข่ายได้ Haiku ดำเนินการนี้โดยการเรียกใช้ฟังก์ชันของคุณ ติดตามพารามิเตอร์ใดๆ ที่ร้องขอผ่าน hk.get_parameter
(เรียกโดย เช่น hk.Linear
) และส่งคืนให้กับคุณ
อ็อบเจ็กต์ params
ที่ส่งคืนคือโครงสร้างข้อมูลที่ซ้อนกันของพารามิเตอร์ทั้งหมดในเครือข่ายของคุณ ซึ่งออกแบบมาเพื่อให้คุณตรวจสอบและจัดการ โดยสรุปแล้ว เป็นการแมปชื่อโมดูลกับพารามิเตอร์ของโมดูล โดยที่พารามิเตอร์โมดูลคือการแมปชื่อพารามิเตอร์กับค่าพารามิเตอร์ ตัวอย่างเช่น:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
ฟังก์ชัน apply
พร้อมด้วยลายเซ็น result = apply(params, rng, ...)
ช่วยให้คุณสามารถ ฉีด ค่าพารามิเตอร์ลงในฟังก์ชันของคุณได้ เมื่อใดก็ตามที่มีการเรียก hk.get_parameter
ค่าที่ส่งคืนจะมาจาก params
ที่คุณระบุเป็นอินพุตเพื่อ apply
:
loss = loss_fn_t . apply ( params , rng , images , labels )
โปรดทราบว่าเนื่องจากการคำนวณจริงที่ดำเนินการโดยฟังก์ชันการสูญเสียของเรานั้นไม่ได้ขึ้นอยู่กับตัวเลขสุ่ม การส่งผ่านตัวสร้างตัวเลขสุ่มจึงไม่จำเป็น ดังนั้นเราจึงสามารถส่งผ่าน None
สำหรับอาร์กิวเมนต์ rng
ได้เช่นกัน (โปรดทราบว่าหากการคำนวณของคุณ ใช้ ตัวเลขสุ่ม การส่งค่า None
สำหรับ rng
จะทำให้เกิดข้อผิดพลาด) ในตัวอย่างด้านบน เราขอให้ Haiku ทำสิ่งนี้ให้เราโดยอัตโนมัติด้วย:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
เนื่องจาก apply
เป็นฟังก์ชันบริสุทธิ์ เราจึงสามารถส่งต่อไปยัง jax.grad
(หรือการแปลงอื่น ๆ ของ JAX):
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
วงจรการฝึกอบรมในตัวอย่างนี้นั้นง่ายมาก รายละเอียดอย่างหนึ่งที่ควรทราบคือการใช้ jax.tree.map
เพื่อใช้ฟังก์ชัน sgd
กับรายการที่ตรงกันทั้งหมดใน params
และ grads
ผลลัพธ์มีโครงสร้างเหมือนกับ params
ก่อนหน้า และสามารถนำมาใช้กับ apply
ได้อีกครั้ง
Haiku เขียนด้วยภาษา Python ล้วนๆ แต่ขึ้นอยู่กับโค้ด C++ ผ่าน JAX
เนื่องจากการติดตั้ง JAX จะแตกต่างกันไปขึ้นอยู่กับเวอร์ชัน CUDA ของคุณ Haiku จึงไม่แสดงรายการ JAX เป็นการขึ้นต่อกันใน requirements.txt
ขั้นแรก ทำตามคำแนะนำเหล่านี้เพื่อติดตั้ง JAX ด้วยการสนับสนุนตัวเร่งความเร็วที่เกี่ยวข้อง
จากนั้นติดตั้ง Haiku โดยใช้ pip:
$ pip install git+https://github.com/deepmind/dm-haiku
หรือคุณสามารถติดตั้งผ่าน PyPI:
$ pip install -U dm-haiku
ตัวอย่างของเราใช้ไลบรารีเพิ่มเติม (เช่น bsuite) คุณสามารถติดตั้งข้อกำหนดเพิ่มเติมครบชุดได้โดยใช้ pip:
$ pip install -r examples/requirements.txt
ในไฮกุ โมดูลทั้งหมดเป็นคลาสย่อยของ hk.Module
คุณสามารถใช้วิธีการใดก็ได้ที่คุณต้องการ (ไม่มีกรณีพิเศษใด ๆ ) แต่โดยทั่วไปแล้วโมดูลจะใช้ __init__
และ __call__
เรามาทำงานผ่านการใช้เลเยอร์เชิงเส้นกัน:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
โมดูลทั้งหมดมีชื่อ เมื่อไม่มีการส่งอาร์กิวเมนต์ name
ไปยังโมดูล ชื่อของมันจะถูกอนุมานจากชื่อของคลาส Python (เช่น MyLinear
กลายเป็น my_linear
) โมดูลสามารถตั้งชื่อพารามิเตอร์ที่เข้าถึงได้โดยใช้ hk.get_parameter(param_name, ...)
เราใช้ API นี้ (แทนที่จะใช้แค่คุณสมบัติอ็อบเจ็กต์) เพื่อให้เราสามารถแปลงโค้ดของคุณให้เป็นฟังก์ชันล้วนๆ โดยใช้ hk.transform
เมื่อใช้โมดูล คุณจะต้องกำหนดฟังก์ชันและแปลงให้เป็นฟังก์ชันคู่เดียวโดยใช้ hk.transform
ดูการเริ่มต้นอย่างรวดเร็วของเราสำหรับรายละเอียดเพิ่มเติมเกี่ยวกับฟังก์ชันที่ส่งคืนจาก transform
:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
บางรุ่นอาจต้องมีการสุ่มตัวอย่างซึ่งเป็นส่วนหนึ่งของการคำนวณ ตัวอย่างเช่น ในตัวเข้ารหัสอัตโนมัติแบบแปรผันที่มีเคล็ดลับการปรับพารามิเตอร์ใหม่ จำเป็นต้องมีตัวอย่างแบบสุ่มจากการแจกแจงแบบปกติมาตรฐาน สำหรับการออกกลางคันเราจำเป็นต้องมีมาสก์แบบสุ่มเพื่อปล่อยหน่วยออกจากอินพุต อุปสรรคหลักในการทำงานกับ JAX คือการจัดการคีย์ PRNG
ในไฮกุ เรามี API อย่างง่ายสำหรับการรักษาลำดับคีย์ PRNG ที่เกี่ยวข้องกับโมดูล: hk.next_rng_key()
(หรือ next_rng_keys()
สำหรับหลายคีย์):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
หากต้องการดูการทำงานกับโมเดลสุ่มให้สมบูรณ์ยิ่งขึ้น โปรดดูตัวอย่าง VAE ของเรา
หมายเหตุ: hk.next_rng_key()
นั้นไม่ได้ใช้งานได้จริง ซึ่งหมายความว่าคุณควรหลีกเลี่ยงการใช้มันควบคู่ไปกับการแปลง JAX ซึ่งอยู่ภายใน hk.transform
สำหรับข้อมูลเพิ่มเติมและวิธีแก้ปัญหาที่เป็นไปได้ โปรดอ่านเอกสารเกี่ยวกับการแปลง Haiku และ Wrapper ที่พร้อมใช้งานสำหรับการแปลง JAX ภายในเครือข่าย Haiku
บางรุ่นอาจต้องการรักษาสถานะภายในที่ไม่แน่นอนไว้ ตัวอย่างเช่น ในการทำให้เป็นมาตรฐานแบบแบตช์ จะมีการรักษาค่าเฉลี่ยเคลื่อนที่ของค่าที่พบระหว่างการฝึกไว้
ในไฮกุ เรามี API อย่างง่ายสำหรับการรักษาสถานะที่ไม่แน่นอนซึ่งเชื่อมโยงกับโมดูล: hk.set_state
และ hk.get_state
เมื่อใช้ฟังก์ชันเหล่านี้ คุณจะต้องแปลงฟังก์ชันโดยใช้ hk.transform_with_state
เนื่องจากลายเซ็นของคู่ฟังก์ชันที่ส่งคืนแตกต่างกัน:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
หากคุณลืมใช้ hk.transform_with_state
ไม่ต้องกังวล เราจะพิมพ์ข้อผิดพลาดที่ชัดเจนซึ่งชี้ให้คุณไปที่ hk.transform_with_state
แทนที่จะทิ้งสถานะของคุณอย่างเงียบๆ
jax.pmap
ฟังก์ชั่นแท้ที่ส่งคืนจาก hk.transform
(หรือ hk.transform_with_state
) เข้ากันได้อย่างสมบูรณ์กับ jax.pmap
สำหรับรายละเอียดเพิ่มเติมเกี่ยวกับการเขียนโปรแกรม SPMD ด้วย jax.pmap
ดูที่นี่
การใช้งานทั่วไปอย่างหนึ่งของ jax.pmap
กับ Haiku คือสำหรับการฝึกอบรมข้อมูลแบบขนานบนตัวเร่งความเร็วหลายตัว ซึ่งอาจข้ามหลายโฮสต์ ด้วยไฮกุนั่นอาจมีลักษณะเช่นนี้:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
หากต้องการดูการฝึกอบรม Haiku แบบกระจายที่สมบูรณ์ยิ่งขึ้น โปรดดูตัวอย่าง ResNet-50 บน ImageNet ของเรา
หากต้องการอ้างอิงที่เก็บนี้:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
ในรายการ bibtex นี้ หมายเลขเวอร์ชันตั้งใจให้มาจาก haiku/__init__.py
และปีสอดคล้องกับการเปิดตัวโอเพ่นซอร์สของโปรเจ็กต์