Tauchen Sie ein in Deep Learning, überarbeitet vom Quanta Magazine
Implementierung der Fused-Cosine-Ähnlichkeitsaufmerksamkeit im gleichen Stil wie Flash Attention. Die Beobachtung ist, dass Sie durch die Verwendung von l2-normalisierten Abfragen und Schlüsseln aus Gründen der numerischen Stabilität die Zeilenmaxima nicht mehr im Auge behalten müssen. Dies vereinfacht den Flash-Aufmerksamkeitsalgorithmus erheblich, vorausgesetzt, dass die Kosinusähnlichkeitsaufmerksamkeit ohne Kosten für die Generalisierung erfolgt.
Mit anderen Worten: stabil, schnell, speichereffizient und längere Kontextaufmerksamkeit ohne Nachteile.
Update: Leider zeigten Robins Experimente viel schlechtere FID-Bewertungen, die sich nicht im Verlust widerspiegelten. Warten auf weitere Experimente. Verwenden Sie diese Bibliothek mit Vorsicht.
Update 2: Die einzige Rettung wäre die Verwendung von gruppiertem l2norm, was möglicherweise mehr Ausdruckskraft ermöglichen könnte. Wenn jemand diese Technik bei seiner generativen Arbeit bewerten und einige FID-Ergebnisse erhalten kann, wäre ich sehr dankbar.
Update 3: Ein Ansatz ähnlich der Cosinus-Sim-Aufmerksamkeit wurde mit einem 22B-Parameter-Vision-Modell von Brain im Maßstab erprobt.
Derzeit sollten autoregressive und variable Sequenzen auf allen Architekturen schneller sein. Bei Sequenzen, die länger als 2048 sind, ist es auch speichereffizient, wo dies bei normaler Aufmerksamkeit nicht der Fall wäre.
Bei nicht-autoregressiver Architektur ohne Maskierung ist die Architektur auf A100 für F16 jedoch immer noch langsamer. Das Ziel besteht darin, eine schnellere Leistung auf dem A100 vorwärts und rückwärts sowohl für F32 als auch für F16 zu erreichen, da der gemeinsame Speicher noch nicht vollständig ausgenutzt ist.
Bei älteren Grafikkarten ohne ausreichend gemeinsamen Speicher muss der Kompromiss zwischen Speichereffizienz und Geschwindigkeit je nach der trainierten Sequenzlänge abgewogen werden.
Arthur Hennequin dafür, dass er mich durch meinen ersten CUDA-Kernel gecoacht hat und eine einfache Referenzimplementierung programmiert hat, die mir dabei geholfen hat, den ersten Kernel zu booten, der eine angemessene Leistung auf den Ausgangswert bringt. Ohne sein Fachwissen wäre diese Arbeit nicht möglich gewesen.
Boris Dayma und Robin Rombach für die Durchführung von Experimenten mit der vereinfachten Cosinus-Sim-Aufmerksamkeit mit fester Skalierung an einigen bedeutenden Text-zu-Bild-Modellen und für die Überprüfung, dass sie tatsächlich genauso gut funktioniert wie normale Aufmerksamkeit.
Markus Rabe für das Verfassen des Papiers, in dem gezeigt wurde, dass Aufmerksamkeit keinen O(n²)-Speicher erfordert, und Tri Dao dafür, dass er alles in einer CUDA-Kernel-Implementierung für regelmäßige Aufmerksamkeit zusammengestellt hat und die Überlegenheit in der Geschwindigkeit mithilfe des gekachelten Ansatzes demonstriert hat, der HBM-Zugriffe minimiert (und für die Berechnung). out dO * O == dP * P
für Rückwärtspass). Ohne ihre Entdeckungen wäre es mir nicht möglich gewesen, meine Pilgerreise auf der Suche nach der ultimativen Aufmerksamkeitsformulierung abzuschließen.
Stability.ai für das großzügige Sponsoring für die Arbeit an der Spitzenforschung im Bereich der künstlichen Intelligenz
$ pip install flash-cosine-sim-attention
Selbstaufmerksamkeit
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
Queraufmerksamkeit
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
Mit Schlüssel-/Wertmaskierung
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
mask = torch . ones ( 1 , 2048 ). bool (). cuda ()
out = flash_cosine_sim_attention ( q , k , v , mask = mask ) # (1, 8, 1024, 64)
Autoregressiv
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
Einköpfiger Schlüssel/Werte (Shazeer et al. & in PaLM verwendet)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
Wenn Sie zwischen l2norm und dem eigentlichen Aufmerksamkeitsschritt Operationen an den Abfragen und Schlüsseln durchführen müssen, setzen Sie einfach l2norm_qk = False
ex.
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention , l2norm_tensors
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
q , k = l2norm_tensors ( q , k )
# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch
out = flash_cosine_sim_attention ( q , k , v , l2norm_qk = False ) # (4, 8, 1024, 64)
Kreuzaufmerksamkeit mit Kausalität funktioniert wie erwartet – (Zwischenspeicherung von Schlüsseln und Werten in autoregressiver Funktion während der Inferenz oder Transformer-XL-ähnliches Training)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (1, 8, 1024, 64)
Wenn Sie Chargen- und Kopfabmessungen zusammengeführt haben, ist das in Ordnung
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 32 , 1024 , 64 ). cuda ()
k = torch . randn ( 32 , 2048 , 64 ). cuda ()
v = torch . randn ( 32 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (32, 1024, 64)
16 - f32
32
64
96
128
16 -f16
80 - in Bearbeitung
bfloat16-Unterstützung, verwenden Sie sfinae wie von Arthur empfohlen
Streamen Sie in Blöcken von qk_mma in den gemeinsam genutzten Speicher, um mma zu berechnen. Prüfen Sie, ob freigegebener Smem zum Zwischenspeichern weiterer Daten verwendet werden kann
Unterstützt O(n) 1d dynamische Positionsverzerrung
Herauszufinden, warum das Zwischenspeichern von SEM-Fragmenten zu Leistungseinbußen führen würde, macht keinen Sinn
Denken Sie über die Verwendung von logsumexp nach – funktioniert, aber zusätzliches Protokoll führt zu einer Leistungseinbuße
Bereiten Sie einen SEM-Fragment-Caching-Mechanismus vor, um so viel Caching zu ermöglichen, wie auf A100 (oder f16) zulässig ist.
Machen Sie die Verarbeitung der Aufmerksamkeitskachelgröße für den Rückwärtsdurchlauf anpassbar
Verschieben Sie die atomare Addition zur überladenen Funktion innerhalb von mma
flexibel, welcher Typ zur Akkumulation verwendet wird
Testen Sie 64x96-Kacheln auf f16
Bringen Sie eine CPU-speichereffiziente Version ein (nur als Schlussfolgerung, da Training keinen Sinn ergibt), indem Sie einfach einfachen Pytorch-Code verwenden
Finden Sie heraus, wie Sie die Verteilung für Architekturen (z. B. A100) unterschiedlich gestalten können, damit Sie die Erweiterung des gemeinsam genutzten Speichers anders nutzen können
Entkoppeln Sie Zeilen- und Spaltengrößen für Aufmerksamkeitskacheln
dk und dv sind jetzt in f16, wenn es möglich ist (nicht einköpfiges kv)
Unterstützung weiterer Standardkopfabmessungen (WIP)
Debuggen und korrigieren Sie die Verzerrungen nach hinten erneut für eine Kopfgröße von 32
Aufmerksamkeitsverzerrungsgradienten korrigieren
Erlauben Sie einköpfige Schlüssel/Werte, wie in PaLM
Atomic Add für f16 korrigieren
Aufmerksamkeitsverzerrung sollte in der Lage sein, Dimensionen einer zusätzlichen Batch-Dimension zu akzeptieren, für Alphafold2 wie Aufmerksamkeitsverzerrung
Automatisieren Sie das Cache-Busting des Kernels, indem Sie die Version als Suffix zum Paketnamen verwenden
Lösen Sie f16-kausale numerische Probleme
Übernehmen Sie alle Erkenntnisse vom Vorwärtskernel zum Rückwärtskernel und stellen Sie sicher, dass er mindestens auf A100 eine bessere Leistung erbringt
Bisher wird die Kosinusähnlichkeitsbetrachtung in der Industrie nicht weit verbreitet eingesetzt. Das einzige große Modell, das bisher damit trainiert wurde, ist SwinV2. Wenn jemand den Ansatz entkräften kann, öffnen Sie bitte ein Problem oder senden Sie mir eine E-Mail. Mit dem x-transformers-Repository können Sie Experimente gegen regelmäßige Aufmerksamkeit durchführen.
Update: Boris Dayma hat freundlicherweise ein Experiment gestartet (blau mit rot als Basislinie), um die Aufmerksamkeit der Kosinusähnlichkeit mit einer festen Skala von 10 in einer realen Modellumgebung zu validieren.
Update 2: Kosinusähnlichkeitsaufmerksamkeit wurde in einem realen Text-zu-Bild-Aufmerksamkeitsnetzwerk unter Verwendung einer konstanten Skala von 10
nachgewiesen. Nicht schlimmer als regelmäßige Aufmerksamkeit. Der Dank geht an Boris Dayma, der sich die Zeit genommen hat, das Experiment durchzuführen und Zweifel an der Technik auszuräumen.
Update 3: Robin Rombach hat den Kernel in diesem Repository mit einer Kopfgröße von 64 und einer festen Skalierung von 10 in einem Text-zu-Bild-Modell getestet und dabei keinen Unterschied zur normalen Aufmerksamkeit festgestellt. Weitere Bewertungen stehen noch aus.
Update 4: Die in Boris' Experimenten beobachtete Leistungsverbesserung ist wahrscheinlich auf die Tatsache zurückzuführen, dass die Aufmerksamkeit der Kosinussimulation es einem ermöglicht, in den Transformatoren von der Pre-Layernorm- zur Post-Layernorm-Konfiguration zu wechseln (da die l2norm effektiv die Pre-Layernorm-Konfiguration einnimmt). Layernorm). Die Cosinus-Sim-Aufmerksamkeit wird wahrscheinlich die gleichen Ergebnisse liefern wie die normale Aufmerksamkeit, ohne dass weitere Änderungen am Transformator erforderlich sind.
Zum Testen sind Ausgabe und Gradienten für nicht-autoregressive und autoregressive Szenarien gleich
$ python setup.py test
Stellen Sie sicher, dass Sie zuerst den CUDA-Kernel installieren
$ python setup . py install
Dann
$ python benchmark . py
Wenn Sie nur Vorwärts- oder Rückwärts-Benchmarking durchführen möchten, hängen Sie entweder das Flag --only-forwards
oder --only-backwards
an das Obige an. Um ein autoregressives Benchmarking durchzuführen, hängen Sie --causal
an
Nach vorne
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.24ms baseline: 0.23ms
seq_len: 256 slower: 1.27x kernel: 0.38ms baseline: 0.30ms
seq_len: 512 slower: 1.28x kernel: 0.87ms baseline: 0.68ms
seq_len: 1024 slower: 1.15x kernel: 2.63ms baseline: 2.28ms
seq_len: 2048 slower: 0.99x kernel: 7.99ms baseline: 8.10ms
seq_len: 4096 slower: 0.88x kernel: 30.82ms baseline: 34.84ms
seq_len: 8192 slower: 0.00x kernel: 121.96ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.20ms baseline: 0.24ms
seq_len: 256 slower: 0.97x kernel: 0.24ms baseline: 0.25ms
seq_len: 512 slower: 1.22x kernel: 0.43ms baseline: 0.35ms
seq_len: 1024 slower: 0.95x kernel: 0.93ms baseline: 0.98ms
seq_len: 2048 slower: 0.90x kernel: 3.16ms baseline: 3.50ms
seq_len: 4096 slower: 0.85x kernel: 11.06ms baseline: 13.07ms
seq_len: 8192 slower: 0.00x kernel: 42.61ms baseline: oom
Rückwärts – muss noch bearbeitet werden
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.07x kernel: 0.61ms baseline: 0.57ms
seq_len: 256 slower: 1.40x kernel: 0.91ms baseline: 0.65ms
seq_len: 512 slower: 1.70x kernel: 2.34ms baseline: 1.38ms
seq_len: 1024 slower: 1.26x kernel: 5.67ms baseline: 4.50ms
seq_len: 2048 slower: 1.29x kernel: 20.60ms baseline: 15.91ms
seq_len: 4096 slower: 1.30x kernel: 78.93ms baseline: 60.81ms
seq_len: 8192 slower: 0.00x kernel: 314.51ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.50ms baseline: 0.55ms
seq_len: 256 slower: 1.06x kernel: 0.58ms baseline: 0.55ms
seq_len: 512 slower: 1.13x kernel: 0.81ms baseline: 0.72ms
seq_len: 1024 slower: 0.97x kernel: 2.09ms baseline: 2.16ms
seq_len: 2048 slower: 0.96x kernel: 7.06ms baseline: 7.35ms
seq_len: 4096 slower: 0.97x kernel: 26.08ms baseline: 26.84ms
seq_len: 8192 slower: 0.00x kernel: 101.02ms baseline: oom
Vorwärts und rückwärts – F32 ist definitiv langsamer
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.83ms baseline: 0.79ms
seq_len: 256 slower: 1.34x kernel: 1.26ms baseline: 0.95ms
seq_len: 512 slower: 1.44x kernel: 3.14ms baseline: 2.18ms
seq_len: 1024 slower: 1.15x kernel: 7.83ms baseline: 6.81ms
seq_len: 2048 slower: 1.20x kernel: 28.83ms baseline: 24.03ms
seq_len: 4096 slower: 1.20x kernel: 111.13ms baseline: 92.51ms
seq_len: 8192 slower: 0.00x kernel: 441.70ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.89x kernel: 0.68ms baseline: 0.77ms
seq_len: 256 slower: 1.03x kernel: 0.80ms baseline: 0.77ms
seq_len: 512 slower: 1.06x kernel: 1.16ms baseline: 1.10ms
seq_len: 1024 slower: 0.93x kernel: 2.94ms baseline: 3.16ms
seq_len: 2048 slower: 0.93x kernel: 10.06ms baseline: 10.87ms
seq_len: 4096 slower: 0.93x kernel: 37.09ms baseline: 39.96ms
seq_len: 8192 slower: 0.00x kernel: 143.13ms baseline: oom
Für autoregressiv ist ein klarer python benchmark.py --causal
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.97x kernel: 0.81ms baseline: 0.84ms
seq_len: 256 slower: 1.07x kernel: 1.12ms baseline: 1.05ms
seq_len: 512 slower: 0.83x kernel: 2.23ms baseline: 2.68ms
seq_len: 1024 slower: 0.55x kernel: 4.83ms baseline: 8.82ms
seq_len: 2048 slower: 0.49x kernel: 15.89ms baseline: 32.68ms
seq_len: 4096 slower: 0.46x kernel: 57.50ms baseline: 126.00ms
seq_len: 8192 slower: 0.00x kernel: 224.76ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.69ms baseline: 0.84ms
seq_len: 256 slower: 0.95x kernel: 0.79ms baseline: 0.83ms
seq_len: 512 slower: 0.78x kernel: 1.06ms baseline: 1.37ms
seq_len: 1024 slower: 0.50x kernel: 2.10ms baseline: 4.24ms
seq_len: 2048 slower: 0.37x kernel: 5.85ms baseline: 15.92ms
seq_len: 4096 slower: 0.31x kernel: 19.80ms baseline: 64.42ms
seq_len: 8192 slower: 0.00x kernel: 75.25ms baseline: oom
Für Sequenzen variabler Länge mit Maskierung ebenfalls ein klarer Gewinn. Gehen Sie davon aus, dass durchschnittlich 25 % der Token python benchmark.py --mask-prob 0.25
maskiert sind
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.95x kernel: 0.84ms baseline: 0.89ms
seq_len: 256 slower: 1.19x kernel: 1.28ms baseline: 1.08ms
seq_len: 512 slower: 1.23x kernel: 3.19ms baseline: 2.59ms
seq_len: 1024 slower: 0.92x kernel: 8.19ms baseline: 8.88ms
seq_len: 2048 slower: 0.92x kernel: 30.08ms baseline: 32.57ms
seq_len: 4096 slower: 0.94x kernel: 123.20ms baseline: 131.22ms
seq_len: 8192 slower: 0.00x kernel: 461.77ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.77ms baseline: 0.90ms
seq_len: 256 slower: 0.93x kernel: 0.86ms baseline: 0.93ms
seq_len: 512 slower: 0.93x kernel: 1.31ms baseline: 1.40ms
seq_len: 1024 slower: 0.76x kernel: 3.31ms baseline: 4.35ms
seq_len: 2048 slower: 0.71x kernel: 11.19ms baseline: 15.65ms
seq_len: 4096 slower: 0.70x kernel: 41.27ms baseline: 59.01ms
seq_len: 8192 slower: 0.00x kernel: 158.60ms baseline: oom
Der Dank geht an Stability für die Bereitstellung des Zugangs zu A100s zum Testen. Vielen Dank an Enrico, der sich die Zeit genommen hat, einige Benchmarks durchzuführen, als ich noch keinen Zugriff hatte.
A100 ist noch in Arbeit. Shared Memory wird noch nicht vollständig ausgenutzt. Merkwürdigerweise scheint F32 besser abzuschneiden als F16
Vorwärts
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 256 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 512 slower: 0.94x kernel: 0.52ms baseline: 0.55ms
seq_len: 1024 slower: 0.75x kernel: 1.23ms baseline: 1.65ms
seq_len: 2048 slower: 0.88x kernel: 4.17ms baseline: 4.73ms
seq_len: 4096 slower: 0.79x kernel: 14.53ms baseline: 18.36ms
seq_len: 8192 slower: 0.64x kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.84x kernel: 0.24ms baseline: 0.29ms
seq_len: 256 slower: 1.02x kernel: 0.29ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.36ms baseline: 0.29ms
seq_len: 1024 slower: 1.48x kernel: 0.79ms baseline: 0.54ms
seq_len: 2048 slower: 1.31x kernel: 2.08ms baseline: 1.59ms
seq_len: 4096 slower: 1.21x kernel: 6.89ms baseline: 5.70ms
seq_len: 8192 slower: 1.07x kernel: 24.80ms baseline: 23.15ms
Rückwärts
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.94x kernel: 0.57ms baseline: 0.60ms
seq_len: 256 slower: 1.29x kernel: 0.75ms baseline: 0.58ms
seq_len: 512 slower: 1.16x kernel: 1.30ms baseline: 1.12ms
seq_len: 1024 slower: 0.98x kernel: 3.14ms baseline: 3.19ms
seq_len: 2048 slower: 1.05x kernel: 11.13ms baseline: 10.63ms
seq_len: 4096 slower: 0.98x kernel: 40.11ms baseline: 40.79ms
seq_len: 8192 slower: 0.97x kernel: 154.96ms baseline: 159.70ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.62ms baseline: 0.60ms
seq_len: 512 slower: 1.36x kernel: 0.82ms baseline: 0.60ms
seq_len: 1024 slower: 1.52x kernel: 1.52ms baseline: 1.01ms
seq_len: 2048 slower: 1.37x kernel: 4.14ms baseline: 3.03ms
seq_len: 4096 slower: 1.33x kernel: 14.23ms baseline: 10.71ms
seq_len: 8192 slower: 1.34x kernel: 53.90ms baseline: 40.28ms
Vorwärts und rückwärts
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 256 slower: 1.23x kernel: 1.07ms baseline: 0.87ms
seq_len: 512 slower: 1.08x kernel: 1.80ms baseline: 1.66ms
seq_len: 1024 slower: 0.94x kernel: 4.33ms baseline: 4.62ms
seq_len: 2048 slower: 0.99x kernel: 15.26ms baseline: 15.44ms
seq_len: 4096 slower: 0.93x kernel: 54.78ms baseline: 59.21ms
seq_len: 8192 slower: 0.91x kernel: 210.38ms baseline: 230.97ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.90x kernel: 0.78ms baseline: 0.86ms
seq_len: 256 slower: 1.00x kernel: 0.87ms baseline: 0.87ms
seq_len: 512 slower: 1.36x kernel: 1.18ms baseline: 0.86ms
seq_len: 1024 slower: 1.49x kernel: 2.31ms baseline: 1.55ms
seq_len: 2048 slower: 1.33x kernel: 6.17ms baseline: 4.63ms
seq_len: 4096 slower: 1.28x kernel: 21.08ms baseline: 16.44ms
seq_len: 8192 slower: 1.24x kernel: 78.75ms baseline: 63.45ms
Autoregressiv
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.82ms baseline: 1.01ms
seq_len: 256 slower: 1.02x kernel: 1.00ms baseline: 0.98ms
seq_len: 512 slower: 0.82x kernel: 1.55ms baseline: 1.89ms
seq_len: 1024 slower: 0.51x kernel: 2.79ms baseline: 5.44ms
seq_len: 2048 slower: 0.45x kernel: 8.37ms baseline: 18.67ms
seq_len: 4096 slower: 0.40x kernel: 29.16ms baseline: 72.97ms
seq_len: 8192 slower: 0.38x kernel: 108.68ms baseline: 285.47ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.81ms baseline: 0.98ms
seq_len: 256 slower: 0.90x kernel: 0.88ms baseline: 0.98ms
seq_len: 512 slower: 1.16x kernel: 1.13ms baseline: 0.97ms
seq_len: 1024 slower: 0.80x kernel: 1.68ms baseline: 2.10ms
seq_len: 2048 slower: 0.54x kernel: 3.66ms baseline: 6.81ms
seq_len: 4096 slower: 0.45x kernel: 11.43ms baseline: 25.32ms
seq_len: 8192 slower: 0.41x kernel: 40.58ms baseline: 99.14ms
Sequenzen mit variabler Länge (bis zu 25 % der Token werden ausgeblendet)
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.85ms baseline: 1.07ms
seq_len: 256 slower: 1.07x kernel: 1.15ms baseline: 1.08ms
seq_len: 512 slower: 1.00x kernel: 1.94ms baseline: 1.94ms
seq_len: 1024 slower: 0.84x kernel: 4.64ms baseline: 5.55ms
seq_len: 2048 slower: 0.84x kernel: 15.86ms baseline: 18.86ms
seq_len: 4096 slower: 0.76x kernel: 55.19ms baseline: 72.47ms
seq_len: 8192 slower: 0.75x kernel: 212.48ms baseline: 282.71ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.83ms baseline: 1.04ms
seq_len: 256 slower: 0.90x kernel: 0.93ms baseline: 1.03ms
seq_len: 512 slower: 1.18x kernel: 1.22ms baseline: 1.04ms
seq_len: 1024 slower: 1.10x kernel: 2.40ms baseline: 2.17ms
seq_len: 2048 slower: 0.89x kernel: 6.27ms baseline: 7.06ms
seq_len: 4096 slower: 0.82x kernel: 21.19ms baseline: 25.95ms
seq_len: 8192 slower: 0.78x kernel: 79.45ms baseline: 101.83ms
$ make train
Versuchen Sie es mit einer Sequenzlänge von 8192. Es wird langsam sein, aber funktionieren (die normale Aufmerksamkeit wird bei > 2048 unterbrochen, Sie werden dies sehen, wenn Sie das Flag --use-cuda-kernel
entfernen)
$ python train . py - - seq - len 8192 - - use - cuda - kernel
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@inproceedings { Henry2020QueryKeyNF ,
title = { Query-Key Normalization for Transformers } ,
author = { Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen } ,
booktitle = { FINDINGS } ,
year = { 2020 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 }
}