البرمجة الاحتمالية مدعومة بواسطة Jax لتجميع Autograd و Jit إلى GPU/TPU/وحدة المعالجة المركزية.
مستندات وأمثلة | المنتدى
Numpyro هي مكتبة برمجة احتمالية خفيفة الوزن توفر الواجهة الخلفية لـ Pyro. نعتمد على JAX من أجل التمايز التلقائي وتجميع JIT إلى GPU / CPU. يخضع Numpyro تحت التطوير النشط ، لذا احذر من الهشاشة ، والبق ، والتغييرات في واجهة برمجة التطبيقات مع تطور التصميم.
تم تصميم Numpyro ليكون خفيف الوزن ويركز على توفير ركيزة مرنة يمكن للمستخدمين البناء عليها:
sample
param
. يجب أن يبدو رمز النموذج مشابهًا جدًا للبيرو باستثناء بعض الاختلافات البسيطة بين Pytorch و Numpy's API. انظر المثال أدناه.jit
و grad
لتجميع خطوة التكامل بأكملها في نواة XLA المحسنة. نقوم أيضًا بإزالة Python النفقات العامة بواسطة JIT لتجميع مرحلة بناء الأشجار بأكملها بالمكسرات (هذا ممكن باستخدام المكسرات التكرارية). هناك أيضًا تطبيق أساسي للاستدلال المتغير مع العديد من أدلة المرنة (AUTO) للاستدلال المتغير التلقائي التلقائي (ADVI). يدعم تطبيق الاستدلال المتغير عددًا من الميزات ، بما في ذلك دعم النماذج ذات المتغيرات الكامنة المنفصلة (انظر TraceGraph_elbo و Traceenum_elbo).torch.distributions
. بالإضافة إلى التوزيعات ، تكون constraints
transforms
مفيدة للغاية عند العمل على فئات التوزيع بدعم محدود. أخيرًا ، يمكن استخدام توزيعات من احتمال TensorFlow (TFP) مباشرة في نماذج Numpyro.sample
param
باستخدام معالجات التأثير من وحدة Numpyro.handlers ، ويمكن تمديدها بسهولة لتنفيذ خوارزميات الاستدلال المخصصة وأدوات الاستدلال. دعنا نستكشف Numpyro باستخدام مثال بسيط. سوف نستخدم مثال المدارس الثمانية من Gelman et al. ، تحليل بيانات Bayesian: Sec. 5.5 ، 2003 ، الذي يدرس تأثير التدريب على أداء SAT في ثماني مدارس.
يتم تقديم البيانات بواسطة:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
، حيث y
هي تأثيرات العلاج و sigma
الخطأ القياسي. نحن نبني نموذجًا هرميًا للدراسة حيث نفترض أن المعلمات على مستوى المجموعة theta
لكل مدرسة يتم أخذ عينات منها من توزيع طبيعي مع mu
متوسط غير معروف و tau
المعروف ، في حين أن البيانات المرصودة يتم توليدها بدورها من التوزيع الطبيعي بوسط والانحراف المعياري الذي قدمه theta
(التأثير الحقيقي) sigma
، على التوالي. يتيح لنا ذلك تقدير المعلمات على مستوى السكان mu
و tau
من خلال التجميع من جميع الملاحظات ، مع السماح بالتغير الفردي بين المدارس باستخدام معلمات theta
على مستوى المجموعة.
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
دعنا نستنتج قيم المعلمات غير المعروفة في نموذجنا عن طريق تشغيل MCMC باستخدام عينة NO-U-Turn (المكسرات). لاحظ استخدام وسيطة extra_fields
في MCMC.Run. بشكل افتراضي ، نجمع فقط عينات من التوزيع الهدف (الخلفي) عند تشغيل الاستدلال باستخدام MCMC
. ومع ذلك ، يمكن تحقيق جمع حقول إضافية مثل الطاقة المحتملة أو احتمال قبول العينة بسهولة باستخدام وسيطة extra_fields
. للحصول على قائمة بالحقول الممكنة التي يمكن جمعها ، راجع كائن HMCSTATE. في هذا المثال ، سنقوم بالإضافة إلى ذلك بجمع potential_energy
لكل عينة.
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
يمكننا طباعة ملخص تشغيل MCMC ، وفحصنا إذا لاحظنا أي اختلافات أثناء الاستدلال. بالإضافة إلى ذلك ، نظرًا لأننا جمعنا الطاقة المحتملة لكل عينات ، يمكننا بسهولة حساب كثافة مفصل السجل المتوقعة.
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
تشير القيم المذكورة أعلاه 1 لـ Gelman Rubin Diagnostic ( r_hat
) إلى أن السلسلة لم تتقارب بالكامل. القيمة المنخفضة لحجم العينة الفعال ( n_eff
) ، وخاصة بالنسبة tau
، ويبدو أن عدد التحولات المتباينة مشكلة. لحسن الحظ ، هذا هو أمراض شائعة يمكن تصحيحها باستخدام معلمة غير محددة لـ tau
في نموذجنا. هذا أمر واضح ومباشر للقيام به في Numpyro باستخدام مثيل TransformedDistribution مع معالج تأثير إعادة التعويض. دعنا نعيد كتابة نفس النموذج ، ولكن بدلاً من أخذ عينات من theta
من أحد Normal(mu, tau)
، سنقوم بدلاً من ذلك بتجربة توزيع Normal(0, 1)
يتم تحويله باستخدام AffinetRansform. لاحظ أنه من خلال القيام بذلك ، يقوم Numpyro بتشغيل HMC عن طريق إنشاء عينات theta_base
لتوزيع القاعدة Normal(0, 1)
بدلاً من ذلك. نرى أن السلسلة الناتجة لا تعاني من نفس الأمراض - تشخيص Gelman Rubin هو 1 لجميع المعلمات وأن حجم العينة الفعال يبدو جيدًا جدًا!
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
لاحظ أنه بالنسبة لفئة التوزيعات مع loc,scale
مثل Normal
و Cauchy
و StudentT
، ونحن نقدم أيضًا إعادة تجديد LoccalereParam لتحقيق نفس الغرض. سيكون الرمز المقابل
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
الآن ، دعنا نفترض أن لدينا مدرسة جديدة لم نلاحظ فيها أي درجات اختبار ، لكننا نود إنشاء تنبؤات. يوفر Numpyro فئة تنبؤية لهذا الغرض. لاحظ أنه في حالة عدم وجود أي بيانات ملحوظة ، نستخدم ببساطة معلمات مستوى السكان لإنشاء تنبؤات. ظروف المنفعة Predictive
مواقع mu
و tau
غير الملحوظة للقيم المرسومة من التوزيع الخلفي من آخر تشغيل MCMC لدينا ، وتشغيل النموذج للأمام لتوليد التنبؤات.
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
للحصول على المزيد من الأمثلة على تحديد النماذج والقيام بالاستدلال في Numpyro:
lax.scan
البدائية للاستدلال السريع.سوف يلاحظ مستخدمو Pyro أن واجهة برمجة التطبيقات (API) لمواصفات النماذج والاستدلال عليها إلى حد كبير مثل Pyro ، بما في ذلك توزيعات API ، حسب التصميم. ومع ذلك ، هناك بعض الاختلافات الأساسية المهمة (التي تنعكس في الداخلية) التي يجب أن يكون المستخدمون على دراية بها. على سبيل المثال ، في Numpyro ، لا يوجد متجر عالمي للمعلمات أو حالة عشوائية ، لجعل من الممكن بالنسبة لنا الاستفادة من مجموعة JIT الخاصة بـ Jax. أيضًا ، قد يحتاج المستخدمون إلى كتابة نماذجهم بأسلوب أكثر وظيفية يعمل بشكل أفضل مع Jax. الرجوع إلى الأسئلة الشائعة للحصول على قائمة من الاختلافات.
نحن نقدم نظرة عامة على معظم خوارزميات الاستدلال التي تدعمها Numpyro ونقدم بعض الإرشادات حول خوارزميات الاستدلال التي قد تكون مناسبة لفئات مختلفة من النماذج.
مثل HMC/NUTS ، تدعم جميع خوارزميات MCMC المتبقية التعداد على المتغيرات الكامنة المنفصلة إن أمكن (انظر القيود). يجب تمييز المواقع المذكورة مع infer={'enumerate': 'parallel'}
كما هو الحال في مثال التعليقات التوضيحية.
Trace_ELBO
ولكنه يحسب جزءًا من Elbo تحليليًا إذا كان ذلك ممكنًا.انظر المستندات لمزيد من التفاصيل.
دعم Windows المحدود: لاحظ أن Numpyro لم يتم اختباره على Windows ، وقد يتطلب بناء Jaxlib من المصدر. انظر قضية Jax هذه لمزيد من التفاصيل. بدلاً من ذلك ، يمكنك تثبيت نظام Windows الفرعي لـ Linux واستخدام Numpyro عليه كما في نظام Linux. انظر أيضًا CUDA على نظام Windows الفرعي لـ Linux ونشر هذا المنتدى إذا كنت ترغب في استخدام وحدات معالجة الرسومات على Windows.
لتثبيت Numpyro مع أحدث إصدار من وحدة المعالجة المركزية من Jax ، يمكنك استخدام PIP:
pip install numpyro
في حالة وجود مشكلات التوافق أثناء تنفيذ الأمر أعلاه ، يمكنك بدلاً من ذلك فرض تثبيت إصدار وحدة المعالجة المركزية المتوافقة مع Jax with
pip install numpyro[cpu]
لاستخدام Numpyro على وحدة معالجة الرسومات ، تحتاج إلى تثبيت CUDA أولاً ثم استخدام أمر PIP التالي:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
إذا كنت بحاجة إلى مزيد من التوجيه ، فيرجى إلقاء نظرة على تعليمات تثبيت GPU Jax.
لتشغيل Numpyro على السحابة TPUs ، يمكنك إلقاء نظرة على بعض jax على أمثلة Cloud TPU.
بالنسبة إلى Cloud TPU VM ، تحتاج إلى إعداد الواجهة الخلفية TPU كما هو مفصل في دليل Cloud TPU VM Jax QuickStart. بعد التحقق من أن الواجهة الخلفية TPU يتم إعدادها بشكل صحيح ، يمكنك تثبيت Numpyro باستخدام أمر pip install numpyro
.
النظام الأساسي الافتراضي: ستستخدم JAX GPU افتراضيًا إذا تم تثبيت حزمة
jaxlib
المدعومة من CUDA. يمكنك استخدام SET_PLATFORM Utilitynumpyro.set_platform("cpu")
للتبديل إلى وحدة المعالجة المركزية في بداية البرنامج.
يمكنك أيضًا تثبيت Numpyro من المصدر:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
يمكنك أيضًا تثبيت Numpyro مع conda:
conda install -c conda-forge numpyro
على عكس Pyro ، لا يعمل numpyro.sample('x', dist.Normal(0, 1))
. لماذا؟
من المرجح أن تستخدم بيان numpyro.sample
خارج سياق الاستدلال. لا يوجد لدى JAX حالة عشوائية عالمية ، وعلى هذا النحو ، تحتاج عينة التوزيع إلى مفتاح مولد أرقام عشوائي صريح (PRNGKEY) لإنشاء عينات من. تستخدم خوارزميات الاستدلال في Numpyro معالج البذور لخيطها في مفتاح مولد الأرقام العشوائية ، وراء الكواليس.
خياراتك هي:
استدعاء التوزيع مباشرة وقم بتوفير PRNGKey
، على سبيل dist.Normal(0, 1).sample(PRNGKey(0))
توفير وسيطة rng_key
إلى numpyro.sample
. على سبيل المثال numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
.
لف الرمز في معالج seed
، يستخدم إما كمدير سياق أو كدالة تلتف فوق المكالمة الأصلية. على سبيل المثال
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
أو كدالة ترتيب أعلى:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
هل يمكنني استخدام نموذج Pyro نفسه للقيام بالاستدلال في Numpyro؟
كما قد لاحظت من الأمثلة ، يدعم Numpyro جميع بدايات Pyro مثل sample
param
plate
module
. بالإضافة إلى ذلك ، تأكدنا من أن واجهة برمجة تطبيقات التوزيعات تعتمد على torch.distributions
، وأن فئات الاستدلال مثل SVI
و MCMC
لها نفس الواجهة. هذا جنبا إلى جنب مع التشابه في API لعمليات Numpy و Pytorch يضمن أن النماذج التي تحتوي على بيانات بدائية Pyro يمكن استخدامها مع إما الخلفية مع بعض التغييرات البسيطة. مثال على بعض الاختلافات جنبا إلى جنب مع التغييرات اللازمة ، ورد أدناه:
torch
في النموذج الخاص بك من حيث عملية jax.numpy
المقابلة. بالإضافة إلى ذلك ، ليس كل عمليات torch
لديها نظير numpy
(والعكس صحيح) ، وأحيانًا توجد اختلافات طفيفة في API.pyro.sample
خارج سياق الاستدلال ملفوفة في معالج seed
، كما ذكر أعلاه.numpyro.param
خارج سياق الاستدلال لن يكون له أي تأثير. لاسترداد قيم المعلمة المحسنة من SVI ، استخدم طريقة svi.get_params. لاحظ أنه لا يزال بإمكانك استخدام عبارات param
داخل نموذج وسيستخدم Numpyro معالج التأثير البديل داخليًا لاستبدال القيم من المُحسّن عند تشغيل النموذج في SVI.بالنسبة لمعظم النماذج الصغيرة ، يجب أن تكون التغييرات المطلوبة لتشغيل الاستدلال في Numpyro بسيطة. بالإضافة إلى ذلك ، نحن نعمل على Pyro-API والتي تتيح لك كتابة نفس الرمز وإرساله إلى العديد من الخلفية ، بما في ذلك Numpyro. سيكون هذا بالضرورة أكثر تقييدًا ، ولكنه يتمتع بميزة كونه لاأدري. راجع الوثائق للحصول على مثال ، وأخبرنا بتعليقاتك.
كيف يمكنني المساهمة في المشروع؟
شكرا لاهتمامك بالمشروع! يمكنك إلقاء نظرة على المشكلات الصديقة للمبتدئين والتي تتميز بعلامة العدد الأول الجيد على Github. أيضا ، يرجى الشعور بالوصول إلينا في المنتدى.
في المدى القريب ، نخطط للعمل على ما يلي. يرجى فتح مشكلات جديدة لطلبات الميزات والتحسينات:
يمكن الاطلاع على الأفكار المحفزة وراء Numpyro ووصف المكسرات التكرارية في هذه الورقة التي ظهرت في برنامج Neups 2019 برامج ورشة عمل للتعلم الآلي.
إذا كنت تستخدم Numpyro ، يرجى التفكير في:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
إلى جانب
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}