Construído com Jax e Pint!
Este módulo fornece uma interface entre Jax e Pint para permitir que a JAX apoie as operações com unidades. A propagação das unidades acontece no tempo de rastreamento, para que as funções Jitt não tenham custo de tempo de execução. Esta biblioteca é experimental, portanto, espere algumas bordas nítidas.
Por exemplo:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
Para instalar, use pip
:
python -m pip install jpu
As únicas dependências são jax
e pint
, e estas também serão instaladas, se ainda não estiver em seu ambiente. Dê uma olhada no Jax Docs para obter mais informações sobre a instalação do JAX em diferentes sistemas.
Aqui está um exemplo um pouco mais completo:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
A limitação mais significativa desta biblioteca é o fato de que os usuários devem usar as funções jpu.numpy
ao interagir com "quantidades" com unidades em vez da interface jax.numpy
. Isso ocorre porque o JAX ainda não fornece uma interface geral para despachar dos UFUNCs nas classes de matriz personalizada. Eu brinquei com a interface __jax_array__
sem documentos, mas não é realmente flexível o suficiente, e atualmente não é compatível com os objetos Pytree.
Até agora, apenas um subconjunto da interface numpy
/ jax.numpy
é implementado. Solicitações de atração adicionando suporte mais amplo (incluindo submódulos) seriam bem -vindos!