¡Construido con Jax y Pint!
Este módulo proporciona una interfaz entre Jax y Pint para permitir que Jax admita las operaciones con unidades. La propagación de unidades ocurre en el momento de traza, por lo que las funciones jitidas no deberían ver el costo de tiempo de ejecución. Esta biblioteca es experimental, así que espere algunos bordes afilados.
Por ejemplo:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
Para instalar, use pip
:
python -m pip install jpu
Las únicas dependencias son jax
y pint
, y estas también se instalarán, si no están en su entorno. Eche un vistazo a los documentos de Jax para obtener más información sobre la instalación de Jax en diferentes sistemas.
Aquí hay un ejemplo un poco más completo:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
La limitación más significativa de esta biblioteca es el hecho de que los usuarios deben usar funciones jpu.numpy
al interactuar con "cantidades" con unidades en lugar de la interfaz jax.numpy
. Esto se debe a que Jax no (¿todavía?) Proporciona una interfaz general para el envío de UFuncs en clases de matriz personalizadas. He jugado con la interfaz __jax_array__
indocumentada, pero no es lo suficientemente flexible, y actualmente no es compatible con los objetos Pytree.
Hasta ahora, solo se implementa un subconjunto de la interfaz numpy
/ jax.numpy
. ¡Las solicitudes de extracción de agregar un soporte más amplio (incluidos los submódulos) serían bienvenidos!