memory efficient attention pytorch
0.1.6
論文「Self-attention Does Not Need O(n²) Memory」で提案されている、メモリ効率の高いマルチヘッド アテンションの実装。さらに、このモジュールはマスキング、因果マスキング、クロス アテンションを処理します。
このリポジトリには、教育目的で Tri Dao が Flash Attendant 2 論文で行った改善の単純な非 CUDA 実装も含まれています。これは注目を集め、ロングコンテキストのトランスフォーマーを構築するためのゲームチェンジャーです。
更新: 今後は、組み込みの Flash アテンション v1 サポートのために Pytorch 2.0 のF.scaled_dot_product_attention
関数を使用するか、公式リポジトリにある Flash アテンション v2 を使用する必要があります。
$ pip install memory-efficient-attention-pytorch
自己回帰言語モデルの場合
import torch
from memory_efficient_attention_pytorch import Attention
attn = Attention (
dim = 512 ,
dim_head = 64 , # dimension per head
heads = 8 , # number of attention heads
causal = True , # autoregressive or not
memory_efficient = True , # whether to use memory efficient attention (can be turned off to test against normal attention)
q_bucket_size = 1024 , # bucket size along queries dimension
k_bucket_size = 2048 # bucket size along key / values dimension
). cuda ()
x = torch . randn ( 1 , 65536 , 512 ). cuda ()
out = attn ( x ) # (1, 65536, 512)
クロスアテンション
import torch
from memory_efficient_attention_pytorch import Attention
cross_attn = Attention (
dim = 512 ,
dim_head = 64 ,
heads = 8 ,
memory_efficient = True ,
q_bucket_size = 1024 ,
k_bucket_size = 2048
). cuda ()
x = torch . randn ( 1 , 65536 , 512 ). cuda ()
context = torch . randn ( 1 , 65536 , 512 ). cuda ()
mask = torch . ones ( 1 , 65536 ). bool (). cuda ()
out = cross_attn ( x , context = context , mask = mask ) # (1, 65536, 512)
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@misc { liu2021swin ,
title = { Swin Transformer V2: Scaling Up Capacity and Resolution } ,
author = { Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo } ,
year = { 2021 } ,
eprint = { 2111.09883 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@article { dao2023flashattention2 ,
title = { Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning,
author = {Dao, Tri},
year = {2023}
}