Programação probabilística alimentada por Jax para compilação AutoGRAD e JIT para GPU/TPU/CPU.
Documentos e exemplos | Fórum
O Numpyro é uma biblioteca de programação probabilística leve que fornece um back -end numpy para pyro. Contamos com Jax para diferenciação automática e compilação JIT em GPU / CPU. O Numpyro está sob desenvolvimento ativo, portanto, cuidado com a fragilidade, os bugs e as alterações na API à medida que o design evolui.
O Numpyro foi projetado para ser leve e se concentra em fornecer um substrato flexível no qual os usuários podem desenvolver:
sample
e param
. O código do modelo deve parecer muito semelhante ao pyro, exceto por algumas pequenas diferenças entre a API de Pytorch e Numpy. Veja o exemplo abaixo.jit
e grad
para compilar toda a etapa de integração em um kernel otimizado para XLA. Também eliminamos a sobrecarga do Python por JIT compilando todo o estágio de construção de árvores em nozes (isso é possível usando nozes iterativas). Há também uma implementação básica de inferência variacional, juntamente com muitos guias flexíveis (automáticos) para a inferência variacional de diferenciação automática (AVI). A implementação de inferência variacional suporta vários recursos, incluindo suporte para modelos com variáveis latentes discretas (consulte Tracegraph_elbo e Traceenum_elbo).torch.distributions
. Além das distribuições, constraints
e transforms
são muito úteis ao operar em classes de distribuição com suporte limitado. Finalmente, as distribuições da Probabilidade do Tensorflow (TFP) podem ser usadas diretamente nos modelos Numpyro.sample
e param
podem ser fornecidos interpretações fora do padrão usando manipuladores de efeitos do módulo Numpyro. Vamos explorar o Numpyro usando um exemplo simples. Usaremos o exemplo de oito escolas de Gelman et al., Análise de dados bayesianos: Sec. 5.5, 2003, que estuda o efeito do treinamento no desempenho do SAT em oito escolas.
Os dados são fornecidos por:
>> > import numpy as np
>> > J = 8
>> > y = np . array ([ 28.0 , 8.0 , - 3.0 , 7.0 , - 1.0 , 1.0 , 18.0 , 12.0 ])
>> > sigma = np . array ([ 15.0 , 10.0 , 16.0 , 11.0 , 9.0 , 11.0 , 10.0 , 18.0 ])
, onde y
estão os efeitos do tratamento e sigma
o erro padrão. Construímos um modelo hierárquico para o estudo em que assumimos que os parâmetros de nível de grupo theta
para cada escola são amostrados de uma distribuição normal com mu
médio desconhecido e tau
de desvio padrão, enquanto os dados observados são gerados por sua vez a partir de uma distribuição normal com média e desvio padrão dado por theta
(efeito verdadeiro) e sigma
, respectivamente. Isso nos permite estimar os parâmetros de nível populacional mu
e tau
, reunindo-se de todas as observações, enquanto ainda permitem variações individuais entre as escolas que usam os parâmetros theta
em nível de grupo.
>> > import numpyro
>> > import numpyro . distributions as dist
>> > # Eight Schools example
... def eight_schools ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... theta = numpyro . sample ( 'theta' , dist . Normal ( mu , tau ))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
Vamos inferir os valores dos parâmetros desconhecidos em nosso modelo executando o MCMC usando o amostrador não-U-Turn (NUTS). Observe o uso do argumento extra_fields
em mcmc.run. Por padrão, coletamos apenas amostras da distribuição de destino (posterior) quando executamos a inferência usando MCMC
. No entanto, a coleta de campos adicionais como energia potencial ou a probabilidade de aceitação de uma amostra pode ser facilmente alcançada usando o argumento extra_fields
. Para uma lista de campos possíveis que podem ser coletados, consulte o objeto HMCState. Neste exemplo, adicionaremos também o potential_energy
para cada amostra.
>> > from jax import random
>> > from numpyro . infer import MCMC , NUTS
>> > nuts_kernel = NUTS ( eight_schools )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
Podemos imprimir o resumo da execução do MCMC e examinar se observamos alguma divergências durante a inferência. Além disso, como coletamos a energia potencial para cada uma das amostras, podemos calcular facilmente a densidade da junta de log esperada.
>> > mcmc . print_summary () # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.14 3.18 3.87 - 0.76 9.50 115.42 1.01
tau 4.12 3.58 3.12 0.51 8.56 90.64 1.02
theta [ 0 ] 6.40 6.22 5.36 - 2.54 15.27 176.75 1.00
theta [ 1 ] 4.96 5.04 4.49 - 1.98 14.22 217.12 1.00
theta [ 2 ] 3.65 5.41 3.31 - 3.47 13.77 247.64 1.00
theta [ 3 ] 4.47 5.29 4.00 - 3.22 12.92 213.36 1.01
theta [ 4 ] 3.22 4.61 3.28 - 3.72 10.93 242.14 1.01
theta [ 5 ] 3.89 4.99 3.71 - 3.39 12.54 206.27 1.00
theta [ 6 ] 6.55 5.72 5.66 - 1.43 15.78 124.57 1.00
theta [ 7 ] 4.81 5.95 4.19 - 3.90 13.40 299.66 1.00
Number of divergences : 19
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 54.55
Os valores acima de 1 para o diagnóstico de Gelman Rubin Split ( r_hat
) indicam que a cadeia não convergiu totalmente. O baixo valor para o tamanho efetivo da amostra ( n_eff
), particularmente para tau
, e o número de transições divergentes parece problemático. Felizmente, essa é uma patologia comum que pode ser retificada usando uma parametrização não centrada para tau
em nosso modelo. Isso é simples de fazer no Numpyro usando uma instância de distribuição transformada juntamente com um manipulador de efeitos de reparametrização. Vamos reescrever o mesmo modelo, mas, em vez de provar theta
a partir de um Normal(mu, tau)
, a amostraremos de uma distribuição Normal(0, 1)
que é transformada usando um afinetransform. Observe que, ao fazer isso, o Numpyro executa o HMC gerando amostras theta_base
para a distribuição Normal(0, 1)
. Vemos que a cadeia resultante não sofre com a mesma patologia - o diagnóstico de Gelman Rubin é 1 para todos os parâmetros e o tamanho efetivo da amostra parece muito bom!
>> > from numpyro . infer . reparam import TransformReparam
>> > # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered ( J , sigma , y = None ):
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... with numpyro . plate ( 'J' , J ):
... with numpyro . handlers . reparam ( config = { 'theta' : TransformReparam ()}):
... theta = numpyro . sample (
... 'theta' ,
... dist . TransformedDistribution ( dist . Normal ( 0. , 1. ),
... dist . transforms . AffineTransform ( mu , tau )))
... numpyro . sample ( 'obs' , dist . Normal ( theta , sigma ), obs = y )
>> > nuts_kernel = NUTS ( eight_schools_noncentered )
>> > mcmc = MCMC ( nuts_kernel , num_warmup = 500 , num_samples = 1000 )
>> > rng_key = random . PRNGKey ( 0 )
>> > mcmc . run ( rng_key , J , sigma , y = y , extra_fields = ( 'potential_energy' ,))
>> > mcmc . print_summary ( exclude_deterministic = False ) # doctest: +SKIP
mean std median 5.0 % 95.0 % n_eff r_hat
mu 4.08 3.51 4.14 - 1.69 9.71 720.43 1.00
tau 3.96 3.31 3.09 0.01 8.34 488.63 1.00
theta [ 0 ] 6.48 5.72 6.08 - 2.53 14.96 801.59 1.00
theta [ 1 ] 4.95 5.10 4.91 - 3.70 12.82 1183.06 1.00
theta [ 2 ] 3.65 5.58 3.72 - 5.71 12.13 581.31 1.00
theta [ 3 ] 4.56 5.04 4.32 - 3.14 12.92 1282.60 1.00
theta [ 4 ] 3.41 4.79 3.47 - 4.16 10.79 801.25 1.00
theta [ 5 ] 3.58 4.80 3.78 - 3.95 11.55 1101.33 1.00
theta [ 6 ] 6.31 5.17 5.75 - 2.93 13.87 1081.11 1.00
theta [ 7 ] 4.81 5.38 4.61 - 3.29 14.05 954.14 1.00
theta_base [ 0 ] 0.41 0.95 0.40 - 1.09 1.95 851.45 1.00
theta_base [ 1 ] 0.15 0.95 0.20 - 1.42 1.66 1568.11 1.00
theta_base [ 2 ] - 0.08 0.98 - 0.10 - 1.68 1.54 1037.16 1.00
theta_base [ 3 ] 0.06 0.89 0.05 - 1.42 1.47 1745.02 1.00
theta_base [ 4 ] - 0.14 0.94 - 0.16 - 1.65 1.45 719.85 1.00
theta_base [ 5 ] - 0.10 0.96 - 0.14 - 1.57 1.51 1128.45 1.00
theta_base [ 6 ] 0.38 0.95 0.42 - 1.32 1.82 1026.50 1.00
theta_base [ 7 ] 0.10 0.97 0.10 - 1.51 1.65 1190.98 1.00
Number of divergences : 0
>> > pe = mcmc . get_extra_fields ()[ 'potential_energy' ]
>> > # Compare with the earlier value
>> > print ( 'Expected log joint density: {:.2f}' . format ( np . mean ( - pe ))) # doctest: +SKIP
Expected log joint density : - 46.09
Observe que, para a classe de distribuições com loc,scale
como Normal
, Cauchy
, StudentT
, também fornecemos um Reparameterizer LocScaleReparam para alcançar o mesmo objetivo. O código correspondente será
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
Agora, vamos supor que temos uma nova escola para a qual não observamos nenhuma notas de teste, mas gostaríamos de gerar previsões. O Numpyro fornece uma classe preditiva para esse fim. Observe que, na ausência de dados observados, simplesmente usamos os parâmetros no nível da população para gerar previsões. A utilidade Predictive
condiciona os locais não observados mu
e tau
para os valores extraídos da distribuição posterior de nossa última execução do MCMC e executa o modelo adiante para gerar previsões.
>> > from numpyro . infer import Predictive
>> > # New School
... def new_school ():
... mu = numpyro . sample ( 'mu' , dist . Normal ( 0 , 5 ))
... tau = numpyro . sample ( 'tau' , dist . HalfCauchy ( 5 ))
... return numpyro . sample ( 'obs' , dist . Normal ( mu , tau ))
>> > predictive = Predictive ( new_school , mcmc . get_samples ())
>> > samples_predictive = predictive ( random . PRNGKey ( 1 ))
>> > print ( np . mean ( samples_predictive [ 'obs' ])) # doctest: +SKIP
3.9886456
Para mais alguns exemplos sobre a especificação de modelos e a inferência no Numpyro:
lax.scan
Primitivo de Jax para inferência rápida.Os usuários do Pyro observarão que a API para especificação e inferência do modelo é amplamente a mesma que o Pyro, incluindo a API de distribuições, por design. No entanto, existem algumas diferenças importantes importantes (refletidas nos internos) que os usuários devem estar cientes. Por exemplo, no Numpyro, não existe uma loja global de parâmetros ou estado aleatório, para possibilitar a compilação JIT da JAX. Além disso, os usuários podem precisar escrever seus modelos em um estilo mais funcional que funciona melhor com o JAX. Consulte as perguntas frequentes para obter uma lista de diferenças.
Fornecemos uma visão geral da maioria dos algoritmos de inferência suportados pelo Numpyro e oferecemos algumas diretrizes sobre quais algoritmos de inferência podem ser apropriados para diferentes classes de modelos.
Como o HMC/porcas, todos os algoritmos MCMC restantes suportam a enumeração de enumeração sobre variáveis latentes discretas, se possível (consulte Restrições). Os sites enumerados precisam ser marcados com infer={'enumerate': 'parallel'}
como no exemplo de anotação.
Trace_ELBO
mas calcula parte do elbo analiticamente se for possível.Veja os documentos para obter mais detalhes.
Suporte ao Windows limitado: Observe que o Numpyro não é testado no Windows e pode exigir a construção de Jaxlib a partir da fonte. Veja esse problema JAX para obter mais detalhes. Como alternativa, você pode instalar o subsistema do Windows para Linux e usar o Numpyro como em um sistema Linux. Consulte também CUDA no subsistema Windows para Linux e esta postagem do fórum, se você quiser usar as GPUs no Windows.
Para instalar o Numpyro com a versão mais recente da CPU do JAX, você pode usar o PIP:
pip install numpyro
Em caso de problemas de compatibilidade, surgem durante a execução do comando acima, você pode forçar a instalação de uma versão compatível da CPU conhecida do JAX com
pip install numpyro[cpu]
Para usar o Numpyro na GPU , você precisa instalar o CUDA primeiro e depois usar o seguinte comando pip:
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Se você precisar de mais orientações, dê uma olhada nas instruções de instalação da JAX GPU.
Para executar o Numpyro no Cloud TPUs , você pode olhar para alguns exemplos JAX nos Cloud TPU.
Para o Cloud TPU VM, você precisa configurar o back -end da TPU, conforme detalhado no Guia do Cloud TPU VM JAX QuickStart. Depois de verificar se o back -end da TPU está configurado corretamente, você pode instalar o Numpyro usando o comando pip install numpyro
.
Plataforma padrão: Jax usará a GPU por padrão se o pacote
jaxlib
apoiado por Cuda estiver instalado. Você pode usar o utilitário set_platformnumpyro.set_platform("cpu")
para alternar para a CPU no início do seu programa.
Você também pode instalar o Numpyro da fonte:
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
Você também pode instalar o Numpyro com o conda:
conda install -c conda-forge numpyro
Ao contrário do piro, numpyro.sample('x', dist.Normal(0, 1))
não funciona. Por que?
Você provavelmente está usando uma declaração numpyro.sample
fora de um contexto de inferência. O JAX não possui um estado aleatório global e, como tal, os amostradores de distribuição precisam de uma chave explícita do gerador de números aleatórios (PRNGKEY) para gerar amostras. Os algoritmos de inferência da Numpyro usam o manipulador de sementes para enfiar em uma chave de gerador de números aleatórios, nos bastidores.
Suas opções são:
Ligue diretamente para a distribuição e forneça um PRNGKey
, por exemplo dist.Normal(0, 1).sample(PRNGKey(0))
Forneça o argumento rng_key
para numpyro.sample
. Por exemplo, numpyro.sample('x', dist.Normal(0, 1), rng_key=PRNGKey(0))
.
Enrole o código em um manipulador seed
, usado como gerente de contexto ou como uma função que envolve o chamável original. por exemplo
with handlers . seed ( rng_seed = 0 ): # random.PRNGKey(0) is used
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 )) # uses a PRNGKey split from random.PRNGKey(0)
y = numpyro . sample ( 'y' , dist . Bernoulli ( x )) # uses different PRNGKey split from the last one
, ou como uma função de ordem superior:
def fn ():
x = numpyro . sample ( 'x' , dist . Beta ( 1 , 1 ))
y = numpyro . sample ( 'y' , dist . Bernoulli ( x ))
return y
print ( handlers . seed ( fn , rng_seed = 0 )())
Posso usar o mesmo modelo de piro para fazer inferência no Numpyro?
Como você deve ter notado nos exemplos, o Numpyro suporta todas as primitivas de pirro, como sample
, param
, plate
e module
e manipuladores de efeito. Além disso, garantimos que a API de distribuições seja baseada nas torch.distributions
e as classes de inferência como SVI
e MCMC
tenham a mesma interface. Isso, juntamente com a semelhança na API para operações de Numpy e Pytorch, garante que modelos contendo declarações primitivas de piro possam ser usadas com um back -end com algumas pequenas alterações. Exemplo de algumas diferenças, juntamente com as mudanças necessárias, estão observadas abaixo:
torch
em seu modelo precisará ser gravada em termos da operação jax.numpy
correspondente. Além disso, nem todas as operações torch
têm uma contraparte numpy
(e vice-versa) e, às vezes, existem pequenas diferenças na API.pyro.sample
Declarações fora de um contexto de inferência precisará ser embrulhada em um manipulador seed
, como mencionado acima.numpyro.param
fora de um contexto de inferência não terá efeito. Para recuperar os valores de parâmetros otimizados do SVI, use o método svi.get_params. Observe que você ainda pode usar as instruções param
dentro de um modelo e o Numpyro usará o manipulador de efeitos substitutos internamente para substituir valores do otimizador ao executar o modelo no SVI.Para a maioria dos modelos pequenos, as mudanças necessárias para a inferência no Numpyro devem ser menores. Além disso, estamos trabalhando no Pyro-API, que permite escrever o mesmo código e despachar-o para vários backnds, incluindo Numpyro. Isso será necessariamente mais restritivo, mas terá a vantagem de ser agnóstico de back -end. Veja a documentação para um exemplo e informe -nos seus comentários.
Como posso contribuir para o projeto?
Obrigado pelo seu interesse no projeto! Você pode dar uma olhada em questões para iniciantes que estão marcadas com a boa etiqueta de primeira edição no Github. Além disso, sinta -se para nos alcançar no fórum.
No curto prazo, planejamos trabalhar com o seguinte. Abra novos problemas para solicitações e aprimoramentos de recursos:
As idéias motivadoras por trás do Numpyro e uma descrição das nozes iterativas podem ser encontradas neste artigo que apareceu no Neurips 2019 Program Transformations for Machine Learning Workshop.
Se você usa o Numpyro, considere citar:
@article{phan2019composable,
title={Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
author={Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
journal={arXiv preprint arXiv:1912.11554},
year={2019}
}
assim como
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}