快速入门|转型|安装指南|神经网络库|更改日志|参考文档
JAX是一个面向加速器的数组计算和程序转换的Python库,专为高性能数值计算和大规模机器学习而设计。
借助 Autograd 的更新版本,JAX 可以自动区分原生 Python 和 NumPy 函数。它可以通过循环、分支、递归和闭包进行微分,并且可以求导数的导数的导数。它支持通过grad
反向模式微分(又名反向传播)以及前向模式微分,并且两者可以按任何顺序任意组合。
新增功能是 JAX 使用 XLA 在 GPU 和 TPU 上编译和运行 NumPy 程序。默认情况下,编译发生在幕后,库调用会被及时编译和执行。但 JAX 还允许您使用单函数 API jit
将自己的 Python 函数即时编译到 XLA 优化的内核中。编译和自动微分可以任意组合,因此您可以在不离开Python的情况下表达复杂的算法并获得最大的性能。您甚至可以使用pmap
同时对多个 GPU 或 TPU 内核进行编程,并区分整个过程。
深入挖掘一下,您会发现 JAX 实际上是一个用于可组合函数转换的可扩展系统。 grad
和jit
都是此类转换的实例。其他功能包括用于自动矢量化的vmap
和用于多个加速器的单程序多数据 (SPMD) 并行编程的pmap
,还有更多功能即将推出。
这是一个研究项目,不是 Google 官方产品。预计会有错误和锋利的边缘。请尝试一下、报告错误并让我们知道您的想法来提供帮助!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
直接在浏览器中使用连接到 Google Cloud GPU 的笔记本。以下是一些入门笔记本:
grad
、用于编译的jit
以及用于矢量化的vmap
JAX 现在在 Cloud TPU 上运行。要试用预览版,请参阅 Cloud TPU Colabs。
要更深入地了解 JAX:
JAX 的核心是一个用于转换数值函数的可扩展系统。以下是主要感兴趣的四个转换: grad
、 jit
、 vmap
和pmap
。
grad
JAX 具有与 Autograd 大致相同的 API。最流行的函数是反向模式梯度的grad
:
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
您可以使用grad
区分任何订单。
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
对于更高级的自动差分,您可以将jax.vjp
用于反向模式雅可比矢量乘积,将jax.jvp
用于正向模式雅可比矢量乘积。两者可以任意组合,也可以与其他 JAX 转换任意组合。以下是组合这些函数以创建有效计算完整 Hessian 矩阵的函数的一种方法:
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
与 Autograd 一样,您可以自由地使用 Python 控制结构的微分:
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
有关更多信息,请参阅有关自动微分的参考文档和 JAX Autodiff Cookbook。
jit
编译您可以使用 XLA 与jit
端到端编译函数,用作@jit
装饰器或高阶函数。
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
您可以根据需要混合jit
和grad
以及任何其他 JAX 转换。
使用jit
对函数可以使用的 Python 控制流类型施加了限制;有关更多信息,请参阅有关使用 JIT 的控制流和逻辑运算符的教程。
vmap
自动矢量化vmap
是矢量化映射。它具有沿数组轴映射函数的熟悉语义,但不是将循环保留在外部,而是将循环向下推入函数的原始操作中,以获得更好的性能。
使用vmap
可以让您不必在代码中携带批量维度。例如,考虑这个简单的非批处理神经网络预测函数:
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
我们经常改为编写jnp.dot(activations, W)
以允许在activations
左侧使用批量维度,但我们编写了这个特定的预测函数以仅应用于单个输入向量。如果我们想立即将此函数应用于一批输入,从语义上讲,我们可以编写
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
但是一次通过网络推送一个示例会很慢!最好将计算向量化,以便在每一层我们都进行矩阵-矩阵乘法而不是矩阵-向量乘法。
vmap
函数为我们完成了这种转换。也就是说,如果我们写
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
然后vmap
函数会将外循环推入函数内,我们的机器最终将执行矩阵-矩阵乘法,就像我们手动完成批处理一样。
在没有vmap
情况下手动批处理简单的神经网络很容易,但在其他情况下手动矢量化可能不切实际或不可能。以有效计算每个示例梯度的问题为例:也就是说,对于一组固定的参数,我们希望计算在批次中的每个示例中单独评估的损失函数的梯度。使用vmap
,这很容易:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
当然, vmap
可以与jit
、 grad
和任何其他 JAX 转换任意组合!我们使用具有正向和反向模式自动微分的vmap
在jax.jacfwd
、 jax.jacrev
和jax.hessian
中进行快速 Jacobian 和 Hessian 矩阵计算。
pmap
进行 SPMD 编程对于多个加速器(例如多个 GPU)的并行编程,请使用pmap
。使用pmap
您可以编写单程序多数据 (SPMD) 程序,包括快速并行集体通信操作。应用pmap
将意味着您编写的函数由 XLA 编译(类似于jit
),然后跨设备并行复制和执行。
以下是 8-GPU 机器上的示例:
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
除了表达纯地图之外,您还可以在设备之间使用快速集体通信操作:
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
您甚至可以嵌套pmap
函数以实现更复杂的通信模式。
所有这些都组合在一起,因此您可以通过并行计算自由地进行微分:
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
当反向模式微分pmap
函数(例如使用grad
)时,计算的后向传递就像前向传递一样被并行化。
有关更多信息,请参阅 SPMD Cookbook 和从头开始的 SPMD MNIST 分类器示例。
要对当前的陷阱进行更全面的调查,并附有示例和解释,我们强烈建议您阅读陷阱笔记本。一些杰出的:
is
进行的对象身份测试)。如果您对不纯的 Python 函数使用 JAX 转换,您可能会看到类似Exception: Can't lift Traced...
或Exception: Different traces at same level
错误。x[i] += y
,但有功能替代方案。在jit
下,这些功能替代方案将自动就地重用缓冲区。jax.lax
包中。float32
)值,要启用双精度(64 位,例如float64
),需要在启动时设置jax_enable_x64
变量(或设置环境变量JAX_ENABLE_X64=True
) 。在 TPU 上,JAX 默认情况下对除“类似 matmul”操作中的内部临时变量(例如jax.numpy.dot
和lax.conv
之外的所有内容使用 32 位值。这些操作有一个precision
参数,可用于通过三个 bfloat16 传递来近似 32 位操作,但代价可能是运行时间较慢。 TPU 上的非 matmul 运算低于通常强调速度而不是精度的实现,因此实际上 TPU 上的计算将不如其他后端上的类似计算精确。np.add(1, np.array([2], np.float32)).dtype
是float64
而不是float32
。jit
)会限制您使用 Python 控制流的方式。如果出现问题,你总是会收到很大的错误。您可能必须使用jit
的static_argnums
参数、像lax.scan
这样的结构化控制流原语,或者仅在较小的子函数上使用jit
。 Linux x86_64 | Linux aarch64 | Mac x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
中央处理器 | 是的 | 是的 | 是的 | 是的 | 是的 | 是的 |
英伟达图形处理器 | 是的 | 是的 | 不 | 不适用 | 不 | 实验性的 |
谷歌TPU | 是的 | 不适用 | 不适用 | 不适用 | 不适用 | 不适用 |
AMD显卡 | 是的 | 不 | 实验性的 | 不适用 | 不 | 不 |
苹果GPU | 不适用 | 不 | 不适用 | 实验性的 | 不适用 | 不适用 |
英特尔GPU | 实验性的 | 不适用 | 不适用 | 不适用 | 不 | 不 |
平台 | 指示 |
---|---|
中央处理器 | pip install -U jax |
英伟达图形处理器 | pip install -U "jax[cuda12]" |
谷歌TPU | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
AMD GPU(Linux) | 使用 Docker、预构建的轮子或从源代码构建。 |
苹果电脑GPU | 请遵循 Apple 的说明。 |
英特尔GPU | 请遵循英特尔的说明。 |
有关替代安装策略的信息,请参阅文档。其中包括从源代码编译、使用 Docker 安装、使用其他版本的 CUDA、社区支持的 conda 构建以及一些常见问题的解答。
Google DeepMind 和 Alphabet 的多个 Google 研究小组开发并共享用于在 JAX 中训练神经网络的库。如果您想要一个功能齐全的神经网络训练库,其中包含示例和操作指南,请尝试 Flax 及其文档网站。
查看 JAX 文档网站上的 JAX 生态系统部分,获取基于 JAX 的网络库列表,其中包括用于梯度处理和优化的 Optax、用于可靠代码和测试的 chex 以及用于神经网络的 Equinox。 (请观看 DeepMind 的 NeurIPS 2020 JAX 生态系统演讲,了解更多详细信息。)
引用这个存储库:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
在上面的 bibtex 条目中,名称按字母顺序排列,版本号是来自 jax/version.py 的版本号,年份对应于项目的开源版本。
SysML 2018 上发表的一篇论文描述了 JAX 的新版本,仅支持自动微分和编译为 XLA。我们目前正在努力在一篇更全面、最新的论文中介绍 JAX 的想法和功能。
有关 JAX API 的详细信息,请参阅参考文档。
要开始成为 JAX 开发人员,请参阅开发人员文档。