クイックスタート|変換|インストールガイド|ニューラルネットライブラリ|変更ログ|参考資料
JAX は、アクセラレータ指向の配列計算およびプログラム変換用の Python ライブラリであり、高性能数値計算および大規模機械学習向けに設計されています。
Autograd の更新バージョンを使用すると、JAX はネイティブ Python 関数と NumPy 関数を自動的に区別できます。ループ、分岐、再帰、クロージャを通じて微分することができ、導関数の導関数の導関数を取得できます。順方向モード微分だけでなく、 grad
による逆方向モード微分 (別名バックプロパゲーション) もサポートしており、この 2 つは任意の順序で任意に合成できます。
新しい点は、JAX が XLA を使用して GPU および TPU 上で NumPy プログラムをコンパイルし、実行することです。コンパイルはデフォルトで内部で行われ、ライブラリ呼び出しがジャストインタイムでコンパイルおよび実行されます。ただし、JAX では、1 関数 API jit
を使用して、独自の Python 関数を XLA 最適化カーネルにジャストインタイムでコンパイルすることもできます。コンパイルと自動微分は任意に構成できるため、Python を離れることなく高度なアルゴリズムを表現し、最大限のパフォーマンスを得ることができます。 pmap
使用して複数の GPU または TPU コアを同時にプログラムし、全体を区別することもできます。
もう少し深く掘り下げると、JAX が実際には、構成可能な関数変換のための拡張可能なシステムであることがわかります。 grad
とjit
どちらもそのような変換の例です。その他には、自動ベクトル化用のvmap
と、マルチ アクセラレータの単一プログラム複数データ (SPMD) 並列プログラミング用のpmap
があり、さらに多くの機能が追加される予定です。
これは研究プロジェクトであり、Google の公式製品ではありません。バグや鋭いエッジが予想されます。ぜひ試してみて、バグを報告し、ご意見をお聞かせください。
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
Google Cloud GPU に接続されたブラウザでノートブックを使用して、すぐに始められます。以下にスターター ノートブックをいくつか示します。
grad
、コンパイルのためのjit
、ベクトル化のためのvmap
JAX は Cloud TPU 上で実行されるようになりました。プレビューを試すには、Cloud TPU Colabs を参照してください。
JAX についてさらに詳しく知りたい場合は、以下を参照してください。
JAX の核心は、数値関数を変換するための拡張可能なシステムです。主に重要な 4 つの変換、 grad
、 jit
、 vmap
、およびpmap
次に示します。
grad
による自動微分JAX には Autograd とほぼ同じ API があります。最も人気のある関数は、リバースモード勾配のgrad
です。
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
grad
を使用して任意の順序に区別できます。
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
より高度な autodiff の場合、リバース モードのベクトル ヤコビアン積にはjax.vjp
、フォワード モードのヤコビアン ベクトル積にはjax.jvp
使用できます。この 2 つは、相互に、または他の JAX 変換と任意に組み合わせることができます。これらを組み合わせて完全なヘッセ行列を効率的に計算する関数を作成する 1 つの方法は次のとおりです。
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
Autograd と同様に、Python 制御構造で自由に微分を使用できます。
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
詳細については、自動微分に関するリファレンス ドキュメントと JAX Autodiff Cookbook を参照してください。
jit
によるコンパイルXLA を使用すると、 @jit
デコレーターまたは高階関数として使用されるjit
で関数をエンドツーエンドでコンパイルできます。
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
jit
とgrad
、およびその他の JAX 変換を自由に組み合わせることができます。
jit
使用すると、関数が使用できる Python 制御フローの種類に制約が課されます。詳細については、JIT を使用した制御フローと論理演算子に関するチュートリアルを参照してください。
vmap
による自動ベクトル化vmap
ベクトル化マップです。これは、配列軸に沿って関数をマッピングするというおなじみのセマンティクスを備えていますが、ループを外側に保持する代わりに、ループを関数のプリミティブ操作に押し込んでパフォーマンスを向上させます。
vmap
使用すると、コード内でバッチ ディメンションを持ち歩く必要がなくなります。たとえば、次の単純なバッチ処理されていないニューラル ネットワーク予測関数を考えてみましょう。
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
多くの場合、代わりにjnp.dot(activations, W)
を記述して、 activations
の左側にバッチ次元を許可しますが、この特定の予測関数は単一の入力ベクトルにのみ適用するように記述しました。この関数を入力のバッチに一度に適用したい場合は、意味的に次のように書くことができます。
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
ただし、一度に 1 つのサンプルをネットワーク経由でプッシュすると時間がかかります。計算をベクトル化して、すべての層で行列とベクトルの乗算ではなく行列と行列の乗算を実行する方が良いでしょう。
vmap
関数はその変換を行います。つまり、次のように書くと、
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
次に、 vmap
関数は関数内に外側のループをプッシュし、マシンは手動でバッチ処理を行ったかのように行列と行列の乗算を正確に実行することになります。
vmap
使用せずに単純なニューラル ネットワークを手動でバッチ処理するのは十分に簡単ですが、場合によっては手動のベクトル化が非現実的または不可能になることがあります。例ごとの勾配を効率的に計算するという問題を考えてみましょう。つまり、パラメータの固定セットについて、バッチ内の各例で個別に評価される損失関数の勾配を計算したいと考えています。 vmap
使用すると、次のように簡単になります。
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
もちろん、 vmap
jit
、 grad
、およびその他の JAX 変換を使用して任意に構成できます。 jax.jacfwd
、 jax.jacrev
、およびjax.hessian
での高速ヤコビアン行列計算とヘッセ行列計算のために、順方向モードと逆方向モードの両方の自動微分を備えたvmap
使用します。
pmap
を使用した SPMD プログラミング複数の GPU など、複数のアクセラレータの並列プログラミングの場合は、 pmap
使用します。 pmap
使用すると、高速並列集合通信操作を含む、単一プログラム複数データ (SPMD) プログラムを作成できます。 pmap
適用すると、作成した関数が XLA ( jit
と同様) によってコンパイルされ、複製されてデバイス間で並行して実行されることを意味します。
以下は 8 GPU マシンの例です。
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
純粋なマップを表現することに加えて、デバイス間で高速な集合通信操作を使用できます。
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
より洗練された通信パターンを実現するために、 pmap
関数をネストすることもできます。
すべてが構成されるため、並列計算を通じて自由に微分することができます。
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
pmap
関数をリバースモードで微分する場合 (例: grad
を使用)、計算のバックワード パスはフォワード パスと同様に並列化されます。
詳細については、SPMD クックブックと SPMD MNIST 分類器のゼロからの例を参照してください。
例と説明を含む現在の問題点をより徹底的に調査するには、「問題点ノートブック」を読むことを強くお勧めします。いくつかの傑出した点:
is
を使用したオブジェクトの同一性テストは保持されません)。不純な Python 関数で JAX 変換を使用すると、 Exception: Can't lift Traced...
またはException: Different traces at same level
あります」のようなエラーが表示される場合があります。x[i] += y
のような配列のインプレース変更更新はサポートされていませんが、機能的な代替手段はあります。 jit
では、これらの代替機能はバッファをインプレースで自動的に再利用します。jax.lax
パッケージにあります。float32
) 値を強制します。倍精度 (64 ビット、たとえばfloat64
) を有効にするには、起動時にjax_enable_x64
変数を設定する (または環境変数JAX_ENABLE_X64=True
を設定する) 必要があります。 。 TPU では、JAX は、 jax.numpy.dot
やlax.conv
などの「matmul のような」操作における内部一時変数を除くすべてにデフォルトで 32 ビット値を使用します。これらの演算には、3 つの bfloat16 パスを介して 32 ビット演算を近似するために使用できるprecision
パラメータがありますが、実行時間が遅くなる可能性があります。 TPU での非 matmul 操作は、精度よりも速度を重視する実装に劣るため、実際の TPU での計算は、他のバックエンドでの同様の計算よりも精度が低くなります。np.add(1, np.array([2], np.float32)).dtype
float32
ではなくfloat64
です。jit
などの一部の変換では、Python 制御フローの使用方法が制限されます。何か問題が発生すると、常に大きなエラーが発生します。 jit
のstatic_argnums
パラメータ、 lax.scan
のような構造化された制御フロー プリミティブを使用するか、より小さなサブ関数でjit
使用する必要がある場合があります。 Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
CPU | はい | はい | はい | はい | はい | はい |
NVIDIA GPU | はい | はい | いいえ | 該当なし | いいえ | 実験的な |
Google TPU | はい | 該当なし | 該当なし | 該当なし | 該当なし | 該当なし |
AMD GPU | はい | いいえ | 実験的な | 該当なし | いいえ | いいえ |
アップルのGPU | 該当なし | いいえ | 該当なし | 実験的な | 該当なし | 該当なし |
インテルGPU | 実験的な | 該当なし | 該当なし | 該当なし | いいえ | いいえ |
プラットフォーム | 説明書 |
---|---|
CPU | pip install -U jax |
NVIDIA GPU | pip install -U "jax[cuda12]" |
Google TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
AMD GPU (Linux) | Docker、事前に構築されたホイール、またはソースから構築を使用します。 |
マックのGPU | Apple の指示に従ってください。 |
インテルGPU | インテルの指示に従ってください。 |
別のインストール戦略については、ドキュメントを参照してください。これには、ソースからのコンパイル、Docker を使用したインストール、他のバージョンの CUDA の使用、コミュニティでサポートされている conda ビルド、およびよくある質問への回答が含まれます。
Google DeepMind および Alphabet の複数の Google 研究グループは、JAX でニューラル ネットワークをトレーニングするためのライブラリを開発および共有しています。例やハウツー ガイドを備えたニューラル ネットワーク トレーニング用の完全な機能を備えたライブラリが必要な場合は、Flax とそのドキュメント サイトを試してください。
JAX ベースのネットワーク ライブラリのリストについては、JAX ドキュメント サイトの JAX エコシステム セクションを確認してください。これには、勾配処理と最適化のための Optax、信頼性の高いコードとテストのための chex、ニューラル ネットワーク用の Equinox が含まれます。 (詳細については、DeepMind での NeurIPS 2020 JAX Ecosystem の講演をご覧ください。)
このリポジトリを引用するには:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
上記の bibtex エントリでは、名前はアルファベット順に並べられており、バージョン番号は jax/version.py からのものであり、年はプロジェクトのオープンソース リリースに対応しています。
XLA への自動微分とコンパイルのみをサポートする JAX の初期バージョンは、SysML 2018 で発表された論文で説明されました。私たちは現在、より包括的で最新の論文で JAX のアイデアと機能をカバーすることに取り組んでいます。
JAX API の詳細については、リファレンス ドキュメントを参照してください。
JAX 開発者として始めるには、開発者ドキュメントを参照してください。