Plongez dans le Deep Learning, refait par Quanta Magazine
Implémentation de l'attention de similarité cosinus fusionnée dans le même style que Flash Attention. L'observation est qu'en adoptant des requêtes et des clés normalisées l2, vous n'avez plus besoin de suivre les maximums de lignes pour la stabilité numérique. Cela simplifie grandement l’algorithme d’attention flash, en supposant que l’attention de similarité cosinusoïdale n’entraîne aucun coût de généralisation.
En d’autres termes, une attention contextuelle stable, rapide, efficace en mémoire et plus longue sans inconvénients.
Mise à jour : Malheureusement, les expériences de Robin ont montré des scores FID d'évaluation bien pires, qui ne se reflètent pas dans la perte. En attendant plus d'expériences. Utilisez cette bibliothèque avec prudence.
Mise à jour 2 : la seule solution salvatrice serait d'utiliser des l2norm groupés, ce qui pourrait potentiellement permettre plus d'expressivité. Si quelqu'un peut évaluer cette technique sur son travail génératif et obtenir des scores FID, ce serait très apprécié.
Mise à jour 3 : une approche similaire à l'attention de simulation cosinus a été éprouvée à grande échelle, avec un modèle de vision à paramètres 22B de Brain.
Pour le moment, les séquences autorégressives et de longueur variable devraient être plus rapides dans toutes les architectures. Pour les séquences plus longues que 2048, cela sera également efficace en termes de mémoire là où une attention régulière ne le serait pas.
Cependant, pour les non autorégressifs sans masquage, l'architecture est encore plus lente sur A100 pour F16. L'objectif est de le faire fonctionner plus rapidement sur l'A100 en avant et en arrière pour F32 et F16, car la mémoire partagée n'est pas encore pleinement exploitée.
Pour les cartes graphiques plus anciennes qui ne disposent pas de suffisamment de mémoire partagée, il faudra évaluer le compromis entre l'efficacité et la vitesse de la mémoire en fonction de la longueur de la séquence à laquelle l'entraînement est effectué.
Arthur Hennequin pour m'avoir accompagné dans mon premier noyau CUDA et pour avoir codé une implémentation de référence simple, ce qui m'a aidé à amorcer le premier noyau présentant des performances raisonnables par rapport à la ligne de base. Ce travail n'aurait pas été possible sans son expertise.
Boris Dayma et Robin Rombach pour avoir mené des expériences sur l'attention simplifiée de simulation de cosinus avec une mise à l'échelle fixe sur certains modèles texte-image importants et vérifié qu'elle fonctionne effectivement aussi bien qu'une attention régulière.
Markus Rabe pour avoir écrit l'article qui montrait que l'attention ne nécessite pas de mémoire O(n²), et Tri Dao pour avoir tout rassemblé dans une implémentation du noyau CUDA pour une attention régulière, démontrant sa supériorité en termes de vitesse en utilisant l'approche en mosaïque minimisant les accès HBM (et pour comprendre out dO * O == dP * P
pour passe arrière). Je n'aurais pas pu terminer mon pèlerinage à la recherche de la formulation ultime de l'attention sans leurs découvertes.
Stability.ai pour son généreux parrainage visant à mener des recherches de pointe sur l'intelligence artificielle
$ pip install flash-cosine-sim-attention
Attention personnelle
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
Attention croisée
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
Avec masquage clé/valeur
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
mask = torch . ones ( 1 , 2048 ). bool (). cuda ()
out = flash_cosine_sim_attention ( q , k , v , mask = mask ) # (1, 8, 1024, 64)
Autorégressif
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
Clé/valeurs à tête unique (Shazeer et al & utilisées dans PaLM)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
Si vous devez effectuer des opérations sur les requêtes et les clés entre la norme l2 et l'étape d'attention réelle, définissez simplement l2norm_qk = False
ex.
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention , l2norm_tensors
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
q , k = l2norm_tensors ( q , k )
# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch
out = flash_cosine_sim_attention ( q , k , v , l2norm_qk = False ) # (4, 8, 1024, 64)
L'attention croisée avec la causalité fonctionne comme prévu - (mise en cache des clés et des valeurs en autorégressive pendant l'inférence, ou formation de type Transformer-XL)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (1, 8, 1024, 64)
Si vous avez fusionné les dimensions du lot et de la tête, ce n'est pas grave.
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 32 , 1024 , 64 ). cuda ()
k = torch . randn ( 32 , 2048 , 64 ). cuda ()
v = torch . randn ( 32 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (32, 1024, 64)
16 - f32
32
64
96
128
16-f16
80 - en cours
support bfloat16, utilisez sfinae comme recommandé par Arthur
diffuser de qk_mma vers la mémoire partagée en morceaux pour calculer le mma, voir si le smem libéré peut être utilisé pour mettre davantage en cache
prise en charge du biais de position dynamique O(n) 1d
comprendre pourquoi la mise en cache des fragments smem entraînerait une dégradation des performances, cela n'a pas de sens
pensez à utiliser logsumexp - fonctionne mais un journal supplémentaire entraîne une dégradation des performances
préparer un mécanisme de mise en cache des fragments smem, pour permettre autant de mise en cache que possible sur A100 (ou f16)
rendre le traitement de la taille des tuiles d'attention personnalisable pour le passage en arrière
déplacer l'ajout atomique à la fonction surchargée à l'intérieur de mma
flexible quant au type utilisé pour l'accumulation
tester les tuiles 64x96 sur f16
apporter une version efficace en mémoire CPU (uniquement à titre d'inférence, car la formation n'a pas de sens) en utilisant simplement du code pytorch
comprendre comment répartir différemment pour les architectures (disons A100), au cas où l'inverse permettrait d'utiliser différemment l'augmentation de la mémoire partagée
découpler les tailles de lignes et de colonnes pour les vignettes d'attention
dk et dv sont maintenant en f16 quand cela peut l'être (kv non à tête unique)
prend en charge des dimensions de tête plus standard (wip)
déboguez et corrigez encore une fois les dégradés de biais vers l'arrière pour une taille de tête de 32
corriger les dégradés de biais d'attention
autoriser les clés/valeurs à tête unique, comme dans PaLM
correction de l'ajout atomique pour f16
le biais d'attention devrait être capable d'accepter les dimensions d'une dimension de lot supplémentaire, pour Alphafold2 comme le biais d'attention
automatiser le contournement du cache du noyau en utilisant la version comme suffixe au nom du package
résoudre les problèmes numériques causals de F16
adopter tous les apprentissages du noyau avant au noyau arrière et assurez-vous qu'il surpasse au moins sur A100
Jusqu’à présent, l’attention portée à la similarité cosinusoïdale n’est pas largement utilisée dans l’industrie. Le seul grand modèle qui a été entraîné jusqu’à présent est SwinV2. Si quelqu'un peut invalider l'approche, veuillez ouvrir un problème ou m'envoyer un e-mail. Vous pouvez exécuter des expériences avec une attention régulière à l'aide du référentiel x-transformers.
Mise à jour : Boris Dayma a gracieusement lancé une expérience (bleu avec rouge comme ligne de base) pour valider l'attention à la similarité cosinus avec une échelle fixe de 10 dans un modèle du monde réel.
Mise à jour 2 : l'attention à la similarité cosinusoïdale a été prouvée dans un réseau d'attention texte-image du monde réel, en utilisant une échelle constante de 10
. Pas pire qu'une attention régulière. Le mérite revient à Boris Dayma d'avoir investi du temps pour mener l'expérience et dissiper les doutes entourant la technique.
Mise à jour 3 : Robin Rombach a testé le noyau dans ce référentiel avec une taille de tête de 64 et une échelle fixe de 10 dans un modèle texte-image, n'observant aucune différence par rapport à une attention régulière. Plus d’évaluations en attente.
Mise à jour 4 : L'amélioration des performances observée dans les expériences de Boris est probablement due au fait que l'attention cosinus-sim permet de passer de la configuration pré-couche à la configuration post-couche dans les transformateurs (car la norme l2 remplace effectivement la configuration pré-couche). couchenorme). L'attention cosinus sim donnera probablement les mêmes résultats qu'une attention régulière, sans aucune autre modification du transformateur.
Pour tester, la sortie et les gradients sont égaux pour les scénarios non autorégressifs et autorégressifs
$ python setup.py test
Assurez-vous d'installer d'abord le noyau CUDA
$ python setup . py install
Alors
$ python benchmark . py
Pour une analyse comparative uniquement en avant ou en arrière, ajoutez l'indicateur --only-forwards
ou --only-backwards
à ce qui précède. Pour comparer l'autorégression, ajoutez --causal
Avant
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.24ms baseline: 0.23ms
seq_len: 256 slower: 1.27x kernel: 0.38ms baseline: 0.30ms
seq_len: 512 slower: 1.28x kernel: 0.87ms baseline: 0.68ms
seq_len: 1024 slower: 1.15x kernel: 2.63ms baseline: 2.28ms
seq_len: 2048 slower: 0.99x kernel: 7.99ms baseline: 8.10ms
seq_len: 4096 slower: 0.88x kernel: 30.82ms baseline: 34.84ms
seq_len: 8192 slower: 0.00x kernel: 121.96ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.20ms baseline: 0.24ms
seq_len: 256 slower: 0.97x kernel: 0.24ms baseline: 0.25ms
seq_len: 512 slower: 1.22x kernel: 0.43ms baseline: 0.35ms
seq_len: 1024 slower: 0.95x kernel: 0.93ms baseline: 0.98ms
seq_len: 2048 slower: 0.90x kernel: 3.16ms baseline: 3.50ms
seq_len: 4096 slower: 0.85x kernel: 11.06ms baseline: 13.07ms
seq_len: 8192 slower: 0.00x kernel: 42.61ms baseline: oom
À l'envers - il faut encore du travail
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.07x kernel: 0.61ms baseline: 0.57ms
seq_len: 256 slower: 1.40x kernel: 0.91ms baseline: 0.65ms
seq_len: 512 slower: 1.70x kernel: 2.34ms baseline: 1.38ms
seq_len: 1024 slower: 1.26x kernel: 5.67ms baseline: 4.50ms
seq_len: 2048 slower: 1.29x kernel: 20.60ms baseline: 15.91ms
seq_len: 4096 slower: 1.30x kernel: 78.93ms baseline: 60.81ms
seq_len: 8192 slower: 0.00x kernel: 314.51ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.50ms baseline: 0.55ms
seq_len: 256 slower: 1.06x kernel: 0.58ms baseline: 0.55ms
seq_len: 512 slower: 1.13x kernel: 0.81ms baseline: 0.72ms
seq_len: 1024 slower: 0.97x kernel: 2.09ms baseline: 2.16ms
seq_len: 2048 slower: 0.96x kernel: 7.06ms baseline: 7.35ms
seq_len: 4096 slower: 0.97x kernel: 26.08ms baseline: 26.84ms
seq_len: 8192 slower: 0.00x kernel: 101.02ms baseline: oom
Avancer et reculer – F32 est nettement plus lent
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.83ms baseline: 0.79ms
seq_len: 256 slower: 1.34x kernel: 1.26ms baseline: 0.95ms
seq_len: 512 slower: 1.44x kernel: 3.14ms baseline: 2.18ms
seq_len: 1024 slower: 1.15x kernel: 7.83ms baseline: 6.81ms
seq_len: 2048 slower: 1.20x kernel: 28.83ms baseline: 24.03ms
seq_len: 4096 slower: 1.20x kernel: 111.13ms baseline: 92.51ms
seq_len: 8192 slower: 0.00x kernel: 441.70ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.89x kernel: 0.68ms baseline: 0.77ms
seq_len: 256 slower: 1.03x kernel: 0.80ms baseline: 0.77ms
seq_len: 512 slower: 1.06x kernel: 1.16ms baseline: 1.10ms
seq_len: 1024 slower: 0.93x kernel: 2.94ms baseline: 3.16ms
seq_len: 2048 slower: 0.93x kernel: 10.06ms baseline: 10.87ms
seq_len: 4096 slower: 0.93x kernel: 37.09ms baseline: 39.96ms
seq_len: 8192 slower: 0.00x kernel: 143.13ms baseline: oom
Pour autorégressif, une victoire claire python benchmark.py --causal
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.97x kernel: 0.81ms baseline: 0.84ms
seq_len: 256 slower: 1.07x kernel: 1.12ms baseline: 1.05ms
seq_len: 512 slower: 0.83x kernel: 2.23ms baseline: 2.68ms
seq_len: 1024 slower: 0.55x kernel: 4.83ms baseline: 8.82ms
seq_len: 2048 slower: 0.49x kernel: 15.89ms baseline: 32.68ms
seq_len: 4096 slower: 0.46x kernel: 57.50ms baseline: 126.00ms
seq_len: 8192 slower: 0.00x kernel: 224.76ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.69ms baseline: 0.84ms
seq_len: 256 slower: 0.95x kernel: 0.79ms baseline: 0.83ms
seq_len: 512 slower: 0.78x kernel: 1.06ms baseline: 1.37ms
seq_len: 1024 slower: 0.50x kernel: 2.10ms baseline: 4.24ms
seq_len: 2048 slower: 0.37x kernel: 5.85ms baseline: 15.92ms
seq_len: 4096 slower: 0.31x kernel: 19.80ms baseline: 64.42ms
seq_len: 8192 slower: 0.00x kernel: 75.25ms baseline: oom
Pour les séquences de longueur variable avec masquage, c'est également une nette victoire. Supposons qu'en moyenne 25 % des jetons soient masqués python benchmark.py --mask-prob 0.25
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.95x kernel: 0.84ms baseline: 0.89ms
seq_len: 256 slower: 1.19x kernel: 1.28ms baseline: 1.08ms
seq_len: 512 slower: 1.23x kernel: 3.19ms baseline: 2.59ms
seq_len: 1024 slower: 0.92x kernel: 8.19ms baseline: 8.88ms
seq_len: 2048 slower: 0.92x kernel: 30.08ms baseline: 32.57ms
seq_len: 4096 slower: 0.94x kernel: 123.20ms baseline: 131.22ms
seq_len: 8192 slower: 0.00x kernel: 461.77ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.77ms baseline: 0.90ms
seq_len: 256 slower: 0.93x kernel: 0.86ms baseline: 0.93ms
seq_len: 512 slower: 0.93x kernel: 1.31ms baseline: 1.40ms
seq_len: 1024 slower: 0.76x kernel: 3.31ms baseline: 4.35ms
seq_len: 2048 slower: 0.71x kernel: 11.19ms baseline: 15.65ms
seq_len: 4096 slower: 0.70x kernel: 41.27ms baseline: 59.01ms
seq_len: 8192 slower: 0.00x kernel: 158.60ms baseline: oom
Merci à Stability pour avoir donné accès aux A100 à des fins de test. Merci à Enrico d'avoir pris le temps de réaliser quelques benchmarks alors que je n'y avais pas encore accès.
L'A100 est toujours en chantier. La mémoire partagée n'est pas encore pleinement exploitée. Bizarrement, le F32 semble faire mieux que le F16
Attaquants
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 256 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 512 slower: 0.94x kernel: 0.52ms baseline: 0.55ms
seq_len: 1024 slower: 0.75x kernel: 1.23ms baseline: 1.65ms
seq_len: 2048 slower: 0.88x kernel: 4.17ms baseline: 4.73ms
seq_len: 4096 slower: 0.79x kernel: 14.53ms baseline: 18.36ms
seq_len: 8192 slower: 0.64x kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.84x kernel: 0.24ms baseline: 0.29ms
seq_len: 256 slower: 1.02x kernel: 0.29ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.36ms baseline: 0.29ms
seq_len: 1024 slower: 1.48x kernel: 0.79ms baseline: 0.54ms
seq_len: 2048 slower: 1.31x kernel: 2.08ms baseline: 1.59ms
seq_len: 4096 slower: 1.21x kernel: 6.89ms baseline: 5.70ms
seq_len: 8192 slower: 1.07x kernel: 24.80ms baseline: 23.15ms
En arrière
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.94x kernel: 0.57ms baseline: 0.60ms
seq_len: 256 slower: 1.29x kernel: 0.75ms baseline: 0.58ms
seq_len: 512 slower: 1.16x kernel: 1.30ms baseline: 1.12ms
seq_len: 1024 slower: 0.98x kernel: 3.14ms baseline: 3.19ms
seq_len: 2048 slower: 1.05x kernel: 11.13ms baseline: 10.63ms
seq_len: 4096 slower: 0.98x kernel: 40.11ms baseline: 40.79ms
seq_len: 8192 slower: 0.97x kernel: 154.96ms baseline: 159.70ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.62ms baseline: 0.60ms
seq_len: 512 slower: 1.36x kernel: 0.82ms baseline: 0.60ms
seq_len: 1024 slower: 1.52x kernel: 1.52ms baseline: 1.01ms
seq_len: 2048 slower: 1.37x kernel: 4.14ms baseline: 3.03ms
seq_len: 4096 slower: 1.33x kernel: 14.23ms baseline: 10.71ms
seq_len: 8192 slower: 1.34x kernel: 53.90ms baseline: 40.28ms
Avant et arrière
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 256 slower: 1.23x kernel: 1.07ms baseline: 0.87ms
seq_len: 512 slower: 1.08x kernel: 1.80ms baseline: 1.66ms
seq_len: 1024 slower: 0.94x kernel: 4.33ms baseline: 4.62ms
seq_len: 2048 slower: 0.99x kernel: 15.26ms baseline: 15.44ms
seq_len: 4096 slower: 0.93x kernel: 54.78ms baseline: 59.21ms
seq_len: 8192 slower: 0.91x kernel: 210.38ms baseline: 230.97ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.90x kernel: 0.78ms baseline: 0.86ms
seq_len: 256 slower: 1.00x kernel: 0.87ms baseline: 0.87ms
seq_len: 512 slower: 1.36x kernel: 1.18ms baseline: 0.86ms
seq_len: 1024 slower: 1.49x kernel: 2.31ms baseline: 1.55ms
seq_len: 2048 slower: 1.33x kernel: 6.17ms baseline: 4.63ms
seq_len: 4096 slower: 1.28x kernel: 21.08ms baseline: 16.44ms
seq_len: 8192 slower: 1.24x kernel: 78.75ms baseline: 63.45ms
Autorégressif
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.82ms baseline: 1.01ms
seq_len: 256 slower: 1.02x kernel: 1.00ms baseline: 0.98ms
seq_len: 512 slower: 0.82x kernel: 1.55ms baseline: 1.89ms
seq_len: 1024 slower: 0.51x kernel: 2.79ms baseline: 5.44ms
seq_len: 2048 slower: 0.45x kernel: 8.37ms baseline: 18.67ms
seq_len: 4096 slower: 0.40x kernel: 29.16ms baseline: 72.97ms
seq_len: 8192 slower: 0.38x kernel: 108.68ms baseline: 285.47ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.81ms baseline: 0.98ms
seq_len: 256 slower: 0.90x kernel: 0.88ms baseline: 0.98ms
seq_len: 512 slower: 1.16x kernel: 1.13ms baseline: 0.97ms
seq_len: 1024 slower: 0.80x kernel: 1.68ms baseline: 2.10ms
seq_len: 2048 slower: 0.54x kernel: 3.66ms baseline: 6.81ms
seq_len: 4096 slower: 0.45x kernel: 11.43ms baseline: 25.32ms
seq_len: 8192 slower: 0.41x kernel: 40.58ms baseline: 99.14ms
Séquences de longueur variable (jusqu'à 25 % de jetons masqués)
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.85ms baseline: 1.07ms
seq_len: 256 slower: 1.07x kernel: 1.15ms baseline: 1.08ms
seq_len: 512 slower: 1.00x kernel: 1.94ms baseline: 1.94ms
seq_len: 1024 slower: 0.84x kernel: 4.64ms baseline: 5.55ms
seq_len: 2048 slower: 0.84x kernel: 15.86ms baseline: 18.86ms
seq_len: 4096 slower: 0.76x kernel: 55.19ms baseline: 72.47ms
seq_len: 8192 slower: 0.75x kernel: 212.48ms baseline: 282.71ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.83ms baseline: 1.04ms
seq_len: 256 slower: 0.90x kernel: 0.93ms baseline: 1.03ms
seq_len: 512 slower: 1.18x kernel: 1.22ms baseline: 1.04ms
seq_len: 1024 slower: 1.10x kernel: 2.40ms baseline: 2.17ms
seq_len: 2048 slower: 0.89x kernel: 6.27ms baseline: 7.06ms
seq_len: 4096 slower: 0.82x kernel: 21.19ms baseline: 25.95ms
seq_len: 8192 slower: 0.78x kernel: 79.45ms baseline: 101.83ms
$ make train
Essayez une longueur de séquence de 8192. Ce sera lent mais cela fonctionnera (l'attention normale sera interrompue à > 2048, vous verrez cela si vous supprimez l'indicateur --use-cuda-kernel
)
$ python train . py - - seq - len 8192 - - use - cuda - kernel
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@inproceedings { Henry2020QueryKeyNF ,
title = { Query-Key Normalization for Transformers } ,
author = { Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen } ,
booktitle = { FINDINGS } ,
year = { 2020 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 }
}