Probabilistische Programmierung von JAX für Autograd und JIT -Kompilierung für GPU/TPU/CPU.
Dokumente und Beispiele | Forum
Numpyro ist eine leichte probabilistische Programmierbibliothek, die Pyro ein numpy Backend bietet. Wir verlassen uns auf JAX zur automatischen Differenzierung und JIT -Zusammenstellung zu GPU / CPU. Numpyro steht in einer aktiven Entwicklung, also hüten Sie sich vor Bröckel, Fehler und Änderungen an der API, während sich das Design entwickelt.
Numpyro ist so konzipiert, dass es leicht ist, und konzentriert sich auf ein flexibles Substrat, auf dem Benutzer aufbauen können:
sample
und param
regelmäßig Python- und Numpy -Code enthalten. Der Modellcode sollte Pyro mit Ausnahme einiger kleiner Unterschiede zwischen Pytorch und Numpys API sehr ähnlich aussehen. Siehe das Beispiel unten.jit
zusammenstellen und den gesamten Integrationsschritt in einen XLA -optimierten Kernel grad
. Wir eliminieren auch Python -Overhead, indem wir die gesamte Baumgebäudestufe in Nüssen zusammenstellen (dies ist mit iterativen Nüssen möglich). Es gibt auch eine grundlegende Implementierung von Variationsinferenz zusammen mit vielen flexiblen (AUTO) -Luides zur automatischen Differenzierungsvariationsinferenz (ADVI). Die Implementierung der Variationsinferenz unterstützt eine Reihe von Funktionen, einschließlich der Unterstützung für Modelle mit diskreten latenten Variablen (siehe Tracegraph_elbo und Traceenum_elbo).torch.distributions
verlassen. Zusätzlich zu Verteilungen sind constraints
und transforms
sehr nützlich, wenn sie in Verteilungsklassen mit begrenzter Unterstützung operieren. Schließlich können Verteilungen aus der Tensorflow -Wahrscheinlichkeit (TFP) direkt in Numpyro -Modellen verwendet werden.sample
und param
nicht standardmäßige Interpretationen unter Verwendung von Effekt-Händlern aus dem Numpyro.Handlers-Modul bereitgestellt werden, und diese können leicht erweitert werden, um benutzerdefinierte Inferenzalgorithmen und Inferenz-Dienstprogramme zu implementieren. Lassen Sie uns Numpyro anhand eines einfachen Beispiels erkunden. Wir werden das Beispiel für acht Schulen von Gelman et al., Bayes'sche Datenanalyse: Sec. 5.5, 2003, in dem die Auswirkungen des Coachings auf die SAT -Leistung in acht Schulen untersucht werden.
Die Daten sind gegeben durch:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
, wo y
die Behandlungseffekte und sigma
der Standardfehler sind. Wir erstellen ein hierarchisches Modell für die Studie, in der wir annehmen, dass die Parameter theta
Gruppenebene für jede Schule aus einer Normalverteilung mit unbekannter mittlerer mu
und Standardabweichung tau
abgetastet werden, während die beobachteten Daten wiederum aus einer Normalverteilung mit Mittelwert erzeugt werden und Standardabweichung durch theta
(wahrer Effekt) bzw. sigma
. Dies ermöglicht es uns, die Parameter auf Bevölkerungsebene zu schätzen tau
indem mu
von allen Beobachtungen zusammengefasst sind und gleichzeitig die individuellen Unterschiede zwischen den Schulen unter Verwendung der theta
-Parameter auf Gruppenebene ermöglicht.
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
Lassen Sie uns die Werte der unbekannten Parameter in unserem Modell schließen, indem wir MCMC unter Verwendung des No-U-Turn-Samplers (Muttern) ausführen. Beachten Sie die Verwendung des Arguments extra_fields
in mcmc.run. Standardmäßig sammeln wir nur Proben aus der Zielverteilung (posterior), wenn wir Inferenz mit MCMC
ausführen. Das Sammeln zusätzlicher Felder wie potentielle Energie oder die Akzeptanzwahrscheinlichkeit einer Stichprobe kann jedoch durch die Verwendung des Arguments extra_fields
leicht erreicht werden. Eine Liste möglicher Felder, die gesammelt werden können, finden Sie im HMCState -Objekt. In diesem Beispiel sammeln wir zusätzlich die potential_energy
für jede Probe.
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
Wir können die Zusammenfassung des MCMC -Laufs drucken und untersuchen, ob wir während der Inferenz Divergenzen beobachtet haben. Da wir die potentielle Energie für jede der Proben gesammelt haben, können wir die erwartete Log -Gelenkdichte leicht berechnen.
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
Die Werte über 1 für die geteilte Gelman Rubin Diagnostic ( r_hat
) geben an, dass die Kette nicht vollständig konvergiert ist. Der niedrige Wert für die effektive Stichprobengröße ( n_eff
), insbesondere für tau
, und die Anzahl der unterschiedlichen Übergänge sieht problematisch aus. Glücklicherweise ist dies eine gemeinsame Pathologie, die durch Verwendung einer nicht zentrierten Parametrisierung für tau
in unserem Modell behoben werden kann. Dies ist in Numpyro unkompliziert, indem eine Transformeddistribution -Instanz zusammen mit einem Reparameterizationseffekt -Handler verwendet wird. Schreiben wir dasselbe Modell um, aber anstatt theta
aus einer Normal(mu, tau)
zu probieren, werden wir es stattdessen von einer Normal(0, 1)
-Ververteilung probieren, die unter Verwendung einer Affinetransform transformiert wird. Beachten Sie, dass Numpyro auf diese Weise HMC ausführt, indem Sie stattdessen die Proben theta_base
für die Normal(0, 1)
-D -Verteilung generieren. Wir sehen, dass die resultierende Kette nicht unter derselben Pathologie leidet - die Gelman Rubin -Diagnose ist 1 für alle Parameter und die effektive Stichprobengröße sieht ziemlich gut aus!
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
Beachten Sie, dass wir für die Klasse der Verteilungen mit loc,scale
wie Normal
, Cauchy
, StudentT
auch einen LocscalereParam Reparameterizer zur Verfügung stellen, um den gleichen Zweck zu erreichen. Der entsprechende Code wird sein
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
Nehmen wir nun an, dass wir eine neue Schule haben, für die wir keine Testergebnisse beobachtet haben, aber wir möchten Vorhersagen generieren. Numpyro bietet eine Vorhersageklasse für einen solchen Zweck. Beachten Sie, dass wir ohne beobachtete Daten einfach die Parameter auf Bevölkerungsebene verwenden, um Vorhersagen zu erzeugen. Die Predictive
Versorgungsbedingungen Die nicht beobachteten mu
und tau
-Stellen zu Werten aus der hinteren Verteilung unseres letzten MCMC -Laufs und führt das Modell vor, um Vorhersagen zu generieren.
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
Weitere Beispiele zum Angeben von Modellen und zur Folgerung in Numpyro:
lax.scan
primitive für schnelle Inferenz konvertiert wird.Pyro -Benutzer werden feststellen, dass die API für die Modellspezifikation und die Inferenz weitgehend mit Pyro, einschließlich der Distributions -API, nach Design entspricht. Es gibt jedoch einige wichtige Kernunterschiede (die sich in den Interna widerspiegeln), die Benutzer kennen sollten. In Numpyro gibt es keinen globalen Parameterspeicher oder zufälligen Zustand, um es uns zu ermöglichen, die JIT -Zusammenstellung von JAX zu nutzen. Außerdem müssen Benutzer ihre Modelle möglicherweise in einem funktionalen Stil schreiben, der mit JAX besser funktioniert. Eine Liste von Unterschieden finden Sie in FAQs.
Wir bieten einen Überblick über die meisten Inferenzalgorithmen, die von Numpyro unterstützt werden, und bieten einige Richtlinien darüber, welche Inferenzalgorithmen für verschiedene Modelleklassen geeignet sein können.
Wie bei HMC/Nuts unterstützen alle verbleibenden MCMC -Algorithmen die Aufzählung über diskrete latente Variablen nach Möglichkeit (siehe Einschränkungen). Aufzählende Stellen müssen mit infer={'enumerate': 'parallel'}
wie im Annotationsbeispiel markiert werden.
Trace_ELBO
, berechnet jedoch einen Teil des ELBO analytisch, wenn dies möglich ist.Weitere Informationen finden Sie in den Dokumenten.
Begrenzte Windows -Unterstützung: Beachten Sie, dass Numpyro unter Windows nicht getestet wird und möglicherweise Jaxlib aus der Quelle erstellt. Weitere Informationen finden Sie in dieser JAX -Ausgabe. Alternativ können Sie Windows -Subsystem für Linux installieren und Numpyro als Linux -System verwenden. Siehe auch CUDA auf Windows Subsystem für Linux und in diesem Forum -Beitrag, wenn Sie GPUs unter Windows verwenden möchten.
Um Numpyro mit der neuesten CPU -Version von JAX zu installieren, können Sie PIP verwenden:
pip install numpyro
Bei Kompatibilitätsproblemen entstehen während der Ausführung des obigen Befehls Sie stattdessen die Installation einer bekannten kompatiblen CPU -Version von JAX mit erzwingen
pip install numpyro[cpu]
Um Numpyro auf der GPU zu verwenden, müssen Sie zuerst CUDA installieren und dann den folgenden PIP -Befehl verwenden:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Wenn Sie weitere Anleitungen benötigen, sehen Sie sich die Anweisungen zur Installation von JAX GPU an.
Um Numpyro auf Cloud -TPUs auszuführen, können Sie sich mit Cloud -TPU -Beispielen JAX ansehen.
Für Cloud TPU VM müssen Sie das TPU -Backend wie in der Cloud TPU VM JAX QuickStart -Handbuch detailliert einrichten. Nachdem Sie überprüft haben, ob das TPU -Backend ordnungsgemäß eingerichtet ist, können Sie Numpyro über den Befehl pip install numpyro
installieren.
Standardplattform: JAX wird standardmäßig GPU verwenden, wenn das CUDA-unterstützte
jaxlib
Paket installiert ist. Sie können SET_PLATFORM Utilitynumpyro.set_platform("cpu")
zu Beginn Ihres Programms zu CPU umstellen.
Sie können Numpyro auch aus Quelle installieren:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
Sie können Numpyro auch mit Conda installieren:
conda install -c conda-forge numpyro
Im Gegensatz zu Pyro funktioniert numpyro.sample('x', dist.Normal(0, 1))
nicht. Warum?
Sie verwenden höchstwahrscheinlich eine Anweisung numpyro.sample
außerhalb eines Inferenzkontexts. JAX hat keinen globalen Zufallszustand, und als solche benötigen Verteilungsabtastproben einen expliziten Zufallszahlengenerator -Schlüssel (Prngkey), um Proben aus zu generieren. Inferenzalgorithmen von Numpyro verwenden den Saatguthandler, um in einer zufälligen Zahlengenerator -Taste hinter den Kulissen zu fädeln.
Ihre Optionen sind:
Rufen Sie die Verteilung direkt an und geben Sie einen PRNGKey
an, z dist.Normal(0, 1).sample(PRNGKey(0))
Geben Sie das Argument rng_key
an numpyro.sample
an. numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
Wickeln Sie den Code in einen seed
-Handler ein, der entweder als Kontextmanager oder als eine Funktion verwendet wird, die sich über den ursprünglichen Callable wickelt. z.B
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
, oder als Funktion höherer Ordnung:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
Kann ich das gleiche Pyro -Modell verwenden, um in Numpyro Schlussfolgerungen zu erzielen?
Wie Sie aus den Beispielen vielleicht bemerkt haben, unterstützt Numpyro alle Pyro -Primitiven wie sample
, param
, plate
und module
und Effekt -Handler. Darüber hinaus haben wir sichergestellt, dass die API der Distributions -API auf torch.distributions
basiert, und die Inferenzklassen wie SVI
und MCMC
haben dieselbe Schnittstelle. Dies zusammen mit der Ähnlichkeit in der API für Numpy- und Pytorch -Operationen stellt sicher, dass Modelle, die Pyro -primitive Aussagen enthalten, mit beiden Backends mit einigen geringfügigen Änderungen verwendet werden können. Beispiel für einige Unterschiede sowie die erforderlichen Änderungen sind nachstehend festgestellt:
torch
in Ihrem Modell muss in Bezug auf die entsprechende jax.numpy
-Operation geschrieben werden. Darüber hinaus haben nicht alle torch
ein numpy
-Gegenstück (und umgekehrt), und manchmal gibt es geringfügige Unterschiede in der API.pyro.sample
-Aussagen außerhalb eines Inferenzkontexts müssen wie oben erwähnt in einen seed
eingewickelt werden.numpyro.param
außerhalb eines Inferenzkontexts keinen Einfluss. Um die optimierten Parameterwerte von SVI abzurufen, verwenden Sie die Methode SVI.get_params. Beachten Sie, dass Sie in einem Modell weiterhin param
-Anweisungen verwenden können, und Numpyro verwendet den Ersatz -Effekt -Handler intern, um die Werte beim Optimierer beim Ausführen des Modells in SVI zu ersetzen.Für die meisten kleinen Modelle sollten Änderungen, die zur Durchführung von Inferenz in Numpyro erforderlich sind, gering sein. Darüber hinaus arbeiten wir an Pyro-API, mit dem Sie denselben Code schreiben und in mehrere Backends, einschließlich Numpyro, schicken können. Dies wird notwendigerweise restriktiver sein, hat aber den Vorteil, Agnald -Agnald zu sein. Ein Beispiel finden Sie in der Dokumentation und teilen Sie uns Ihr Feedback mit.
Wie kann ich zum Projekt beitragen?
Vielen Dank für Ihr Interesse an dem Projekt! Sie können sich anfängerfreundliche Probleme ansehen, die mit dem guten ersten Ausgabe -Tag auf GitHub gekennzeichnet sind. Bitte fühlen Sie sich auch im Forum an uns.
Kurzfristig planen wir, am Folgenden zu arbeiten. Bitte öffnen Sie neue Probleme für Feature -Anfragen und -verbesserungen:
Die motivierenden Ideen hinter Numpyro und eine Beschreibung der iterativen Nüsse finden Sie in diesem Artikel, der in den Neurips 2019 -Programmtransformationen für maschinelles Lernens Workshop aufgetreten ist.
Wenn Sie Numpyro verwenden, erwägen Sie bitte:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
sowie
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}