由JAX驱动的自摩rad和JIT汇编的概率编程到GPU/TPU/CPU。
文档和示例|论坛
Numpyro是一个轻巧的概率编程库,为Pyro提供了Numpy后端。我们依靠JAX自动分化和与GPU / CPU的JIT汇编。 Numpyro正在积极开发中,因此随着设计的发展,请当心API的脆性,错误和更改。
Numpyro的设计为轻量级,专注于提供一种灵活的底物,用户可以基于:
sample
和param
类的pyro原语。模型代码看起来应该与Pyro非常相似,除了Pytorch和Numpy的API之间的一些较小差异。请参见下面的示例。jit
和grad
,以将整个集成步骤编译为XLA优化的内核。我们还通过将JIT汇编成坚果的整个树木建筑阶段(使用迭代性螺母)来消除python的头顶。还有一个基本的变异推理实现以及许多灵活(自动)指南用于自动分化变化推理(ADVI)。变异推理实现支持许多功能,包括对具有离散潜在变量的模型的支持(请参阅TraceGraph_elbo和TraceEnum_elbo)。torch.distributions
中的相同的API和批处理语义。除分布外,在具有有限支持的分配类别上运行时, constraints
和transforms
非常有用。最后,来自Tensorflow概率(TFP)的分布可以直接用于Numpyro模型。sample
和param
的原始物,并且可以轻松地扩展到实现自定义推理算法和推理实用程序。 让我们使用一个简单的示例探索Numpyro。我们将使用贝叶斯数据分析Gelman等人的八所学校示例:秒。 5.5,2003,研究教练对八所学校SAT表现的影响。
数据由:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
,其中y
是治疗效果和sigma
标准误差。我们为研究构建了一个层次模型,我们假设每个学校的组级参数theta
是从平均mu
和标准偏差tau
的正态分布中取样的,而观察到的数据又是由平均分布产生的。和theta
(真实效果)和sigma
的标准偏差。这使我们能够通过从所有观测值中汇总来估算人口级参数mu
和tau
,同时仍然允许使用组级别的theta
参数在学校之间进行单个变化。
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
让我们通过使用No-U-Turn采样器(NUTS)运行MCMC来推断模型中未知参数的值。请注意mcmc.run中extra_fields
参数的用法。默认情况下,我们仅在使用MCMC
进行推理时才从目标(后部)分布中收集样本。但是,通过使用extra_fields
参数,可以轻松地收集诸如势能或样本的接受概率之类的其他字段。有关可以收集的可能字段的列表,请参见HMCSTATE对象。在此示例中,我们还将为每个样品收集potential_energy
_energy。
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
我们可以打印MCMC运行的摘要,并检查在推理过程中是否观察到任何分歧。此外,由于我们收集了每个样品的势能,因此我们可以轻松计算预期的对数关节密度。
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
拆分Gelman Rubin诊断( r_hat
)的值高于1的值表明该链没有完全收敛。有效样本量( n_eff
)的低值,特别是对于tau
,而不同的过渡数似乎有问题。幸运的是,这是一种常见的病理,可以通过在我们的模型中使用tau
的非中心参数化来纠正。通过使用变换分布实例以及重新聚集效应处理程序,这在Numpyro中很简单。让我们重写相同的模型,但我们将不是从Normal(mu, tau)
中采样theta
,而是将其从使用Affinetransform转换的基础Normal(0, 1)
分布中进行采样。请注意,通过这样做,Numpyro通过生成samples theta_base
的基本Normal(0, 1)
分布来运行HMC。我们看到所得链不会患有相同的病理 - 所有参数的Gelman Rubin诊断为1,有效的样本量看起来不错!
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
请注意,对于具有loc,scale
例如Normal
, Cauchy
, StudentT
,我们还提供了Locscalereparam Reparameperizer,以实现相同的目的。相应的代码将是
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
现在,让我们假设我们有一所新学校,我们尚未观察到任何考试成绩,但我们想产生预测。 Numpyro为此目的提供了一个预测类。请注意,在没有任何观察到的数据的情况下,我们只需使用总体级参数来生成预测。 Predictive
效用条件条件未观察到的mu
和tau
位点是从我们上次MCMC运行的后验分布中得出的值,并将模型向前运行以生成预测。
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
有关指定模型并在Numpyro中进行推断的更多示例:
lax.scan
原始推理。Pyro用户会注意到,用于模型规范和推理的API与Pyro(包括Distribution api)的API大致相同。但是,用户应意识到的一些重要的核心差异(反映在内部质量中)。例如在Numpyro中,没有全局参数存储或随机状态,可以使我们有可能利用JAX的JAX汇编。此外,用户可能需要以更具功能的样式编写其模型,该模型与JAX更好。有关差异列表,请参阅常见问题解答。
我们概述了Numpyro支持的大多数推理算法,并提供了一些指南,涉及哪种推理算法可能适用于不同类别的模型。
像HMC/坚果一样,所有剩余的MCMC算法(如果可能)在离散的潜在变量上支持枚举(请参阅限制)。需要在注释示例中标记枚举站点,以infer={'enumerate': 'parallel'}
。
Trace_ELBO
一样,但是如果可能的话,可以分析地计算ELBO的一部分。有关更多详细信息,请参见文档。
有限的Windows支持:请注意,Numpyro在Windows上未经测试,可能需要从源构建Jaxlib。有关更多详细信息,请参见此JAX问题。另外,您可以安装Windows子系统的Linux,并在Linux系统上使用Numpyro。如果您想在Windows上使用GPU,请参见Windows子系统上的CUDA和该论坛帖子。
要使用最新的CPU版本JAX安装Numpyro,您可以使用PIP:
pip install numpyro
如果在执行上述命令期间出现兼容性问题,则可以将已知兼容的CPU版本的JAX的安装与
pip install numpyro[cpu]
要在GPU上使用Numpyro ,您需要先安装CUDA,然后使用以下PIP命令:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
如果您需要进一步的指导,请查看JAX GPU安装说明。
要在Cloud TPU上运行Numpyro ,您可以在云TPU示例上查看一些JAX。
对于Cloud TPU VM,您需要在云TPU VM JAX QuickStart指南中详细介绍TPU后端。验证了正确设置TPU后端后,您可以使用pip install numpyro
命令安装NumPyro。
默认平台:如果安装了CUDA支持的
jaxlib
软件包,JAX将默认使用GPU。您可以使用set_platform实用程序numpyro.set_platform("cpu")
在程序开头切换到CPU。
您也可以从来源安装Numpyro:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
您也可以使用conda安装numpyro:
conda install -c conda-forge numpyro
与pyro不同, numpyro.sample('x', dist.Normal(0, 1))
不起作用。为什么?
您最有可能在推理上下文之外使用numpyro.sample
语句。 JAX没有全局随机状态,因此,分发采样器需要一个显式的随机数生成器密钥(PRNGKEY)来生成样本。 Numpyro的推理算法使用种子处理程序在场景后面的随机数生成器键中螺纹。
您的选择是:
直接调用分布并提供一个PRNGKey
,例如dist.Normal(0, 1).sample(PRNGKey(0))
向numpyro.sample
提供rng_key
参数。例如numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
。
将代码包装在seed
处理程序中,用作上下文管理器或用原始可召唤的函数。例如
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
,或作为高阶功能:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
我可以使用相同的Pyro模型在Numpyro中进行推断吗?
从示例中您可能注意到的那样,Numpyro支持所有Pyro原语,例如sample
, param
, plate
和module
以及效果处理程序。此外,我们确保了分布API基于torch.distributions
,并且SVI
和MCMC
等推理类具有相同的接口。这与numpy和pytorch操作的API中的相似性确保了包含Pyro原始语句的模型可以与任何一个较小的更改一起使用。下面指出了一些差异以及所需的更改的示例:
torch
操作都需要根据相应的jax.numpy
操作编写。此外,并非所有的torch
操作都有一个numpy
对应物(反之亦然),有时API差异很小。pyro.sample
样本语句将需要包裹在seed
处理程序中。numpyro.param
在推理上下文之外没有效果。要从SVI检索优化的参数值,请使用SVI.GEG_PARAMS方法。请注意,您仍然可以在模型中使用param
语句,而Numpyro将在内部使用替代效果处理程序在SVI中运行模型时从优化器中替换值。对于大多数小型型号,在Numpyro中进行推断所需的变化应很小。此外,我们正在研究Pyro-API,该pyro-api使您可以编写相同的代码并将其派遣到包括Numpyro在内的多个后端。这必然会更加限制,但具有后端不可知论的优势。请参阅文档以获取示例,并让我们知道您的反馈。
我该如何为该项目做出贡献?
感谢您对该项目的兴趣!您可以查看Github上标有良好第一期标签的初学者友好问题。另外,请在论坛上与我们联系。
在短期内,我们计划在以下工作。请为功能请求和增强功能打开新问题:
Numpyro背后的动机思想以及对机器学习研讨会的Neurips 2019计划转换中出现的本文中可以找到对迭代性坚果的描述。
如果您使用numpyro,请考虑引用:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
也
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}