เริ่มต้นอย่างรวดเร็ว | การเปลี่ยนแปลง | คู่มือการติดตั้ง | ไลบรารีเน็ตประสาท | บันทึกการเปลี่ยนแปลง | เอกสารอ้างอิง
JAX คือไลบรารี Python สำหรับการคำนวณอาเรย์ที่มุ่งเน้นการเร่งความเร็วและการแปลงโปรแกรม ซึ่งออกแบบมาสำหรับการประมวลผลเชิงตัวเลขประสิทธิภาพสูงและการเรียนรู้ของเครื่องขนาดใหญ่
ด้วย Autograd เวอร์ชันอัปเดต JAX สามารถแยกแยะฟังก์ชัน Python และ NumPy ดั้งเดิมได้โดยอัตโนมัติ มันสามารถแยกความแตกต่างผ่านการวนซ้ำ การแตกแขนง การเรียกซ้ำ และการปิด และสามารถรับอนุพันธ์ของอนุพันธ์ของอนุพันธ์ได้ รองรับการสร้างความแตกต่างของโหมดย้อนกลับ (aka backpropagation) ผ่านทาง grad
และการสร้างความแตกต่างในโหมดไปข้างหน้า และทั้งสองสามารถประกอบขึ้นตามอำเภอใจในลำดับใดก็ได้
มีอะไรใหม่คือ JAX ใช้ XLA เพื่อคอมไพล์และรันโปรแกรม NumPy บน GPU และ TPU การคอมไพล์จะเกิดขึ้นภายใต้ประทุนตามค่าเริ่มต้น โดยที่การเรียกไลบรารีจะได้รับการคอมไพล์และดำเนินการทันเวลาพอดี แต่ JAX ยังให้คุณคอมไพล์ฟังก์ชัน Python ของคุณเองลงในเคอร์เนลที่ปรับให้เหมาะสมกับ XLA ได้ทันเวลาโดยใช้ API ฟังก์ชันเดียว jit
การคอมไพล์และการสร้างความแตกต่างอัตโนมัติสามารถสร้างขึ้นได้ตามใจชอบ ดังนั้นคุณจึงสามารถแสดงอัลกอริธึมที่ซับซ้อนและรับประสิทธิภาพสูงสุดโดยไม่ต้องออกจาก Python คุณสามารถตั้งโปรแกรม GPU หรือแกน TPU หลายตัวพร้อมกันได้โดยใช้ pmap
และแยกความแตกต่างจากสิ่งทั้งหมด
ลองขุดลึกลงไปอีกหน่อย แล้วคุณจะเห็นว่า JAX เป็นระบบที่ขยายได้จริงๆ สำหรับการแปลงฟังก์ชันที่เขียนได้ ทั้ง grad
และ jit
เป็นตัวอย่างของการเปลี่ยนแปลงดังกล่าว ส่วนอื่นๆ ได้แก่ vmap
สำหรับ vectorization อัตโนมัติ และ pmap
สำหรับการเขียนโปรแกรมแบบขนานหลายข้อมูล (SPMD) โปรแกรมเดี่ยวของตัวเร่งความเร็วหลายตัว และยังมีอีกมากมายที่จะตามมา
นี่เป็นโครงการวิจัย ไม่ใช่ผลิตภัณฑ์อย่างเป็นทางการของ Google คาดว่าจะมีข้อบกพร่องและขอบคม โปรดช่วยด้วยการทดลองใช้ รายงานข้อบกพร่อง และแจ้งให้เราทราบว่าคุณคิดอย่างไร!
import jax . numpy as jnp
from jax import grad , jit , vmap
def predict ( params , inputs ):
for W , b in params :
outputs = jnp . dot ( inputs , W ) + b
inputs = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
def loss ( params , inputs , targets ):
preds = predict ( params , inputs )
return jnp . sum (( preds - targets ) ** 2 )
grad_loss = jit ( grad ( loss )) # compiled gradient evaluation function
perex_grads = jit ( vmap ( grad_loss , in_axes = ( None , 0 , 0 ))) # fast per-example grads
ใช้งานได้ทันทีโดยใช้โน้ตบุ๊กในเบราว์เซอร์ของคุณ ซึ่งเชื่อมต่อกับ Google Cloud GPU นี่คือสมุดบันทึกเริ่มต้นบางส่วน:
grad
สำหรับการสร้างความแตกต่าง, jit
สำหรับการรวบรวม และ vmap
สำหรับการสร้างเวกเตอร์ขณะนี้ JAX ทำงานบน Cloud TPU หากต้องการลองใช้ตัวอย่าง โปรดดูที่ Cloud TPU Colabs
หากต้องการเจาะลึกเกี่ยวกับ JAX:
โดยพื้นฐานแล้ว JAX เป็นระบบที่ขยายได้สำหรับการแปลงฟังก์ชันตัวเลข ต่อไปนี้เป็นการเปลี่ยนแปลงที่สนใจหลักสี่ประการ: grad
, jit
, vmap
และ pmap
grad
JAX มี API เหมือนกับ Autograd โดยประมาณ ฟังก์ชันที่ได้รับความนิยมมากที่สุดคือ grad
สำหรับการไล่ระดับสีแบบโหมดย้อนกลับ:
from jax import grad
import jax . numpy as jnp
def tanh ( x ): # Define a function
y = jnp . exp ( - 2.0 * x )
return ( 1.0 - y ) / ( 1.0 + y )
grad_tanh = grad ( tanh ) # Obtain its gradient function
print ( grad_tanh ( 1.0 )) # Evaluate it at x = 1.0
# prints 0.4199743
คุณสามารถแยกความแตกต่างจากคำสั่งซื้อใดๆ ด้วย grad
print ( grad ( grad ( grad ( tanh )))( 1.0 ))
# prints 0.62162673
สำหรับ autodiff ขั้นสูง คุณสามารถใช้ jax.vjp
สำหรับผลิตภัณฑ์ vector-Jacobian แบบโหมดย้อนกลับ และ jax.jvp
สำหรับผลิตภัณฑ์ Jacobian-vector แบบโหมดส่งต่อ ทั้งสองสามารถถูกประกอบขึ้นโดยพลการต่อกัน และกับการแปลง JAX อื่นๆ ต่อไปนี้เป็นวิธีหนึ่งในการเขียนเมทริกซ์เหล่านั้นเพื่อสร้างฟังก์ชันที่คำนวณเมทริกซ์ Hessian แบบเต็มได้อย่างมีประสิทธิภาพ:
from jax import jit , jacfwd , jacrev
def hessian ( fun ):
return jit ( jacfwd ( jacrev ( fun )))
เช่นเดียวกับ Autograd คุณสามารถใช้การสร้างความแตกต่างกับโครงสร้างควบคุม Python ได้ฟรี:
def abs_val ( x ):
if x > 0 :
return x
else :
return - x
abs_val_grad = grad ( abs_val )
print ( abs_val_grad ( 1.0 )) # prints 1.0
print ( abs_val_grad ( - 1.0 )) # prints -1.0 (abs_val is re-evaluated)
ดูเอกสารอ้างอิงเกี่ยวกับการสร้างความแตกต่างอัตโนมัติและ JAX Autodiff Cookbook สำหรับข้อมูลเพิ่มเติม
jit
คุณสามารถใช้ XLA เพื่อคอมไพล์ฟังก์ชันของคุณตั้งแต่ต้นจนจบด้วย jit
ซึ่งใช้เป็น @jit
มัณฑนากรหรือเป็นฟังก์ชันลำดับที่สูงกว่า
import jax . numpy as jnp
from jax import jit
def slow_f ( x ):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp . ones (( 5000 , 5000 ))
fast_f = jit ( slow_f )
% timeit - n10 - r3 fast_f ( x ) # ~ 4.5 ms / loop on Titan X
% timeit - n10 - r3 slow_f ( x ) # ~ 14.5 ms / loop (also on GPU via JAX)
คุณสามารถผสม jit
และ grad
และการแปลง JAX อื่นๆ ได้ตามที่คุณต้องการ
การใช้ jit
จะทำให้เกิดข้อจำกัดกับประเภทของโฟลว์การควบคุม Python ที่ฟังก์ชันสามารถใช้ได้ ดูบทช่วยสอนเกี่ยวกับโฟลว์การควบคุมและตัวดำเนินการเชิงตรรกะด้วย JIT สำหรับข้อมูลเพิ่มเติม
vmap
vmap
เป็นแผนที่แบบเวกเตอร์ มันมีความหมายที่คุ้นเคยในการแมปฟังก์ชันตามแกนอาเรย์ แต่แทนที่จะเก็บลูปไว้ด้านนอก มันจะดันลูปลงไปที่การดำเนินการดั้งเดิมของฟังก์ชันเพื่อประสิทธิภาพที่ดีขึ้น
การใช้ vmap
ช่วยให้คุณไม่ต้องแบกขนาดเป็นชุดในโค้ดของคุณ ตัวอย่างเช่น ลองพิจารณาฟังก์ชันการทำนายโครงข่ายประสาทเทียมแบบง่ายๆ ที่ไม่แบตช์ แบบง่ายๆ นี้:
def predict ( params , input_vec ):
assert input_vec . ndim == 1
activations = input_vec
for W , b in params :
outputs = jnp . dot ( W , activations ) + b # `activations` on the right-hand side!
activations = jnp . tanh ( outputs ) # inputs to the next layer
return outputs # no activation on last layer
เรามักจะเขียน jnp.dot(activations, W)
แทนเพื่ออนุญาตมิติข้อมูลแบทช์ทางด้านซ้ายของ activations
แต่เราได้เขียนฟังก์ชันการทำนายเฉพาะนี้เพื่อใช้กับเวกเตอร์อินพุตเดี่ยวเท่านั้น หากเราต้องการใช้ฟังก์ชันนี้กับชุดอินพุตในคราวเดียว เราก็สามารถเขียนเชิงความหมายได้
from functools import partial
predictions = jnp . stack ( list ( map ( partial ( predict , params ), input_batch )))
แต่การผลักดันตัวอย่างหนึ่งผ่านเครือข่ายในแต่ละครั้งจะช้า! เป็นการดีกว่าถ้าคำนวณแบบเวคเตอร์ เพื่อที่ว่าในทุกเลเยอร์ เราจะทำการคูณเมทริกซ์-เมทริกซ์ แทนที่จะคูณเมทริกซ์-เวกเตอร์
ฟังก์ชัน vmap
ทำการเปลี่ยนแปลงนั้นให้เรา นั่นคือถ้าเราเขียน
from jax import vmap
predictions = vmap ( partial ( predict , params ))( input_batch )
# or, alternatively
predictions = vmap ( predict , in_axes = ( None , 0 ))( params , input_batch )
จากนั้นฟังก์ชัน vmap
จะดันลูปด้านนอกเข้าไปภายในฟังก์ชัน และเครื่องของเราจะลงเอยด้วยการคูณเมทริกซ์-เมทริกซ์เหมือนกับว่าเราทำการแบทช์ด้วยมือ
มันง่ายพอที่จะแบทช์โครงข่ายประสาทเทียมอย่างง่ายด้วยตนเองโดยไม่ต้องใช้ vmap
แต่ในกรณีอื่น ๆ การสร้างเวกเตอร์ด้วยตนเองอาจไม่สามารถทำได้หรือเป็นไปไม่ได้ แก้ไขปัญหาของการคำนวณการไล่ระดับสีต่อตัวอย่างอย่างมีประสิทธิภาพ กล่าวคือ สำหรับชุดพารามิเตอร์คงที่ เราต้องการคำนวณการไล่ระดับสีของฟังก์ชันการสูญเสียของเราโดยประเมินแยกกันในแต่ละตัวอย่างในชุดงาน ด้วย vmap
มันง่ายมาก:
per_example_gradients = vmap ( partial ( grad ( loss ), params ))( inputs , targets )
แน่นอนว่า vmap
สามารถแต่งขึ้นได้ตามใจชอบด้วย jit
, grad
และการแปลง JAX อื่นๆ! เราใช้ vmap
ที่มีการสร้างความแตกต่างอัตโนมัติทั้งโหมดเดินหน้าและถอยหลังเพื่อการคำนวณเมทริกซ์ Jacobian และ Hessian ที่รวดเร็วใน jax.jacfwd
, jax.jacrev
และ jax.hessian
pmap
สำหรับการเขียนโปรแกรมแบบขนานของตัวเร่งความเร็วหลายตัว เช่น GPU หลายตัว ให้ใช้ pmap
ด้วย pmap
คุณจะเขียนโปรแกรมเดี่ยวหลายข้อมูล (SPMD) รวมถึงการดำเนินการสื่อสารรวมแบบขนานที่รวดเร็ว การใช้ pmap
หมายความว่าฟังก์ชันที่คุณเขียนถูกคอมไพล์โดย XLA (คล้ายกับ jit
) จากนั้นจำลองและดำเนินการแบบขนานระหว่างอุปกรณ์ต่างๆ
นี่คือตัวอย่างบนเครื่อง 8-GPU:
from jax import random , pmap
import jax . numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random . split ( random . key ( 0 ), 8 )
mats = pmap ( lambda key : random . normal ( key , ( 5000 , 6000 )))( keys )
# Run a local matmul on each device in parallel (no data transfer)
result = pmap ( lambda x : jnp . dot ( x , x . T ))( mats ) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print ( pmap ( jnp . mean )( result ))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
นอกเหนือจากการแสดงแผนที่อย่างแท้จริงแล้ว คุณยังสามารถใช้การดำเนินการสื่อสารรวมที่รวดเร็วระหว่างอุปกรณ์ต่างๆ ได้:
from functools import partial
from jax import lax
@ partial ( pmap , axis_name = 'i' )
def normalize ( x ):
return x / lax . psum ( x , 'i' )
print ( normalize ( jnp . arange ( 4. )))
# prints [0. 0.16666667 0.33333334 0.5 ]
คุณยังสามารถซ้อนฟังก์ชัน pmap
สำหรับรูปแบบการสื่อสารที่ซับซ้อนยิ่งขึ้นได้
ทุกอย่างประกอบขึ้นเป็นองค์ประกอบ คุณจึงมีอิสระที่จะแยกแยะความแตกต่างผ่านการคำนวณแบบคู่ขนาน:
from jax import grad
@ pmap
def f ( x ):
y = jnp . sin ( x )
@ pmap
def g ( z ):
return jnp . cos ( z ) * jnp . tan ( y . sum ()) * jnp . tanh ( x ). sum ()
return grad ( lambda w : jnp . sum ( g ( w )))( x )
print ( f ( x ))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print ( grad ( lambda x : jnp . sum ( f ( x )))( x ))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
เมื่อโหมดย้อนกลับสร้างความแตกต่างให้กับฟังก์ชัน pmap
(เช่น ด้วย grad
) การย้อนกลับของการคำนวณจะขนานกันเหมือนกับการส่งผ่านไปข้างหน้า
ดู SPMD Cookbook และตัวแยกประเภท SPMD MNIST ตั้งแต่ต้นเพื่อดูข้อมูลเพิ่มเติม
หากต้องการสำรวจ gotchas ปัจจุบันอย่างละเอียดยิ่งขึ้น พร้อมตัวอย่างและคำอธิบาย เราขอแนะนำให้อ่าน Gotchas Notebook ความโดดเด่นบางประการ:
is
ไม่ถูกรักษาไว้) หากคุณใช้การแปลง JAX กับฟังก์ชัน Python ที่ไม่บริสุทธิ์ คุณอาจเห็นข้อผิดพลาดเช่น Exception: Can't lift Traced...
หรือ Exception: Different traces at same level
x[i] += y
แต่มีทางเลือกอื่นที่ใช้งานได้ ภายใต้ jit
ทางเลือกการทำงานเหล่านั้นจะนำบัฟเฟอร์กลับมาใช้ใหม่โดยอัตโนมัติjax.lax
float32
) ตามค่าเริ่มต้น และเพื่อเปิดใช้งานความแม่นยำสองเท่า (64 บิต เช่น float64
) เราจำเป็นต้องตั้งค่าตัวแปร jax_enable_x64
เมื่อเริ่มต้น (หรือตั้งค่าตัวแปรสภาพแวดล้อม JAX_ENABLE_X64=True
) . บน TPU นั้น JAX จะใช้ค่า 32 บิตเป็นค่าเริ่มต้นสำหรับทุกอย่าง ยกเว้น ตัวแปรชั่วคราวภายในในการดำเนินการ 'matmul-like' เช่น jax.numpy.dot
และ lax.conv
การดำเนินการเหล่านั้นมีพารามิเตอร์ precision
ซึ่งสามารถใช้เพื่อประมาณการดำเนินการ 32 บิตผ่านการผ่าน bfloat16 สามครั้ง โดยมีต้นทุนรันไทม์ที่ช้ากว่า การดำเนินการที่ไม่ใช่ matmul บน TPU ต่ำกว่าการใช้งานที่มักเน้นความเร็วมากกว่าความแม่นยำ ดังนั้นในทางปฏิบัติการคำนวณบน TPU จะมีความแม่นยำน้อยกว่าการคำนวณที่คล้ายกันบนแบ็กเอนด์อื่นๆnp.add(1, np.array([2], np.float32)).dtype
คือ float64
มากกว่า float32
jit
จะจำกัดวิธีการใช้โฟลว์การควบคุม Python คุณจะได้รับข้อผิดพลาดที่ดังเสมอหากมีสิ่งผิดปกติเกิดขึ้น คุณอาจต้องใช้พารามิเตอร์ static_argnums
ของ jit
ซึ่งเป็นโฟลว์การควบคุมที่มีโครงสร้างเบื้องต้น เช่น lax.scan
หรือเพียงใช้ jit
กับฟังก์ชันย่อยที่เล็กกว่า ลินุกซ์ x86_64 | ลินุกซ์ aarch64 | แมค x86_64 | แมค aarch64 | วินโดว์ x86_64 | วินโดวส์ WSL2 x86_64 | |
---|---|---|---|---|---|---|
ซีพียู | ใช่ | ใช่ | ใช่ | ใช่ | ใช่ | ใช่ |
NVIDIA GPU | ใช่ | ใช่ | เลขที่ | ไม่มี | เลขที่ | ทดลอง |
กูเกิลทีพียู | ใช่ | ไม่มี | ไม่มี | ไม่มี | ไม่มี | ไม่มี |
เอเอ็มดีจีพียู | ใช่ | เลขที่ | ทดลอง | ไม่มี | เลขที่ | เลขที่ |
แอปเปิลจีพียู | ไม่มี | เลขที่ | ไม่มี | ทดลอง | ไม่มี | ไม่มี |
อินเทลจีพียู | ทดลอง | ไม่มี | ไม่มี | ไม่มี | เลขที่ | เลขที่ |
แพลตฟอร์ม | คำแนะนำ |
---|---|
ซีพียู | pip install -U jax |
NVIDIA GPU | pip install -U "jax[cuda12]" |
กูเกิลทีพียู | pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html |
เอเอ็มดี GPU (ลินุกซ์) | ใช้ Docker ล้อที่สร้างไว้ล่วงหน้า หรือสร้างจากแหล่งที่มา |
แมคจีพียู | ปฏิบัติตามคำแนะนำของ Apple |
อินเทลจีพียู | ปฏิบัติตามคำแนะนำของ Intel |
ดูเอกสารประกอบสำหรับข้อมูลเกี่ยวกับกลยุทธ์การติดตั้งทางเลือก ซึ่งรวมถึงการคอมไพล์จากแหล่งที่มา การติดตั้งด้วย Docker การใช้ CUDA เวอร์ชันอื่น การสร้าง Conda ที่สนับสนุนโดยชุมชน และคำตอบสำหรับคำถามที่พบบ่อยบางข้อ
กลุ่มวิจัยของ Google หลายกลุ่มที่ Google DeepMind และ Alphabet พัฒนาและแบ่งปันไลบรารีสำหรับการฝึกอบรมโครงข่ายประสาทเทียมใน JAX หากคุณต้องการไลบรารีที่มีคุณลักษณะครบถ้วนสำหรับการฝึกอบรมโครงข่ายประสาทเทียมพร้อมตัวอย่างและคำแนะนำวิธีใช้ ลองใช้ Flax และไซต์เอกสารประกอบ
ตรวจสอบส่วนระบบนิเวศของ JAX บนไซต์เอกสารประกอบของ JAX เพื่อดูรายการไลบรารีเครือข่ายที่ใช้ JAX ซึ่งรวมถึง Optax สำหรับการประมวลผลและการเพิ่มประสิทธิภาพแบบไล่ระดับ chex สำหรับโค้ดและการทดสอบที่เชื่อถือได้ และ Equinox สำหรับเครือข่ายประสาทเทียม (ดูระบบนิเวศ JAX ของ NeurIPS 2020 ที่ DeepMind พูดคุยที่นี่เพื่อดูรายละเอียดเพิ่มเติม)
หากต้องการอ้างอิงที่เก็บนี้:
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/jax-ml/jax},
version = {0.3.13},
year = {2018},
}
ในรายการ bibtex ข้างต้น ชื่อจะเรียงตามตัวอักษร หมายเลขเวอร์ชันตั้งใจให้เป็นจาก jax/version.py และปีสอดคล้องกับการเปิดตัวโอเพ่นซอร์สของโปรเจ็กต์
JAX เวอร์ชันใหม่ซึ่งสนับสนุนเฉพาะการสร้างความแตกต่างและการคอมไพล์อัตโนมัติไปยัง XLA ได้รับการอธิบายไว้ในรายงานที่ปรากฏใน SysML 2018 ขณะนี้เรากำลังดำเนินการเพื่อครอบคลุมแนวคิดและความสามารถของ JAX ในรายงานที่ครอบคลุมและทันสมัยยิ่งขึ้น
สำหรับรายละเอียดเกี่ยวกับ JAX API โปรดดูเอกสารอ้างอิง
สำหรับการเริ่มต้นเป็นนักพัฒนา JAX โปรดดูเอกสารประกอบสำหรับนักพัฒนา