jpu
v0.0.4
用jax和pint建造!
該模塊提供了JAX和PINT之間的接口,以允許JAX支持單位的操作。單位的傳播發生在微量的時間,因此jitter功能應看到運行時的成本。該庫是實驗性的,因此可以期待一些鋒利的邊緣。
例如:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
要安裝,請使用pip
:
python -m pip install jpu
唯一的依賴性是jax
和pint
,如果尚未在您的環境中,則將安裝這些依賴項。查看JAX文檔,以獲取有關在不同系統上安裝JAX的更多信息。
這是一個更完整的示例:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
該庫的最重要限制是,用戶必須使用jpu.numpy
函數與單位而不是jax.numpy
接口進行交互時。這是因為Jax(尚未(?)還沒有提供用於在自定義數組類中調度UFUNCS的一般接口。我已經玩過無證件的__jax_array__
界面,但是它的靈活性不夠,並且目前與pytree對像不兼容。
到目前為止,僅實現了numpy
/ jax.numpy
接口的一個子集。歡迎添加更廣泛支持的拉動請求(包括子模型)!