memory efficient attention pytorch
0.1.6
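Implementation of memory-efficient multi-head attention as proposed in the paper Self-attention Does Not Need O(n²) Memory. In addition, the module handles masking, causal masking and cross attention.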
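The central idea of the paper can be illustrated without any GPU code. The keys and values are split into chunks, and each chunk is summarized by three quantities: its unnormalized weighted values, its softmax denominator, and its maximum score. All chunk summaries are then rescaled to a common maximum before the final division. The NumPy sketch below follows that description for a single head. `summarize_chunk`, `chunked_attention` and `k_chunk_size` are hypothetical names, not this package's API. Query chunking, which the paper also uses, is omitted because each query row is independent anyway.

```python
import numpy as np


def naive_attention(q, k, v):
    # Reference implementation: materializes the full (n_queries, n_keys) score matrix.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v


def summarize_chunk(q, k_chunk, v_chunk):
    # Hypothetical helper: statistics for one key/value chunk.
    scores = q @ k_chunk.T / np.sqrt(q.shape[-1])   # (n_queries, chunk)
    chunk_max = scores.max(axis=-1, keepdims=True)   # per-query max, for stability
    exp_scores = np.exp(scores - chunk_max)
    # Unnormalized weighted values, softmax denominator, and the max used above.
    return exp_scores @ v_chunk, exp_scores.sum(axis=-1, keepdims=True), chunk_max


def chunked_attention(q, k, v, k_chunk_size=4):
    # Hypothetical sketch: scores exist for only one key chunk at a time.
    values, sums, maxes = [], [], []
    for start in range(0, k.shape[0], k_chunk_size):
        stop = start + k_chunk_size
        chunk_values, chunk_sum, chunk_max = summarize_chunk(q, k[start:stop], v[start:stop])
        values.append(chunk_values)
        sums.append(chunk_sum)
        maxes.append(chunk_max)
    # Rescale every chunk summary to the same (global) max before combining.
    global_max = np.max(np.stack(maxes), axis=0)
    alphas = [np.exp(m - global_max) for m in maxes]
    numerator = sum(a * cv for a, cv in zip(alphas, values))
    denominator = sum(a * cs for a, cs in zip(alphas, sums))
    return numerator / denominator
```

The result matches `naive_attention` up to floating-point error. The difference is that the (n_queries, n_keys) score matrix is never formed.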
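This repository also contains a naive non-CUDA implementation of the improvements made by Tri Dao in his Flash Attention 2 paper, for educational purposes. Flash Attention is a game changer for attention and for building long-context transformers.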
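The main FlashAttention-2 changes to the forward pass, as described in the paper, can be sketched in the same style. The outer loop runs over query blocks, so blocks could be processed in parallel. The softmax denominator is divided out only once per query block instead of after every key block. The row-wise log-sum-exp is kept for the backward pass. This is an illustrative NumPy sketch under those assumptions, not the code in this repository. `flash2_style_forward` is a hypothetical name, and the bucket parameter names are borrowed from the `Attention` constructor only for readability.

```python
import numpy as np


def flash2_style_forward(q, k, v, q_bucket_size=4, k_bucket_size=4):
    # Hypothetical name; single head, no dropout or masking.
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[-1]))
    lse = np.empty(n)
    # Outer loop over query blocks: each block is independent of the others.
    for qs in range(0, n, q_bucket_size):
        q_block = q[qs:qs + q_bucket_size]
        rows = q_block.shape[0]
        row_max = np.full((rows, 1), -np.inf)   # running max of scores
        row_sum = np.zeros((rows, 1))           # running softmax denominator
        acc = np.zeros((rows, v.shape[-1]))     # unnormalized output
        # Inner loop over key/value blocks.
        for ks in range(0, k.shape[0], k_bucket_size):
            scores = q_block @ k[ks:ks + k_bucket_size].T * scale
            new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
            probs = np.exp(scores - new_max)
            correction = np.exp(row_max - new_max)  # rescales earlier blocks
            row_sum = correction * row_sum + probs.sum(axis=-1, keepdims=True)
            acc = correction * acc + probs @ v[ks:ks + k_bucket_size]
            row_max = new_max
        # Divide by the denominator once, after all key blocks have been seen.
        out[qs:qs + rows] = acc / row_sum
        # Row-wise log-sum-exp of the scaled scores, kept for a backward pass.
        lse[qs:qs + rows] = (row_max + np.log(row_sum))[:, 0]
    return out, lse
```

Because `exp(scores - lse)` is exactly the softmax, storing `lse` is enough to recompute the attention weights block by block later.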
Update: from now on, you should just use the F.scaled_dot_product_attention function in PyTorch 2.0 for built-in Flash Attention v1 support, or use Flash Attention v2 from the official repository.
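Here is a minimal sketch of the recommended replacement, assuming PyTorch 2.0 or later with `torch.nn.functional` imported as `F`. The call computes causal attention on CPU tensors and checks the result against an explicit softmax. Which kernel PyTorch actually dispatches to (Flash, memory-efficient, or the plain math path) depends on the device, dtype and PyTorch version.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Layout expected by F.scaled_dot_product_attention: (batch, heads, seq_len, dim_head)
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Fused attention; PyTorch chooses the backend. The default scale is 1 / sqrt(dim_head).
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit reference: the same computation with a materialized (seq_len, seq_len) mask.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
future = torch.triu(torch.ones(128, 128), diagonal=1).bool()  # True above the diagonal
scores = scores.masked_fill(future, float("-inf"))
ref = scores.softmax(dim=-1) @ v
```

Unlike `Attention` in this package, the function takes already-projected and head-split tensors. The input and output projections stay in your own module.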
```bash
$ pip install memory-efficient-attention-pytorch
```
For autoregressive language models
```python
import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)
```
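Some arithmetic shows why the bucket sizes matter. For the configuration above, a full fp32 score matrix for one head and one batch element would take 65536 × 65536 × 4 bytes = 16 GiB. A single 1024 × 2048 bucket of scores takes 1024 × 2048 × 4 bytes = 8 MiB. The NumPy sketch below shows how causal masking fits into bucketed attention. The mask is applied inside each score block. Key buckets that start after the last query of the current query bucket are skipped, because they would be fully masked. This is an illustration under stated assumptions (single head, no projections, hypothetical function name), not a description of how this package schedules its buckets.

```python
import numpy as np


def causal_bucketed_attention(q, k, v, q_bucket_size=4, k_bucket_size=4):
    # Hypothetical sketch of causal self-attention over buckets; only one
    # (q_bucket_size, k_bucket_size) block of scores exists at a time.
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[-1]))
    for qs in range(0, n, q_bucket_size):
        qe = min(qs + q_bucket_size, n)
        q_idx = np.arange(qs, qe)[:, None]
        row_max = np.full((qe - qs, 1), -np.inf)
        row_sum = np.zeros((qe - qs, 1))
        acc = np.zeros((qe - qs, v.shape[-1]))
        for ks in range(0, n, k_bucket_size):
            if ks > qe - 1:
                # Every key from here on comes after every query in this
                # bucket, so the remaining blocks are fully masked: skip them.
                break
            ke = min(ks + k_bucket_size, n)
            scores = q[qs:qe] @ k[ks:ke].T * scale
            k_idx = np.arange(ks, ke)[None, :]
            scores = np.where(k_idx > q_idx, -np.inf, scores)  # mask future keys
            new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
            probs = np.exp(scores - new_max)
            correction = np.exp(row_max - new_max)
            row_sum = correction * row_sum + probs.sum(axis=-1, keepdims=True)
            acc = correction * acc + probs @ v[ks:ke]
            row_max = new_max
        out[qs:qe] = acc / row_sum
    return out
```

Every query can see key 0, so the first key bucket always gives each row a finite maximum. Later, fully masked rows within a block therefore contribute zeros instead of NaNs.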
Cross attention
```python
import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()
out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
```
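In the cross-attention call, `mask` has one boolean per context position. Assume that `True` marks positions to attend to, which is consistent with the all-ones mask in the example. Under that assumption, the NumPy sketch below applies such a key mask bucket by bucket for one batch element and one head. Masked scores are replaced by a large negative finite number rather than `-inf`, so a bucket whose keys are all masked does not produce NaNs. Later unmasked buckets then rescale its contribution to zero. Keys and values are taken directly from `context` without projections, and `masked_cross_attention` is a hypothetical name.

```python
import numpy as np


def masked_cross_attention(x, context, mask, k_bucket_size=4):
    # Hypothetical sketch. x: (n, d) queries; context: (m, d) keys and values;
    # mask: (m,) bool, assumed True = attend to this context position.
    n, d = x.shape
    scale = 1.0 / np.sqrt(d)
    # Finite stand-in for -inf, so fully masked buckets stay NaN-free.
    mask_value = -np.finfo(x.dtype).max
    row_max = np.full((n, 1), -np.inf)
    row_sum = np.zeros((n, 1))
    acc = np.zeros((n, context.shape[-1]))
    for ks in range(0, context.shape[0], k_bucket_size):
        c = context[ks:ks + k_bucket_size]
        scores = x @ c.T * scale
        scores = np.where(mask[None, ks:ks + k_bucket_size], scores, mask_value)
        new_max = np.maximum(row_max, scores.max(axis=-1, keepdims=True))
        probs = np.exp(scores - new_max)
        correction = np.exp(row_max - new_max)
        row_sum = correction * row_sum + probs.sum(axis=-1, keepdims=True)
        acc = correction * acc + probs @ c
        row_max = new_max
    return acc / row_sum
```

If every key in a row is masked, this sketch falls back to an unweighted average of the context rows instead of NaN. The example's all-ones mask never hits that case.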
```bibtex
@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory},
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```
```bibtex
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```
```bibtex
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R{\'e}},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
```
```bibtex
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
```