深入研究深度學習,由 Quanta 雜誌重做
以與 Flash Attention 相同的風格實現融合餘弦相似度注意力。觀察結果是,透過採用 l2 規範化查詢和鍵,您不再需要追蹤行最大值以獲得數值穩定性。假設餘弦相似性注意力沒有泛化成本,這大大簡化了閃存注意力演算法。
換句話說,穩定、快速、記憶體效率高、情境注意力時間更長,沒有任何缺點。
更新:不幸的是,Robin 的實驗顯示出較差的評估 FID 分數,並未反映在損失中。等待更多實驗。請謹慎使用該函式庫。
更新 2:唯一的優點是使用分組 l2norm,這可能允許更多的表達能力。如果有人可以在他們的生成工作中評估這項技術並獲得一些 FID 分數,我們將不勝感激。
更新 3:類似於 cosine sim Attention 的方法已經透過 Brain 的 22B 參數視覺模型得到了大規模驗證。
目前,自回歸和可變長度序列在所有架構中都應該更快。對於長度超過 2048 的序列,它也將具有記憶效率,而常規注意力則不會。
然而,對於沒有掩蔽的非自回歸,F16 的 A100 架構仍然較慢。目的是讓它在 A100 上向前和向後為 F32 和 F16 執行更快,因為共享記憶體尚未充分利用。
較舊的顯示卡沒有足夠的共享內存,因此必須根據訓練的序列長度來權衡內存效率和速度。
Arthur Hennequin 指導我完成第一個 CUDA 內核,並編寫了一個簡單的參考實現,這幫助我引導了第一個在合理性能範圍內達到基線的內核。如果沒有他的專業知識,這項工作是不可能完成的。
Boris Dayma 和 Robin Rombach 在一些重要的文本到圖像模型上運行了簡化的餘弦模擬注意力實驗,並進行了固定縮放,並驗證了它的性能確實與常規注意力一樣好。
Markus Rabe 撰寫的論文表明注意力不需要 O(n²) 內存,Tri Dao 將其全部放在 CUDA 內核實現中以進行定期注意力,展示了使用平鋪方法最小化 HBM 訪問的速度優勢(併計算out dO * O == dP * P
用於向後傳遞)。如果沒有他們的發現,我就無法完成我尋找終極注意力配方的朝聖之旅。
Stability.ai 慷慨贊助前沿人工智慧研究
$ pip install flash-cosine-sim-attention
自我關注
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
交叉注意力
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
帶鍵/值屏蔽
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
mask = torch . ones ( 1 , 2048 ). bool (). cuda ()
out = flash_cosine_sim_attention ( q , k , v , mask = mask ) # (1, 8, 1024, 64)
自回歸
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
單頭鍵/值(Shazeer 等人並在 PaLM 中使用)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
如果您需要在 l2norm 和實際注意步驟之間對查詢和鍵進行操作,只需設定l2norm_qk = False
前任。
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention , l2norm_tensors
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
q , k = l2norm_tensors ( q , k )
# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch
out = flash_cosine_sim_attention ( q , k , v , l2norm_qk = False ) # (4, 8, 1024, 64)
交叉注意力與因果關係按預期工作 - (在推理過程中緩存自回歸中的鍵和值,或類似 Transformer-xl 的訓練)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (1, 8, 1024, 64)
如果您合併了批次和頭部尺寸,那就可以了
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 32 , 1024 , 64 ). cuda ()
k = torch . randn ( 32 , 2048 , 64 ). cuda ()
v = torch . randn ( 32 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (32, 1024, 64)
16 - f32
32
64
96
128
16-f16
80 - 進行中
bfloat16 支持,按照 Arthur 的推薦使用 sfinae
從 qk_mma 流到共享記憶體以計算 mma,看看釋放的 smem 是否可以用於快取更多
支持 O(n) 1d 動態位置偏差
弄清楚為什麼 smem 片段快取會導致效能下降,這沒有意義
考慮使用 logsumexp - 有效,但額外的日誌會導致效能下降
準備一個 smem 片段快取機制,以允許 A100(或 f16)上允許的盡可能多的快取
使注意力圖塊尺寸處理可自訂以用於向後傳遞
將原子加入移動到 mma 內的重載函數
靈活使用哪種類型進行累積
在 f16 上測試 64x96 塊
僅使用簡單的 pytorch 程式碼引入 CPU 記憶體高效版本(僅用於推理,因為訓練沒有意義)
弄清楚如何針對架構(例如 A100)進行不同的分派,以防向後可以以不同的方式利用共享記憶體的增加
將注意力圖塊的行和列大小解耦
dk 和 dv 現在處於 f16 狀態(非單頭 kv)
支援更多標準頭尺寸(在製品)
再次調試並修復頭部大小為 32 的向後梯度偏差
修復注意力偏差梯度
允許單頭鍵/值,如 PaLM 中那樣
修復 f16 的原子添加
注意力偏差應該能夠接受額外批量維度的維度,對於 Alphafold2 來說就像注意力偏差
使用版本作為套件名稱的後綴自動清除核心的緩存
解 f16 因果數值問題
採用從前向核心到後向核心的所有學習,並確保它至少在 A100 上表現優於
到目前為止,餘弦相似度注意力機制在工業界還沒有廣泛應用。到目前為止,唯一用它訓練過的大型模型是 SwinV2。如果有人可以使該方法無效,請提出問題或給我發送電子郵件。您可以使用 x-transformers 儲存庫針對常規注意力進行實驗。
更新:Boris Dayma 慷慨地啟動了一項實驗(藍色和紅色作為基線),以在現實世界模型設定中以固定比例 10 驗證餘弦相似性注意力。
更新 2:餘弦相似性注意力已經在現實世界的文本到圖像注意力網絡中得到了證明,使用的常數尺度為10
。不比平時的關注差。鮑里斯·戴瑪 (Boris Dayma) 投入時間進行實驗並消除了對該技術的疑慮。
更新 3:Robin Rombach 在文本到圖像模型中測試了此存儲庫中的內核,頭部大小為 64,固定比例為 10,觀察到與常規注意力沒有任何差異。更多評估待定。
更新 4:Boris 實驗中看到的性能改進可能是由於餘弦模擬注意力允許在 Transformer 中從前層範數配置切換到後層範數配置(因為 l2norm 有效地取代了前層範數)。 )。餘弦模擬注意力可能會產生與常規注意力相同的結果,而不需要對變壓器進行任何其他更改。
用於測試非自回歸和自回歸場景的輸出和梯度是否相等
$ python setup.py test
確保先安裝CUDA內核
$ python setup . py install
然後
$ python benchmark . py
對於僅向前或向後進行基準測試,請在上面附加--only-forwards
或--only-backwards
標誌。若要對自回歸進行基準測試,請附加--causal
向前
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.24ms baseline: 0.23ms
seq_len: 256 slower: 1.27x kernel: 0.38ms baseline: 0.30ms
seq_len: 512 slower: 1.28x kernel: 0.87ms baseline: 0.68ms
seq_len: 1024 slower: 1.15x kernel: 2.63ms baseline: 2.28ms
seq_len: 2048 slower: 0.99x kernel: 7.99ms baseline: 8.10ms
seq_len: 4096 slower: 0.88x kernel: 30.82ms baseline: 34.84ms
seq_len: 8192 slower: 0.00x kernel: 121.96ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.20ms baseline: 0.24ms
seq_len: 256 slower: 0.97x kernel: 0.24ms baseline: 0.25ms
seq_len: 512 slower: 1.22x kernel: 0.43ms baseline: 0.35ms
seq_len: 1024 slower: 0.95x kernel: 0.93ms baseline: 0.98ms
seq_len: 2048 slower: 0.90x kernel: 3.16ms baseline: 3.50ms
seq_len: 4096 slower: 0.85x kernel: 11.06ms baseline: 13.07ms
seq_len: 8192 slower: 0.00x kernel: 42.61ms baseline: oom
倒退 - 仍需努力
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.07x kernel: 0.61ms baseline: 0.57ms
seq_len: 256 slower: 1.40x kernel: 0.91ms baseline: 0.65ms
seq_len: 512 slower: 1.70x kernel: 2.34ms baseline: 1.38ms
seq_len: 1024 slower: 1.26x kernel: 5.67ms baseline: 4.50ms
seq_len: 2048 slower: 1.29x kernel: 20.60ms baseline: 15.91ms
seq_len: 4096 slower: 1.30x kernel: 78.93ms baseline: 60.81ms
seq_len: 8192 slower: 0.00x kernel: 314.51ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.50ms baseline: 0.55ms
seq_len: 256 slower: 1.06x kernel: 0.58ms baseline: 0.55ms
seq_len: 512 slower: 1.13x kernel: 0.81ms baseline: 0.72ms
seq_len: 1024 slower: 0.97x kernel: 2.09ms baseline: 2.16ms
seq_len: 2048 slower: 0.96x kernel: 7.06ms baseline: 7.35ms
seq_len: 4096 slower: 0.97x kernel: 26.08ms baseline: 26.84ms
seq_len: 8192 slower: 0.00x kernel: 101.02ms baseline: oom
前進和後退 - F32 肯定更慢
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.83ms baseline: 0.79ms
seq_len: 256 slower: 1.34x kernel: 1.26ms baseline: 0.95ms
seq_len: 512 slower: 1.44x kernel: 3.14ms baseline: 2.18ms
seq_len: 1024 slower: 1.15x kernel: 7.83ms baseline: 6.81ms
seq_len: 2048 slower: 1.20x kernel: 28.83ms baseline: 24.03ms
seq_len: 4096 slower: 1.20x kernel: 111.13ms baseline: 92.51ms
seq_len: 8192 slower: 0.00x kernel: 441.70ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.89x kernel: 0.68ms baseline: 0.77ms
seq_len: 256 slower: 1.03x kernel: 0.80ms baseline: 0.77ms
seq_len: 512 slower: 1.06x kernel: 1.16ms baseline: 1.10ms
seq_len: 1024 slower: 0.93x kernel: 2.94ms baseline: 3.16ms
seq_len: 2048 slower: 0.93x kernel: 10.06ms baseline: 10.87ms
seq_len: 4096 slower: 0.93x kernel: 37.09ms baseline: 39.96ms
seq_len: 8192 slower: 0.00x kernel: 143.13ms baseline: oom
對於自回歸,明顯獲勝python benchmark.py --causal
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.97x kernel: 0.81ms baseline: 0.84ms
seq_len: 256 slower: 1.07x kernel: 1.12ms baseline: 1.05ms
seq_len: 512 slower: 0.83x kernel: 2.23ms baseline: 2.68ms
seq_len: 1024 slower: 0.55x kernel: 4.83ms baseline: 8.82ms
seq_len: 2048 slower: 0.49x kernel: 15.89ms baseline: 32.68ms
seq_len: 4096 slower: 0.46x kernel: 57.50ms baseline: 126.00ms
seq_len: 8192 slower: 0.00x kernel: 224.76ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.69ms baseline: 0.84ms
seq_len: 256 slower: 0.95x kernel: 0.79ms baseline: 0.83ms
seq_len: 512 slower: 0.78x kernel: 1.06ms baseline: 1.37ms
seq_len: 1024 slower: 0.50x kernel: 2.10ms baseline: 4.24ms
seq_len: 2048 slower: 0.37x kernel: 5.85ms baseline: 15.92ms
seq_len: 4096 slower: 0.31x kernel: 19.80ms baseline: 64.42ms
seq_len: 8192 slower: 0.00x kernel: 75.25ms baseline: oom
對於帶有遮罩的可變長度序列,這也是一個明顯的勝利。假設平均 25% 的標記被屏蔽掉python benchmark.py --mask-prob 0.25
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.95x kernel: 0.84ms baseline: 0.89ms
seq_len: 256 slower: 1.19x kernel: 1.28ms baseline: 1.08ms
seq_len: 512 slower: 1.23x kernel: 3.19ms baseline: 2.59ms
seq_len: 1024 slower: 0.92x kernel: 8.19ms baseline: 8.88ms
seq_len: 2048 slower: 0.92x kernel: 30.08ms baseline: 32.57ms
seq_len: 4096 slower: 0.94x kernel: 123.20ms baseline: 131.22ms
seq_len: 8192 slower: 0.00x kernel: 461.77ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.77ms baseline: 0.90ms
seq_len: 256 slower: 0.93x kernel: 0.86ms baseline: 0.93ms
seq_len: 512 slower: 0.93x kernel: 1.31ms baseline: 1.40ms
seq_len: 1024 slower: 0.76x kernel: 3.31ms baseline: 4.35ms
seq_len: 2048 slower: 0.71x kernel: 11.19ms baseline: 15.65ms
seq_len: 4096 slower: 0.70x kernel: 41.27ms baseline: 59.01ms
seq_len: 8192 slower: 0.00x kernel: 158.60ms baseline: oom
感謝 Stability 提供 A100 進行測試的存取權限。感謝 Enrico 在我還無法訪問時花時間運行一些基準測試。
A100 仍在進行中。共享記憶體尚未充分利用。奇怪的是,F32 似乎比 F16 表現更好
前鋒
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 256 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 512 slower: 0.94x kernel: 0.52ms baseline: 0.55ms
seq_len: 1024 slower: 0.75x kernel: 1.23ms baseline: 1.65ms
seq_len: 2048 slower: 0.88x kernel: 4.17ms baseline: 4.73ms
seq_len: 4096 slower: 0.79x kernel: 14.53ms baseline: 18.36ms
seq_len: 8192 slower: 0.64x kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.84x kernel: 0.24ms baseline: 0.29ms
seq_len: 256 slower: 1.02x kernel: 0.29ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.36ms baseline: 0.29ms
seq_len: 1024 slower: 1.48x kernel: 0.79ms baseline: 0.54ms
seq_len: 2048 slower: 1.31x kernel: 2.08ms baseline: 1.59ms
seq_len: 4096 slower: 1.21x kernel: 6.89ms baseline: 5.70ms
seq_len: 8192 slower: 1.07x kernel: 24.80ms baseline: 23.15ms
向後
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.94x kernel: 0.57ms baseline: 0.60ms
seq_len: 256 slower: 1.29x kernel: 0.75ms baseline: 0.58ms
seq_len: 512 slower: 1.16x kernel: 1.30ms baseline: 1.12ms
seq_len: 1024 slower: 0.98x kernel: 3.14ms baseline: 3.19ms
seq_len: 2048 slower: 1.05x kernel: 11.13ms baseline: 10.63ms
seq_len: 4096 slower: 0.98x kernel: 40.11ms baseline: 40.79ms
seq_len: 8192 slower: 0.97x kernel: 154.96ms baseline: 159.70ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.62ms baseline: 0.60ms
seq_len: 512 slower: 1.36x kernel: 0.82ms baseline: 0.60ms
seq_len: 1024 slower: 1.52x kernel: 1.52ms baseline: 1.01ms
seq_len: 2048 slower: 1.37x kernel: 4.14ms baseline: 3.03ms
seq_len: 4096 slower: 1.33x kernel: 14.23ms baseline: 10.71ms
seq_len: 8192 slower: 1.34x kernel: 53.90ms baseline: 40.28ms
前進和後退
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 256 slower: 1.23x kernel: 1.07ms baseline: 0.87ms
seq_len: 512 slower: 1.08x kernel: 1.80ms baseline: 1.66ms
seq_len: 1024 slower: 0.94x kernel: 4.33ms baseline: 4.62ms
seq_len: 2048 slower: 0.99x kernel: 15.26ms baseline: 15.44ms
seq_len: 4096 slower: 0.93x kernel: 54.78ms baseline: 59.21ms
seq_len: 8192 slower: 0.91x kernel: 210.38ms baseline: 230.97ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.90x kernel: 0.78ms baseline: 0.86ms
seq_len: 256 slower: 1.00x kernel: 0.87ms baseline: 0.87ms
seq_len: 512 slower: 1.36x kernel: 1.18ms baseline: 0.86ms
seq_len: 1024 slower: 1.49x kernel: 2.31ms baseline: 1.55ms
seq_len: 2048 slower: 1.33x kernel: 6.17ms baseline: 4.63ms
seq_len: 4096 slower: 1.28x kernel: 21.08ms baseline: 16.44ms
seq_len: 8192 slower: 1.24x kernel: 78.75ms baseline: 63.45ms
自回歸
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.82ms baseline: 1.01ms
seq_len: 256 slower: 1.02x kernel: 1.00ms baseline: 0.98ms
seq_len: 512 slower: 0.82x kernel: 1.55ms baseline: 1.89ms
seq_len: 1024 slower: 0.51x kernel: 2.79ms baseline: 5.44ms
seq_len: 2048 slower: 0.45x kernel: 8.37ms baseline: 18.67ms
seq_len: 4096 slower: 0.40x kernel: 29.16ms baseline: 72.97ms
seq_len: 8192 slower: 0.38x kernel: 108.68ms baseline: 285.47ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.81ms baseline: 0.98ms
seq_len: 256 slower: 0.90x kernel: 0.88ms baseline: 0.98ms
seq_len: 512 slower: 1.16x kernel: 1.13ms baseline: 0.97ms
seq_len: 1024 slower: 0.80x kernel: 1.68ms baseline: 2.10ms
seq_len: 2048 slower: 0.54x kernel: 3.66ms baseline: 6.81ms
seq_len: 4096 slower: 0.45x kernel: 11.43ms baseline: 25.32ms
seq_len: 8192 slower: 0.41x kernel: 40.58ms baseline: 99.14ms
可變長度序列(最多屏蔽 25% 的標記)
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.85ms baseline: 1.07ms
seq_len: 256 slower: 1.07x kernel: 1.15ms baseline: 1.08ms
seq_len: 512 slower: 1.00x kernel: 1.94ms baseline: 1.94ms
seq_len: 1024 slower: 0.84x kernel: 4.64ms baseline: 5.55ms
seq_len: 2048 slower: 0.84x kernel: 15.86ms baseline: 18.86ms
seq_len: 4096 slower: 0.76x kernel: 55.19ms baseline: 72.47ms
seq_len: 8192 slower: 0.75x kernel: 212.48ms baseline: 282.71ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.83ms baseline: 1.04ms
seq_len: 256 slower: 0.90x kernel: 0.93ms baseline: 1.03ms
seq_len: 512 slower: 1.18x kernel: 1.22ms baseline: 1.04ms
seq_len: 1024 slower: 1.10x kernel: 2.40ms baseline: 2.17ms
seq_len: 2048 slower: 0.89x kernel: 6.27ms baseline: 7.06ms
seq_len: 4096 slower: 0.82x kernel: 21.19ms baseline: 25.95ms
seq_len: 8192 slower: 0.78x kernel: 79.45ms baseline: 101.83ms
$ make train
嘗試 8192 序列長度。它會很慢,但會起作用(正常注意力將在 > 2048 處中斷,如果刪除--use-cuda-kernel
標誌,您會看到這一點)
$ python train . py - - seq - len 8192 - - use - cuda - kernel
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@inproceedings { Henry2020QueryKeyNF ,
title = { Query-Key Normalization for Transformers } ,
author = { Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen } ,
booktitle = { FINDINGS } ,
year = { 2020 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 }
}