Быстрый старт | Преобразования | Руководство по установке | Библиотеки нейронных сетей | Журналы изменений | Справочная документация
JAX — это библиотека Python для вычислений массивов с использованием ускорителей и преобразования программ, предназначенная для высокопроизводительных численных вычислений и крупномасштабного машинного обучения.
Благодаря обновленной версии Autograd JAX может автоматически различать собственные функции Python и NumPy. Он может различаться с помощью циклов, ветвей, рекурсии и замыканий, а также может принимать производные от производных. Он поддерживает дифференцирование в обратном режиме (также известное как обратное распространение ошибки) посредством grad
, а также дифференцирование в прямом режиме, и эти два режима могут быть составлены произвольно и в любом порядке.
Новым является то, что JAX использует XLA для компиляции и запуска ваших программ NumPy на графических процессорах и TPU. По умолчанию компиляция происходит «под капотом», при этом вызовы библиотеки компилируются и выполняются «точно в срок». Но JAX также позволяет вам своевременно компилировать ваши собственные функции Python в ядра, оптимизированные для XLA, с помощью однофункционального API jit
. Компиляцию и автоматическое дифференцирование можно составлять произвольно, поэтому вы можете выражать сложные алгоритмы и получать максимальную производительность, не выходя из Python. Вы даже можете одновременно запрограммировать несколько ядер графических процессоров или TPU, используя pmap
, и различать все это.
Копните немного глубже, и вы увидите, что JAX действительно является расширяемой системой для преобразований составных функций. И grad
, и jit
являются примерами таких преобразований. Другие — это vmap
для автоматической векторизации и pmap
для параллельного программирования одной программы и нескольких данных (SPMD) с несколькими ускорителями, и это еще не все.
Это исследовательский проект, а не официальный продукт Google. Ожидайте ошибок и острых краев. Пожалуйста, помогите, опробовав его, сообщив об ошибках и сообщив нам, что вы думаете!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
Сразу приступайте к работе, используя ноутбук в браузере, подключенный к графическому процессору Google Cloud. Вот несколько стартовых блокнотов:
grad
для дифференциации, jit
для компиляции и vmap
для векторизации.JAX теперь работает на облачных TPU. Чтобы опробовать предварительную версию, посетите Cloud TPU Colabs.
Для более глубокого погружения в JAX:
По своей сути JAX — это расширяемая система преобразования числовых функций. Вот четыре преобразования, представляющие основной интерес: grad
, jit
, vmap
и pmap
.
grad
JAX имеет примерно тот же API, что и Autograd. Самая популярная функция — grad
для градиентов обратного режима:
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
Вы можете дифференцировать любой порядок с помощью grad
.
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
Для более продвинутого автодифференцирования вы можете использовать jax.vjp
для векторно-якобианских произведений в обратном режиме и jax.jvp
для векторно-якобианских произведений в прямом режиме. Их можно произвольно составлять друг с другом и с другими преобразованиями JAX. Вот один из способов скомпоновать их, чтобы создать функцию, которая эффективно вычисляет полные матрицы Гессе:
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
Как и в случае с Autograd, вы можете использовать дифференциацию со структурами управления Python:
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
Дополнительную информацию см. в справочной документации по автоматическому дифференцированию и в кулинарной книге JAX Autodiff.
jit
Вы можете использовать XLA для сквозной компиляции ваших функций с помощью jit
, который используется либо как декоратор @jit
, либо как функция более высокого порядка.
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
Вы можете смешивать jit
и grad
и любое другое преобразование JAX по своему усмотрению.
Использование jit
накладывает ограничения на тип потока управления Python, который может использовать функция; дополнительную информацию см. в руководстве «Поток управления и логические операторы с JIT».
vmap
vmap
— карта векторизации. Он имеет знакомую семантику отображения функции по осям массива, но вместо того, чтобы держать цикл снаружи, он помещает цикл в примитивные операции функции для повышения производительности.
Использование vmap
может избавить вас от необходимости переносить размеры пакетов в код. Например, рассмотрим эту простую непакетную функцию прогнозирования нейронной сети:
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
Вместо этого мы часто пишем jnp.dot(activations, W)
чтобы учесть размерность пакета в левой части activations
, но мы написали эту конкретную функцию прогнозирования, чтобы она применялась только к одиночным входным векторам. Если бы мы хотели применить эту функцию к пакету входных данных одновременно, семантически мы могли бы просто написать
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
Но передача по сети одного примера за раз будет медленной! Лучше векторизовать вычисления, чтобы на каждом уровне мы выполняли умножение матрицы на матрицу, а не на умножение матрицы на вектор.
Функция vmap
делает это преобразование за нас. То есть, если мы напишем
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
тогда функция vmap
поместит внешний цикл внутрь функции, и наша машина в конечном итоге выполнит умножение матриц на матрицы точно так же, как если бы мы выполняли пакетную обработку вручную.
Достаточно легко вручную собрать простую нейронную сеть без vmap
, но в других случаях ручная векторизация может оказаться непрактичной или невозможной. Возьмем задачу эффективного вычисления градиентов для каждого примера: то есть для фиксированного набора параметров мы хотим вычислить градиент нашей функции потерь, оцениваемой отдельно для каждого примера в пакете. С vmap
это легко:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
Конечно, vmap
можно составить произвольно с помощью jit
, grad
и любого другого JAX-преобразования! Мы используем vmap
с автоматическим дифференцированием как в прямом, так и в обратном режиме для быстрых вычислений матриц Якобиана и Гессиана в jax.jacfwd
, jax.jacrev
и jax.hessian
.
pmap
Для параллельного программирования нескольких ускорителей, например нескольких графических процессоров, используйте pmap
. С помощью pmap
вы пишете однопрограммные программы с несколькими данными (SPMD), включая быстрые параллельные коллективные коммуникационные операции. Применение pmap
будет означать, что написанная вами функция компилируется XLA (аналогично jit
), затем реплицируется и выполняется параллельно на всех устройствах.
Вот пример на машине с 8 графическими процессорами:
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
Помимо выражения чистых карт, вы можете использовать быстрые коллективные операции связи между устройствами:
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
Вы даже можете вкладывать функции pmap
для более сложных шаблонов взаимодействия.
Все это складывается, поэтому вы можете проводить дифференциацию с помощью параллельных вычислений:
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
При дифференцировании функции pmap
в обратном режиме (например, с помощью grad
) обратный проход вычислений распараллеливается так же, как и прямой проход.
Дополнительную информацию см. в кулинарной книге SPMD и примере классификатора SPMD MNIST с нуля.
Для более подробного изучения текущих ошибок с примерами и пояснениями мы настоятельно рекомендуем прочитать «Блокнот ошибок». Некоторые выдающиеся достижения:
is
не сохраняется). Если вы используете преобразование JAX для нечистой функции Python, вы можете увидеть ошибку типа Exception: Can't lift Traced...
или Exception: Different traces at same level
.x[i] += y
, не поддерживаются, но существуют функциональные альтернативы. Под jit
эти функциональные альтернативы будут автоматически повторно использовать буферы на месте.jax.lax
.float32
), а для включения двойной точности (64-битные, например float64
) необходимо установить переменную jax_enable_x64
при запуске (или установить переменную среды JAX_ENABLE_X64=True
) . В TPU JAX по умолчанию использует 32-битные значения для всего, кроме внутренних временных переменных в операциях, подобных matmul, таких как jax.numpy.dot
и lax.conv
. Эти операции имеют параметр precision
, который можно использовать для аппроксимации 32-битных операций за три прохода bfloat16, что может привести к увеличению времени выполнения. Операции, не относящиеся к Matmul, на TPU ниже реализаций, в которых скорость часто важнее точности, поэтому на практике вычисления на TPU будут менее точными, чем аналогичные вычисления на других бэкэндах.np.add(1, np.array([2], np.float32)).dtype
— это float64
, а не float32
.jit
, ограничивают возможности использования потока управления Python. Вы всегда будете получать громкие ошибки, если что-то пойдет не так. Возможно, вам придется использовать параметр static_argnums
jit
, примитивы структурированного потока управления, такие как lax.scan
, или просто использовать jit
для небольших подфункций. Linux x86_64 | Linux aarch64 | Мак x86_64 | Mac aarch64 | Windows x86_64 | Windows WSL2 x86_64 | |
---|---|---|---|---|---|---|
Процессор | да | да | да | да | да | да |
NVIDIA графический процессор | да | да | нет | н/д | нет | экспериментальный |
Гугл ТПУ | да | н/д | н/д | н/д | н/д | н/д |
AMD графический процессор | да | нет | экспериментальный | н/д | нет | нет |
Apple графический процессор | н/д | нет | н/д | экспериментальный | н/д | н/д |
графический процессор Intel | экспериментальный | н/д | н/д | н/д | нет | нет |
Платформа | Инструкции |
---|---|
Процессор | pip install -U jax |
NVIDIA графический процессор | pip install -U "jax[cuda12]" |
Гугл ТПУ | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
Графический процессор AMD (Linux) | Используйте Docker, готовые колеса или сборку из исходного кода. |
Графический процессор Mac | Следуйте инструкциям Apple. |
графический процессор Intel | Следуйте инструкциям Intel. |
См. документацию для получения информации об альтернативных стратегиях установки. К ним относятся компиляция из исходного кода, установка с помощью Docker, использование других версий CUDA, сборка conda, поддерживаемая сообществом, а также ответы на некоторые часто задаваемые вопросы.
Несколько исследовательских групп Google в Google DeepMind и Alphabet разрабатывают и совместно используют библиотеки для обучения нейронных сетей в JAX. Если вам нужна полнофункциональная библиотека для обучения нейронных сетей с примерами и практическими руководствами, попробуйте Flax и его сайт документации.
Посетите раздел «Экосистема JAX» на сайте документации JAX, чтобы увидеть список сетевых библиотек на основе JAX, в который входят Optax для обработки и оптимизации градиентов, chex для надежного кода и тестирования и Equinox для нейронных сетей. (Для получения дополнительных подробностей посмотрите обсуждение экосистемы JAX NeurIPS 2020 на DeepMind здесь.)
Чтобы процитировать этот репозиторий:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
В приведенной выше записи bibtex имена расположены в алфавитном порядке, номер версии должен быть взят из jax/version.py, а год соответствует выпуску проекта с открытым исходным кодом.
Зарождающаяся версия JAX, поддерживающая только автоматическое дифференцирование и компиляцию в XLA, была описана в документе, опубликованном на SysML 2018. В настоящее время мы работаем над освещением идей и возможностей JAX в более полной и актуальной статье.
Подробную информацию о JAX API см. в справочной документации.
Чтобы начать работу в качестве разработчика JAX, см. документацию для разработчиков.