memory efficient attention pytorch
0.1.6
การใช้ความสนใจแบบหลายหัวที่มีประสิทธิภาพของหน่วยความจำตามที่เสนอในรายงาน การเอาใจใส่ตนเองไม่จำเป็นต้องมีหน่วยความจำ O(n²) นอกจากนี้ โมดูลจะดูแลเรื่องการมาสก์ การมาสก์เชิงสาเหตุ และความสนใจแบบข้าม
พื้นที่เก็บข้อมูลนี้ยังประกอบด้วยการนำการปรับปรุงที่ Tri Dao กระทำโดย Tri Dao อย่างไร้เดียงสาด้วยเอกสาร Flash Attention 2 ของเขาเพื่อวัตถุประสงค์ทางการศึกษา มันเป็นตัวเปลี่ยนเกมสำหรับความสนใจและสร้างหม้อแปลงบริบทยาว
อัปเดต: จากนี้ไป คุณควรใช้ฟังก์ชัน F.scaled_dot_product_attention
ใน Pytorch 2.0 สำหรับการรองรับ Flash Attention v1 ในตัว - หรือใช้ Flash Attention v2 ที่พื้นที่เก็บข้อมูลอย่างเป็นทางการ
$ pip install memory-efficient-attention-pytorch
สำหรับโมเดลภาษาแบบถอยหลังอัตโนมัติ
import torch
from memory_efficient_attention_pytorch import Attention
attn = Attention (
dim = 512 ,
dim_head = 64 , # dimension per head
heads = 8 , # number of attention heads
causal = True , # autoregressive or not
memory_efficient = True , # whether to use memory efficient attention (can be turned off to test against normal attention)
q_bucket_size = 1024 , # bucket size along queries dimension
k_bucket_size = 2048 # bucket size along key / values dimension
). cuda ()
x = torch . randn ( 1 , 65536 , 512 ). cuda ()
out = attn ( x ) # (1, 65536, 512)
ข้ามความสนใจ
import torch
from memory_efficient_attention_pytorch import Attention
cross_attn = Attention (
dim = 512 ,
dim_head = 64 ,
heads = 8 ,
memory_efficient = True ,
q_bucket_size = 1024 ,
k_bucket_size = 2048
). cuda ()
x = torch . randn ( 1 , 65536 , 512 ). cuda ()
context = torch . randn ( 1 , 65536 , 512 ). cuda ()
mask = torch . ones ( 1 , 65536 ). bool (). cuda ()
out = cross_attn ( x , context = context , mask = mask ) # (1, 65536, 512)
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@misc { liu2021swin ,
title = { Swin Transformer V2: Scaling Up Capacity and Resolution } ,
author = { Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo } ,
year = { 2021 } ,
eprint = { 2111.09883 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@article { dao2023flashattention2 ,
title = { Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning,
author = {Dao, Tri},
year = {2023}
}