ดำดิ่งสู่การเรียนรู้เชิงลึก ปรับปรุงใหม่โดย Quanta Magazine
การใช้ความสนใจความคล้ายคลึงโคไซน์แบบหลอมรวมในรูปแบบเดียวกับ Flash Attention ข้อสังเกตก็คือด้วยการใช้คำค้นหาและคีย์ที่ทำให้เป็นมาตรฐาน l2 คุณไม่จำเป็นต้องติดตามค่าสูงสุดของแถวเพื่อความเสถียรของตัวเลขอีกต่อไป วิธีนี้จะช่วยลดความซับซ้อนของอัลกอริธึมความสนใจแบบแฟลชได้อย่างมาก โดยสมมติว่าความสนใจที่มีความคล้ายคลึงกันของโคไซน์นั้นไม่มีค่าใช้จ่ายในการสรุปทั่วไป
กล่าวอีกนัยหนึ่งคือ มีเสถียรภาพ รวดเร็ว หน่วยความจำมีประสิทธิภาพ และความสนใจบริบทที่ยาวนานขึ้นโดยไม่มีข้อเสีย
อัปเดต: น่าเสียดายที่การทดลองของ Robin แสดงให้เห็นคะแนน FID การประเมินที่แย่กว่ามากซึ่งไม่ได้สะท้อนถึงการสูญเสีย อยู่ระหว่างการทดลองเพิ่มเติม ใช้ห้องสมุดนี้ด้วยความระมัดระวัง
อัปเดต 2: สิ่งเดียวที่ช่วยประหยัดได้คือการใช้ l2norm ที่จัดกลุ่ม ซึ่งอาจช่วยให้แสดงออกได้มากขึ้น หากใครก็ตามสามารถประเมินเทคนิคนี้ในงานกำเนิดของตนและได้รับคะแนน FID บ้างจะยินดีเป็นอย่างยิ่ง
อัปเดต 3: วิธีการที่คล้ายคลึงกับความสนใจของโคไซน์ซิมได้รับการพิสูจน์แล้วในวงกว้าง ด้วยโมเดลการมองเห็นพารามิเตอร์ 22B จาก Brain
ในขณะนี้ ลำดับการถดถอยอัตโนมัติและความยาวผันแปรควรจะเร็วขึ้นในทุกสถาปัตยกรรม สำหรับลำดับที่ยาวกว่าปี 2048 หน่วยความจำก็จะมีประสิทธิภาพเช่นกัน โดยที่จะไม่ให้ความสนใจเป็นประจำ
อย่างไรก็ตาม สำหรับการไม่ถดถอยอัตโนมัติโดยไม่มีการปิดบัง สถาปัตยกรรมจะยังคงช้ากว่าบน A100 สำหรับ F16 จุดมุ่งหมายคือเพื่อให้ทำงานเร็วขึ้นบน A100 ไปข้างหน้าและข้างหลังสำหรับทั้ง F32 และ F16 เนื่องจากหน่วยความจำที่ใช้ร่วมกันยังไม่ถูกใช้ประโยชน์อย่างเต็มที่
กราฟิกการ์ดรุ่นเก่าที่ไม่มีหน่วยความจำที่ใช้ร่วมกันเพียงพอ เราจะต้องวัดข้อดีข้อเสียของประสิทธิภาพและความเร็วของหน่วยความจำ ขึ้นอยู่กับความยาวของลำดับที่กำลังฝึกอยู่
Arthur Hennequin สำหรับการฝึกสอนฉันผ่านเคอร์เนล CUDA แรกของฉัน และสำหรับการเขียนโค้ดการใช้งานอ้างอิงแบบง่าย ซึ่งช่วยให้ฉันบูตเคอร์เนลแรกที่มีประสิทธิภาพที่เหมาะสมจนถึงระดับพื้นฐาน งานนี้คงเป็นไปไม่ได้หากปราศจากความเชี่ยวชาญของเขา
Boris Dayma และ Robin Rombach สำหรับการทดลองดำเนินการเกี่ยวกับความสนใจของโคไซน์ซิมที่เรียบง่ายพร้อมการปรับขนาดคงที่ในโมเดลข้อความเป็นรูปภาพที่สำคัญบางรายการ และตรวจสอบว่ามีประสิทธิภาพเช่นเดียวกับความสนใจปกติหรือไม่
Markus Rabe สำหรับการเขียนกระดาษที่แสดงความสนใจไม่จำเป็นต้องใช้หน่วยความจำ O(n²) และ Tri Dao สำหรับการรวบรวมทั้งหมดเข้าด้วยกันในการใช้งานเคอร์เนล CUDA เพื่อความสนใจอย่างสม่ำเสมอ แสดงให้เห็นถึงความเหนือกว่าในด้านความเร็วโดยใช้แนวทางแบบเรียงต่อกันที่ลดการเข้าถึง HBM ให้เหลือน้อยที่สุด (และสำหรับการคำนวณ ออก dO * O == dP * P
เพื่อถอยหลัง) คงไม่สามารถเดินทางไปแสวงบุญของฉันเพื่อแสวงหาการกำหนดความสนใจขั้นสูงสุดได้หากปราศจากการค้นพบของพวกเขา
Stability.ai สำหรับการสนับสนุนอย่างมีน้ำใจในการทำงานวิจัยปัญญาประดิษฐ์ที่ล้ำหน้า
$ pip install flash-cosine-sim-attention
ความสนใจตนเอง
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
ข้ามความสนใจ
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
ด้วยการมาสก์คีย์ / ค่า
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
mask = torch . ones ( 1 , 2048 ). bool (). cuda ()
out = flash_cosine_sim_attention ( q , k , v , mask = mask ) # (1, 8, 1024, 64)
ถดถอยอัตโนมัติ
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
คีย์ / ค่าแบบหัวเดียว (Shazeer et al & ใช้ใน PaLM)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
หากคุณต้องการดำเนินการกับข้อความค้นหาและคีย์ระหว่าง l2norm และขั้นตอนความสนใจจริง เพียงตั้งค่า l2norm_qk = False
อดีต.
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention , l2norm_tensors
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
q , k = l2norm_tensors ( q , k )
# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch
out = flash_cosine_sim_attention ( q , k , v , l2norm_qk = False ) # (4, 8, 1024, 64)
ความสนใจข้ามกับการทำงานเชิงสาเหตุตามที่คาดไว้ - (การแคชคีย์และค่าในแบบ autoregressive ในระหว่างการอนุมาน หรือ Transformer-xl เช่นการฝึก)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (1, 8, 1024, 64)
หากคุณมีการรวมขนาดแบทช์และส่วนหัวเข้าด้วยกัน ก็ไม่เป็นไร
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 32 , 1024 , 64 ). cuda ()
k = torch . randn ( 32 , 2048 , 64 ). cuda ()
v = torch . randn ( 32 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (32, 1024, 64)
16 - f32
32
64
96
128
16 -f16
80 - อยู่ระหว่างดำเนินการ
รองรับ bfloat16 ใช้ sfinae ตามที่ Arthur แนะนำ
สตรีมจาก qk_mma ไปยังหน่วยความจำที่ใช้ร่วมกันเป็นชิ้น ๆ เพื่อคำนวณ mma ดูว่า smem ที่เป็นอิสระสามารถใช้สำหรับแคชเพิ่มเติมได้หรือไม่
รองรับอคติตำแหน่งไดนามิก O (n) 1d
ลองหาคำตอบว่าเหตุใดการแคชแฟรกเมนต์ smem จึงทำให้ประสิทธิภาพลดลง จึงไม่สมเหตุสมผล
ลองคิดถึงการใช้ logsumexp - ใช้งานได้ แต่บันทึกพิเศษทำให้ประสิทธิภาพลดลง
เตรียมกลไกการแคชส่วน smem เพื่อให้แคชได้มากที่สุดเท่าที่อนุญาตบน A100 (หรือ f16)
ทำให้การประมวลผลขนาดไทล์ความสนใจปรับแต่งได้สำหรับการย้อนกลับ
ย้ายอะตอมมิกเพิ่มไปยังฟังก์ชันโอเวอร์โหลดภายใน mma
ยืดหยุ่นได้ว่าจะนำไปใช้สะสมประเภทไหน
ทดสอบไทล์ขนาด 64x96 บน f16
นำเวอร์ชันที่มีประสิทธิภาพของหน่วยความจำ CPU เข้ามา (สำหรับการอนุมานเท่านั้น เนื่องจากการฝึกอบรมไม่สมเหตุสมผล) โดยใช้โค้ด pytorch ธรรมดา
หาวิธีการจัดส่งที่แตกต่างกันสำหรับสถาปัตยกรรม (เช่น A100) ในกรณีที่ย้อนหลังสามารถใช้การเพิ่มหน่วยความจำที่ใช้ร่วมกันแตกต่างกัน
แยกขนาดแถวและคอลัมน์สำหรับไทล์ความสนใจ
dk และ dv ตอนนี้อยู่ใน f16 เมื่อเป็นไปได้ (ไม่ใช่ kv หัวเดียว)
รองรับขนาดหัวมาตรฐานมากขึ้น (wip)
แก้ไขข้อบกพร่องและแก้ไขการไล่ระดับสีย้อนหลังอีกครั้งสำหรับขนาดหัว 32
แก้ไขการไล่ระดับอคติของความสนใจ
อนุญาตให้ใช้คีย์ / ค่าแบบหัวเดียวเช่นเดียวกับใน PaLM
แก้ไขการเพิ่มอะตอมมิกสำหรับ f16
อคติความสนใจควรจะสามารถยอมรับมิติของมิติชุดพิเศษได้ สำหรับ Alphafold2 เช่นการให้น้ำหนักความสนใจ
ทำให้เคอร์เนลป้องกันแคชโดยอัตโนมัติโดยใช้เวอร์ชันเป็นส่วนต่อท้ายชื่อแพ็คเกจ
แก้ไขปัญหาตัวเลขเชิงสาเหตุ f16
นำการเรียนรู้ทั้งหมดตั้งแต่เคอร์เนลไปข้างหน้าไปจนถึงเคอร์เนลถอยหลัง และตรวจสอบให้แน่ใจว่ามีประสิทธิภาพเหนือกว่า A100 เป็นอย่างน้อย
จนถึงขณะนี้ความสนใจความคล้ายคลึงของโคไซน์ยังไม่มีการใช้กันอย่างแพร่หลายในอุตสาหกรรม รุ่นใหญ่เพียงรุ่นเดียวที่ได้รับการฝึกจนถึงตอนนี้คือ SwinV2 หากใครสามารถยกเลิกแนวทางนี้ได้ โปรดเปิดปัญหาหรือส่งอีเมลถึงฉัน คุณสามารถดำเนินการทดสอบโดยไม่สนใจความสนใจเป็นประจำได้โดยใช้พื้นที่เก็บข้อมูล x-transformers
อัปเดต: Boris Dayma ได้เริ่มการทดลองอย่างสง่างาม (สีน้ำเงินที่มีสีแดงเป็นพื้นฐาน) เพื่อตรวจสอบความสนใจของความคล้ายคลึงของโคไซน์ด้วยมาตราส่วนคงที่ที่ 10 ในการตั้งค่าแบบจำลองในโลกแห่งความเป็นจริง
อัปเดต 2: ความสนใจที่คล้ายคลึงกันของโคไซน์ได้รับการพิสูจน์แล้วในเครือข่ายความสนใจจากข้อความเป็นรูปภาพในโลกแห่งความเป็นจริง โดยใช้มาตราส่วนคงที่ที่ 10
ไม่เลวร้ายไปกว่าการเอาใจใส่เป็นประจำ เครดิตเป็นของ Boris Dayma สำหรับการสละเวลาเพื่อทำการทดสอบและขจัดข้อสงสัยเกี่ยวกับเทคนิคนี้
อัปเดต 3: Robin Rombach ได้ทดสอบเคอร์เนลในพื้นที่เก็บข้อมูลนี้ด้วยขนาดส่วนหัวที่ 64 และขนาดคงที่ที่ 10 ในรูปแบบข้อความเป็นรูปภาพ โดยไม่สังเกตเห็นความแตกต่างจากความสนใจปกติ การประเมินเพิ่มเติมที่รอดำเนินการ
อัปเดต 4: การปรับปรุงประสิทธิภาพที่เห็นในการทดลองของ Boris น่าจะเนื่องมาจากข้อเท็จจริงที่ว่าความสนใจของโคไซน์ซิมช่วยให้สามารถเปลี่ยนจากการกำหนดค่า pre layernorm ไปเป็น post layernorm ในหม้อแปลงไฟฟ้าได้ (เนื่องจาก l2norm เข้ามาแทนที่ pre- ชั้นมาตรฐาน) ความสนใจของโคไซน์ซิมน่าจะให้ผลลัพธ์เช่นเดียวกับความสนใจปกติ โดยไม่มีการเปลี่ยนแปลงใดๆ กับหม้อแปลง
สำหรับการทดสอบเอาต์พุตและการไล่ระดับสีจะเท่ากันสำหรับสถานการณ์ที่ไม่ถดถอยอัตโนมัติและถดถอยอัตโนมัติ
$ python setup.py test
ตรวจสอบให้แน่ใจว่าได้ติดตั้งเคอร์เนล CUDA ก่อน
$ python setup . py install
แล้ว
$ python benchmark . py
สำหรับการเปรียบเทียบแบบไปข้างหน้าหรือข้างหลังเท่านั้น ให้เพิ่มแฟล็ก --only-forwards
หรือ --only-backwards
ท้ายด้านบน หากต้องการวัดประสิทธิภาพการถดถอยอัตโนมัติ ให้ผนวก --causal
ซึ่งไปข้างหน้า
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.24ms baseline: 0.23ms
seq_len: 256 slower: 1.27x kernel: 0.38ms baseline: 0.30ms
seq_len: 512 slower: 1.28x kernel: 0.87ms baseline: 0.68ms
seq_len: 1024 slower: 1.15x kernel: 2.63ms baseline: 2.28ms
seq_len: 2048 slower: 0.99x kernel: 7.99ms baseline: 8.10ms
seq_len: 4096 slower: 0.88x kernel: 30.82ms baseline: 34.84ms
seq_len: 8192 slower: 0.00x kernel: 121.96ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.20ms baseline: 0.24ms
seq_len: 256 slower: 0.97x kernel: 0.24ms baseline: 0.25ms
seq_len: 512 slower: 1.22x kernel: 0.43ms baseline: 0.35ms
seq_len: 1024 slower: 0.95x kernel: 0.93ms baseline: 0.98ms
seq_len: 2048 slower: 0.90x kernel: 3.16ms baseline: 3.50ms
seq_len: 4096 slower: 0.85x kernel: 11.06ms baseline: 13.07ms
seq_len: 8192 slower: 0.00x kernel: 42.61ms baseline: oom
ถอยหลัง - ยังต้องทำงาน
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.07x kernel: 0.61ms baseline: 0.57ms
seq_len: 256 slower: 1.40x kernel: 0.91ms baseline: 0.65ms
seq_len: 512 slower: 1.70x kernel: 2.34ms baseline: 1.38ms
seq_len: 1024 slower: 1.26x kernel: 5.67ms baseline: 4.50ms
seq_len: 2048 slower: 1.29x kernel: 20.60ms baseline: 15.91ms
seq_len: 4096 slower: 1.30x kernel: 78.93ms baseline: 60.81ms
seq_len: 8192 slower: 0.00x kernel: 314.51ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.50ms baseline: 0.55ms
seq_len: 256 slower: 1.06x kernel: 0.58ms baseline: 0.55ms
seq_len: 512 slower: 1.13x kernel: 0.81ms baseline: 0.72ms
seq_len: 1024 slower: 0.97x kernel: 2.09ms baseline: 2.16ms
seq_len: 2048 slower: 0.96x kernel: 7.06ms baseline: 7.35ms
seq_len: 4096 slower: 0.97x kernel: 26.08ms baseline: 26.84ms
seq_len: 8192 slower: 0.00x kernel: 101.02ms baseline: oom
เดินหน้าและถอยหลัง - F32 ช้าลงอย่างแน่นอน
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.83ms baseline: 0.79ms
seq_len: 256 slower: 1.34x kernel: 1.26ms baseline: 0.95ms
seq_len: 512 slower: 1.44x kernel: 3.14ms baseline: 2.18ms
seq_len: 1024 slower: 1.15x kernel: 7.83ms baseline: 6.81ms
seq_len: 2048 slower: 1.20x kernel: 28.83ms baseline: 24.03ms
seq_len: 4096 slower: 1.20x kernel: 111.13ms baseline: 92.51ms
seq_len: 8192 slower: 0.00x kernel: 441.70ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.89x kernel: 0.68ms baseline: 0.77ms
seq_len: 256 slower: 1.03x kernel: 0.80ms baseline: 0.77ms
seq_len: 512 slower: 1.06x kernel: 1.16ms baseline: 1.10ms
seq_len: 1024 slower: 0.93x kernel: 2.94ms baseline: 3.16ms
seq_len: 2048 slower: 0.93x kernel: 10.06ms baseline: 10.87ms
seq_len: 4096 slower: 0.93x kernel: 37.09ms baseline: 39.96ms
seq_len: 8192 slower: 0.00x kernel: 143.13ms baseline: oom
สำหรับการถดถอยอัตโนมัติ win python benchmark.py --causal
ที่ชัดเจน
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.97x kernel: 0.81ms baseline: 0.84ms
seq_len: 256 slower: 1.07x kernel: 1.12ms baseline: 1.05ms
seq_len: 512 slower: 0.83x kernel: 2.23ms baseline: 2.68ms
seq_len: 1024 slower: 0.55x kernel: 4.83ms baseline: 8.82ms
seq_len: 2048 slower: 0.49x kernel: 15.89ms baseline: 32.68ms
seq_len: 4096 slower: 0.46x kernel: 57.50ms baseline: 126.00ms
seq_len: 8192 slower: 0.00x kernel: 224.76ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.69ms baseline: 0.84ms
seq_len: 256 slower: 0.95x kernel: 0.79ms baseline: 0.83ms
seq_len: 512 slower: 0.78x kernel: 1.06ms baseline: 1.37ms
seq_len: 1024 slower: 0.50x kernel: 2.10ms baseline: 4.24ms
seq_len: 2048 slower: 0.37x kernel: 5.85ms baseline: 15.92ms
seq_len: 4096 slower: 0.31x kernel: 19.80ms baseline: 64.42ms
seq_len: 8192 slower: 0.00x kernel: 75.25ms baseline: oom
สำหรับลำดับความยาวที่แปรผันได้พร้อมการมาสก์ก็ถือเป็นชัยชนะที่ชัดเจนเช่นกัน สมมติว่าโดยเฉลี่ย 25% ของโทเค็นถูกปกปิด python benchmark.py --mask-prob 0.25
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.95x kernel: 0.84ms baseline: 0.89ms
seq_len: 256 slower: 1.19x kernel: 1.28ms baseline: 1.08ms
seq_len: 512 slower: 1.23x kernel: 3.19ms baseline: 2.59ms
seq_len: 1024 slower: 0.92x kernel: 8.19ms baseline: 8.88ms
seq_len: 2048 slower: 0.92x kernel: 30.08ms baseline: 32.57ms
seq_len: 4096 slower: 0.94x kernel: 123.20ms baseline: 131.22ms
seq_len: 8192 slower: 0.00x kernel: 461.77ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.77ms baseline: 0.90ms
seq_len: 256 slower: 0.93x kernel: 0.86ms baseline: 0.93ms
seq_len: 512 slower: 0.93x kernel: 1.31ms baseline: 1.40ms
seq_len: 1024 slower: 0.76x kernel: 3.31ms baseline: 4.35ms
seq_len: 2048 slower: 0.71x kernel: 11.19ms baseline: 15.65ms
seq_len: 4096 slower: 0.70x kernel: 41.27ms baseline: 59.01ms
seq_len: 8192 slower: 0.00x kernel: 158.60ms baseline: oom
ขอขอบคุณ Stability ที่ให้สิทธิ์เข้าถึง A100 เพื่อทำการทดสอบ ขอขอบคุณ Enrico ที่สละเวลาเรียกใช้การวัดประสิทธิภาพบางส่วนในขณะที่ฉันยังไม่มีสิทธิ์เข้าถึง
A100 ยังคงอยู่ในระหว่างดำเนินการ หน่วยความจำที่ใช้ร่วมกันยังไม่ถูกใช้ประโยชน์อย่างเต็มที่ น่าแปลกที่ F32 ดูเหมือนจะทำได้ดีกว่า F16
ไปข้างหน้า
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 256 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 512 slower: 0.94x kernel: 0.52ms baseline: 0.55ms
seq_len: 1024 slower: 0.75x kernel: 1.23ms baseline: 1.65ms
seq_len: 2048 slower: 0.88x kernel: 4.17ms baseline: 4.73ms
seq_len: 4096 slower: 0.79x kernel: 14.53ms baseline: 18.36ms
seq_len: 8192 slower: 0.64x kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.84x kernel: 0.24ms baseline: 0.29ms
seq_len: 256 slower: 1.02x kernel: 0.29ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.36ms baseline: 0.29ms
seq_len: 1024 slower: 1.48x kernel: 0.79ms baseline: 0.54ms
seq_len: 2048 slower: 1.31x kernel: 2.08ms baseline: 1.59ms
seq_len: 4096 slower: 1.21x kernel: 6.89ms baseline: 5.70ms
seq_len: 8192 slower: 1.07x kernel: 24.80ms baseline: 23.15ms
ถอยหลัง
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.94x kernel: 0.57ms baseline: 0.60ms
seq_len: 256 slower: 1.29x kernel: 0.75ms baseline: 0.58ms
seq_len: 512 slower: 1.16x kernel: 1.30ms baseline: 1.12ms
seq_len: 1024 slower: 0.98x kernel: 3.14ms baseline: 3.19ms
seq_len: 2048 slower: 1.05x kernel: 11.13ms baseline: 10.63ms
seq_len: 4096 slower: 0.98x kernel: 40.11ms baseline: 40.79ms
seq_len: 8192 slower: 0.97x kernel: 154.96ms baseline: 159.70ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.62ms baseline: 0.60ms
seq_len: 512 slower: 1.36x kernel: 0.82ms baseline: 0.60ms
seq_len: 1024 slower: 1.52x kernel: 1.52ms baseline: 1.01ms
seq_len: 2048 slower: 1.37x kernel: 4.14ms baseline: 3.03ms
seq_len: 4096 slower: 1.33x kernel: 14.23ms baseline: 10.71ms
seq_len: 8192 slower: 1.34x kernel: 53.90ms baseline: 40.28ms
ไปข้างหน้าและข้างหลัง
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 256 slower: 1.23x kernel: 1.07ms baseline: 0.87ms
seq_len: 512 slower: 1.08x kernel: 1.80ms baseline: 1.66ms
seq_len: 1024 slower: 0.94x kernel: 4.33ms baseline: 4.62ms
seq_len: 2048 slower: 0.99x kernel: 15.26ms baseline: 15.44ms
seq_len: 4096 slower: 0.93x kernel: 54.78ms baseline: 59.21ms
seq_len: 8192 slower: 0.91x kernel: 210.38ms baseline: 230.97ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.90x kernel: 0.78ms baseline: 0.86ms
seq_len: 256 slower: 1.00x kernel: 0.87ms baseline: 0.87ms
seq_len: 512 slower: 1.36x kernel: 1.18ms baseline: 0.86ms
seq_len: 1024 slower: 1.49x kernel: 2.31ms baseline: 1.55ms
seq_len: 2048 slower: 1.33x kernel: 6.17ms baseline: 4.63ms
seq_len: 4096 slower: 1.28x kernel: 21.08ms baseline: 16.44ms
seq_len: 8192 slower: 1.24x kernel: 78.75ms baseline: 63.45ms
ถดถอยอัตโนมัติ
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.82ms baseline: 1.01ms
seq_len: 256 slower: 1.02x kernel: 1.00ms baseline: 0.98ms
seq_len: 512 slower: 0.82x kernel: 1.55ms baseline: 1.89ms
seq_len: 1024 slower: 0.51x kernel: 2.79ms baseline: 5.44ms
seq_len: 2048 slower: 0.45x kernel: 8.37ms baseline: 18.67ms
seq_len: 4096 slower: 0.40x kernel: 29.16ms baseline: 72.97ms
seq_len: 8192 slower: 0.38x kernel: 108.68ms baseline: 285.47ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.81ms baseline: 0.98ms
seq_len: 256 slower: 0.90x kernel: 0.88ms baseline: 0.98ms
seq_len: 512 slower: 1.16x kernel: 1.13ms baseline: 0.97ms
seq_len: 1024 slower: 0.80x kernel: 1.68ms baseline: 2.10ms
seq_len: 2048 slower: 0.54x kernel: 3.66ms baseline: 6.81ms
seq_len: 4096 slower: 0.45x kernel: 11.43ms baseline: 25.32ms
seq_len: 8192 slower: 0.41x kernel: 40.58ms baseline: 99.14ms
ลำดับความยาวผันแปรได้ (โทเค็นมากถึง 25% ถูกปกปิด)
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.85ms baseline: 1.07ms
seq_len: 256 slower: 1.07x kernel: 1.15ms baseline: 1.08ms
seq_len: 512 slower: 1.00x kernel: 1.94ms baseline: 1.94ms
seq_len: 1024 slower: 0.84x kernel: 4.64ms baseline: 5.55ms
seq_len: 2048 slower: 0.84x kernel: 15.86ms baseline: 18.86ms
seq_len: 4096 slower: 0.76x kernel: 55.19ms baseline: 72.47ms
seq_len: 8192 slower: 0.75x kernel: 212.48ms baseline: 282.71ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.83ms baseline: 1.04ms
seq_len: 256 slower: 0.90x kernel: 0.93ms baseline: 1.03ms
seq_len: 512 slower: 1.18x kernel: 1.22ms baseline: 1.04ms
seq_len: 1024 slower: 1.10x kernel: 2.40ms baseline: 2.17ms
seq_len: 2048 slower: 0.89x kernel: 6.27ms baseline: 7.06ms
seq_len: 4096 slower: 0.82x kernel: 21.19ms baseline: 25.95ms
seq_len: 8192 slower: 0.78x kernel: 79.45ms baseline: 101.83ms
$ make train
ลองความยาวลำดับ 8192 มันจะช้า แต่จะได้ผล (ความสนใจปกติจะพังที่> 2048 คุณจะเห็นสิ่งนี้หากคุณลบการตั้งค่าสถานะ --use-cuda-kernel
)
$ python train . py - - seq - len 8192 - - use - cuda - kernel
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@inproceedings { Henry2020QueryKeyNF ,
title = { Query-Key Normalization for Transformers } ,
author = { Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen } ,
booktitle = { FINDINGS } ,
year = { 2020 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 }
}