jpu
v0.0.4
Jax와 파인트로 제작되었습니다!
이 모듈은 JAX와 파인트 사이의 인터페이스를 제공하여 JAX가 장치로 작업을 지원할 수 있도록합니다. 단위의 전파는 추적 시간에 발생하므로 JITTER 기능은 런타임 비용을 보지 않아야합니다. 이 라이브러리는 실험적이므로 날카로운 가장자리가 예상됩니다.
예를 들어:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
설치하려면 pip
사용하십시오.
python -m pip install jpu
유일한 종속성은 jax
와 pint
이며, 환경에도 아직 설치됩니다. 다른 시스템에 JAX 설치에 대한 자세한 내용은 JAX 문서를 살펴보십시오.
다음은 약간 더 완전한 예입니다.
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
이 라이브러리의 가장 중요한 제한은 사용자가 jax.numpy
인터페이스 대신 "수량"과 상호 작용할 때 jpu.numpy
함수를 사용해야한다는 사실입니다. JAX는 (아직?) 사용자 정의 배열 클래스에서 ufuncs를 발송하기위한 일반적인 인터페이스를 제공하지 않기 때문입니다. 나는 문서화되지 않은 __jax_array__
인터페이스와 함께 연주했지만 실제로는 충분히 유연하지 않으며 현재 Pytree 객체와 호환되지 않습니다.
지금까지 numpy
/ jax.numpy
인터페이스의 하위 집합 만 구현됩니다. 더 넓은 지원 (하위 모듈 포함)을 추가하는 풀 요청을 환영합니다!