Pemrograman probabilistik yang ditenagai oleh JAX untuk kompilasi Autograd dan JIT ke GPU/TPU/CPU.
Dokumen dan Contoh | Forum
Numpyro adalah perpustakaan pemrograman probabilistik ringan yang menyediakan backend yang tidak bisa dipenuhi piro. Kami mengandalkan JAX untuk diferensiasi otomatis dan kompilasi JIT ke GPU / CPU. Numpyro sedang dalam pengembangan aktif, jadi waspadalah terhadap kerapuhan, bug, dan perubahan pada API saat desain berkembang.
Numpyro dirancang agar ringan dan berfokus pada penyediaan substrat fleksibel yang dapat dibangun oleh pengguna:
sample
dan param
. Kode model harus terlihat sangat mirip dengan pyro kecuali untuk beberapa perbedaan kecil antara Pytorch dan API Numpy. Lihat contoh di bawah ini.jit
dan grad
untuk menyusun seluruh langkah integrasi menjadi kernel yang dioptimalkan XLA. Kami juga menghilangkan overhead python dengan JIT menyusun seluruh tahap bangunan pohon dalam kacang -kacangan (ini dimungkinkan menggunakan kacang berulang). Ada juga implementasi inferensi variasi dasar bersama dengan banyak panduan fleksibel (otomatis) untuk inferensi variasional diferensiasi otomatis (Advi). Implementasi inferensi variasional mendukung sejumlah fitur, termasuk dukungan untuk model dengan variabel laten diskrit (lihat Tracegraph_elbo dan traceenum_elbo).torch.distributions
. Selain distribusi, constraints
dan transforms
sangat berguna ketika beroperasi pada kelas distribusi dengan dukungan terikat. Akhirnya, distribusi dari TensorFlow Probability (TFP) dapat langsung digunakan dalam model Numpyro.sample
dan param
dapat diberikan interpretasi yang tidak standar menggunakan penangan efek dari modul Numpyro.handlers, dan ini dapat dengan mudah diperluas untuk menerapkan algoritma inferensi khusus dan utilitas inferensi. Mari kita jelajahi Numpyro menggunakan contoh sederhana. Kami akan menggunakan contoh delapan sekolah dari Gelman et al., Analisis Data Bayesian: Sec. 5.5, 2003, yang mempelajari pengaruh pembinaan pada kinerja SAT di delapan sekolah.
Data diberikan oleh:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
, di mana y
adalah efek pengobatan dan sigma
kesalahan standar. Kami membangun model hierarkis untuk penelitian di mana kami mengasumsikan bahwa parameter tingkat kelompok theta
untuk setiap sekolah diambil sampelnya dari distribusi normal dengan rata-rata mu
dan standar deviasi tau
, sedangkan data yang diamati pada gilirannya dihasilkan dari distribusi normal dengan rata-rata dan standar deviasi yang diberikan oleh theta
(efek sejati) dan sigma
, masing -masing. Hal ini memungkinkan kita untuk memperkirakan parameter tingkat populasi mu
dan tau
dengan menggabungkan dari semua pengamatan, sementara masih memungkinkan untuk variasi individu di antara sekolah-sekolah menggunakan parameter theta
tingkat kelompok.
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
Mari kita menyimpulkan nilai-nilai parameter yang tidak diketahui dalam model kami dengan menjalankan MCMC menggunakan sampler no-U-turn (NUTS). Perhatikan penggunaan argumen extra_fields
di MCMC.Run. Secara default, kami hanya mengumpulkan sampel dari distribusi target (posterior) ketika kami menjalankan inferensi menggunakan MCMC
. Namun, mengumpulkan bidang tambahan seperti energi potensial atau probabilitas penerimaan sampel dapat dengan mudah dicapai dengan menggunakan argumen extra_fields
. Untuk daftar kemungkinan bidang yang dapat dikumpulkan, lihat objek HMCState. Dalam contoh ini, kami juga akan mengumpulkan potential_energy
untuk setiap sampel.
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
Kita dapat mencetak ringkasan menjalankan MCMC, dan memeriksa jika kita mengamati divergensi selama inferensi. Selain itu, karena kami mengumpulkan energi potensial untuk masing -masing sampel, kami dapat dengan mudah menghitung kepadatan sambungan log yang diharapkan.
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
Nilai di atas 1 untuk split Gelman Rubin Diagnostik ( r_hat
) menunjukkan bahwa rantai belum sepenuhnya konvergen. Nilai rendah untuk ukuran sampel efektif ( n_eff
), terutama untuk tau
, dan jumlah transisi divergen terlihat bermasalah. Untungnya, ini adalah patologi umum yang dapat diperbaiki dengan menggunakan parameterisasi yang tidak berpusat pada tau
dalam model kami. Ini mudah dilakukan di Numpyro dengan menggunakan instance transformedDistribution bersama dengan penangan efek reparameterisasi. Mari kita tulis ulang model yang sama tetapi alih -alih mencicipi theta
dari Normal(mu, tau)
, kami sebaliknya akan mencicipinya dari distribusi basis Normal(0, 1)
yang ditransformasikan menggunakan affinetransform. Perhatikan bahwa dengan melakukannya, Numpyro menjalankan HMC dengan menghasilkan sampel theta_base
untuk distribusi basis Normal(0, 1)
sebagai gantinya. Kita melihat bahwa rantai yang dihasilkan tidak menderita patologi yang sama - Diagnostik Gelman Rubin adalah 1 untuk semua parameter dan ukuran sampel yang efektif terlihat cukup bagus!
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
Perhatikan bahwa untuk kelas distribusi dengan loc,scale
seperti Normal
, Cauchy
, StudentT
, kami juga menyediakan LocScaleReparam Reparameterizer untuk mencapai tujuan yang sama. Kode yang sesuai akan
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
Sekarang, mari kita asumsikan bahwa kita memiliki sekolah baru yang belum kita amati nilai tes, tetapi kita ingin menghasilkan prediksi. Numpyro menyediakan kelas prediktif untuk tujuan seperti itu. Perhatikan bahwa dengan tidak adanya data yang diamati, kami cukup menggunakan parameter tingkat populasi untuk menghasilkan prediksi. Kondisi utilitas Predictive
situs mu
dan tau
yang tidak teramati untuk nilai -nilai yang diambil dari distribusi posterior dari menjalankan MCMC terakhir kami, dan menjalankan model ke depan untuk menghasilkan prediksi.
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
Untuk beberapa contoh lagi tentang menentukan model dan melakukan inferensi di Numpyro:
lax.scan
primitif untuk inferensi cepat.Pengguna pyro akan mencatat bahwa API untuk spesifikasi model dan inferensi sebagian besar sama dengan piro, termasuk API distribusi, dengan desain. Namun, ada beberapa perbedaan inti penting (tercermin dalam internal) yang harus diperhatikan pengguna. Misalnya di Numpyro, tidak ada toko parameter global atau keadaan acak, untuk memungkinkan kita memanfaatkan kompilasi JIT Jax. Juga, pengguna mungkin perlu menulis model mereka dengan gaya yang lebih fungsional yang bekerja lebih baik dengan JAX. Lihat FAQ untuk daftar perbedaan.
Kami memberikan gambaran tentang sebagian besar algoritma inferensi yang didukung oleh Numpyro dan menawarkan beberapa pedoman tentang algoritma inferensi mana yang mungkin sesuai untuk berbagai kelas model.
Seperti HMC/Nuts, semua algoritma MCMC yang tersisa mendukung enumerasi atas variabel laten diskrit jika memungkinkan (lihat pembatasan). Situs -situs yang disebutkan perlu ditandai dengan infer={'enumerate': 'parallel'}
seperti dalam contoh anotasi.
Trace_ELBO
tetapi menghitung bagian dari Elbo secara analitik jika hal itu dimungkinkan.Lihat dokumen untuk lebih jelasnya.
Dukungan Windows Terbatas: Perhatikan bahwa Numpyro belum teruji di jendela, dan mungkin memerlukan membangun jaxlib dari sumber. Lihat masalah JAX ini untuk lebih jelasnya. Atau, Anda dapat menginstal Subsistem Windows untuk Linux dan menggunakan Numpyro di atasnya seperti pada sistem Linux. Lihat juga CUDA di Subsistem Windows untuk Linux dan posting forum ini jika Anda ingin menggunakan GPU di Windows.
Untuk menginstal Numpyro dengan JAX versi CPU terbaru, Anda dapat menggunakan PIP:
pip install numpyro
Dalam hal masalah kompatibilitas muncul selama pelaksanaan perintah di atas, Anda sebaliknya dapat memaksa pemasangan versi CPU yang kompatibel dengan JAX dengan
pip install numpyro[cpu]
Untuk menggunakan Numpyro pada GPU , Anda perlu menginstal CUDA terlebih dahulu dan kemudian menggunakan perintah PIP berikut:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Jika Anda memerlukan panduan lebih lanjut, silakan lihat instruksi instalasi GPU JAX.
Untuk menjalankan Numpyro di Cloud TPU , Anda dapat melihat beberapa contoh Jax di Cloud TPU.
Untuk Cloud TPU VM, Anda perlu mengatur backend TPU sebagaimana dirinci dalam panduan Cloud TPU VM Jax QuickStart. Setelah Anda memverifikasi bahwa backend TPU diatur dengan benar, Anda dapat menginstal Numpyro menggunakan pip install numpyro
Command.
Platform default: JAX akan menggunakan GPU secara default jika paket
jaxlib
yang didukung CUDA diinstal. Anda dapat menggunakan set_platform utilitynumpyro.set_platform("cpu")
untuk beralih ke CPU di awal program Anda.
Anda juga dapat menginstal numpyro dari sumber:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
Anda juga dapat menginstal numpyro dengan conda:
conda install -c conda-forge numpyro
Tidak seperti di Pyro, numpyro.sample('x', dist.Normal(0, 1))
tidak berfungsi. Mengapa?
Anda kemungkinan besar menggunakan pernyataan numpyro.sample
di luar konteks inferensi. Jax tidak memiliki keadaan acak global, dan dengan demikian, sampler distribusi memerlukan kunci generator bilangan acak eksplisit (PRNGKEY) untuk menghasilkan sampel dari. Algoritma inferensi Numpyro menggunakan pawang benih untuk memasukkan kunci generator bilangan acak, di belakang layar.
Pilihan Anda adalah:
Hubungi distribusi secara langsung dan berikan PRNGKey
, misalnya dist.Normal(0, 1).sample(PRNGKey(0))
Berikan argumen rng_key
kepada numpyro.sample
. misalnya numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
.
Bungkus kode dalam penangan seed
, digunakan sebagai manajer konteks atau sebagai fungsi yang membungkus yang dapat dipanggil asli. misalnya
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
, atau sebagai fungsi urutan yang lebih tinggi:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
Bisakah saya menggunakan model pyro yang sama untuk melakukan inferensi di Numpyro?
Seperti yang mungkin Anda perhatikan dari contoh -contohnya, Numpyro mendukung semua primitif piro seperti sample
, param
, plate
dan module
, dan penangan efek. Selain itu, kami telah memastikan bahwa API Distribusi didasarkan pada torch.distributions
, dan kelas inferensi seperti SVI
dan MCMC
memiliki antarmuka yang sama. Ini bersama dengan kesamaan dalam API untuk operasi Numpy dan Pytorch memastikan bahwa model yang mengandung pernyataan primitif pyro dapat digunakan dengan backend dengan beberapa perubahan kecil. Contoh beberapa perbedaan bersama dengan perubahan yang diperlukan, dicatat di bawah ini:
torch
apa pun dalam model Anda perlu ditulis dalam hal operasi jax.numpy
yang sesuai. Selain itu, tidak semua operasi torch
memiliki mitra numpy
(dan sebaliknya), dan kadang-kadang ada perbedaan kecil dalam API.pyro.sample
di luar konteks inferensi perlu dibungkus dengan penangan seed
, seperti yang disebutkan di atas.numpyro.param
di luar konteks inferensi tidak akan berpengaruh. Untuk mengambil nilai parameter yang dioptimalkan dari SVI, gunakan metode svi.get_params. Perhatikan bahwa Anda masih dapat menggunakan pernyataan param
di dalam model dan Numpyro akan menggunakan penangan efek pengganti secara internal untuk mengganti nilai dari pengoptimal saat menjalankan model di SVI.Untuk sebagian besar model kecil, perubahan yang diperlukan untuk menjalankan inferensi di Numpyro harus kecil. Selain itu, kami sedang mengerjakan Pyro-API yang memungkinkan Anda untuk menulis kode yang sama dan mengirimkannya ke beberapa backend, termasuk Numpyro. Ini tentu akan lebih ketat, tetapi memiliki keuntungan menjadi agnostik backend. Lihat contoh dokumentasi, dan beri tahu kami umpan balik Anda.
Bagaimana saya bisa berkontribusi pada proyek?
Terima kasih atas minat Anda pada proyek ini! Anda dapat melihat masalah ramah pemula yang ditandai dengan tag edisi pertama yang baik di GitHub. Juga, silakan menghubungi kami di forum.
Dalam waktu dekat, kami berencana untuk mengerjakan yang berikut ini. Harap buka masalah baru untuk permintaan dan peningkatan fitur:
Gagasan memotivasi di balik Numpyro dan deskripsi kacang berulang dapat ditemukan dalam makalah ini yang muncul dalam transformasi program Neurips 2019 untuk lokakarya pembelajaran mesin.
Jika Anda menggunakan numpyro, harap pertimbangkan mengutip:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
maupun
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}