jpu
v0.0.4
JaxとPintで構築されました!
このモジュールは、JAXとPINTの間のインターフェイスを提供して、JAXがユニットで操作をサポートできるようにします。ユニットの伝播はTRACE時に発生するため、ジット関数はランタイムコストが表示されないはずです。このライブラリは実験的なので、鋭いエッジを期待してください。
例えば:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
インストールするには、 pip
を使用してください。
python -m pip install jpu
唯一の依存関係はjax
とpint
であり、これらも環境に存在していなくてもインストールされます。さまざまなシステムにJaxのインストールの詳細については、Jax Docsをご覧ください。
ここにもう少し完全な例があります:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
このライブラリの最も重要な制限は、 jax.numpy
インターフェイスの代わりにユニットと「数量」と対話するときに、ユーザーがjpu.numpy
関数を使用する必要があるという事実です。これは、Jaxが(まだ?)カスタム配列クラスでUFUNCを発送するための一般的なインターフェイスを提供していないためです。私は文書化されていない__jax_array__
インターフェイスをいじりましたが、それはあまり柔軟ではなく、現在はpytreeオブジェクトと互換性がありません。
これまでのところ、 numpy
/ jax.numpy
インターフェイスのサブセットのみが実装されています。幅広いサポートを追加するリクエスト(サブモジュールを含む)を歓迎します!