สร้างด้วย Jax และ Pint!
โมดูลนี้ให้อินเทอร์เฟซระหว่าง JAX และ PINT เพื่อให้ JAX รองรับการดำเนินการกับหน่วย การแพร่กระจายของหน่วยเกิดขึ้นในเวลาที่ติดตามดังนั้นฟังก์ชั่นที่ได้รับการคัดเลือกไม่ควรเห็นค่าใช้จ่ายรันไทม์ ห้องสมุดนี้ทดลองใช้ดังนั้นคาดว่าจะมีขอบคม
ตัวอย่างเช่น:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
ในการติดตั้งให้ใช้ pip
:
python -m pip install jpu
การพึ่งพาเพียงอย่างเดียวคือ jax
และ pint
และสิ่งเหล่านี้จะถูกติดตั้งหากไม่ได้อยู่ในสภาพแวดล้อมของคุณ ดูเอกสาร JAX สำหรับข้อมูลเพิ่มเติมเกี่ยวกับการติดตั้ง JAX ในระบบที่แตกต่างกัน
นี่คือตัวอย่างที่สมบูรณ์กว่าเล็กน้อย:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
ข้อ จำกัด ที่สำคัญที่สุดของไลบรารีนี้คือความจริงที่ว่าผู้ใช้จะต้องใช้ฟังก์ชั่น jpu.numpy
เมื่อโต้ตอบกับ "ปริมาณ" กับหน่วยแทน jax.numpy
อินเตอร์เฟส นี่เป็นเพราะ JAX ไม่ได้ (ยัง?) ให้อินเทอร์เฟซทั่วไปสำหรับการส่ง UFUNCS ในคลาสอาร์เรย์ที่กำหนดเอง ฉันได้เล่นกับอินเทอร์เฟซ __jax_array__
ที่ไม่มีเอกสาร แต่มันไม่ยืดหยุ่นเพียงพอและไม่สามารถใช้งานได้กับวัตถุ Pytree
จนถึงตอนนี้มีการใช้งานส่วนย่อยของอินเตอร์เฟส numpy
/ jax.numpy
เท่านั้น การร้องขอแบบดึงเพิ่มการสนับสนุนที่กว้างขึ้น (รวมถึง submodules) จะได้รับการต้อนรับ!