深入研究深度学习,由 Quanta 杂志重做
以与 Flash Attention 相同的风格实现融合余弦相似度注意力。观察结果是,通过采用 l2 规范化查询和键,您不再需要跟踪行最大值以获得数值稳定性。假设余弦相似性注意力没有泛化成本,这极大地简化了闪存注意力算法。
换句话说,稳定、快速、内存效率高、上下文注意力时间更长,没有任何缺点。
更新:不幸的是,Robin 的实验显示出更差的评估 FID 分数,并未反映在损失中。等待更多实验。请谨慎使用该库。
更新 2:唯一的优点是使用分组 l2norm,这可能会允许更多的表达能力。如果有人可以在他们的生成工作中评估这项技术并获得一些 FID 分数,我们将不胜感激。
更新 3:类似于 cosine sim Attention 的方法已经通过 Brain 的 22B 参数视觉模型得到了大规模验证。
目前,自回归和可变长度序列在所有架构中都应该更快。对于长度超过 2048 的序列,它也将具有记忆效率,而常规注意力则不会。
然而,对于没有掩蔽的非自回归,F16 的 A100 架构仍然较慢。目的是让它在 A100 上向前和向后为 F32 和 F16 执行更快,因为共享内存尚未充分利用。
较旧的显卡没有足够的共享内存,因此必须根据训练的序列长度来权衡内存效率和速度。
Arthur Hennequin 指导我完成第一个 CUDA 内核,并编写了一个简单的参考实现,这帮助我引导了第一个在合理性能范围内达到基线的内核。如果没有他的专业知识,这项工作是不可能完成的。
Boris Dayma 和 Robin Rombach 在一些重要的文本到图像模型上运行了简化的余弦模拟注意力实验,并进行了固定缩放,并验证了它的性能确实与常规注意力一样好。
Markus Rabe 撰写的论文表明注意力不需要 O(n²) 内存,Tri Dao 将其全部放在 CUDA 内核实现中以进行定期注意力,展示了使用平铺方法最小化 HBM 访问的速度优势(并计算out dO * O == dP * P
用于向后传递)。如果没有他们的发现,我就无法完成我寻找终极注意力配方的朝圣之旅。
Stability.ai 慷慨赞助前沿人工智能研究
$ pip install flash-cosine-sim-attention
自我关注
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
交叉注意力
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
带键/值屏蔽
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
mask = torch . ones ( 1 , 2048 ). bool (). cuda ()
out = flash_cosine_sim_attention ( q , k , v , mask = mask ) # (1, 8, 1024, 64)
自回归
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
单头键/值(Shazeer 等人并在 PaLM 中使用)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
如果您需要在 l2norm 和实际注意步骤之间对查询和键进行操作,只需设置l2norm_qk = False
前任。
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention , l2norm_tensors
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
q , k = l2norm_tensors ( q , k )
# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch
out = flash_cosine_sim_attention ( q , k , v , l2norm_qk = False ) # (4, 8, 1024, 64)
交叉注意力与因果关系按预期工作 - (在推理过程中缓存自回归中的键和值,或类似 Transformer-xl 的训练)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (1, 8, 1024, 64)
如果您合并了批次和头部尺寸,那就可以了
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 32 , 1024 , 64 ). cuda ()
k = torch . randn ( 32 , 2048 , 64 ). cuda ()
v = torch . randn ( 32 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (32, 1024, 64)
16 - f32
32
64
96
128
16-f16
80 - 进行中
bfloat16 支持,按照 Arthur 的推荐使用 sfinae
从 qk_mma 流到共享内存以计算出 mma,看看释放的 smem 是否可以用于缓存更多
支持 O(n) 1d 动态位置偏差
弄清楚为什么 smem 片段缓存会导致性能下降,这没有意义
考虑使用 logsumexp - 有效,但额外的日志会导致性能下降
准备一个 smem 片段缓存机制,以允许 A100(或 f16)上允许的尽可能多的缓存
使注意力图块尺寸处理可定制以用于向后传递
将原子添加移动到 mma 内的重载函数
灵活使用哪种类型进行累积
在 f16 上测试 64x96 块
仅使用简单的 pytorch 代码引入 CPU 内存高效版本(仅用于推理,因为训练没有意义)
弄清楚如何针对架构(例如 A100)进行不同的分派,以防向后可以以不同的方式利用共享内存的增加
将注意力图块的行和列大小解耦
dk 和 dv 现在处于 f16 状态(非单头 kv)
支持更多标准头尺寸(在制品)
再次调试并修复头部大小为 32 的向后梯度偏差
修复注意力偏差梯度
允许单头键/值,如 PaLM 中那样
修复 f16 的原子添加
注意力偏差应该能够接受额外批量维度的维度,对于 Alphafold2 来说就像注意力偏差
使用版本作为包名称的后缀自动清除内核的缓存
解决 f16 因果数值问题
采用从前向内核到后向内核的所有学习,并确保它至少在 A100 上表现优于
到目前为止,余弦相似度注意力机制在工业界还没有广泛应用。迄今为止,唯一用它训练过的大型模型是 SwinV2。如果有人可以使该方法无效,请提出问题或给我发送电子邮件。您可以使用 x-transformers 存储库针对常规注意力进行实验。
更新:Boris Dayma 慷慨地启动了一项实验(蓝色和红色作为基线),以在现实世界模型设置中以固定比例 10 验证余弦相似性注意力。
更新 2:余弦相似性注意力已经在现实世界的文本到图像注意力网络中得到了证明,使用的常数尺度为10
。不比平时的关注差。鲍里斯·戴玛 (Boris Dayma) 投入时间进行实验并消除了对该技术的疑虑。
更新 3:Robin Rombach 在文本到图像模型中测试了此存储库中的内核,头部大小为 64,固定比例为 10,观察到与常规注意力没有任何差异。更多评估待定。
更新 4:Boris 实验中看到的性能改进可能是由于余弦模拟注意力允许在 Transformer 中从前层范数配置切换到后层范数配置(因为 l2norm 有效地取代了前层范数)。层范数)。余弦模拟注意力可能会产生与常规注意力相同的结果,而不需要对变压器进行任何其他更改。
用于测试非自回归和自回归场景的输出和梯度是否相等
$ python setup.py test
确保先安装CUDA内核
$ python setup . py install
然后
$ python benchmark . py
对于仅向前或向后进行基准测试,请在上面附加--only-forwards
或--only-backwards
标志。要对自回归进行基准测试,请附加--causal
向前
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.24ms baseline: 0.23ms
seq_len: 256 slower: 1.27x kernel: 0.38ms baseline: 0.30ms
seq_len: 512 slower: 1.28x kernel: 0.87ms baseline: 0.68ms
seq_len: 1024 slower: 1.15x kernel: 2.63ms baseline: 2.28ms
seq_len: 2048 slower: 0.99x kernel: 7.99ms baseline: 8.10ms
seq_len: 4096 slower: 0.88x kernel: 30.82ms baseline: 34.84ms
seq_len: 8192 slower: 0.00x kernel: 121.96ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.20ms baseline: 0.24ms
seq_len: 256 slower: 0.97x kernel: 0.24ms baseline: 0.25ms
seq_len: 512 slower: 1.22x kernel: 0.43ms baseline: 0.35ms
seq_len: 1024 slower: 0.95x kernel: 0.93ms baseline: 0.98ms
seq_len: 2048 slower: 0.90x kernel: 3.16ms baseline: 3.50ms
seq_len: 4096 slower: 0.85x kernel: 11.06ms baseline: 13.07ms
seq_len: 8192 slower: 0.00x kernel: 42.61ms baseline: oom
倒退 - 仍需努力
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.07x kernel: 0.61ms baseline: 0.57ms
seq_len: 256 slower: 1.40x kernel: 0.91ms baseline: 0.65ms
seq_len: 512 slower: 1.70x kernel: 2.34ms baseline: 1.38ms
seq_len: 1024 slower: 1.26x kernel: 5.67ms baseline: 4.50ms
seq_len: 2048 slower: 1.29x kernel: 20.60ms baseline: 15.91ms
seq_len: 4096 slower: 1.30x kernel: 78.93ms baseline: 60.81ms
seq_len: 8192 slower: 0.00x kernel: 314.51ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.50ms baseline: 0.55ms
seq_len: 256 slower: 1.06x kernel: 0.58ms baseline: 0.55ms
seq_len: 512 slower: 1.13x kernel: 0.81ms baseline: 0.72ms
seq_len: 1024 slower: 0.97x kernel: 2.09ms baseline: 2.16ms
seq_len: 2048 slower: 0.96x kernel: 7.06ms baseline: 7.35ms
seq_len: 4096 slower: 0.97x kernel: 26.08ms baseline: 26.84ms
seq_len: 8192 slower: 0.00x kernel: 101.02ms baseline: oom
前进和后退 - F32 肯定更慢
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.83ms baseline: 0.79ms
seq_len: 256 slower: 1.34x kernel: 1.26ms baseline: 0.95ms
seq_len: 512 slower: 1.44x kernel: 3.14ms baseline: 2.18ms
seq_len: 1024 slower: 1.15x kernel: 7.83ms baseline: 6.81ms
seq_len: 2048 slower: 1.20x kernel: 28.83ms baseline: 24.03ms
seq_len: 4096 slower: 1.20x kernel: 111.13ms baseline: 92.51ms
seq_len: 8192 slower: 0.00x kernel: 441.70ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.89x kernel: 0.68ms baseline: 0.77ms
seq_len: 256 slower: 1.03x kernel: 0.80ms baseline: 0.77ms
seq_len: 512 slower: 1.06x kernel: 1.16ms baseline: 1.10ms
seq_len: 1024 slower: 0.93x kernel: 2.94ms baseline: 3.16ms
seq_len: 2048 slower: 0.93x kernel: 10.06ms baseline: 10.87ms
seq_len: 4096 slower: 0.93x kernel: 37.09ms baseline: 39.96ms
seq_len: 8192 slower: 0.00x kernel: 143.13ms baseline: oom
对于自回归,明显获胜python benchmark.py --causal
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.97x kernel: 0.81ms baseline: 0.84ms
seq_len: 256 slower: 1.07x kernel: 1.12ms baseline: 1.05ms
seq_len: 512 slower: 0.83x kernel: 2.23ms baseline: 2.68ms
seq_len: 1024 slower: 0.55x kernel: 4.83ms baseline: 8.82ms
seq_len: 2048 slower: 0.49x kernel: 15.89ms baseline: 32.68ms
seq_len: 4096 slower: 0.46x kernel: 57.50ms baseline: 126.00ms
seq_len: 8192 slower: 0.00x kernel: 224.76ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.69ms baseline: 0.84ms
seq_len: 256 slower: 0.95x kernel: 0.79ms baseline: 0.83ms
seq_len: 512 slower: 0.78x kernel: 1.06ms baseline: 1.37ms
seq_len: 1024 slower: 0.50x kernel: 2.10ms baseline: 4.24ms
seq_len: 2048 slower: 0.37x kernel: 5.85ms baseline: 15.92ms
seq_len: 4096 slower: 0.31x kernel: 19.80ms baseline: 64.42ms
seq_len: 8192 slower: 0.00x kernel: 75.25ms baseline: oom
对于带有掩码的可变长度序列,这也是一个明显的胜利。假设平均 25% 的标记被屏蔽掉python benchmark.py --mask-prob 0.25
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.95x kernel: 0.84ms baseline: 0.89ms
seq_len: 256 slower: 1.19x kernel: 1.28ms baseline: 1.08ms
seq_len: 512 slower: 1.23x kernel: 3.19ms baseline: 2.59ms
seq_len: 1024 slower: 0.92x kernel: 8.19ms baseline: 8.88ms
seq_len: 2048 slower: 0.92x kernel: 30.08ms baseline: 32.57ms
seq_len: 4096 slower: 0.94x kernel: 123.20ms baseline: 131.22ms
seq_len: 8192 slower: 0.00x kernel: 461.77ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.77ms baseline: 0.90ms
seq_len: 256 slower: 0.93x kernel: 0.86ms baseline: 0.93ms
seq_len: 512 slower: 0.93x kernel: 1.31ms baseline: 1.40ms
seq_len: 1024 slower: 0.76x kernel: 3.31ms baseline: 4.35ms
seq_len: 2048 slower: 0.71x kernel: 11.19ms baseline: 15.65ms
seq_len: 4096 slower: 0.70x kernel: 41.27ms baseline: 59.01ms
seq_len: 8192 slower: 0.00x kernel: 158.60ms baseline: oom
感谢 Stability 提供 A100 进行测试的访问权限。感谢 Enrico 在我还无法访问时花时间运行一些基准测试。
A100 仍在进行中。共享内存尚未得到充分利用。奇怪的是,F32 似乎比 F16 表现更好
前锋
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 256 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 512 slower: 0.94x kernel: 0.52ms baseline: 0.55ms
seq_len: 1024 slower: 0.75x kernel: 1.23ms baseline: 1.65ms
seq_len: 2048 slower: 0.88x kernel: 4.17ms baseline: 4.73ms
seq_len: 4096 slower: 0.79x kernel: 14.53ms baseline: 18.36ms
seq_len: 8192 slower: 0.64x kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.84x kernel: 0.24ms baseline: 0.29ms
seq_len: 256 slower: 1.02x kernel: 0.29ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.36ms baseline: 0.29ms
seq_len: 1024 slower: 1.48x kernel: 0.79ms baseline: 0.54ms
seq_len: 2048 slower: 1.31x kernel: 2.08ms baseline: 1.59ms
seq_len: 4096 slower: 1.21x kernel: 6.89ms baseline: 5.70ms
seq_len: 8192 slower: 1.07x kernel: 24.80ms baseline: 23.15ms
向后
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.94x kernel: 0.57ms baseline: 0.60ms
seq_len: 256 slower: 1.29x kernel: 0.75ms baseline: 0.58ms
seq_len: 512 slower: 1.16x kernel: 1.30ms baseline: 1.12ms
seq_len: 1024 slower: 0.98x kernel: 3.14ms baseline: 3.19ms
seq_len: 2048 slower: 1.05x kernel: 11.13ms baseline: 10.63ms
seq_len: 4096 slower: 0.98x kernel: 40.11ms baseline: 40.79ms
seq_len: 8192 slower: 0.97x kernel: 154.96ms baseline: 159.70ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.62ms baseline: 0.60ms
seq_len: 512 slower: 1.36x kernel: 0.82ms baseline: 0.60ms
seq_len: 1024 slower: 1.52x kernel: 1.52ms baseline: 1.01ms
seq_len: 2048 slower: 1.37x kernel: 4.14ms baseline: 3.03ms
seq_len: 4096 slower: 1.33x kernel: 14.23ms baseline: 10.71ms
seq_len: 8192 slower: 1.34x kernel: 53.90ms baseline: 40.28ms
前进和后退
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 256 slower: 1.23x kernel: 1.07ms baseline: 0.87ms
seq_len: 512 slower: 1.08x kernel: 1.80ms baseline: 1.66ms
seq_len: 1024 slower: 0.94x kernel: 4.33ms baseline: 4.62ms
seq_len: 2048 slower: 0.99x kernel: 15.26ms baseline: 15.44ms
seq_len: 4096 slower: 0.93x kernel: 54.78ms baseline: 59.21ms
seq_len: 8192 slower: 0.91x kernel: 210.38ms baseline: 230.97ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.90x kernel: 0.78ms baseline: 0.86ms
seq_len: 256 slower: 1.00x kernel: 0.87ms baseline: 0.87ms
seq_len: 512 slower: 1.36x kernel: 1.18ms baseline: 0.86ms
seq_len: 1024 slower: 1.49x kernel: 2.31ms baseline: 1.55ms
seq_len: 2048 slower: 1.33x kernel: 6.17ms baseline: 4.63ms
seq_len: 4096 slower: 1.28x kernel: 21.08ms baseline: 16.44ms
seq_len: 8192 slower: 1.24x kernel: 78.75ms baseline: 63.45ms
自回归
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.82ms baseline: 1.01ms
seq_len: 256 slower: 1.02x kernel: 1.00ms baseline: 0.98ms
seq_len: 512 slower: 0.82x kernel: 1.55ms baseline: 1.89ms
seq_len: 1024 slower: 0.51x kernel: 2.79ms baseline: 5.44ms
seq_len: 2048 slower: 0.45x kernel: 8.37ms baseline: 18.67ms
seq_len: 4096 slower: 0.40x kernel: 29.16ms baseline: 72.97ms
seq_len: 8192 slower: 0.38x kernel: 108.68ms baseline: 285.47ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.81ms baseline: 0.98ms
seq_len: 256 slower: 0.90x kernel: 0.88ms baseline: 0.98ms
seq_len: 512 slower: 1.16x kernel: 1.13ms baseline: 0.97ms
seq_len: 1024 slower: 0.80x kernel: 1.68ms baseline: 2.10ms
seq_len: 2048 slower: 0.54x kernel: 3.66ms baseline: 6.81ms
seq_len: 4096 slower: 0.45x kernel: 11.43ms baseline: 25.32ms
seq_len: 8192 slower: 0.41x kernel: 40.58ms baseline: 99.14ms
可变长度序列(最多屏蔽 25% 的标记)
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.85ms baseline: 1.07ms
seq_len: 256 slower: 1.07x kernel: 1.15ms baseline: 1.08ms
seq_len: 512 slower: 1.00x kernel: 1.94ms baseline: 1.94ms
seq_len: 1024 slower: 0.84x kernel: 4.64ms baseline: 5.55ms
seq_len: 2048 slower: 0.84x kernel: 15.86ms baseline: 18.86ms
seq_len: 4096 slower: 0.76x kernel: 55.19ms baseline: 72.47ms
seq_len: 8192 slower: 0.75x kernel: 212.48ms baseline: 282.71ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.83ms baseline: 1.04ms
seq_len: 256 slower: 0.90x kernel: 0.93ms baseline: 1.03ms
seq_len: 512 slower: 1.18x kernel: 1.22ms baseline: 1.04ms
seq_len: 1024 slower: 1.10x kernel: 2.40ms baseline: 2.17ms
seq_len: 2048 slower: 0.89x kernel: 6.27ms baseline: 7.06ms
seq_len: 4096 slower: 0.82x kernel: 21.19ms baseline: 25.95ms
seq_len: 8192 slower: 0.78x kernel: 79.45ms baseline: 101.83ms
$ make train
尝试 8192 序列长度。它会很慢,但会起作用(正常注意力将在 > 2048 处中断,如果删除--use-cuda-kernel
标志,您会看到这一点)
$ python train . py - - seq - len 8192 - - use - cuda - kernel
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@inproceedings { Henry2020QueryKeyNF ,
title = { Query-Key Normalization for Transformers } ,
author = { Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen } ,
booktitle = { FINDINGS } ,
year = { 2020 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 }
}