نظرة عامة | لماذا هايكو؟ | البداية السريعة | التثبيت | أمثلة | دليل المستخدم | التوثيق | نقلا عن هايكو
مهم
اعتبارًا من يوليو 2023، توصي Google DeepMind بأن تعتمد المشاريع الجديدة الكتان بدلاً من Haiku. Flax هي مكتبة شبكة عصبية تم تطويرها في الأصل بواسطة Google Brain والآن بواسطة Google DeepMind.
في وقت كتابة هذا التقرير، كان لدى Flax مجموعة شاملة من الميزات المتوفرة في Haiku، وفريق تطوير أكبر وأكثر نشاطًا والمزيد من الاعتماد مع المستخدمين خارج Alphabet. يحتوي Flax على وثائق وأمثلة أكثر شمولاً ومجتمع نشط يقوم بإنشاء أمثلة شاملة.
سيظل Haiku مدعومًا بأقصى جهد، إلا أن المشروع سيدخل في وضع الصيانة، مما يعني أن جهود التطوير ستركز على إصلاحات الأخطاء والتوافق مع الإصدارات الجديدة من JAX.
سيتم إصدار إصدارات جديدة للحفاظ على عمل Haiku مع الإصدارات الأحدث من Python وJAX، ولكننا لن نضيف (أو نقبل العلاقات العامة) لميزات جديدة.
لدينا استخدام كبير لـ Haiku داخليًا في Google DeepMind ونخطط حاليًا لدعم Haiku في هذا الوضع إلى أجل غير مسمى.
الهايكو هي أداة
لبناء الشبكات العصبية
فكر في: "Sonnet for JAX"
Haiku هي مكتبة شبكة عصبية بسيطة لـ JAX تم تطويرها بواسطة بعض مؤلفي Sonnet، وهي مكتبة شبكة عصبية لـ TensorFlow.
يمكن العثور على وثائق عن هايكو على https://dm-haiku.readthedocs.io/.
توضيح: إذا كنت تبحث عن نظام التشغيل Haiku، فيرجى الاطلاع على https://haiku-os.org/.
JAX هي مكتبة حوسبة رقمية تجمع بين NumPy والتمايز التلقائي ودعم GPU/TPU من الدرجة الأولى.
Haiku هي مكتبة شبكة عصبية بسيطة لـ JAX تمكن المستخدمين من استخدام نماذج البرمجة الموجهة للكائنات المألوفة مع السماح بالوصول الكامل إلى تحويلات وظائف JAX النقية.
يوفر Haiku أداتين أساسيتين: تجريد الوحدة، hk.Module
، وتحويل الوظيفة البسيط، hk.transform
.
hk.Module
s هي كائنات Python التي تحتوي على مراجع لمعلماتها الخاصة، والوحدات النمطية الأخرى، والأساليب التي تطبق الوظائف على مدخلات المستخدم.
يحول hk.transform
الدوال التي تستخدم هذه الوحدات الموجهة للكائنات "غير النقية" وظيفيًا إلى دوال خالصة يمكن استخدامها مع jax.jit
و jax.grad
و jax.pmap
وما إلى ذلك.
يوجد عدد من مكتبات الشبكات العصبية لـ JAX. لماذا يجب عليك اختيار هايكو؟
Module
لـ Sonnet لإدارة الحالة مع الاحتفاظ بإمكانية الوصول إلى تحويلات وظائف JAX.hk.transform
)، يهدف Haiku إلى مطابقة واجهة برمجة التطبيقات الخاصة بـ Sonnet 2. يجب أن تتطابق الوحدات النمطية والأساليب وأسماء الوسائط والافتراضيات وأنظمة التهيئة.hk.next_rng_key()
بإرجاع مفتاح rng فريد.دعونا نلقي نظرة على مثال للشبكة العصبية، ووظيفة الخسارة، وحلقة التدريب. (لمزيد من الأمثلة، راجع دليل الأمثلة الخاص بنا. يعد مثال MNIST مكانًا جيدًا للبدء.)
import haiku as hk
import jax . numpy as jnp
def softmax_cross_entropy ( logits , labels ):
one_hot = jax . nn . one_hot ( labels , logits . shape [ - 1 ])
return - jnp . sum ( jax . nn . log_softmax ( logits ) * one_hot , axis = - 1 )
def loss_fn ( images , labels ):
mlp = hk . Sequential ([
hk . Linear ( 300 ), jax . nn . relu ,
hk . Linear ( 100 ), jax . nn . relu ,
hk . Linear ( 10 ),
])
logits = mlp ( images )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
rng = jax . random . PRNGKey ( 42 )
dummy_images , dummy_labels = next ( input_dataset )
params = loss_fn_t . init ( rng , dummy_images , dummy_labels )
def update_rule ( param , update ):
return param - 0.01 * update
for images , labels in input_dataset :
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
params = jax . tree . map ( update_rule , params , grads )
جوهر هايكو هو hk.transform
. تتيح لك وظيفة transform
كتابة وظائف الشبكة العصبية التي تعتمد على المعلمات (هنا أوزان الطبقات Linear
) دون مطالبتك بكتابة النموذج المعياري بشكل صريح لتهيئة تلك المعلمات. يقوم transform
بذلك عن طريق تحويل الوظيفة إلى زوج من الوظائف النقية (كما هو مطلوب بواسطة JAX) init
و apply
.
init
تسمح لك وظيفة init
، مع params = init(rng, ...)
(حيث ...
هي الوسائط الخاصة بالوظيفة غير المحولة)، بجمع القيمة الأولية لأي معلمات في الشبكة. يقوم Haiku بذلك عن طريق تشغيل وظيفتك، وتتبع أي معلمات مطلوبة من خلال hk.get_parameter
(يُستدعى بواسطة hk.Linear
على سبيل المثال ) وإعادتها إليك.
كائن params
الذي تم إرجاعه عبارة عن بنية بيانات متداخلة لجميع المعلمات في شبكتك، وهي مصممة لتتمكن من فحصها ومعالجتها. بشكل ملموس، فهو عبارة عن تعيين لاسم الوحدة النمطية لمعلمات الوحدة النمطية، حيث تكون معلمة الوحدة النمطية عبارة عن تعيين لاسم المعلمة لقيمة المعلمة. على سبيل المثال:
{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
'w': ndarray(..., shape=(28, 300), dtype=float32)},
'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
'w': ndarray(..., shape=(1000, 100), dtype=float32)},
'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
'w': ndarray(..., shape=(100, 10), dtype=float32)}}
apply
تسمح لك وظيفة apply
، التي تحتوي على result = apply(params, rng, ...)
، بإدخال قيم المعلمات في وظيفتك. عندما يتم استدعاء hk.get_parameter
، فإن القيمة التي يتم إرجاعها ستأتي من params
التي تقدمها كمدخل apply
:
loss = loss_fn_t . apply ( params , rng , images , labels )
لاحظ أنه نظرًا لأن الحساب الفعلي الذي يتم إجراؤه بواسطة دالة الخسارة لدينا لا يعتمد على أرقام عشوائية، فإن تمرير منشئ أرقام عشوائية ليس ضروريًا، لذا يمكننا أيضًا تمرير None
للوسيطة rng
. (لاحظ أنه إذا كان حسابك يستخدم أرقامًا عشوائية، فإن تمرير None
لـ rng
سيؤدي إلى ظهور خطأ.) في مثالنا أعلاه، نطلب من Haiku القيام بذلك نيابةً عنا تلقائيًا باستخدام:
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
بما أن apply
دالة خالصة، يمكننا تمريرها إلى jax.grad
(أو أي من تحويلات JAX الأخرى):
grads = jax . grad ( loss_fn_t . apply )( params , images , labels )
حلقة التدريب في هذا المثال بسيطة جدًا. إحدى التفاصيل التي يجب ملاحظتها هي استخدام jax.tree.map
لتطبيق الدالة sgd
عبر جميع الإدخالات المطابقة في params
grads
. النتيجة لها نفس بنية params
السابقة ويمكن استخدامها مرة أخرى مع apply
.
الهايكو مكتوب بلغة بايثون النقية، ولكنه يعتمد على كود C++ عبر JAX.
نظرًا لأن تثبيت JAX يختلف وفقًا لإصدار CUDA الخاص بك، فإن Haiku لا يدرج JAX باعتباره تبعية في requirements.txt
.
أولاً، اتبع هذه الإرشادات لتثبيت JAX مع دعم التسريع ذي الصلة.
ثم قم بتثبيت Haiku باستخدام النقطة:
$ pip install git+https://github.com/deepmind/dm-haiku
وبدلاً من ذلك، يمكنك التثبيت عبر PyPI:
$ pip install -U dm-haiku
تعتمد أمثلتنا على مكتبات إضافية (مثل bsuite). يمكنك تثبيت المجموعة الكاملة من المتطلبات الإضافية باستخدام النقطة:
$ pip install -r examples/requirements.txt
في Haiku، جميع الوحدات هي فئة فرعية من hk.Module
. يمكنك تنفيذ أي طريقة تريدها (ليس هناك ما هو خاص)، ولكن عادةً ما تقوم الوحدات بتنفيذ __init__
و __call__
.
دعونا نعمل من خلال تنفيذ طبقة خطية:
class MyLinear ( hk . Module ):
def __init__ ( self , output_size , name = None ):
super (). __init__ ( name = name )
self . output_size = output_size
def __call__ ( self , x ):
j , k = x . shape [ - 1 ], self . output_size
w_init = hk . initializers . TruncatedNormal ( 1. / np . sqrt ( j ))
w = hk . get_parameter ( "w" , shape = [ j , k ], dtype = x . dtype , init = w_init )
b = hk . get_parameter ( "b" , shape = [ k ], dtype = x . dtype , init = jnp . zeros )
return jnp . dot ( x , w ) + b
جميع الوحدات لها اسم. عندما لا يتم تمرير أي وسيطة name
إلى الوحدة، يتم استنتاج اسمها من اسم فئة Python (على سبيل المثال يصبح MyLinear
my_linear
). يمكن أن تحتوي الوحدات على معلمات مسماة يتم الوصول إليها باستخدام hk.get_parameter(param_name, ...)
. نحن نستخدم واجهة برمجة التطبيقات هذه (بدلاً من استخدام خصائص الكائن فقط) حتى نتمكن من تحويل التعليمات البرمجية الخاصة بك إلى وظيفة خالصة باستخدام hk.transform
.
عند استخدام الوحدات، تحتاج إلى تعريف الوظائف وتحويلها إلى زوج من الوظائف النقية باستخدام hk.transform
. راجع البداية السريعة للحصول على مزيد من التفاصيل حول الوظائف التي يتم إرجاعها من transform
:
def forward_fn ( x ):
model = MyLinear ( 10 )
return model ( x )
# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk . transform ( forward_fn )
x = jnp . ones ([ 1 , 1 ])
# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you pass a RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk . PRNGSequence ( 42 )
params = forward . init ( next ( key ), x )
# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument. Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward . apply ( params , None , x )
قد تتطلب بعض النماذج أخذ عينات عشوائية كجزء من الحساب. على سبيل المثال، في أجهزة التشفير التلقائي المتغيرة مع خدعة إعادة المعلمة، هناك حاجة إلى عينة عشوائية من التوزيع الطبيعي القياسي. بالنسبة للتسرب، نحتاج إلى قناع عشوائي لإسقاط الوحدات من الإدخال. العقبة الرئيسية أمام إنجاز هذا العمل مع JAX هي إدارة مفاتيح PRNG.
في Haiku، نوفر واجهة برمجة تطبيقات بسيطة للحفاظ على تسلسل مفاتيح PRNG المرتبط بالوحدات النمطية: hk.next_rng_key()
(أو next_rng_keys()
لمفاتيح متعددة):
class MyDropout ( hk . Module ):
def __init__ ( self , rate = 0.5 , name = None ):
super (). __init__ ( name = name )
self . rate = rate
def __call__ ( self , x ):
key = hk . next_rng_key ()
p = jax . random . bernoulli ( key , 1.0 - self . rate , shape = x . shape )
return x * p / ( 1.0 - self . rate )
forward = hk . transform ( lambda x : MyDropout ()( x ))
key1 , key2 = jax . random . split ( jax . random . PRNGKey ( 42 ), 2 )
params = forward . init ( key1 , x )
prediction = forward . apply ( params , key2 , x )
للحصول على نظرة أكثر اكتمالاً حول العمل مع النماذج العشوائية، يرجى الاطلاع على مثال VAE الخاص بنا.
ملاحظة: hk.next_rng_key()
ليس خالصًا وظيفيًا مما يعني أنه يجب عليك تجنب استخدامه مع تحويلات JAX الموجودة داخل hk.transform
. لمزيد من المعلومات والحلول الممكنة، يرجى مراجعة المستندات الخاصة بتحويلات Haiku والأغلفة المتوفرة لتحويلات JAX داخل شبكات Haiku.
قد ترغب بعض النماذج في الحفاظ على حالة داخلية قابلة للتغيير. على سبيل المثال، في التسوية الدفعية، يتم الحفاظ على متوسط متحرك للقيم التي تمت مواجهتها أثناء التدريب.
في Haiku، نقدم واجهة برمجة تطبيقات بسيطة للحفاظ على الحالة القابلة للتغيير المرتبطة بالوحدات النمطية: hk.set_state
و hk.get_state
. عند استخدام هذه الوظائف، تحتاج إلى تحويل وظيفتك باستخدام hk.transform_with_state
نظرًا لأن توقيع زوج الوظائف الذي تم إرجاعه مختلف:
def forward ( x , is_training ):
net = hk . nets . ResNet50 ( 1000 )
return net ( x , is_training )
forward = hk . transform_with_state ( forward )
# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params , state = forward . init ( rng , x , is_training = True )
# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits , state = forward . apply ( params , state , rng , x , is_training = True )
إذا نسيت استخدام hk.transform_with_state
فلا تقلق، فسنطبع خطأ واضحًا يوجهك إلى hk.transform_with_state
بدلاً من إسقاط حالتك بصمت.
jax.pmap
الوظائف النقية التي يتم إرجاعها من hk.transform
(أو hk.transform_with_state
) متوافقة تمامًا مع jax.pmap
. لمزيد من التفاصيل حول برمجة SPMD باستخدام jax.pmap
، انظر هنا.
أحد الاستخدامات الشائعة لـ jax.pmap
مع Haiku هو التدريب المتوازي للبيانات على العديد من المسرعات، وربما عبر مضيفين متعددين. مع هايكو، قد يبدو الأمر كما يلي:
def loss_fn ( inputs , labels ):
logits = hk . nets . MLP ([ 8 , 4 , 2 ])( x )
return jnp . mean ( softmax_cross_entropy ( logits , labels ))
loss_fn_t = hk . transform ( loss_fn )
loss_fn_t = hk . without_apply_rng ( loss_fn_t )
# Initialize the model on a single device.
rng = jax . random . PRNGKey ( 428 )
sample_image , sample_label = next ( input_dataset )
params = loss_fn_t . init ( rng , sample_image , sample_label )
# Replicate params onto all devices.
num_devices = jax . local_device_count ()
params = jax . tree . map ( lambda x : np . stack ([ x ] * num_devices ), params )
def make_superbatch ():
"""Constructs a superbatch, i.e. one batch of data per device."""
# Get N batches, then split into list-of-images and list-of-labels.
superbatch = [ next ( input_dataset ) for _ in range ( num_devices )]
superbatch_images , superbatch_labels = zip ( * superbatch )
# Stack the superbatches to be one array with a leading dimension, rather than
# a python list. This is what `jax.pmap` expects as input.
superbatch_images = np . stack ( superbatch_images )
superbatch_labels = np . stack ( superbatch_labels )
return superbatch_images , superbatch_labels
def update ( params , inputs , labels , axis_name = 'i' ):
"""Updates params based on performance on inputs and labels."""
grads = jax . grad ( loss_fn_t . apply )( params , inputs , labels )
# Take the mean of the gradients across all data-parallel replicas.
grads = jax . lax . pmean ( grads , axis_name )
# Update parameters using SGD or Adam or ...
new_params = my_update_rule ( params , grads )
return new_params
# Run several training updates.
for _ in range ( 10 ):
superbatch_images , superbatch_labels = make_superbatch ()
params = jax . pmap ( update , axis_name = 'i' )( params , superbatch_images ,
superbatch_labels )
لإلقاء نظرة أكثر اكتمالاً على تدريب Haiku الموزع، قم بإلقاء نظرة على ResNet-50 الخاص بنا على مثال ImageNet.
للاستشهاد بهذا المستودع:
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.13},
year = {2020},
}
في إدخال bibtex هذا، من المفترض أن يكون رقم الإصدار من haiku/__init__.py
، والعام يتوافق مع إصدار المشروع مفتوح المصدر.