Dibangun dengan Jax dan Pint!
Modul ini menyediakan antarmuka antara Jax dan pint untuk memungkinkan Jax mendukung operasi dengan unit. Propagasi unit terjadi pada waktu jejak, sehingga fungsi yang di -retak harus melihat tidak ada biaya runtime. Perpustakaan ini bersifat eksperimental jadi harapkan beberapa tepi yang tajam.
Misalnya:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
Untuk menginstal, gunakan pip
:
python -m pip install jpu
Satu -satunya dependensi adalah jax
dan pint
, dan ini juga akan diinstal, jika belum di lingkungan Anda. Lihatlah dokumen JAX untuk informasi lebih lanjut tentang menginstal JAX pada sistem yang berbeda.
Berikut adalah contoh yang sedikit lebih lengkap:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
Keterbatasan paling signifikan dari perpustakaan ini adalah fakta bahwa pengguna harus menggunakan fungsi jpu.numpy
saat berinteraksi dengan "jumlah" dengan unit alih -alih antarmuka jax.numpy
. Ini karena Jax belum (belum?) Menyediakan antarmuka umum untuk pengiriman UFuncs pada kelas array khusus. Saya telah bermain -main dengan antarmuka __jax_array__
yang tidak berdokumen, tetapi itu tidak cukup fleksibel, dan saat ini tidak kompatibel dengan objek Pytree.
Sejauh ini, hanya sebagian antarmuka numpy
/ jax.numpy
yang diimplementasikan. Permintaan tarik menambahkan dukungan yang lebih luas (termasuk submodul) akan diterima!