jpu
v0.0.4
用jax和pint建造!
该模块提供了JAX和PINT之间的接口,以允许JAX支持单位的操作。单位的传播发生在微量的时间,因此jitter功能应看到运行时的成本。该库是实验性的,因此可以期待一些锋利的边缘。
例如:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
要安装,请使用pip
:
python -m pip install jpu
唯一的依赖性是jax
和pint
,如果尚未在您的环境中,则将安装这些依赖项。查看JAX文档,以获取有关在不同系统上安装JAX的更多信息。
这是一个更完整的示例:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
该库的最重要限制是,用户必须使用jpu.numpy
函数与单位而不是jax.numpy
接口进行交互时。这是因为Jax(尚未(?)还没有提供用于在自定义数组类中调度UFUNCS的一般接口。我已经玩过无证件的__jax_array__
界面,但是它的灵活性不够,并且目前与pytree对象不兼容。
到目前为止,仅实现了numpy
/ jax.numpy
接口的一个子集。欢迎添加更广泛支持的拉动请求(包括子模型)!