由JAX驅動的自摩rad和JIT彙編的概率編程到GPU/TPU/CPU。
文檔和示例|論壇
Numpyro是一個輕巧的概率編程庫,為Pyro提供了Numpy後端。我們依靠JAX自動分化和與GPU / CPU的JIT彙編。 Numpyro正在積極開發中,因此隨著設計的發展,請當心API的脆性,錯誤和更改。
Numpyro的設計為輕量級,專注於提供一種靈活的底物,用戶可以基於:
sample
和param
類的pyro原語。模型代碼看起來應該與Pyro非常相似,除了Pytorch和Numpy的API之間的一些較小差異。請參見下面的示例。jit
和grad
,以將整個集成步驟編譯為XLA優化的內核。我們還通過將JIT彙編成堅果的整個樹木建築階段(使用迭代性螺母)來消除python的頭頂。還有一個基本的變異推理實現以及許多靈活(自動)指南用於自動分化變化推理(ADVI)。變異推理實現支持許多功能,包括對具有離散潛在變量的模型的支持(請參閱TraceGraph_elbo和TraceEnum_elbo)。torch.distributions
中的相同的API和批處理語義。除分佈外,在具有有限支持的分配類別上運行時, constraints
和transforms
非常有用。最後,來自Tensorflow概率(TFP)的分佈可以直接用於Numpyro模型。sample
和param
的原始物,並且可以輕鬆地擴展到實現自定義推理算法和推理實用程序。 讓我們使用一個簡單的示例探索Numpyro。我們將使用貝葉斯數據分析Gelman等人的八所學校示例:秒。 5.5,2003,研究教練對八所學校SAT表現的影響。
數據由:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
,其中y
是治療效果和sigma
標準誤差。我們為研究構建了一個層次模型,我們假設每個學校的組級參數theta
是從平均mu
和標準偏差tau
的正態分佈中取樣的,而觀察到的數據又是由平均分佈產生的。和theta
(真實效果)和sigma
的標準偏差。這使我們能夠通過從所有觀測值中匯總來估算人口級參數mu
和tau
,同時仍然允許使用組級別的theta
參數在學校之間進行單個變化。
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
讓我們通過使用No-U-Turn採樣器(NUTS)運行MCMC來推斷模型中未知參數的值。請注意mcmc.run中extra_fields
參數的用法。默認情況下,我們僅在使用MCMC
進行推理時才從目標(後部)分佈中收集樣本。但是,通過使用extra_fields
參數,可以輕鬆地收集諸如勢能或樣本的接受概率之類的其他字段。有關可以收集的可能字段的列表,請參見HMCSTATE對象。在此示例中,我們還將為每個樣品收集potential_energy
_energy。
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
我們可以打印MCMC運行的摘要,並檢查在推理過程中是否觀察到任何分歧。此外,由於我們收集了每個樣品的勢能,因此我們可以輕鬆計算預期的對數關節密度。
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
拆分Gelman Rubin診斷( r_hat
)的值高於1的值表明該鏈沒有完全收斂。有效樣本量( n_eff
)的低值,特別是對於tau
,而不同的過渡數似乎有問題。幸運的是,這是一種常見的病理,可以通過在我們的模型中使用tau
的非中心參數化來糾正。通過使用變換分佈實例以及重新聚集效應處理程序,這在Numpyro中很簡單。讓我們重寫相同的模型,但我們將不是從Normal(mu, tau)
中採樣theta
,而是將其從使用Affinetransform轉換的基礎Normal(0, 1)
分佈中進行採樣。請注意,通過這樣做,Numpyro通過生成samples theta_base
的基本Normal(0, 1)
分佈來運行HMC。我們看到所得鏈不會患有相同的病理 - 所有參數的Gelman Rubin診斷為1,有效的樣本量看起來不錯!
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
請注意,對於具有loc,scale
例如Normal
, Cauchy
, StudentT
,我們還提供了Locscalereparam Reparameperizer,以實現相同的目的。相應的代碼將是
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
現在,讓我們假設我們有一所新學校,我們尚未觀察到任何考試成績,但我們想產生預測。 Numpyro為此目的提供了一個預測類。請注意,在沒有任何觀察到的數據的情況下,我們只需使用總體級參數來生成預測。 Predictive
效用條件條件未觀察到的mu
和tau
位點是從我們上次MCMC運行的後驗分佈中得出的值,並將模型向前運行以生成預測。
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
有關指定模型並在Numpyro中進行推斷的更多示例:
lax.scan
原始推理。Pyro用戶會注意到,用於模型規範和推理的API與Pyro(包括Distribution api)的API大致相同。但是,用戶應意識到的一些重要的核心差異(反映在內部質量中)。例如在Numpyro中,沒有全局參數存儲或隨機狀態,可以使我們有可能利用JAX的JIT彙編。此外,用戶可能需要以更具功能的樣式編寫其模型,該模型與JAX更好。有關差異列表,請參閱常見問題解答。
我們概述了Numpyro支持的大多數推理算法,並提供了一些指南,涉及哪種推理算法可能適用於不同類別的模型。
像HMC/堅果一樣,所有剩餘的MCMC算法(如果可能)在離散的潛在變量上支持枚舉(請參閱限制)。需要在註釋示例中標記枚舉站點,以infer={'enumerate': 'parallel'}
。
Trace_ELBO
一樣,但是如果可能的話,可以分析地計算ELBO的一部分。有關更多詳細信息,請參見文檔。
有限的Windows支持:請注意,Numpyro在Windows上未經測試,可能需要從源構建Jaxlib。有關更多詳細信息,請參見此JAX問題。另外,您可以安裝Windows子系統的Linux,並在Linux系統上使用Numpyro。如果您想在Windows上使用GPU,請參見Windows子系統上的CUDA和該論壇帖子。
要使用最新的CPU版本JAX安裝Numpyro,您可以使用PIP:
pip install numpyro
如果在執行上述命令期間出現兼容性問題,則可以將已知兼容的CPU版本的JAX的安裝與
pip install numpyro[cpu]
要在GPU上使用Numpyro ,您需要先安裝CUDA,然後使用以下PIP命令:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
如果您需要進一步的指導,請查看JAX GPU安裝說明。
要在Cloud TPU上運行Numpyro ,您可以在雲TPU示例上查看一些JAX。
對於Cloud TPU VM,您需要在雲TPU VM JAX QuickStart指南中詳細介紹TPU後端。驗證了正確設置TPU後端后,您可以使用pip install numpyro
命令安裝NumPyro。
默認平台:如果安裝了CUDA支持的
jaxlib
軟件包,JAX將默認使用GPU。您可以使用set_platform實用程序numpyro.set_platform("cpu")
在程序開頭切換到CPU。
您也可以從來源安裝Numpyro:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
您也可以使用conda安裝numpyro:
conda install -c conda-forge numpyro
與pyro不同, numpyro.sample('x', dist.Normal(0, 1))
不起作用。為什麼?
您最有可能在推理上下文之外使用numpyro.sample
語句。 JAX沒有全局隨機狀態,因此,分發採樣器需要一個顯式的隨機數生成器密鑰(PRNGKEY)來生成樣本。 Numpyro的推理算法使用種子處理程序在場景後面的隨機數生成器鍵中螺紋。
您的選擇是:
直接調用分佈並提供一個PRNGKey
,例如dist.Normal(0, 1).sample(PRNGKey(0))
向numpyro.sample
提供rng_key
參數。例如numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
。
將代碼包裝在seed
處理程序中,用作上下文管理器或用原始可召喚的函數。例如
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
,或作為高階功能:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
我可以使用相同的Pyro模型在Numpyro中進行推斷嗎?
從示例中您可能注意到的那樣,Numpyro支持所有Pyro原語,例如sample
, param
, plate
和module
以及效果處理程序。此外,我們確保了分佈API基於torch.distributions
,並且SVI
和MCMC
等推理類具有相同的接口。這與numpy和pytorch操作的API中的相似性確保了包含Pyro原始語句的模型可以與任何一個較小的更改一起使用。下面指出了一些差異以及所需的更改的示例:
torch
操作都需要根據相應的jax.numpy
操作編寫。此外,並非所有的torch
操作都有一個numpy
對應物(反之亦然),有時API差異很小。pyro.sample
樣本語句將需要包裹在seed
處理程序中。numpyro.param
在推理上下文之外沒有效果。要從SVI檢索優化的參數值,請使用SVI.GEG_PARAMS方法。請注意,您仍然可以在模型中使用param
語句,而Numpyro將在內部使用替代效果處理程序在SVI中運行模型時從優化器中替換值。對於大多數小型型號,在Numpyro中進行推斷所需的變化應很小。此外,我們正在研究Pyro-API,該pyro-api使您可以編寫相同的代碼並將其派遣到包括Numpyro在內的多個後端。這必然會更加限制,但具有後端不可知論的優勢。請參閱文檔以獲取示例,並讓我們知道您的反饋。
我該如何為該項目做出貢獻?
感謝您對該項目的興趣!您可以查看Github上標有良好第一期標籤的初學者友好問題。另外,請在論壇上與我們聯繫。
在短期內,我們計劃在以下工作。請為功能請求和增強功能打開新問題:
Numpyro背後的動機思想以及對機器學習研討會的Neurips 2019計劃轉換中出現的本文中可以找到對迭代性堅果的描述。
如果您使用numpyro,請考慮引用:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
也
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}