البداية السريعة | التحولات | دليل التثبيت | مكتبات الشبكة العصبية | تغيير السجلات | المستندات المرجعية
JAX هي مكتبة Python لحساب المصفوفات الموجهة نحو التسريع وتحويل البرامج، وهي مصممة للحوسبة الرقمية عالية الأداء والتعلم الآلي على نطاق واسع.
بفضل الإصدار المحدث من Autograd، يمكن لـ JAX التمييز تلقائيًا بين وظائف Python وNumPy الأصلية. ويمكنه التفريق من خلال الحلقات والفروع والتكرار والإغلاق، ويمكنه أخذ مشتقات مشتقات المشتقات. وهو يدعم التمايز في الوضع العكسي (المعروف أيضًا باسم الانتشار العكسي) عبر grad
وكذلك التمايز في الوضع الأمامي، ويمكن تكوين الاثنين بشكل تعسفي لأي ترتيب.
الجديد هو أن JAX يستخدم XLA لتجميع برامج NumPy وتشغيلها على وحدات معالجة الرسومات ووحدات TPU. يحدث التجميع بشكل افتراضي، حيث يتم تجميع وتنفيذ مكالمات المكتبة في الوقت المناسب. لكن JAX يتيح لك أيضًا تجميع وظائف Python الخاصة بك في الوقت المناسب إلى نواة محسنة لـ XLA باستخدام واجهة برمجة التطبيقات ذات الوظيفة الواحدة، jit
. يمكن تجميع التجميع والتمايز التلقائي بشكل تعسفي، حتى تتمكن من التعبير عن خوارزميات معقدة والحصول على أقصى قدر من الأداء دون مغادرة بايثون. يمكنك أيضًا برمجة عدة وحدات معالجة رسوميات أو نوى TPU في وقت واحد باستخدام pmap
، والتمييز بين كل شيء.
قم بالتعمق أكثر، وسترى أن JAX هو بالفعل نظام قابل للتوسيع لتحويلات الوظائف القابلة للتركيب. يعد كل من grad
و jit
مثالين على مثل هذه التحولات. البعض الآخر هو vmap
للتوجيه التلقائي و pmap
للبرمجة المتوازية لبرنامج واحد متعدد البيانات (SPMD) للمسرعات المتعددة، مع المزيد في المستقبل.
هذا مشروع بحثي، وليس أحد منتجات Google الرسمية. توقع الأخطاء والحواف الحادة. الرجاء المساعدة من خلال تجربتها والإبلاغ عن الأخطاء وإخبارنا برأيك!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
انتقل مباشرة إلى استخدام جهاز كمبيوتر محمول في متصفحك، متصل بوحدة معالجة الرسومات Google Cloud. فيما يلي بعض دفاتر الملاحظات المبتدئة:
grad
للتمايز، jit
للتجميع، و vmap
للتوجيهيعمل JAX الآن على وحدات TPU السحابية. لتجربة المعاينة، راجع Cloud TPU Colabs.
للتعمق أكثر في JAX:
يعد JAX في جوهره نظامًا قابلاً للتوسيع لتحويل الوظائف الرقمية. فيما يلي أربعة تحويلات ذات أهمية أساسية: grad
و jit
و vmap
و pmap
.
grad
تمتلك JAX تقريبًا نفس واجهة برمجة التطبيقات مثل Autograd. الوظيفة الأكثر شيوعًا هي grad
لتدرجات الوضع العكسي:
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
يمكنك التفريق بين أي أمر مع grad
.
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
للحصول على تمييز تلقائي أكثر تقدمًا، يمكنك استخدام jax.vjp
لمنتجات المتجهات-Jacobian ذات الوضع العكسي و jax.jvp
لمنتجات ناقلات Jacobian ذات الوضع الأمامي. يمكن أن يتكون الاثنان بشكل تعسفي مع بعضهما البعض، ومع تحويلات JAX الأخرى. فيما يلي إحدى الطرق لتكوين تلك لإنشاء دالة تحسب بكفاءة مصفوفات هسه الكاملة:
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
كما هو الحال مع Autograd، لديك الحرية في استخدام التمايز مع هياكل التحكم في Python:
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
راجع المستندات المرجعية حول التمايز التلقائي وكتاب JAX Autodiff Cookbook للمزيد.
jit
يمكنك استخدام XLA لتجميع وظائفك من البداية إلى النهاية باستخدام jit
، والتي تُستخدم إما كمصمم @jit
أو كدالة ذات ترتيب أعلى.
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
يمكنك مزج jit
و grad
وأي تحويلات JAX أخرى كيفما تشاء.
يؤدي استخدام jit
إلى وضع قيود على نوع تدفق التحكم في Python الذي يمكن للوظيفة استخدامه؛ راجع البرنامج التعليمي حول التحكم في التدفق والمشغلين المنطقيين باستخدام JIT للمزيد.
vmap
vmap
هي الخريطة الموجهة. تحتوي على دلالات مألوفة لتعيين دالة على طول محاور المصفوفة، ولكن بدلاً من الاحتفاظ بالحلقة في الخارج، فإنها تدفع الحلقة لأسفل إلى العمليات البدائية للوظيفة للحصول على أداء أفضل.
يمكن أن يوفر عليك استخدام vmap
الاضطرار إلى حمل أبعاد الدُفعة في التعليمات البرمجية الخاصة بك. على سبيل المثال، ضع في اعتبارك وظيفة التنبؤ البسيطة للشبكة العصبية غير المجمعة :
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
غالبًا ما نكتب بدلاً من ذلك jnp.dot(activations, W)
للسماح ببُعد دفعة على الجانب الأيسر من activations
، ولكننا كتبنا دالة التنبؤ المحددة هذه لتطبيقها فقط على متجهات الإدخال الفردية. إذا أردنا تطبيق هذه الوظيفة على مجموعة من المدخلات مرة واحدة، فيمكننا الكتابة فقط من الناحية الدلالية
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
لكن دفع مثال واحد عبر الشبكة في كل مرة سيكون بطيئًا! من الأفضل توجيه الحساب، بحيث نقوم في كل طبقة بضرب المصفوفة-المصفوفة بدلاً من الضرب بالمصفوفة-المتجه.
تقوم وظيفة vmap
بهذا التحويل بالنسبة لنا. يعني إذا كتبنا
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
ثم تقوم وظيفة vmap
بدفع الحلقة الخارجية داخل الوظيفة، وسينتهي جهازنا بتنفيذ عمليات ضرب المصفوفة والمصفوفة تمامًا كما لو كنا قد قمنا بعملية التجميع يدويًا.
من السهل تجميع شبكة عصبية بسيطة يدويًا بدون vmap
، لكن في حالات أخرى قد يكون التحويل اليدوي غير عملي أو مستحيل. لنأخذ على سبيل المثال مشكلة حساب التدرجات لكل مثال بكفاءة: أي أنه بالنسبة لمجموعة ثابتة من المعلمات، نريد حساب تدرج دالة الخسارة لدينا التي يتم تقييمها بشكل منفصل في كل مثال في الدفعة. مع vmap
، الأمر سهل:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
بالطبع، يمكن تكوين vmap
بشكل تعسفي باستخدام jit
و grad
وأي تحويلات JAX أخرى! نحن نستخدم vmap
مع التمايز التلقائي في الوضعين الأمامي والخلفي لإجراء حسابات مصفوفة جاكوبيان وهسي السريعة في jax.jacfwd
و jax.jacrev
و jax.hessian
.
pmap
للبرمجة المتوازية للمسرعات المتعددة، مثل وحدات معالجة الرسومات المتعددة، استخدم pmap
. باستخدام pmap
يمكنك كتابة برامج أحادية البرنامج ومتعددة البيانات (SPMD)، بما في ذلك عمليات الاتصال الجماعي المتوازي السريع. تطبيق pmap
يعني أن الوظيفة التي تكتبها يتم تجميعها بواسطة XLA (على غرار jit
)، ثم يتم نسخها وتنفيذها بالتوازي عبر الأجهزة.
فيما يلي مثال على جهاز مزود بـ 8 GPU:
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
بالإضافة إلى التعبير عن خرائط خالصة، يمكنك استخدام عمليات الاتصال الجماعي السريع بين الأجهزة:
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
يمكنك أيضًا دمج وظائف pmap
للحصول على أنماط اتصال أكثر تطورًا.
يتم تركيب كل شيء، لذلك لديك الحرية في التمييز من خلال الحسابات المتوازية:
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
عند التمييز بين دالة pmap
في الوضع العكسي (على سبيل المثال مع grad
)، يتم توازي التمرير الخلفي للحساب تمامًا مثل التمرير الأمامي.
راجع كتاب الطبخ SPMD ومصنف SPMD MNIST من مثال البداية للمزيد.
للحصول على مسح أكثر شمولاً للمشاكل الحالية، مع الأمثلة والتفسيرات، نوصي بشدة بقراءة دفتر Gotchas. بعض أبرزها:
is
الاحتفاظ باختبار هوية الكائن). إذا كنت تستخدم تحويل JAX على دالة Python غير نقية، فقد ترى خطأ مثل Exception: Can't lift Traced...
أو Exception: Different traces at same level
.x[i] += y
، غير مدعومة، ولكن هناك بدائل وظيفية. ضمن jit
، ستعيد هذه البدائل الوظيفية استخدام المخازن المؤقتة في مكانها تلقائيًا.jax.lax
.float32
) بشكل افتراضي، ولتمكين الدقة المزدوجة (64 بت، على سبيل المثال float64
) يحتاج المرء إلى تعيين متغير jax_enable_x64
عند بدء التشغيل (أو تعيين متغير البيئة JAX_ENABLE_X64=True
) . في TPU، يستخدم JAX قيم 32 بت افتراضيًا لكل شيء باستثناء المتغيرات المؤقتة الداخلية في العمليات "المشابهة لـ matmul"، مثل jax.numpy.dot
و lax.conv
. تحتوي هذه العمليات على معلمة precision
يمكن استخدامها لتقريب عمليات 32 بت عبر ثلاث تمريرات bfloat16، مع تكلفة وقت تشغيل ربما أبطأ. العمليات غير المكتملة على TPU أقل من التطبيقات التي غالبًا ما تؤكد على السرعة أكثر من الدقة، لذلك في الممارسة العملية، ستكون الحسابات على TPU أقل دقة من الحسابات المماثلة على الواجهات الخلفية الأخرى.np.add(1, np.array([2], np.float32)).dtype
هو float64
بدلاً من float32
.jit
، تقيد كيفية استخدام تدفق التحكم في Python. سوف تحصل دائمًا على أخطاء عالية إذا حدث خطأ ما. قد يتعين عليك استخدام معلمة static_argnums
الخاصة بـ jit
، أو أساسيات تدفق التحكم المنظم مثل lax.scan
، أو مجرد استخدام jit
على وظائف فرعية أصغر. لينكس x86_64 | لينكس آرتش64 | ماك x86_64 | ماك آرتش64 | ويندوز x86_64 | ويندوز WSL2 x86_64 | |
---|---|---|---|---|---|---|
وحدة المعالجة المركزية | نعم | نعم | نعم | نعم | نعم | نعم |
نفيديا GPU | نعم | نعم | لا | غير متوفر | لا | تجريبي |
جوجل تي بي يو | نعم | غير متوفر | غير متوفر | غير متوفر | غير متوفر | غير متوفر |
معالج رسوميات AMD | نعم | لا | تجريبي | غير متوفر | لا | لا |
معالج رسوميات أبل | غير متوفر | لا | غير متوفر | تجريبي | غير متوفر | غير متوفر |
إنتل GPU | تجريبي | غير متوفر | غير متوفر | غير متوفر | لا | لا |
منصة | تعليمات |
---|---|
وحدة المعالجة المركزية | pip install -U jax |
نفيديا GPU | pip install -U "jax[cuda12]" |
جوجل تي بي يو | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
وحدة معالجة الرسومات AMD (لينكس) | استخدم Docker، أو العجلات المعدة مسبقًا، أو قم بالبناء من المصدر. |
ماك GPU | اتبع تعليمات أبل. |
إنتل GPU | اتبع تعليمات إنتل. |
راجع الوثائق للحصول على معلومات حول استراتيجيات التثبيت البديلة. يتضمن ذلك التجميع من المصدر، والتثبيت باستخدام Docker، واستخدام إصدارات أخرى من CUDA، وبناء conda مدعوم من المجتمع، والإجابات على بعض الأسئلة الشائعة.
تعمل مجموعات بحث Google المتعددة في Google DeepMind وAlphabet على تطوير ومشاركة المكتبات لتدريب الشبكات العصبية في JAX. إذا كنت تريد مكتبة كاملة الميزات للتدريب على الشبكات العصبية مع أمثلة وأدلة إرشادية، فجرّب Flax وموقع التوثيق الخاص به.
راجع قسم JAX Ecosystem على موقع وثائق JAX للحصول على قائمة بمكتبات الشبكات المستندة إلى JAX، والتي تتضمن Optax لمعالجة التدرج وتحسينه، وchex للحصول على تعليمات برمجية واختبارات موثوقة، وEquinox للشبكات العصبية. (شاهد النظام البيئي NeurIPS 2020 JAX في DeepMind وهو يتحدث هنا للحصول على تفاصيل إضافية.)
للاستشهاد بهذا المستودع:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
في إدخال bibtex أعلاه، تكون الأسماء مرتبة أبجديًا، والمقصود أن يكون رقم الإصدار هو ذلك من jax/version.py، والسنة تتوافق مع إصدار المشروع مفتوح المصدر.
تم وصف إصدار حديث من JAX، يدعم التمايز والتجميع التلقائي فقط إلى XLA، في ورقة بحثية ظهرت في SysML 2018. نحن نعمل حاليًا على تغطية أفكار JAX وقدراتها في ورقة أكثر شمولاً وحداثة.
للحصول على تفاصيل حول JAX API، راجع الوثائق المرجعية.
للبدء كمطور JAX، راجع وثائق المطور.