Реализация эффективного многоголового внимания с памятью, как предложено в статье «Самовнимание не нуждается в памяти O(n²)». Кроме того, модуль позаботится о маскировке, причинно-следственной маскировке, а также о перекрестном внимании.
Этот репозиторий также содержит наивную реализацию без CUDA улучшений, сделанных Три Дао в его статье Flash Attention 2, для образовательных целей. Это меняет правила игры в плане внимания и создания преобразователей длительного контекста.
Обновление: с этого момента вам следует просто использовать функцию F.scaled_dot_product_attention
в Pytorch 2.0 для встроенной поддержки Flash Attention v1 или использовать Flash Attention v2 в официальном репозитории.
$ pip install memory-efficient-attention-pytorch
Для авторегрессионной языковой модели
import torch
from memory_efficient_attention_pytorch import Attention
attn = Attention (
dim = 512 ,
dim_head = 64 , # dimension per head
heads = 8 , # number of attention heads
causal = True , # autoregressive or not
memory_efficient = True , # whether to use memory efficient attention (can be turned off to test against normal attention)
q_bucket_size = 1024 , # bucket size along queries dimension
k_bucket_size = 2048 # bucket size along key / values dimension
). cuda ()
x = torch . randn ( 1 , 65536 , 512 ). cuda ()
out = attn ( x ) # (1, 65536, 512)
Перекрестное внимание
import torch
from memory_efficient_attention_pytorch import Attention
cross_attn = Attention (
dim = 512 ,
dim_head = 64 ,
heads = 8 ,
memory_efficient = True ,
q_bucket_size = 1024 ,
k_bucket_size = 2048
). cuda ()
x = torch . randn ( 1 , 65536 , 512 ). cuda ()
context = torch . randn ( 1 , 65536 , 512 ). cuda ()
mask = torch . ones ( 1 , 65536 ). bool (). cuda ()
out = cross_attn ( x , context = context , mask = mask ) # (1, 65536, 512)
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@misc { liu2021swin ,
title = { Swin Transformer V2: Scaling Up Capacity and Resolution } ,
author = { Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo } ,
year = { 2021 } ,
eprint = { 2111.09883 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@article { dao2023flashattention2 ,
title = { Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning,
author = {Dao, Tri},
year = {2023}
}