快速入門|轉型|安裝指南|神經網路庫|更改日誌|參考文檔
JAX是一個以加速器為導向的陣列運算和程式轉換的Python函式庫,專為高效能數值運算和大規模機器學習而設計。
借助 Autograd 的更新版本,JAX 可以自動區分原生 Python 和 NumPy 函數。它可以透過循環、分支、遞歸和閉包進行微分,並且可以求導數的導數的導數。它支援通過grad
反向模式微分(又稱反向傳播)以及前向模式微分,兩者可以按任何順序任意組合。
新增功能是 JAX 使用 XLA 在 GPU 和 TPU 上編譯和執行 NumPy 程式。預設情況下,編譯發生在幕後,函式庫調用會被及時編譯和執行。但 JAX 也允許您使用單函數 API jit
將自己的 Python 函數即時編譯到 XLA 最佳化的核心。編譯和自動微分可以任意組合,因此您可以在不離開Python的情況下表達複雜的演算法並獲得最大的效能。您甚至可以使用pmap
同時對多個 GPU 或 TPU 核心進行編程,並區分整個過程。
深入挖掘一下,您會發現 JAX 實際上是一個用於可組合函數轉換的可擴展系統。 grad
和jit
都是這類轉換的實例。其他功能包括用於自動向量化的vmap
和用於多個加速器的單程式多資料 (SPMD) 平行程式設計的pmap
,還有更多功能即將推出。
這是一個研究項目,不是 Google 官方產品。預計會有錯誤和鋒利的邊緣。請嘗試、報告錯誤並讓我們知道您的想法來提供幫助!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
直接在瀏覽器中使用連接到 Google Cloud GPU 的筆電。以下是一些入門筆記本:
grad
、用於編譯的jit
以及用於向量化的vmap
JAX 現在在 Cloud TPU 上運行。若要試用預覽版,請參閱 Cloud TPU Colabs。
要更深入了解 JAX:
JAX 的核心是一個用於轉換數值函數的可擴展系統。以下是主要感興趣的四個轉換: grad
、 jit
、 vmap
和pmap
。
grad
JAX 具有與 Autograd 大致相同的 API。最流行的函數是反向模式梯度的grad
:
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
您可以使用grad
區分任何訂單。
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
對於更進階的自動差分,您可以將jax.vjp
用於反向模式雅可比向量乘積,將jax.jvp
用於正向模式雅可比向量乘積。兩者可以任意組合,也可以與其他 JAX 轉換任意組合。以下是組合這些矩陣以創建有效計算完整 Hessian 矩陣的函數的一種方法:
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
與 Autograd 一樣,您可以自由地使用 Python 控制結構的微分:
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
有關更多信息,請參閱有關自動微分的參考文檔和 JAX Autodiff Cookbook。
jit
編譯您可以使用 XLA 與jit
端對端編譯函數,用作@jit
裝飾器或高階函數。
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
您可以根據需要混合jit
和grad
以及任何其他 JAX 轉換。
使用jit
對函數可以使用的 Python 控制流程類型施加了限制;有關更多信息,請參閱有關使用 JIT 的控制流和邏輯運算符的教程。
vmap
自動向量化vmap
是向量化映射。它具有沿著數組軸映射函數的熟悉語義,但不是將循環保留在外部,而是將循環向下推入函數的原始操作中,以獲得更好的性能。
使用vmap
可以讓您不必在程式碼中攜帶批量維度。例如,考慮這個簡單的非批次神經網路預測函數:
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
我們經常改為編寫jnp.dot(activations, W)
以允許在activations
的左側使用批量維度,但我們編寫了這個特定的預測函數以僅應用於單個輸入向量。如果我們想立即將此函數應用於一批輸入,從語義上講,我們可以編寫
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
但是一次透過網路推送一個範例會很慢!最好將計算向量化,以便在每一層我們都進行矩陣-矩陣乘法而不是矩陣-向量乘法。
vmap
函數為我們完成了這個轉換。也就是說,如果我們寫
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
然後vmap
函數會將外循環推入函數內,我們的機器最終將執行矩陣-矩陣乘法,就像我們手動完成批次一樣。
在沒有vmap
情況下手動批次簡單的神經網路很容易,但在其他情況下手動向量化可能不切實際或不可能。以有效計算每個範例梯度的問題為例:也就是說,對於一組固定的參數,我們希望計算在批次中的每個範例中單獨評估的損失函數的梯度。使用vmap
,這很容易:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
當然, vmap
可以與jit
、 grad
和任何其他 JAX 轉換任意組合!我們使用具有正向和反向模式自動微分的vmap
在jax.jacfwd
、 jax.jacrev
和jax.hessian
中進行快速 Jacobian 和 Hessian 矩陣計算。
pmap
進行 SPMD 編程對於多個加速器(例如多個 GPU)的平行編程,請使用pmap
。使用pmap
您可以編寫單程序多資料 (SPMD) 程序,包括快速並行集體通訊作業。應用程式pmap
將意味著您編寫的函數由 XLA 編譯(類似於jit
),然後跨裝置複製和並行執行。
以下是 8-GPU 機器上的範例:
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
除了表達純地圖之外,您還可以在設備之間使用快速集體通訊操作:
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
您甚至可以嵌套pmap
函數以實現更複雜的通訊模式。
所有這些都組合在一起,因此您可以透過並行計算自由地進行微分:
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
當反向模式微分pmap
函數(例如使用grad
)時,計算的後向傳遞就像前向傳遞一樣並行化。
有關更多信息,請參閱 SPMD Cookbook 和從頭開始的 SPMD MNIST 分類器範例。
要對目前的陷阱進行更全面的調查,並附有範例和解釋,我們強烈建議您閱讀陷阱筆記本。一些傑出的:
is
進行的物件身份測試)。如果您對不純的 Python 函數使用 JAX 轉換,您可能會看到類似Exception: Can't lift Traced...
或Exception: Different traces at same level
錯誤。x[i] += y
,但有功能替代方案。在jit
下,這些功能替代方案將自動就地重複使用緩衝區。jax.lax
套件中。float32
)值,若要啟用雙精確度(64 位元,例如float64
),需要在啟動時設定jax_enable_x64
變數(或設定環境變數JAX_ENABLE_X64=True
) 。在 TPU 上,JAX 預設對除「類似於 matmul」操作中的內部臨時變數(例如jax.numpy.dot
和lax.conv
之外的所有內容使用 32 位元值。這些操作有一個precision
參數,可用於透過三個 bfloat16 傳遞來近似 32 位元操作,但代價可能是運行時間較慢。 TPU 上的非 matmul 運算低於通常強調速度而不是精度的實現,因此實際上 TPU 上的計算將不如其他後端上的類似計算精確。np.add(1, np.array([2], np.float32)).dtype
是float64
而不是float32
。jit
)會限制您使用 Python 控制流的方式。如果出現問題,你總是會收到很大的錯誤。您可能必須使用jit
的static_argnums
參數、像lax.scan
這樣的結構化控制流原語,或僅在較小的子函數上使用jit
。 Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
中央處理器 | 是的 | 是的 | 是的 | 是的 | 是的 | 是的 |
英偉達圖形處理器 | 是的 | 是的 | 不 | 不適用 | 不 | 實驗性的 |
谷歌TPU | 是的 | 不適用 | 不適用 | 不適用 | 不適用 | 不適用 |
AMD顯示卡 | 是的 | 不 | 實驗性的 | 不適用 | 不 | 不 |
蘋果GPU | 不適用 | 不 | 不適用 | 實驗性的 | 不適用 | 不適用 |
英特爾GPU | 實驗性的 | 不適用 | 不適用 | 不適用 | 不 | 不 |
平台 | 指示 |
---|---|
中央處理器 | pip install -U jax |
英偉達圖形處理器 | pip install -U "jax[cuda12]" |
谷歌TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
AMD GPU(Linux) | 使用 Docker、預先建置輪子或從原始碼建置。 |
蘋果電腦GPU | 請遵循 Apple 的說明。 |
英特爾GPU | 請遵循英特爾的說明。 |
有關替代安裝策略的信息,請參閱文件。其中包括從原始碼編譯、使用 Docker 安裝、使用其他版本的 CUDA、社群支援的 conda 建置以及一些常見問題的解答。
Google DeepMind 和 Alphabet 的多個 Google 研究小組開發並分享用於在 JAX 中訓練神經網路的程式庫。如果您想要一個功能齊全的神經網路訓練庫,其中包含範例和操作指南,請嘗試 Flax 及其文件網站。
查看 JAX 文件網站上的 JAX 生態系統部分,以取得基於 JAX 的網路庫列表,其中包括用於梯度處理和優化的 Optax、可靠程式碼和測試的 chex 以及用於神經網路的 Equinox。 (請觀看 DeepMind 的 NeurIPS 2020 JAX 生態系統演講,以了解更多詳細資訊。)
引用這個儲存庫:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
在上面的 bibtex 條目中,名稱按字母順序排列,版本號是來自 jax/version.py 的版本號,年份對應於專案的開源版本。
SysML 2018 上發表的一篇論文描述了 JAX 的新版本,僅支援自動微分和編譯為 XLA。
有關 JAX API 的詳細信息,請參閱參考文件。
若要開始成為 JAX 開發人員,請參閱開發人員文件。