การเขียนโปรแกรมความน่าจะเป็นขับเคลื่อนโดย JAX สำหรับการรวบรวม Autograd และ JIT ไปยัง GPU/TPU/CPU
เอกสารและตัวอย่าง | ฟอรัม
Numpyro เป็นไลบรารีการเขียนโปรแกรมที่น่าจะเป็นน้ำหนักเบาที่ให้แบ็กเอนด์ Numpy สำหรับ Pyro เราพึ่งพา JAX สำหรับความแตกต่างอัตโนมัติและการรวบรวม JIT ไปยัง GPU / CPU Numpyro อยู่ระหว่างการพัฒนาที่ใช้งานอยู่ดังนั้นระวังความเปราะบางข้อบกพร่องและการเปลี่ยนแปลง API เมื่อการออกแบบวิวัฒนาการ
Numpyro ได้รับการออกแบบให้มี น้ำหนักเบา และมุ่งเน้นไปที่การจัดหาสารตั้งต้นที่ยืดหยุ่นที่ผู้ใช้สามารถสร้างได้:
sample
และ param
รหัสโมเดลควรดูคล้ายกับ Pyro มากยกเว้นความแตกต่างเล็กน้อยระหว่าง Pytorch และ API ของ Numpy ดูตัวอย่างด้านล่างjit
และ grad
เพื่อรวบรวมขั้นตอนการรวมทั้งหมดลงในเคอร์เนลที่ปรับให้เหมาะสม XLA นอกจากนี้เรายังกำจัดค่าใช้จ่ายของ Python โดย JIT รวบรวมเวทีการสร้างต้นไม้ทั้งหมดในถั่ว (เป็นไปได้โดยใช้น็อตวนซ้ำ) นอกจากนี้ยังมีการใช้การอนุมานแบบแปรผันขั้นพื้นฐานพร้อมกับคำแนะนำ (อัตโนมัติ) ที่ยืดหยุ่นจำนวนมากสำหรับการอนุมานความแปรปรวนของความแตกต่างอัตโนมัติ (ADVI) การใช้การอนุมานแบบแปรปรวนรองรับคุณสมบัติจำนวนมากรวมถึงการสนับสนุนสำหรับรุ่นที่มีตัวแปรแฝงแบบไม่ต่อเนื่อง (ดู Tracegraph_elbo และ Traceenum_elbo)torch.distributions
นอกเหนือจากการแจกแจง constraints
และ transforms
ยังมีประโยชน์มากเมื่อทำงานในคลาสการแจกจ่ายด้วยการสนับสนุนที่มีขอบเขต ในที่สุดการแจกแจงจากความน่าจะเป็นของ Tensorflow (TFP) สามารถใช้โดยตรงในรุ่น NUMPYROsample
และ param
สามารถให้การตีความที่ไม่เป็นมาตรฐานโดยใช้ตัวจัดการเอฟเฟกต์จากโมดูล numpyro.handlers และสิ่งเหล่านี้สามารถขยายได้อย่างง่ายดายเพื่อใช้อัลกอริทึมการอนุมานที่กำหนดเองและยูทิลิตี้การอนุมาน ให้เราสำรวจ numpyro โดยใช้ตัวอย่างง่ายๆ เราจะใช้ตัวอย่างโรงเรียนแปดแห่งจาก Gelman และคณะการวิเคราะห์ข้อมูลแบบเบย์: วินาที 5.5, 2003 ซึ่งศึกษาผลของการฝึกสอนต่อประสิทธิภาพ SAT ในแปดโรงเรียน
ข้อมูลได้รับจาก:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
โดยที่ y
คือผลการรักษาและ sigma
เป็นข้อผิดพลาดมาตรฐาน เราสร้างแบบจำลองลำดับชั้นสำหรับการศึกษาที่เราสมมติว่าพารามิเตอร์ระดับกลุ่ม theta
สำหรับแต่ละโรงเรียนจะได้รับการสุ่มตัวอย่างจากการแจกแจงแบบปกติโดยไม่ทราบค่าเฉลี่ย mu
และค่าเบี่ยงเบนมาตรฐาน tau
ในขณะที่ข้อมูลที่สังเกตได้ถูกสร้างขึ้นจากการแจกแจงปกติด้วยค่าเฉลี่ย และค่าเบี่ยงเบนมาตรฐานที่ได้รับจาก theta
(ผลจริง) และ sigma
มาตามลำดับ สิ่งนี้ช่วยให้เราสามารถประเมินพารามิเตอร์ระดับประชากร mu
และ tau
โดยรวมจากการสังเกตทั้งหมดในขณะที่ยังคงอนุญาตให้มีการเปลี่ยนแปลงส่วนบุคคลระหว่างโรงเรียนโดยใช้พารามิเตอร์ theta
ระดับกลุ่ม
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
ให้เราอนุมานค่าของพารามิเตอร์ที่ไม่รู้จักในโมเดลของเราโดยเรียกใช้ MCMC โดยใช้ตัวอย่าง No-U-Turn (Nuts) หมายเหตุการใช้อาร์กิวเมนต์ extra_fields
ใน MCMC.Run โดยค่าเริ่มต้นเราจะรวบรวมตัวอย่างจากการกระจายเป้าหมาย (หลัง) เมื่อเราเรียกใช้การอนุมานโดยใช้ MCMC
อย่างไรก็ตามการรวบรวมฟิลด์เพิ่มเติมเช่นพลังงานที่มีศักยภาพหรือความน่าจะเป็นในการยอมรับของตัวอย่างสามารถทำได้อย่างง่ายดายโดยใช้อาร์กิวเมนต์ extra_fields
สำหรับรายการฟิลด์ที่เป็นไปได้ที่สามารถรวบรวมได้ให้ดูที่วัตถุ HMCSTATE ในตัวอย่างนี้เราจะรวบรวม potential_energy
สำหรับแต่ละตัวอย่าง
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
เราสามารถพิมพ์บทสรุปของ MCMC Run และตรวจสอบว่าเราสังเกตเห็นความแตกต่างใด ๆ ในระหว่างการอนุมาน นอกจากนี้เนื่องจากเรารวบรวมพลังงานที่เป็นไปได้สำหรับแต่ละตัวอย่างเราจึงสามารถคำนวณความหนาแน่นร่วมของบันทึกที่คาดหวังได้อย่างง่ายดาย
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
ค่าที่สูงกว่า 1 สำหรับการวินิจฉัยแบบแยกเจลแมน Rubin ( r_hat
) บ่งชี้ว่าห่วงโซ่ยังไม่ได้มาบรรจบกันอย่างเต็มที่ ค่าต่ำสำหรับขนาดตัวอย่างที่มีประสิทธิภาพ ( n_eff
) โดยเฉพาะอย่างยิ่งสำหรับ tau
และจำนวนการเปลี่ยนที่แตกต่างนั้นดูมีปัญหา โชคดีที่นี่เป็นพยาธิสภาพทั่วไปที่สามารถแก้ไขได้โดยใช้การกำหนดพารามิเตอร์ที่ไม่ได้เป็นศูนย์กลางสำหรับ tau
ในแบบจำลองของเรา สิ่งนี้ตรงไปตรงมาที่จะทำใน Numpyro โดยใช้อินสแตนซ์ transformedDistribution พร้อมกับตัวจัดการเอฟเฟกต์ reparametization ให้เราเขียนโมเดลเดียวกันใหม่ แต่แทนที่จะสุ่มตัวอย่าง theta
จาก Normal(mu, tau)
เราจะสุ่มตัวอย่างจากการแจกแจงฐาน Normal(0, 1)
ที่เปลี่ยนโดยใช้ AffineTransform โปรดทราบว่าโดยการทำเช่นนั้น NumPyro เรียกใช้ HMC โดยการสร้างตัวอย่าง theta_base
สำหรับการแจกแจงฐาน Normal(0, 1)
แทน เราเห็นว่าห่วงโซ่ผลลัพธ์ไม่ได้รับผลกระทบจากพยาธิสภาพเดียวกัน - การวินิจฉัย Gelman Rubin คือ 1 สำหรับพารามิเตอร์ทั้งหมดและขนาดตัวอย่างที่มีประสิทธิภาพดูค่อนข้างดี!
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
โปรดทราบว่าสำหรับชั้นเรียนของการแจกแจงด้วย loc,scale
เช่น Normal
Cauchy
StudentT
เรายังให้ reparameterizer locscalereparam reparameterizer เพื่อให้บรรลุวัตถุประสงค์เดียวกัน รหัสที่เกี่ยวข้องจะเป็น
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
ตอนนี้ให้เราสมมติว่าเรามีโรงเรียนใหม่ที่เราไม่ได้สังเกตคะแนนการทดสอบใด ๆ แต่เราต้องการสร้างการคาดการณ์ Numpyro จัดให้มีคลาสการทำนายเพื่อจุดประสงค์ดังกล่าว โปรดทราบว่าในกรณีที่ไม่มีข้อมูลใด ๆ ที่สังเกตได้เราเพียงใช้พารามิเตอร์ระดับประชากรเพื่อสร้างการคาดการณ์ เงื่อนไขยูทิลิตี้ Predictive
ไซต์ mu
และ tau
ที่ไม่ได้รับการตรวจสอบไปจนถึงค่าที่ดึงมาจากการกระจายหลังจาก MCMC ครั้งล่าสุดของเราและเรียกใช้แบบจำลองไปข้างหน้าเพื่อสร้างการทำนาย
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
สำหรับตัวอย่างเพิ่มเติมเกี่ยวกับการระบุโมเดลและการอนุมานใน NUMPYRO:
lax.scan
Primitive สำหรับการอนุมานอย่างรวดเร็วผู้ใช้ Pyro จะทราบว่า API สำหรับข้อกำหนดของแบบจำลองและการอนุมานส่วนใหญ่เหมือนกับ Pyro รวมถึงการแจกแจง API โดยการออกแบบ อย่างไรก็ตามมีความแตกต่างหลักที่สำคัญบางประการ (สะท้อนให้เห็นในภายใน) ที่ผู้ใช้ควรตระหนักถึง เช่นใน numpyro ไม่มีการจัดเก็บพารามิเตอร์ทั่วโลกหรือสถานะสุ่มเพื่อให้เราสามารถใช้ประโยชน์จากการรวบรวม JIT ของ Jax นอกจากนี้ผู้ใช้อาจต้องเขียนโมเดลของพวกเขาในรูปแบบ การทำงาน ที่ดีขึ้นซึ่งทำงานได้ดีขึ้นกับ JAX อ้างถึงคำถามที่พบบ่อยสำหรับรายการความแตกต่าง
เราให้ภาพรวมของอัลกอริทึมการอนุมานส่วนใหญ่ที่สนับสนุนโดย NUMPYRO และเสนอแนวทางบางอย่างเกี่ยวกับอัลกอริทึมการอนุมานที่อาจเหมาะสมสำหรับคลาสที่แตกต่างกันของแบบจำลอง
เช่นเดียวกับ HMC/Nuts อัลกอริทึม MCMC ที่เหลือทั้งหมดสนับสนุนการแจงนับมากกว่าตัวแปรแฝงแบบไม่ต่อเนื่องหากเป็นไปได้ (ดูข้อ จำกัด ) ไซต์ที่แจกแจงจะต้องทำเครื่องหมายด้วย infer={'enumerate': 'parallel'}
เช่นในตัวอย่างคำอธิบายประกอบ
Trace_ELBO
แต่คำนวณส่วนหนึ่งของ ELBO ในการวิเคราะห์หากเป็นไปได้ดูเอกสารสำหรับรายละเอียดเพิ่มเติม
การสนับสนุน Windows ที่ จำกัด : โปรดทราบว่า Numpyro ยังไม่ได้ทดสอบบน Windows และอาจต้องสร้าง Jaxlib จากแหล่งที่มา ดูปัญหา JAX นี้สำหรับรายละเอียดเพิ่มเติม หรือคุณสามารถติดตั้งระบบย่อย Windows สำหรับ Linux และใช้ numpyro บนระบบในระบบ Linux ดูเพิ่มเติม CUDA บนระบบย่อย Windows สำหรับ Linux และโพสต์ฟอรัมนี้หากคุณต้องการใช้ GPU บน Windows
ในการติดตั้ง numpyro ด้วย JAX เวอร์ชันล่าสุดของ CPU คุณสามารถใช้ PIP:
pip install numpyro
ในกรณีของปัญหาความเข้ากันได้เกิดขึ้นระหว่างการดำเนินการคำสั่งข้างต้นคุณสามารถบังคับให้ติดตั้ง JAX เวอร์ชันที่เข้ากันได้
pip install numpyro[cpu]
ในการใช้ NUMPYRO บน GPU คุณต้องติดตั้ง CUDA ก่อนจากนั้นใช้คำสั่ง PIP ต่อไปนี้:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
หากคุณต้องการคำแนะนำเพิ่มเติมโปรดดูคำแนะนำการติดตั้ง JAX GPU
ในการเรียกใช้ numpyro บนคลาวด์ TPUs คุณสามารถดูตัวอย่าง JAX บนตัวอย่าง Cloud TPU
สำหรับ Cloud TPU VM คุณต้องตั้งค่าแบ็กเอนด์ TPU ตามรายละเอียดในคู่มือ Cloud TPU VM JAX Quickstart หลังจากที่คุณได้ตรวจสอบแล้วว่าแบ็กเอนด์ TPU ได้รับการตั้งค่าอย่างถูกต้องคุณสามารถติดตั้ง numpyro โดยใช้คำสั่ง pip install numpyro
แพลตฟอร์มเริ่มต้น: JAX จะใช้ GPU โดยค่าเริ่มต้นหากติดตั้งแพ็คเกจ
jaxlib
ที่รองรับ Cuda คุณสามารถใช้ SET_PLATFORM UTILITYnumpyro.set_platform("cpu")
เพื่อเปลี่ยนเป็น CPU ที่จุดเริ่มต้นของโปรแกรมของคุณ
นอกจากนี้คุณยังสามารถติดตั้ง numpyro จากแหล่งที่มา:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
คุณยังสามารถติดตั้ง numpyro ด้วย conda:
conda install -c conda-forge numpyro
ซึ่งแตกต่างจากใน pyro, numpyro.sample('x', dist.Normal(0, 1))
ไม่ทำงาน ทำไม
คุณมักจะใช้คำสั่ง numpyro.sample
นอกบริบทการอนุมาน JAX ไม่มีสถานะสุ่มทั่วโลกและเช่นนี้ตัวอย่างการแจกจ่ายจำเป็นต้องใช้คีย์ตัวสร้างตัวเลขสุ่มที่ชัดเจน (PRNGKEY) เพื่อสร้างตัวอย่างจาก อัลกอริธึมการอนุมานของ Numpyro ใช้ตัวจัดการเมล็ดเพื่อเธรดในคีย์เครื่องกำเนิดตัวเลขแบบสุ่มเบื้องหลัง
ตัวเลือกของคุณคือ:
เรียกการกระจายโดยตรงและให้ PRNGKey
เช่น dist.Normal(0, 1).sample(PRNGKey(0))
ระบุอาร์กิวเมนต์ rng_key
ให้กับ numpyro.sample
เช่น numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
ห่อรหัสในตัวจัดการ seed
ใช้เป็นตัวจัดการบริบทหรือเป็นฟังก์ชั่นที่พันผ่าน callable ดั้งเดิม เช่น
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
หรือเป็นฟังก์ชั่นการสั่งซื้อที่สูงขึ้น:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
ฉันสามารถใช้โมเดล Pyro เดียวกันเพื่อทำการอนุมานใน Numpyro ได้หรือไม่?
ดังที่คุณอาจสังเกตเห็นจากตัวอย่าง NumPyro รองรับ Pyro Primitives ทั้งหมดเช่น sample
, param
, plate
และ module
และ Effect Handlers นอกจากนี้เรายังมั่นใจได้ว่าการแจกแจง API นั้นขึ้นอยู่กับ torch.distributions
และคลาสการอนุมานเช่น SVI
และ MCMC
มีอินเทอร์เฟซเดียวกัน สิ่งนี้พร้อมกับความคล้ายคลึงกันใน API สำหรับการดำเนินการ NUMPY และ Pytorch ทำให้มั่นใจได้ว่าแบบจำลองที่มีข้อความ Pyro Primitive สามารถใช้กับแบ็กเอนด์ที่มีการเปลี่ยนแปลงเล็กน้อย ตัวอย่างความแตกต่างบางอย่างพร้อมกับการเปลี่ยนแปลงที่จำเป็นจะมีการบันทึกไว้ด้านล่าง:
torch
ใด ๆ ในโมเดลของคุณจะต้องเขียนในแง่ของการดำเนินการ jax.numpy
ที่เกี่ยวข้อง นอกจากนี้การดำเนินการ torch
ทั้งหมดไม่ได้มี numpy
คู่ (และในทางกลับกัน) และบางครั้งก็มีความแตกต่างเล็กน้อยใน APIpyro.sample
นอกบริบทการอนุมานจะต้องถูกห่อหุ้มด้วยตัวจัดการ seed
ตามที่กล่าวไว้ข้างต้นnumpyro.param
นอกบริบทการอนุมานจะไม่มีผล ในการดึงค่าพารามิเตอร์ที่ดีที่สุดจาก SVI ให้ใช้วิธี SVI.Get_Params โปรดทราบว่าคุณยังสามารถใช้คำสั่ง param
ภายในโมเดลและ NUMPYRO จะใช้ตัวจัดการเอฟเฟกต์ทดแทนภายในเพื่อทดแทนค่าจากเครื่องมือเพิ่มประสิทธิภาพเมื่อเรียกใช้โมเดลใน SVIสำหรับรุ่นเล็กส่วนใหญ่การเปลี่ยนแปลงที่จำเป็นในการเรียกใช้การอนุมานใน Numpyro ควรเป็นเล็กน้อย นอกจากนี้เรากำลังทำงานกับ Pyro-API ซึ่งช่วยให้คุณสามารถเขียนรหัสเดียวกันและส่งไปยังแบ็กเอนด์หลายรายการรวมถึง NUMPyro สิ่งนี้จะต้องมีข้อ จำกัด มากขึ้น แต่มีข้อได้เปรียบในการเป็นผู้ไม่เชื่อเรื่องพระเจ้า ดูเอกสารสำหรับตัวอย่างและแจ้งให้เราทราบความคิดเห็นของคุณ
ฉันจะมีส่วนร่วมในโครงการได้อย่างไร?
ขอบคุณสำหรับความสนใจในโครงการ! คุณสามารถดูปัญหาที่เป็นมิตรกับผู้เริ่มต้นที่ทำเครื่องหมายด้วยแท็กปัญหาแรกที่ดีใน GitHub นอกจากนี้โปรดรู้สึกถึงเราในฟอรัม
ในระยะเวลาอันใกล้นี้เราวางแผนที่จะทำงานต่อไปนี้ กรุณาเปิดปัญหาใหม่สำหรับคำขอและการปรับปรุงคุณสมบัติ:
แนวคิดที่สร้างแรงบันดาลใจที่อยู่เบื้องหลัง Numpyro และคำอธิบายของถั่วซ้ำสามารถพบได้ในบทความนี้ที่ปรากฏใน Neurips 2019 การแปลงโปรแกรมสำหรับการประชุมเชิงปฏิบัติการการเรียนรู้ของเครื่อง
หากคุณใช้ numpyro โปรดพิจารณาอ้าง:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
เช่นเดียวกับ
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}