GPU/TPU/CPUへのオートグラードおよびJITコンパイル用のJAXを搭載した確率的プログラミング。
ドキュメントと例|フォーラム
Numpyroは、PyroにNumpyバックエンドを提供する軽量確率的プログラミングライブラリです。 GPU / CPUへの自動分化とJITコンピレーションのためにJAXに依存しています。 Numpyroは積極的な開発中であるため、設計が進化するにつれて、Brittleness、Bugs、およびAPIの変更に注意してください。
Numpyroは軽量になるように設計されており、ユーザーが構築できる柔軟な基質の提供に焦点を当てています。
sample
やparam
などのPyro Primitivesに加えて、通常のPythonおよびNumpyコードを含めることができます。モデルコードは、PytorchとNumpyのAPIのいくつかの小さな違いを除いて、Pyroに非常によく似ている必要があります。以下の例を参照してください。jit
とgrad
を作成して、統合ステップ全体をXLA最適化されたカーネルにコンパイルできます。また、JITがナットでツリービルディングステージ全体をコンパイルすることにより、Pythonオーバーヘッドを排除します(これは、反復ナットを使用して可能です)。また、自動分化の変動推論(ADVI)のための多くの柔軟な(自動)ガイドとともに、基本的な変動推論の実装もあります。変分推論の実装は、個別の潜在変数を持つモデルのサポートを含む多くの機能をサポートしています(tracegraph_elboおよびtraceenum_elboを参照)。torch.distributions
と同じAPIおよびバッチセマンティクスに依存することができます。分布に加えて、 constraints
とtransforms
、限界サポートを備えた配布クラスで操作する場合に非常に役立ちます。最後に、Tensorflow確率(TFP)からの分布は、Numpyroモデルで直接使用できます。sample
やparam
などのプリミティブは、numpyro.handlersモジュールの効果ハンドラーを使用して非標準解釈を提供できます。これらを簡単に拡張して、カスタム推論アルゴリズムと推論ユーティリティを実装できます。 簡単な例を使用して、numpyroを探索しましょう。 Gelman et al。、Bayesian Data Analysis:Sec。の8つの学校の例を使用します。 5.5、2003、8つの学校でのSATパフォーマンスに対するコーチングの効果を研究しています。
データは次のように与えられます:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
、ここで、 y
治療効果であり、 sigma
標準誤差です。各学校のグループレベルのtheta
不明な平均mu
および標準偏差tau
を持つ正規分布からサンプリングされ、観測されたデータは平均分布の正規分布から生成されると仮定する研究の階層モデルを構築します。それぞれtheta
(真の効果)とsigma
によって与えられる標準偏差。これにより、すべての観測からプールすることにより、人口レベルのパラメーターmu
とtau
推定することができ、グループレベルのtheta
パラメーターを使用して学校間の個別のバリエーションを可能にします。
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
No-Uターンサンプラー(NUTS)を使用してMCMCを実行することにより、モデルの未知のパラメーターの値を推測しましょう。 MCMC.RUNでのextra_fields
引数の使用法に注意してください。デフォルトでは、 MCMC
使用して推論を実行する場合、ターゲット(事後)分布からサンプルのみを収集します。ただし、ポテンシャルエネルギーやサンプルの受け入れ確率などの追加のフィールドを収集することは、 extra_fields
引数を使用して簡単に実現できます。収集できる可能性のあるフィールドのリストについては、HMCSTATEオブジェクトを参照してください。この例では、各サンプルのpotential_energy
をさらに収集します。
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
MCMC実行の概要を印刷し、推論中に発散を観察したかどうかを調べることができます。さらに、各サンプルのポテンシャルエネルギーを収集したため、予想されるログジョイント密度を簡単に計算できます。
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
Split Gelman Rubin Diagnostic( r_hat
)の1を超える値は、チェーンが完全に収束していないことを示しています。特にtau
の効果的なサンプルサイズ( n_eff
)の値が低く、発散遷移の数には問題があります。幸いなことに、これはモデルのtau
の非中心パラメーター化を使用することで修正できる一般的な病理です。これは、Reparameterization EffectハンドラーとともにTransformedDistributionインスタンスを使用することにより、Numpyroで行うのが簡単です。同じモデルを書き直しましょうが、 Normal(mu, tau)
からtheta
サンプリングする代わりに、Affintransformを使用して変換されるベースNormal(0, 1)
分布からサンプリングします。そうすることにより、numpyroは、代わりにベースNormal(0, 1)
分布のサンプルtheta_base
生成することによりHMCを実行することに注意してください。結果として得られるチェーンは同じ病理に悩まされていないことがわかります。GelmanRubinの診断はすべてのパラメーターに対して1であり、効果的なサンプルサイズは非常によく見えます。
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
Normal
、 Cauchy
、 StudentT
などのloc,scale
パラメーターを使用した分布のクラスについては、同じ目的を達成するためにlocscalererparam Reparameterizerも提供することに注意してください。対応するコードは次のとおりです
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
それでは、テストスコアを観察していない新しい学校があると仮定しましょうが、予測を生成したいと思います。 Numpyroは、そのような目的のために予測クラスを提供します。観察されたデータがない場合、人口レベルのパラメーターを使用して予測を生成するだけであることに注意してください。 Predictive
ユーティリティは、観測されていないmu
およびtau
サイトを、前回のMCMC実行からの事後分布から引き出される値に条件付けし、モデルを前方に実行して予測を生成します。
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
モデルを指定し、numpyroで推論を行うことに関するいくつかの例については、
lax.scan
ます。Pyroユーザーは、モデル仕様と推論のAPIは、Distumionsions APIを含むPyroとほぼ同じであることに注意してください。ただし、ユーザーが知っておくべき重要なコアの違い(内部に反映)があります。たとえば、numpyroでは、JaxのJITコンピレーションを活用できるようにするために、グローバルパラメーターストアやランダム状態はありません。また、ユーザーは、JAXでより適切に機能するより機能的なスタイルでモデルを作成する必要がある場合があります。違いのリストについては、FAQを参照してください。
Numpyroでサポートされているほとんどの推論アルゴリズムの概要を提供し、さまざまなクラスのモデルにどのような推論アルゴリズムが適切であるかについてのいくつかのガイドラインを提供します。
HMC/ナットと同様に、残りのすべてのMCMCアルゴリズムは、可能であれば離散潜在変数よりも列挙をサポートしています(制限を参照)。列挙されたサイトには、注釈の例のように、 infer={'enumerate': 'parallel'}
でマークする必要があります。
Trace_ELBO
に似ていますが、そうすることが可能な場合は分析的にエルボの一部を計算します。詳細については、ドキュメントを参照してください。
限られたウィンドウのサポート: numpyroはWindowsでテストされておらず、ソースからJaxlibを構築する必要がある場合があることに注意してください。詳細については、このJaxの問題を参照してください。または、Linux用のWindowsサブシステムをインストールし、LinuxシステムのようにNumpyroを使用することもできます。 WindowsでGPUを使用する場合は、Linux用のWindowsサブシステムとこのフォーラム投稿も参照してください。
Jaxの最新のCPUバージョンでNumpyroをインストールするには、PIPを使用できます。
pip install numpyro
上記のコマンドの実行中に互換性の問題が発生した場合、代わりに既知の互換性のあるCPUバージョンのjaxのインストールを強制することができます
pip install numpyro[cpu]
GPUでNumpyroを使用するには、最初にCUDAをインストールしてから、次のPIPコマンドを使用する必要があります。
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
さらにガイダンスが必要な場合は、JAX GPUのインストール手順をご覧ください。
クラウドTPUでnumpyroを実行するには、クラウドTPUの例でいくつかのJaxを見ることができます。
Cloud TPU VMの場合、Cloud TPU VM Jax Quickstartガイドで詳述されているように、TPUバックエンドをセットアップする必要があります。 TPUバックエンドが適切にセットアップされていることを確認した後、 pip install numpyro
コマンドを使用してNumpyroをインストールできます。
デフォルトプラットフォーム: JAXは、CUDAがサポートしている
jaxlib
パッケージがインストールされている場合、デフォルトでGPUを使用します。 set_platform utilitynumpyro.set_platform("cpu")
使用して、プログラムの開始時にCPUに切り替えることができます。
SourceからNumpyroをインストールすることもできます。
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
CondaでNumpyroをインストールすることもできます。
conda install -c conda-forge numpyro
Pyroとは異なり、 numpyro.sample('x', dist.Normal(0, 1))
機能しません。なぜ?
推論のコンテキスト外でnumpyro.sample
ステートメントを使用している可能性が高いです。 JAXにはグローバルなランダム状態がないため、分布サンプラーには、サンプルを生成するために明示的な乱数ジェネレーターキー(PRNGKEY)が必要です。 Numpyroの推論アルゴリズムは、シードハンドラーを使用して、舞台裏の乱数ジェネレーターキーに糸を塗ります。
あなたのオプションは次のとおりです。
分布を直接呼び出して、 PRNGKey
を提供しますdist.Normal(0, 1).sample(PRNGKey(0))
numpyro.sample
にrng_key
引数を提供します。たとえば、 numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
。
コードをseed
ハンドラーに包み、コンテキストマネージャーとして、または元の呼び出し可能なものをラップする関数として使用します。例えば
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
、または高次関数として:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
Numpyroで推論を行うために同じPyroモデルを使用できますか?
例から気づいたように、Numpyroは、 sample
、 param
、 plate
、 module
、効果ハンドラーなどのすべてのPyroプリミティブをサポートしています。さらに、分布APIがtorch.distributions
に基づいていることを確認し、 SVI
やMCMC
などの推論クラスに同じインターフェイスがあります。これは、numpyおよびpytorch操作のAPIの類似性とともに、Pyro Primitiveステートメントを含むモデルを、いくつかの小さな変更を伴ういずれかのバックエンドで使用できることを保証します。必要な変更とともにいくつかの違いの例を以下に示します。
torch
操作は、対応するjax.numpy
操作の観点から記述する必要があります。さらに、すべてのtorch
操作にnumpy
カウンターパート(およびその逆)があるわけではなく、APIにわずかな違いがある場合があります。pyro.sample
ステートメントは、前述のように、 seed
ハンドラーに包む必要があります。numpyro.param
使用すると効果がありません。 SVIから最適化されたパラメーター値を取得するには、svi.get_paramsメソッドを使用します。モデル内でparam
ステートメントを使用できることに注意してください。Numpyroは、SVIでモデルを実行するときにオプティマイザーの値を内部的に使用するために、代替エフェクトハンドラーを内部的に使用することに注意してください。ほとんどの小さなモデルでは、Numpyroで推論を実行するために必要な変更はマイナーである必要があります。さらに、Pyro-APIに取り組んでおり、同じコードを書き、Numpyroを含む複数のバックエンドに派遣できます。これは必然的により制限的になりますが、不可知論者になるという利点があります。例のドキュメントを参照して、フィードバックをお知らせください。
どうすればプロジェクトに貢献できますか?
プロジェクトに興味を持ってくれてありがとう! GitHubの優れた最初の問題タグでマークされた初心者向けの問題をご覧ください。また、フォーラムで私たちに連絡してください。
近期では、以下に取り組む予定です。機能のリクエストと拡張機能については、新しい問題を開いてください。
Numpyroの背後にある動機付けのアイデアと反復ナッツの説明は、機械学習ワークショップのNeurips 2019プログラム変換に掲載されたこのペーパーに記載されています。
Numpyroを使用している場合は、引用を検討してください。
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
同様に
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}