Programación probabilística alimentada por JAX para Autograd y JIT Compilación a GPU/TPU/CPU.
Documentos y ejemplos | Foro
Numpyro es una biblioteca de programación probabilística ligera que proporciona un backend numpy para Pyro. Confiamos en Jax para la diferenciación automática y la compilación JIT a GPU / CPU. Numpyro está en desarrollo activo, así que tenga cuidado con la fragilidad, los errores y los cambios en la API a medida que evoluciona el diseño.
Numpyro está diseñado para ser liviano y se centra en proporcionar un sustrato flexible en el que los usuarios puedan construir:
sample
y param
. El código del modelo debe verse muy similar a Pyro, excepto por algunas diferencias menores entre Pytorch y la API de Numpy. Vea el ejemplo a continuación.jit
y grad
para compilar todo el paso de integración en un núcleo optimizado XLA. También eliminamos la sobrecarga de Python al compilar toda la etapa de construcción de árboles en las nueces (esto es posible usando nueces iterativas). También hay una implementación de inferencia de variacional básica junto con muchas guías flexibles (auto) para la inferencia de variacional de diferenciación automática (AVENI). La implementación de inferencia variacional admite una serie de características, incluido el soporte para modelos con variables latentes discretas (ver TraceGraph_elbo y TraceEnum_elbo).torch.distributions
. Además de las distribuciones, constraints
y transforms
son muy útiles cuando se operan en clases de distribución con soporte limitado. Finalmente, las distribuciones de la probabilidad de flujo de tensor (TFP) se pueden usar directamente en modelos Numpyro.sample
y param
se pueden proporcionar interpretaciones no estándar utilizando manejadores de efectos del módulo numpyro.handlers, y estas se pueden extender fácilmente para implementar algoritmos de inferencia y utilidades de inferencia de inferencia personalizadas. Exploremos Numpyro usando un ejemplo simple. Usaremos el ejemplo de las ocho escuelas de Gelman et al., Análisis de datos bayesianos: Sec. 5.5, 2003, que estudia el efecto del entrenamiento en el rendimiento del SAT en ocho escuelas.
Los datos están dados por:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
, donde y
son los efectos del tratamiento y sigma
el error estándar. Construimos tau
modelo jerárquico para el estudio donde suponemos que los theta
a nivel de grupo se muestrean a partir de una distribución normal con mu
media desconocida y desviación estándar, mientras que los datos observados se generan a partir de una distribución normal con media y desviación estándar dada por theta
(efecto verdadero) y sigma
, respectivamente. Esto nos permite estimar los parámetros a nivel de población mu
y tau
al agrupar todas las observaciones, al tiempo que permiten la variación individual entre las escuelas que utilizan los parámetros theta
a nivel de grupo.
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
Inferimos los valores de los parámetros desconocidos en nuestro modelo ejecutando MCMC utilizando la muestra de no-u-girath (nueces). Tenga en cuenta el uso del argumento extra_fields
en McMc.run. Por defecto, solo recolectamos muestras de la distribución de destino (posterior) cuando ejecutamos inferencia usando MCMC
. Sin embargo, la recolección de campos adicionales como la energía potencial o la probabilidad de aceptación de una muestra se puede lograr fácilmente utilizando el argumento extra_fields
. Para obtener una lista de posibles campos que se pueden recopilar, consulte el objeto HMCState. En este ejemplo, además recolectaremos el potential_energy
para cada muestra.
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
Podemos imprimir el resumen de la ejecución de MCMC y examinar si observamos alguna divergencia durante la inferencia. Además, dado que recolectamos la energía potencial para cada una de las muestras, podemos calcular fácilmente la densidad de la junta de registro esperada.
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
Los valores anteriores 1 para el diagnóstico dividido de Gelman Rubin ( r_hat
) indican que la cadena no ha convergido completamente. El bajo valor para el tamaño de muestra efectivo ( n_eff
), particularmente para tau
, y el número de transiciones divergentes parece problemático. Afortunadamente, esta es una patología común que puede rectificarse mediante el uso de una parametrización no centrada para tau
en nuestro modelo. Esto es sencillo en Numpyro mediante el uso de una instancia de distribución transformada junto con un controlador de efecto de reparameterización. Reescribamos el mismo modelo, pero en lugar de muestrear theta
de un Normal(mu, tau)
, en su lugar lo probaremos de una distribución base Normal(0, 1)
que se transforma utilizando una affinetransform. Tenga en cuenta que al hacerlo, Numpyro ejecuta HMC generando muestras theta_base
para la distribución base Normal(0, 1)
en su lugar. Vemos que la cadena resultante no sufre de la misma patología: el diagnóstico de Gelman Rubin es 1 para todos los parámetros y el tamaño de muestra efectivo se ve bastante bien.
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
Tenga en cuenta que para la clase de distribuciones con loc,scale
como Normal
, Cauchy
, StudentT
, también proporcionamos un reparador de LocScalereParam para lograr el mismo propósito. El código correspondiente será
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
Ahora, supongamos que tenemos una nueva escuela para la cual no hemos observado ningún puntaje de prueba, pero nos gustaría generar predicciones. Numpyro proporciona una clase predictiva para tal propósito. Tenga en cuenta que en ausencia de cualquier datos observados, simplemente utilizamos los parámetros a nivel de población para generar predicciones. La utilidad Predictive
condiciona los sitios mu
y tau
no observados a los valores extraídos de la distribución posterior de nuestra última ejecución de MCMC, y ejecuta el modelo hacia adelante para generar predicciones.
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
Para algunos ejemplos más sobre especificar modelos y hacer inferencia en Numpyro:
lax.scan
primitivo de Jax para una inferencia rápida.Los usuarios de PYRO notarán que la API para la especificación e inferencia del modelo es en gran medida lo mismo que Pyro, incluida la API de distribuciones, por diseño. Sin embargo, existen algunas diferencias centrales importantes (reflejadas en las partes internas) que los usuarios deben tener en cuenta. Por ejemplo, en Numpyro, no existe una tienda de parámetros global o un estado aleatorio, para que sea posible que aprovechemos la compilación JAT de Jax. Además, los usuarios pueden necesitar escribir sus modelos en un estilo más funcional que funcione mejor con Jax. Consulte las Preguntas frecuentes para obtener una lista de diferencias.
Proporcionamos una visión general de la mayoría de los algoritmos de inferencia compatibles con Numpyro y ofrecemos algunas pautas sobre qué algoritmos de inferencia pueden ser apropiados para diferentes clases de modelos.
Al igual que HMC/Nuts, todos los algoritmos MCMC restantes respaldan la enumeración sobre variables latentes discretas si es posible (ver restricciones). Los sitios enumerados deben marcarse con infer={'enumerate': 'parallel'}
como en el ejemplo de anotación.
Trace_ELBO
pero calcula parte del Elbo analíticamente si es posible hacerlo.Vea los documentos para obtener más detalles.
Soporte limitado de Windows: Tenga en cuenta que Numpyro no se ha probado en Windows y puede requerir construir Jaxlib desde la fuente. Vea este tema de Jax para más detalles. Alternativamente, puede instalar el subsistema de Windows para Linux y usar Numpyro en él como en un sistema Linux. Consulte también CUDA en el subsistema de Windows para Linux y esta publicación del foro si desea usar GPU en Windows.
Para instalar Numpyro con la última versión de CPU de Jax, puede usar PIP:
pip install numpyro
En caso de que surjan problemas de compatibilidad durante la ejecución del comando anterior, puede forzar la instalación de una versión de CPU compatible conocida de Jax con
pip install numpyro[cpu]
Para usar Numpyro en la GPU , primero debe instalar CUDA y luego usar el siguiente comando PIP:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Si necesita más orientación, eche un vistazo a las instrucciones de instalación de JAX GPU.
Para ejecutar Numpyro en TPUS en la nube , puede ver algunos ejemplos de Jax en la TPU de la nube.
Para Cloud TPU VM, debe configurar el backend de TPU como se detalla en la guía Cloud TPU VM Jax QuickStart. Después de haber verificado que el backend de TPU está configurado correctamente, puede instalar Numpyro utilizando el comando pip install numpyro
.
Plataforma predeterminada: Jax usará GPU de forma predeterminada si se instala el paquete
jaxlib
respaldado por CUDA. Puede usar SET_PLATFORM Utilitynumpyro.set_platform("cpu")
para cambiar a CPU al comienzo de su programa.
También puede instalar numpyro desde la fuente:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
También puede instalar numpyro con conda:
conda install -c conda-forge numpyro
A diferencia de Pyro, numpyro.sample('x', dist.Normal(0, 1))
no funciona. ¿Por qué?
Lo más probable es que esté utilizando una declaración numpyro.sample
fuera de un contexto de inferencia. Jax no tiene un estado aleatorio global y, como tal, los muestreadores de distribución necesitan una clave explícita del generador de números aleatorios (PRNGKEY) para generar muestras desde. Los algoritmos de inferencia de Numpyro usan el controlador de semillas para enhebrar en una tecla de generador de números aleatorios, detrás de escena.
Tus opciones son:
Llame a la distribución directamente y proporcione un PRNGKey
, por ejemplo, dist.Normal(0, 1).sample(PRNGKey(0))
Proporcione el argumento rng_key
a numpyro.sample
. por ejemplo numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
.
Envuelva el código en un controlador seed
, utilizado como administrador de contexto o en función que se envuelva sobre el invocado original. p.ej
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
, o como una función de orden superior:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
¿Puedo usar el mismo modelo Pyro para hacer inferencia en Numpyro?
Como puede haber notado en los ejemplos, Numpyro admite todas las primitivas de Pyro como sample
, param
, plate
y module
, y los manejadores de efectos. Además, hemos asegurado que la API de distribuciones se basa en torch.distributions
, y las clases de inferencia como SVI
y MCMC
tienen la misma interfaz. Esto junto con la similitud en la API para las operaciones Numpy y Pytorch asegura que los modelos que contienen declaraciones Pyro primitivas se pueden usar con backend con algunos cambios menores. El ejemplo de algunas diferencias junto con los cambios necesarios se indican a continuación:
torch
en su modelo deberá escribirse en términos de la operación jax.numpy
correspondiente. Además, no todas las operaciones torch
tienen una contraparte numpy
(y viceversa), y a veces hay diferencias menores en la API.pyro.sample
fuera de un contexto de inferencia deberán envolverse en un controlador seed
, como se mencionó anteriormente.numpyro.param
fuera de un contexto de inferencia no tendrá ningún efecto. Para recuperar los valores de parámetros optimizados de SVI, use el método svi.get_params. Tenga en cuenta que aún puede usar las declaraciones param
dentro de un modelo y Numpyro usará el controlador de efecto sustituto internamente para sustituir los valores del optimizador cuando ejecuta el modelo en SVI.Para la mayoría de los modelos pequeños, los cambios necesarios para ejecutar inferencia en Numpyro deberían ser menores. Además, estamos trabajando en Pyro-API, que le permite escribir el mismo código y enviarlo a múltiples backends, incluido Numpyro. Esto será necesariamente más restrictivo, pero tiene la ventaja de ser agnóstico de backend. Vea la documentación para obtener un ejemplo y háganos saber sus comentarios.
¿Cómo puedo contribuir al proyecto?
¡Gracias por su interés en el proyecto! Puede echar un vistazo a los problemas amigables para principiantes que están marcados con la buena etiqueta de primer problema en GitHub. Además, sienta que nos comunique con nosotros en el foro.
En el corto plazo, planeamos trabajar en lo siguiente. Abra nuevos problemas para las solicitudes y mejoras de funciones:
Las ideas motivadoras detrás de Numpyro y una descripción de las nueces iterativas se pueden encontrar en este documento que apareció en las transformaciones del programa Neurips 2019 para el taller de aprendizaje automático.
Si usa Numpyro, considere citar:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
así como
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}