Gebaut mit Jax und Pint!
Dieses Modul bietet eine Schnittstelle zwischen JAX und Pint, damit JAX Operationen mit Einheiten unterstützen kann. Die Ausbreitung von Einheiten erfolgt zum Zeitpunkt der Spur, sodass Jitt -Funktionen keine Laufzeitkosten feststellen sollten. Diese Bibliothek ist experimentell. Erwarten Sie also einige scharfe Kanten.
Zum Beispiel:
>> > import jax
>> > import jax . numpy as jnp
>> > import jpu
>> >
>> > u = jpu . UnitRegistry ()
>> >
>> > @ jax . jit
... def add_two_lengths ( a , b ):
... return a + b
...
>> > add_two_lengths ( 3 * u . m , jnp . array ([ 4.5 , 1.2 , 3.9 ]) * u . cm )
< Quantity ([ 3.045 3.012 3.039 ], 'meter' ) >
Verwenden Sie zum Installieren pip
:
python -m pip install jpu
Die einzigen Abhängigkeiten sind jax
und pint
, und diese werden auch installiert, wenn nicht bereits in Ihrer Umgebung. Schauen Sie sich die JAX -Dokumente an, um weitere Informationen zur Installation von JAX auf verschiedenen Systemen zu erhalten.
Hier ist ein etwas vollständigeres Beispiel:
>> > import jax
>> > import numpy as np
>> > from jpu import UnitRegistry , numpy as jnpu
>> >
>> > u = UnitRegistry ()
>> >
>> > @ jax . jit
... def projectile_motion ( v_init , theta , time , g = u . standard_gravity ):
... """Compute the motion of a projectile with support for units"""
... x = v_init * time * jnpu . cos ( theta )
... y = v_init * time * jnpu . sin ( theta ) - 0.5 * g * jnpu . square ( time )
... return x . to ( u . m ), y . to ( u . m )
...
>> > x , y = projectile_motion (
... 5.0 * u . km / u . h , 60 * u . deg , np . linspace ( 0 , 1 , 50 ) * u . s
... )
Die wichtigste Einschränkung dieser Bibliothek ist die Tatsache, dass Benutzer bei der Interaktion mit "Mengen" mit Einheiten anstelle der jax.numpy
-Schnittstelle jpu.numpy
-Funktionen verwenden müssen. Dies liegt daran, dass JAX (noch?) NICHT eine allgemeine Schnittstelle zum Versenden von UFuncs in benutzerdefinierten Array -Klassen bereitstellt. Ich habe mit der undokumentierten __jax_array__
-Schnittstelle herumgespielt, aber sie ist nicht wirklich flexibel genug und ist derzeit nicht mit Pytree -Objekten kompatibel.
Bisher wird nur eine Untergruppe der Schnittstelle von numpy
/ jax.numpy
implementiert. Anfragen, die eine breitere Unterstützung (einschließlich Submodules) hinzufügen, wären willkommen!