انغمس في التعلم العميق، الذي أعادت مجلة Quanta بنائه
تنفيذ انتباه تشابه جيب التمام المندمج بنفس أسلوب Flash Attention. الملاحظة هي أنه من خلال اعتماد الاستعلامات والمفاتيح المقيسة l2، لم تعد بحاجة إلى تتبع الحد الأقصى للصف لتحقيق الاستقرار الرقمي. يؤدي هذا إلى تبسيط خوارزمية انتباه الفلاش إلى حد كبير، بافتراض أن انتباه تشابه جيب التمام يأتي دون أي تكلفة تعميم.
بمعنى آخر، مستقر، وسريع، وفعال في الذاكرة، واهتمام بالسياق لفترة أطول دون أي جوانب سلبية.
تحديث: لسوء الحظ، أظهرت تجارب روبن تقييمًا أسوأ بكثير، ولم تنعكس درجات FID في الخسارة. في انتظار المزيد من التجارب. استخدم هذه المكتبة بحذر.
التحديث 2: الحل الوحيد هو استخدام l2norm المُجمَّع، والذي قد يسمح بمزيد من التعبير. إذا كان بإمكان أي شخص تقييم هذه التقنية على عمله التوليدي والحصول على بعض درجات FID، فسيكون موضع تقدير كبير.
التحديث 3: تم إثبات وجود نهج مشابه لاهتمام شريحة جيب التمام على نطاق واسع، مع نموذج رؤية معلمة 22B من Brain.
في الوقت الحالي، يجب أن تكون تسلسلات الانحدار الذاتي والمتغيرة الطول أسرع عبر جميع البنيات. بالنسبة للتسلسلات الأطول من 2048، ستكون أيضًا ذات كفاءة في الذاكرة حيث لا يكون الاهتمام المنتظم كذلك.
ومع ذلك، بالنسبة إلى عدم الانحدار التلقائي بدون إخفاء، لا تزال البنية أبطأ على A100 لـ F16. الهدف هو جعله يعمل بشكل أسرع على A100 للأمام والخلف لكل من F32 وF16، حيث لم يتم استغلال الذاكرة المشتركة بشكل كامل بعد.
بطاقات الرسوميات الأقدم التي لا تحتوي على ذاكرة مشتركة كافية، سيتعين على المرء قياس المفاضلة بين كفاءة الذاكرة وسرعتها اعتمادًا على طول التسلسل الذي يتم التدريب عليه.
آرثر هينكين لتدريبي خلال أول نواة CUDA الخاصة بي، ولصياغة تطبيق مرجعي بسيط، مما ساعدني على تمهيد أول نواة تأتي ضمن أداء معقول لخط الأساس. ولم يكن هذا العمل ممكناً لولا خبرته.
بوريس دايما وروبن رومباك لإجراء تجارب على انتباه شريحة جيب التمام المبسطة مع القياس الثابت على بعض نماذج تحويل النص إلى الصورة المهمة والتحقق من أنها تؤدي بالفعل نفس الاهتمام المنتظم.
ماركوس راب لكتابة الورقة التي أظهرت أن الاهتمام لا يتطلب ذاكرة O(n²)، وTri Dao لتجميعها معًا في تطبيق CUDA kernel من أجل الاهتمام المنتظم، مما يدل على التفوق في السرعة باستخدام النهج المتجانب الذي يقلل من الوصول إلى HBM (ولتحديد out dO * O == dP * P
للتمرير للخلف). لم أكن لأتمكن من إكمال رحلة حجتي بحثًا عن صياغة الاهتمام النهائي لولا اكتشافاتهم.
Stability.ai على الرعاية السخية للعمل على أحدث أبحاث الذكاء الاصطناعي
$ pip install flash-cosine-sim-attention
الاهتمام الذاتي
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
عبر الاهتمام
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
مع إخفاء المفتاح/القيمة
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
mask = torch . ones ( 1 , 2048 ). bool (). cuda ()
out = flash_cosine_sim_attention ( q , k , v , mask = mask ) # (1, 8, 1024, 64)
الانحدار الذاتي
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
المفتاح/القيم ذات الرأس الواحد (Shazeer et al & المستخدمة في PaLM)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
إذا كنت بحاجة إلى إجراء عمليات على الاستعلامات والمفاتيح بين l2norm وخطوة الانتباه الفعلية، فما عليك سوى تعيين l2norm_qk = False
السابق.
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention , l2norm_tensors
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
q , k = l2norm_tensors ( q , k )
# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch
out = flash_cosine_sim_attention ( q , k , v , l2norm_qk = False ) # (4, 8, 1024, 64)
تقاطع الانتباه مع الأعمال السببية كما هو متوقع - (التخزين المؤقت للمفاتيح والقيم في الانحدار التلقائي أثناء الاستدلال، أو التدريب مثل المحول XL)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (1, 8, 1024, 64)
إذا تم دمج أبعاد الدُفعة والرأس، فلا بأس بذلك
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 32 , 1024 , 64 ). cuda ()
k = torch . randn ( 32 , 2048 , 64 ). cuda ()
v = torch . randn ( 32 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (32, 1024, 64)
16 - ص32
32
64
96
128
16-ص16
80 - قيد التقدم
دعم bfloat16، استخدم sfinae كما أوصى آرثر
الدفق من qk_mma إلى الذاكرة المشتركة في أجزاء لحساب mma، لمعرفة ما إذا كان يمكن استخدام smem المحرر للتخزين المؤقت أكثر
دعم O(n) 1d التحيز الموضعي الديناميكي
معرفة لماذا يؤدي التخزين المؤقت لجزء smem إلى انخفاض الأداء، فهذا لا معنى له
فكر في استخدام logsumexp - يعمل ولكن السجل الإضافي يؤدي إلى تدهور الأداء
قم بإعداد آلية التخزين المؤقت لأجزاء smem، للسماح بأكبر قدر ممكن من التخزين المؤقت المسموح به على A100 (أو f16)
جعل معالجة حجم بلاط الانتباه قابلة للتخصيص للتمرير إلى الخلف
نقل الإضافة الذرية إلى وظيفة مثقلة داخل مجلس العمل المتحد
مرن أي نوع يستخدم للتراكم
اختبار البلاط 64x96 على f16
إحضار نسخة فعالة من ذاكرة وحدة المعالجة المركزية (للاستدلال فقط، لأن التدريب ليس له معنى) باستخدام رمز pytorch العادي فقط
اكتشف كيفية الإرسال بشكل مختلف للبنيات (على سبيل المثال A100)، في حالة إمكانية الاستفادة من الزيادة في الذاكرة المشتركة بشكل مختلف
فصل أحجام الصفوف والأعمدة لبلاطات الانتباه
dk وdv موجودان الآن في f16 عندما يمكن أن يكونا (kv غير أحادي الرأس)
دعم المزيد من أبعاد الرأس القياسية (wip)
قم بتصحيح وإصلاح التدرجات المتحيزة للخلف مرة أخرى لحجم الرأس 32
إصلاح تدرجات انحياز الانتباه
السماح بمفتاح/قيم ذات رأس واحد، كما هو الحال في PaLM
إصلاح الإضافة الذرية لـ f16
يجب أن يكون انحياز الانتباه قادرًا على قبول أبعاد بُعد دفعة إضافية، بالنسبة لـ Alphafold2 مثل انحياز الانتباه
أتمتة عملية خرق ذاكرة التخزين المؤقت للنواة باستخدام الإصدار كلاحقة لاسم الحزمة
حل المسائل العددية السببية f16
اعتماد جميع الدروس المستفادة من النواة الأمامية إلى النواة العكسية والتأكد من تفوقها على الأقل على A100
حتى الآن لا يستخدم اهتمام تشابه جيب التمام على نطاق واسع في الصناعة. النموذج الكبير الوحيد الذي تم تدريبه عليه حتى الآن هو SwinV2. إذا كان أي شخص يمكن أن يبطل هذا النهج، يرجى فتح قضية أو مراسلتي عبر البريد الإلكتروني. يمكنك إجراء التجارب مقابل الاهتمام المنتظم باستخدام مستودع x-transformers.
تحديث: بدأ بوريس دايما تجربة (الأزرق مع الأحمر كخط أساسي) للتحقق من صحة اهتمام تشابه جيب التمام بمقياس ثابت قدره 10 في إعداد نموذج حقيقي.
التحديث 2: تم إثبات الاهتمام بتشابه جيب التمام في شبكة تحويل النص إلى الصورة في العالم الحقيقي، باستخدام مقياس ثابت قدره 10
. ليس أسوأ من الاهتمام المنتظم. يعود الفضل إلى بوريس دايما لاستثماره الوقت لإجراء التجربة وإزالة الشكوك المحيطة بهذه التقنية.
التحديث 3: قام Robin Rombach باختبار النواة في هذا المستودع بحجم رأس يبلغ 64 ومقياس ثابت قدره 10 في نموذج تحويل النص إلى صورة، ولم يلاحظ أي اختلاف عن الاهتمام المنتظم. المزيد من التقييمات في انتظار.
التحديث 4: من المحتمل أن يكون التحسن في الأداء الذي شوهد في تجارب بوريس يرجع إلى حقيقة أن انتباه جيب التمام يسمح للشخص بالتبديل من التكوين المسبق للطبقة إلى تكوين الطبقة اللاحقة في المحولات (حيث يحل المعيار l2 بشكل فعال محل المعيار المسبق طبقة). من المرجح أن يؤدي انتباه شريحة جيب التمام إلى نتائج مماثلة للانتباه العادي، دون أي تغييرات أخرى على المحول.
لاختبار المخرجات والتدرجات تكون متساوية بالنسبة لسيناريوهات الانحدار الذاتي وغير الانحدار الذاتي
$ python setup.py test
تأكد أولاً من تثبيت نواة CUDA
$ python setup . py install
ثم
$ python benchmark . py
من أجل قياس الأداء للأمام أو للخلف فقط، قم بإلحاق إما علامة --only-forwards
أو --only-backwards
إلى ما ورد أعلاه. لقياس الانحدار الذاتي، قم بإلحاق --causal
إلى الأمام
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.24ms baseline: 0.23ms
seq_len: 256 slower: 1.27x kernel: 0.38ms baseline: 0.30ms
seq_len: 512 slower: 1.28x kernel: 0.87ms baseline: 0.68ms
seq_len: 1024 slower: 1.15x kernel: 2.63ms baseline: 2.28ms
seq_len: 2048 slower: 0.99x kernel: 7.99ms baseline: 8.10ms
seq_len: 4096 slower: 0.88x kernel: 30.82ms baseline: 34.84ms
seq_len: 8192 slower: 0.00x kernel: 121.96ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.20ms baseline: 0.24ms
seq_len: 256 slower: 0.97x kernel: 0.24ms baseline: 0.25ms
seq_len: 512 slower: 1.22x kernel: 0.43ms baseline: 0.35ms
seq_len: 1024 slower: 0.95x kernel: 0.93ms baseline: 0.98ms
seq_len: 2048 slower: 0.90x kernel: 3.16ms baseline: 3.50ms
seq_len: 4096 slower: 0.85x kernel: 11.06ms baseline: 13.07ms
seq_len: 8192 slower: 0.00x kernel: 42.61ms baseline: oom
إلى الوراء - لا يزال بحاجة إلى العمل
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.07x kernel: 0.61ms baseline: 0.57ms
seq_len: 256 slower: 1.40x kernel: 0.91ms baseline: 0.65ms
seq_len: 512 slower: 1.70x kernel: 2.34ms baseline: 1.38ms
seq_len: 1024 slower: 1.26x kernel: 5.67ms baseline: 4.50ms
seq_len: 2048 slower: 1.29x kernel: 20.60ms baseline: 15.91ms
seq_len: 4096 slower: 1.30x kernel: 78.93ms baseline: 60.81ms
seq_len: 8192 slower: 0.00x kernel: 314.51ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.50ms baseline: 0.55ms
seq_len: 256 slower: 1.06x kernel: 0.58ms baseline: 0.55ms
seq_len: 512 slower: 1.13x kernel: 0.81ms baseline: 0.72ms
seq_len: 1024 slower: 0.97x kernel: 2.09ms baseline: 2.16ms
seq_len: 2048 slower: 0.96x kernel: 7.06ms baseline: 7.35ms
seq_len: 4096 slower: 0.97x kernel: 26.08ms baseline: 26.84ms
seq_len: 8192 slower: 0.00x kernel: 101.02ms baseline: oom
للأمام والخلف - F32 أبطأ بالتأكيد
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.83ms baseline: 0.79ms
seq_len: 256 slower: 1.34x kernel: 1.26ms baseline: 0.95ms
seq_len: 512 slower: 1.44x kernel: 3.14ms baseline: 2.18ms
seq_len: 1024 slower: 1.15x kernel: 7.83ms baseline: 6.81ms
seq_len: 2048 slower: 1.20x kernel: 28.83ms baseline: 24.03ms
seq_len: 4096 slower: 1.20x kernel: 111.13ms baseline: 92.51ms
seq_len: 8192 slower: 0.00x kernel: 441.70ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.89x kernel: 0.68ms baseline: 0.77ms
seq_len: 256 slower: 1.03x kernel: 0.80ms baseline: 0.77ms
seq_len: 512 slower: 1.06x kernel: 1.16ms baseline: 1.10ms
seq_len: 1024 slower: 0.93x kernel: 2.94ms baseline: 3.16ms
seq_len: 2048 slower: 0.93x kernel: 10.06ms baseline: 10.87ms
seq_len: 4096 slower: 0.93x kernel: 37.09ms baseline: 39.96ms
seq_len: 8192 slower: 0.00x kernel: 143.13ms baseline: oom
بالنسبة للانحدار التلقائي، فوز واضح python benchmark.py --causal
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.97x kernel: 0.81ms baseline: 0.84ms
seq_len: 256 slower: 1.07x kernel: 1.12ms baseline: 1.05ms
seq_len: 512 slower: 0.83x kernel: 2.23ms baseline: 2.68ms
seq_len: 1024 slower: 0.55x kernel: 4.83ms baseline: 8.82ms
seq_len: 2048 slower: 0.49x kernel: 15.89ms baseline: 32.68ms
seq_len: 4096 slower: 0.46x kernel: 57.50ms baseline: 126.00ms
seq_len: 8192 slower: 0.00x kernel: 224.76ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.69ms baseline: 0.84ms
seq_len: 256 slower: 0.95x kernel: 0.79ms baseline: 0.83ms
seq_len: 512 slower: 0.78x kernel: 1.06ms baseline: 1.37ms
seq_len: 1024 slower: 0.50x kernel: 2.10ms baseline: 4.24ms
seq_len: 2048 slower: 0.37x kernel: 5.85ms baseline: 15.92ms
seq_len: 4096 slower: 0.31x kernel: 19.80ms baseline: 64.42ms
seq_len: 8192 slower: 0.00x kernel: 75.25ms baseline: oom
بالنسبة للتسلسلات ذات الطول المتغير مع الإخفاء، يعد ذلك بمثابة فوز واضح أيضًا. افترض في المتوسط أن 25% من الرموز المميزة مخفية في python benchmark.py --mask-prob 0.25
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.95x kernel: 0.84ms baseline: 0.89ms
seq_len: 256 slower: 1.19x kernel: 1.28ms baseline: 1.08ms
seq_len: 512 slower: 1.23x kernel: 3.19ms baseline: 2.59ms
seq_len: 1024 slower: 0.92x kernel: 8.19ms baseline: 8.88ms
seq_len: 2048 slower: 0.92x kernel: 30.08ms baseline: 32.57ms
seq_len: 4096 slower: 0.94x kernel: 123.20ms baseline: 131.22ms
seq_len: 8192 slower: 0.00x kernel: 461.77ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.77ms baseline: 0.90ms
seq_len: 256 slower: 0.93x kernel: 0.86ms baseline: 0.93ms
seq_len: 512 slower: 0.93x kernel: 1.31ms baseline: 1.40ms
seq_len: 1024 slower: 0.76x kernel: 3.31ms baseline: 4.35ms
seq_len: 2048 slower: 0.71x kernel: 11.19ms baseline: 15.65ms
seq_len: 4096 slower: 0.70x kernel: 41.27ms baseline: 59.01ms
seq_len: 8192 slower: 0.00x kernel: 158.60ms baseline: oom
نتوجه بالشكر إلى شركة Stability لتوفير الوصول إلى A100s للاختبار. شكرًا لـ Enrico على الوقت الذي أمضيته في تشغيل بعض المعايير عندما لم أتمكن من الوصول إليها بعد.
A100 لا يزال قيد التقدم. لم يتم استغلال الذاكرة المشتركة بشكل كامل بعد. والغريب أن أداء F32 أفضل من F16
إلى الأمام
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 256 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 512 slower: 0.94x kernel: 0.52ms baseline: 0.55ms
seq_len: 1024 slower: 0.75x kernel: 1.23ms baseline: 1.65ms
seq_len: 2048 slower: 0.88x kernel: 4.17ms baseline: 4.73ms
seq_len: 4096 slower: 0.79x kernel: 14.53ms baseline: 18.36ms
seq_len: 8192 slower: 0.64x kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.84x kernel: 0.24ms baseline: 0.29ms
seq_len: 256 slower: 1.02x kernel: 0.29ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.36ms baseline: 0.29ms
seq_len: 1024 slower: 1.48x kernel: 0.79ms baseline: 0.54ms
seq_len: 2048 slower: 1.31x kernel: 2.08ms baseline: 1.59ms
seq_len: 4096 slower: 1.21x kernel: 6.89ms baseline: 5.70ms
seq_len: 8192 slower: 1.07x kernel: 24.80ms baseline: 23.15ms
إلى الوراء
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.94x kernel: 0.57ms baseline: 0.60ms
seq_len: 256 slower: 1.29x kernel: 0.75ms baseline: 0.58ms
seq_len: 512 slower: 1.16x kernel: 1.30ms baseline: 1.12ms
seq_len: 1024 slower: 0.98x kernel: 3.14ms baseline: 3.19ms
seq_len: 2048 slower: 1.05x kernel: 11.13ms baseline: 10.63ms
seq_len: 4096 slower: 0.98x kernel: 40.11ms baseline: 40.79ms
seq_len: 8192 slower: 0.97x kernel: 154.96ms baseline: 159.70ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.62ms baseline: 0.60ms
seq_len: 512 slower: 1.36x kernel: 0.82ms baseline: 0.60ms
seq_len: 1024 slower: 1.52x kernel: 1.52ms baseline: 1.01ms
seq_len: 2048 slower: 1.37x kernel: 4.14ms baseline: 3.03ms
seq_len: 4096 slower: 1.33x kernel: 14.23ms baseline: 10.71ms
seq_len: 8192 slower: 1.34x kernel: 53.90ms baseline: 40.28ms
إلى الأمام والخلف
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 256 slower: 1.23x kernel: 1.07ms baseline: 0.87ms
seq_len: 512 slower: 1.08x kernel: 1.80ms baseline: 1.66ms
seq_len: 1024 slower: 0.94x kernel: 4.33ms baseline: 4.62ms
seq_len: 2048 slower: 0.99x kernel: 15.26ms baseline: 15.44ms
seq_len: 4096 slower: 0.93x kernel: 54.78ms baseline: 59.21ms
seq_len: 8192 slower: 0.91x kernel: 210.38ms baseline: 230.97ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.90x kernel: 0.78ms baseline: 0.86ms
seq_len: 256 slower: 1.00x kernel: 0.87ms baseline: 0.87ms
seq_len: 512 slower: 1.36x kernel: 1.18ms baseline: 0.86ms
seq_len: 1024 slower: 1.49x kernel: 2.31ms baseline: 1.55ms
seq_len: 2048 slower: 1.33x kernel: 6.17ms baseline: 4.63ms
seq_len: 4096 slower: 1.28x kernel: 21.08ms baseline: 16.44ms
seq_len: 8192 slower: 1.24x kernel: 78.75ms baseline: 63.45ms
الانحدار الذاتي
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.82ms baseline: 1.01ms
seq_len: 256 slower: 1.02x kernel: 1.00ms baseline: 0.98ms
seq_len: 512 slower: 0.82x kernel: 1.55ms baseline: 1.89ms
seq_len: 1024 slower: 0.51x kernel: 2.79ms baseline: 5.44ms
seq_len: 2048 slower: 0.45x kernel: 8.37ms baseline: 18.67ms
seq_len: 4096 slower: 0.40x kernel: 29.16ms baseline: 72.97ms
seq_len: 8192 slower: 0.38x kernel: 108.68ms baseline: 285.47ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.81ms baseline: 0.98ms
seq_len: 256 slower: 0.90x kernel: 0.88ms baseline: 0.98ms
seq_len: 512 slower: 1.16x kernel: 1.13ms baseline: 0.97ms
seq_len: 1024 slower: 0.80x kernel: 1.68ms baseline: 2.10ms
seq_len: 2048 slower: 0.54x kernel: 3.66ms baseline: 6.81ms
seq_len: 4096 slower: 0.45x kernel: 11.43ms baseline: 25.32ms
seq_len: 8192 slower: 0.41x kernel: 40.58ms baseline: 99.14ms
تسلسلات طويلة متغيرة (ما يصل إلى 25% من الرموز المميزة)
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.85ms baseline: 1.07ms
seq_len: 256 slower: 1.07x kernel: 1.15ms baseline: 1.08ms
seq_len: 512 slower: 1.00x kernel: 1.94ms baseline: 1.94ms
seq_len: 1024 slower: 0.84x kernel: 4.64ms baseline: 5.55ms
seq_len: 2048 slower: 0.84x kernel: 15.86ms baseline: 18.86ms
seq_len: 4096 slower: 0.76x kernel: 55.19ms baseline: 72.47ms
seq_len: 8192 slower: 0.75x kernel: 212.48ms baseline: 282.71ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.83ms baseline: 1.04ms
seq_len: 256 slower: 0.90x kernel: 0.93ms baseline: 1.03ms
seq_len: 512 slower: 1.18x kernel: 1.22ms baseline: 1.04ms
seq_len: 1024 slower: 1.10x kernel: 2.40ms baseline: 2.17ms
seq_len: 2048 slower: 0.89x kernel: 6.27ms baseline: 7.06ms
seq_len: 4096 slower: 0.82x kernel: 21.19ms baseline: 25.95ms
seq_len: 8192 slower: 0.78x kernel: 79.45ms baseline: 101.83ms
$ make train
حاول 8192 طول التسلسل. سيكون بطيئًا ولكنه سيعمل (سينقطع الاهتمام الطبيعي عند> 2048، وسترى هذا إذا قمت بإزالة علامة --use-cuda-kernel
)
$ python train . py - - seq - len 8192 - - use - cuda - kernel
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@inproceedings { Henry2020QueryKeyNF ,
title = { Query-Key Normalization for Transformers } ,
author = { Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen } ,
booktitle = { FINDINGS } ,
year = { 2020 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 }
}