Programmation probabiliste alimentée par JAX pour la compilation Autograd et JIT à GPU / TPU / CPU.
Docs et exemples | Forum
Numpyro est une bibliothèque de programmation probabiliste légère qui fournit un backend Numpy pour Pyro. Nous comptons sur JAX pour la différenciation automatique et la compilation JIT au GPU / CPU. Numpyro est en cours de développement actif, alors méfiez-vous de la fragilité, des bogues et des modifications de l'API à mesure que la conception évolue.
Numpyro est conçu pour être léger et se concentre sur la fourniture d'un substrat flexible sur lequel les utilisateurs peuvent construire:
sample
et param
. Le code du modèle doit ressembler très à Pyro, à l'exception de quelques différences mineures entre Pytorch et l'API de Numpy. Voir l'exemple ci-dessous.jit
et grad
pour compiler toute l'étape d'intégration dans un noyau optimisé XLA. Nous éliminons également les frais généraux de Python en compilant le stade entier de la construction d'arbres dans les noix (cela est possible en utilisant des noix itératives). Il existe également une mise en œuvre de base de l'inférence variationnelle avec de nombreux guides flexibles (auto) pour l'inférence variationnelle de différenciation automatique (ADVI). L'implémentation de l'inférence variationnelle prend en charge un certain nombre de fonctionnalités, y compris la prise en charge des modèles avec des variables latentes discrètes (voir TraceGraph_elbo et Traceenum_elbo).torch.distributions
. En plus des distributions, constraints
et transforms
sont très utiles lors du fonctionnement des classes de distribution avec un support limité. Enfin, les distributions de la probabilité TensorFlow (TFP) peuvent être directement utilisées dans les modèles Numpyro.sample
et param
peuvent être fournies des interprétations non standard en utilisant des aileurs d'effet du module Numpyro.Handlers, et ceux-ci peuvent être facilement étendus pour implémenter des algorithmes d'inférence personnalisés et des utilitaires d'inférence. Explorons Numpyro en utilisant un exemple simple. Nous utiliserons l'exemple des huit écoles de Gelman et al., Bayesian Data Analysis: Sec. 5.5, 2003, qui étudie l'effet du coaching sur les performances SAT dans huit écoles.
Les données sont données par:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
, où y
sont les effets du traitement et sigma
l'erreur standard. Nous construisons un modèle hiérarchique pour l'étude où nous supposons que les paramètres au niveau du groupe theta
pour chaque école sont échantillonnés à partir d'une distribution normale avec mu
moyen inconnu et tau
d'écart type, tandis que les données observées sont à leur tour générées à partir d'une distribution normale avec une moyenne avec moyenne et l'écart type donné par theta
(effet réel) et sigma
, respectivement. Cela nous permet d'estimer les paramètres au niveau de la population mu
et tau
en se présentant à partir de toutes les observations, tout en permettant une variation individuelle entre les écoles en utilisant les paramètres theta
au niveau du groupe.
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
Inférons les valeurs des paramètres inconnus dans notre modèle en exécutant MCMC en utilisant l'échantillonneur sans tour (noix). Remarquez l'utilisation de l'argument extra_fields
dans mcmc.run. Par défaut, nous collectons uniquement des échantillons à partir de la distribution cible (postérieure) lorsque nous exécutons l'inférence en utilisant MCMC
. Cependant, la collecte de champs supplémentaires comme l'énergie potentielle ou la probabilité d'acceptation d'un échantillon peuvent être facilement réalisées en utilisant l'argument extra_fields
. Pour une liste de champs possibles qui peuvent être collectés, consultez l'objet HMCState. Dans cet exemple, nous collecterons également le potential_energy
pour chaque échantillon.
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
Nous pouvons imprimer le résumé de la course MCMC et examiner si nous avons observé des divergences pendant l'inférence. De plus, puisque nous avons collecté l'énergie potentielle pour chacun des échantillons, nous pouvons facilement calculer la densité de joint logarithmique attendue.
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
Les valeurs supérieures à 1 pour le diagnostic de Gelman Rubin divisé ( r_hat
) indiquent que la chaîne n'a pas complètement convergé. La faible valeur pour la taille effective de l'échantillon ( n_eff
), en particulier pour tau
, et le nombre de transitions divergentes semble problématique. Heureusement, il s'agit d'une pathologie courante qui peut être rectifiée en utilisant une paramétrisation non centrée pour tau
dans notre modèle. Ceci est simple à faire dans Numpyro en utilisant une instance de distribution transformée avec un gestionnaire d'effet de réparamétrisation. Reprécions le même modèle mais au lieu de l'échantillonnage theta
à partir d'une Normal(mu, tau)
, nous le goûterons à partir d'une distribution Normal(0, 1)
qui est transformée à l'aide d'un AffineTransform. Notez qu'en faisant cela, Numpyro exécute HMC en générant des échantillons theta_base
pour la distribution Normal(0, 1)
à la place. Nous voyons que la chaîne résultante ne souffre pas de la même pathologie - le diagnostic Gelman Rubin est 1 pour tous les paramètres et la taille effective de l'échantillon semble assez belle!
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
Notez que pour la classe de distributions avec loc,scale
tels que Normal
, Cauchy
, StudentT
, nous fournissons également un réparamètre locscalereparam pour atteindre le même objectif. Le code correspondant sera
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
Maintenant, supposons que nous avons une nouvelle école pour laquelle nous n'avons observé aucun score des tests, mais nous aimerions générer des prédictions. Numpyro fournit une classe prédictive à un tel but. Notez qu'en l'absence de toutes les données observées, nous utilisons simplement les paramètres au niveau de la population pour générer des prédictions. L'utilité Predictive
conditionne les sites mu
et tau
non observés aux valeurs tirées de la distribution postérieure de notre dernière exécution MCMC, et exécute le modèle vers l'avant pour générer des prédictions.
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
Pour quelques exemples supplémentaires sur la spécification des modèles et l'inférence dans Numpyro:
lax.scan
de Jax pour une inférence rapide.Les utilisateurs de Pyro noteront que l'API pour la spécification et l'inférence du modèle est en grande partie la même que Pyro, y compris l'API Distributions, par conception. Cependant, il existe des différences de base importantes (reflétées dans les internes) dont les utilisateurs doivent être conscients. Par exemple dans Numpyro, il n'y a pas de magasin de paramètres global ou d'état aléatoire, pour nous permettre de tirer parti de la compilation JIT de Jax. En outre, les utilisateurs peuvent avoir besoin d'écrire leurs modèles dans un style plus fonctionnel qui fonctionne mieux avec Jax. Reportez-vous aux FAQ pour une liste de différences.
Nous donnons un aperçu de la plupart des algorithmes d'inférence pris en charge par Numpyro et proposons quelques directives sur les algorithmes d'inférence peut être approprié pour différentes classes de modèles.
Comme HMC / Nuts, tous les algorithmes MCMC restants prennent en charge l'énumération sur des variables latentes discrètes si possible (voir les restrictions). Les sites énumérés doivent être marqués avec infer={'enumerate': 'parallel'}
comme dans l'exemple d'annotation.
Trace_ELBO
mais calcule une partie de l'ELBO analytiquement si cela est possible.Voir les documents pour plus de détails.
Prise en charge des fenêtres limitées: notez que Numpyro n'est pas testé sur les fenêtres et peut nécessiter la construction de Jaxlib de Source. Voir ce problème JAX pour plus de détails. Alternativement, vous pouvez installer le sous-système Windows pour Linux et utiliser Numpyro dessus comme sur un système Linux. Voir aussi CUDA sur le sous-système Windows pour Linux et ce message de forum si vous souhaitez utiliser des GPU sur Windows.
Pour installer Numpyro avec la dernière version CPU de Jax, vous pouvez utiliser PIP:
pip install numpyro
En cas de problèmes de compatibilité survenant lors de l'exécution de la commande ci-dessus, vous pouvez plutôt forcer l'installation d'une version CPU compatible connue de JAX avec
pip install numpyro[cpu]
Pour utiliser Numpyro sur le GPU , vous devez d'abord installer CUDA, puis utiliser la commande PIP suivante:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Si vous avez besoin de conseils supplémentaires, veuillez consulter les instructions d'installation du GPU JAX.
Pour exécuter Numpyro sur les TPU cloud , vous pouvez consulter des exemples JAX sur Cloud TPU.
Pour Cloud TPU VM, vous devez configurer le backend TPU comme détaillé dans le Guide Cloud TPU VM JAX QuickStart. Après avoir vérifié que le backend TPU est correctement configuré, vous pouvez installer Numpyro à l'aide de la commande pip install numpyro
.
Plateforme par défaut: JAX utilisera GPU par défaut si le package
jaxlib
soutenu par CUDA est installé. Vous pouvez utiliser set_platform utilitairenumpyro.set_platform("cpu")
pour passer au CPU au début de votre programme.
Vous pouvez également installer Numpyro à partir de la source:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
Vous pouvez également installer Numpyro avec conda:
conda install -c conda-forge numpyro
Contrairement à Pyro, numpyro.sample('x', dist.Normal(0, 1))
ne fonctionne pas. Pourquoi?
Vous utilisez très probablement une déclaration numpyro.sample
en dehors d'un contexte d'inférence. JAX n'a pas d'état aléatoire global et, en tant que tel, les échantillonneurs de distribution ont besoin d'une clé de générateur de nombres aléatoires explicites (PRNGKE) pour générer des échantillons à partir de. Les algorithmes d'inférence de Numpyro utilisent le gestionnaire de graines pour enfiler dans une touche de générateur de nombres aléatoires, dans les coulisses.
Vos options sont:
Appelez directement la distribution et fournissez un PRNGKey
, par exemple dist.Normal(0, 1).sample(PRNGKey(0))
Fournissez l'argument rng_key
à numpyro.sample
. par exemple numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
.
Enveloppez le code dans un gestionnaire seed
, utilisé soit comme gestionnaire de contexte, soit comme une fonction qui s'enroule sur le callable d'origine. par exemple
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
, ou comme fonction d'ordre supérieur:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
Puis-je utiliser le même modèle pyro pour faire l'inférence dans Numpyro?
Comme vous l'avez peut-être remarqué à partir des exemples, Numpyro prend en charge toutes les primitives pyro comme sample
, param
, plate
et module
et les gestionnaires d'effet. De plus, nous nous sommes assurés que l'API de distribution est basée sur torch.distributions
, et les classes d'inférence comme SVI
et MCMC
ont la même interface. Ceci ainsi que la similitude de l'API pour les opérations Numpy et Pytorch garantissent que les modèles contenant des instructions primitives pyro peuvent être utilisés avec l'un ou l'autre backend avec quelques changements mineurs. Exemple de certaines différences ainsi que les changements nécessaires, sont notés ci-dessous:
torch
dans votre modèle devra être écrite en termes de fonctionnement jax.numpy
correspondant. De plus, toutes les opérations torch
n'ont pas un homologue numpy
(et vice-versa), et parfois il existe des différences mineures dans l'API.pyro.sample
en dehors d'un contexte d'inférence devront être enveloppées dans un gestionnaire seed
, comme mentionné ci-dessus.numpyro.param
en dehors d'un contexte d'inférence n'aura aucun effet. Pour récupérer les valeurs de paramètres optimisées à partir de SVI, utilisez la méthode svi.get_params. Notez que vous pouvez toujours utiliser des instructions param
dans un modèle et Numpyro utilisera le gestionnaire d'effet de substitut en interne pour remplacer les valeurs de l'optimiseur lors de l'exécution du modèle dans SVI.Pour la plupart des petits modèles, les modifications requises pour exécuter l'inférence dans Numpyro devraient être mineures. De plus, nous travaillons sur Pyro-API qui vous permet d'écrire le même code et de le expédier à plusieurs backends, y compris Numpyro. Cela sera nécessairement plus restrictif, mais a l'avantage d'être backend agnostique. Voir la documentation pour un exemple et faites-nous part de vos commentaires.
Comment puis-je contribuer au projet?
Merci de votre intérêt pour le projet! Vous pouvez jeter un œil aux problèmes adaptés aux débutants qui sont marqués par la bonne étiquette de premier numéro sur Github. Soyez également sensible à nous contacter sur le forum.
À court terme, nous prévoyons de travailler sur ce qui suit. Veuillez ouvrir de nouveaux problèmes pour les demandes de fonctionnalités et les améliorations:
Les idées de motivation derrière Numpyro et une description des noix itératives peuvent être trouvées dans cet article qui est apparu dans les transformations de programme de Neirips 2019 pour l'atelier d'apprentissage automatique.
Si vous utilisez Numpyro, veuillez envisager de citer:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
ainsi que
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}