Quanta Magazine이 다시 만든 딥 러닝에 대해 알아보세요
Flash Attention과 동일한 스타일로 융합된 코사인 유사성 주목을 구현합니다. l2 정규화된 쿼리와 키를 채택하면 더 이상 수치 안정성을 위해 행 최대값을 추적할 필요가 없습니다. 이는 코사인 유사성 주의가 일반화 비용 없이 발생한다고 가정하여 플래시 주의 알고리즘을 크게 단순화합니다.
즉, 안정적이고 빠르며 메모리 효율적이며 단점이 없는 더 긴 상황에 대한 관심입니다.
업데이트: 불행하게도 Robin의 실험에서는 손실에 반영되지 않은 훨씬 더 나쁜 평가 FID 점수가 나타났습니다. 더 많은 실험이 대기 중입니다. 이 라이브러리를 주의해서 사용하세요.
업데이트 2: 유일하게 절약되는 은혜는 그룹화된 l2norm을 사용하는 것입니다. 이는 잠재적으로 더 많은 표현성을 허용할 수 있습니다. 누구든지 생성 작업에서 이 기술을 평가하고 FID 점수를 얻을 수 있다면 매우 감사하겠습니다.
업데이트 3: 코사인 시뮬레이션 어텐션과 유사한 접근 방식이 Brain의 22B 매개변수 비전 모델을 통해 대규모로 입증되었습니다.
현재로서는 자동 회귀 및 가변 길이 시퀀스가 모든 아키텍처에서 더 빨라야 합니다. 2048보다 긴 시퀀스의 경우 정기적인 주의가 필요하지 않은 메모리 효율성도 있습니다.
그러나 마스킹이 없는 비자동회귀의 경우 F16의 A100에서는 아키텍처가 여전히 느립니다. 목표는 공유 메모리가 아직 완전히 활용되지 않았기 때문에 F32와 F16 모두에 대해 A100 앞뒤로 더 빠르게 작동하도록 하는 것입니다.
공유 메모리가 충분하지 않은 구형 그래픽 카드의 경우 훈련되는 시퀀스 길이에 따라 메모리 효율성과 속도의 균형을 측정해야 합니다.
첫 번째 CUDA 커널을 통해 나에게 지도해 주고 간단한 참조 구현을 코딩해 준 Arthur Hennequin은 합리적인 성능 범위 내에서 기준에 맞는 첫 번째 커널을 부트스트랩하는 데 도움을 주었습니다. 이 작업은 그의 전문성이 없었다면 불가능했을 것이다.
Boris Dayma와 Robin Rombach는 일부 중요한 텍스트-이미지 모델에서 고정 스케일링을 사용하여 단순화된 코사인 시뮬레이션 Attention을 실험하고 이것이 실제로 일반 Attention만큼 성능을 발휘하는지 확인했습니다.
관심을 보인 논문을 쓴 Markus Rabe는 O(n²) 메모리가 필요하지 않으며 Tri Dao는 정기적인 관심을 위해 이 모든 것을 CUDA 커널 구현에 통합하고 HBM 액세스를 최소화하는 타일 접근 방식을 사용하여 속도의 우수성을 입증했습니다. out dO * O == dP * P
(역방향 패스). 그들의 발견이 없었다면 궁극적인 주목을 받는 제제를 찾는 나의 순례를 완료할 수 없었을 것입니다.
최첨단 인공지능 연구를 위한 Stability.ai의 아낌없는 후원
$ pip install flash-cosine-sim-attention
자기 관심
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
교차주의
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v ) # (1, 8, 1024, 64)
키/값 마스킹 포함
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
mask = torch . ones ( 1 , 2048 ). bool (). cuda ()
out = flash_cosine_sim_attention ( q , k , v , mask = mask ) # (1, 8, 1024, 64)
자기회귀
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
단일 방향 키/값(Shazeer et al & PaLM에서 사용됨)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (4, 8, 1024, 64)
l2norm과 실제 주의 단계 사이에서 쿼리와 키에 대한 작업을 수행해야 하는 경우 l2norm_qk = False
로 설정하면 됩니다.
전.
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention , l2norm_tensors
q = torch . randn ( 4 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 4 , 1024 , 64 ). cuda ()
v = torch . randn ( 4 , 1024 , 64 ). cuda ()
q , k = l2norm_tensors ( q , k )
# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch
out = flash_cosine_sim_attention ( q , k , v , l2norm_qk = False ) # (4, 8, 1024, 64)
예상대로 인과 관계 작업에 주의를 기울이십시오. - (추론 중 자동 회귀의 키 및 값 캐싱 또는 훈련과 같은 Transformer-xl)
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 1 , 8 , 1024 , 64 ). cuda ()
k = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
v = torch . randn ( 1 , 8 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (1, 8, 1024, 64)
배치와 헤드 치수를 병합했다면 괜찮습니다.
import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention
q = torch . randn ( 32 , 1024 , 64 ). cuda ()
k = torch . randn ( 32 , 2048 , 64 ). cuda ()
v = torch . randn ( 32 , 2048 , 64 ). cuda ()
out = flash_cosine_sim_attention ( q , k , v , causal = True ) # (32, 1024, 64)
16 - f32
32
64
96
128
16 -f16
80 - 진행 중
bfloat16 지원, Arthur가 권장하는 대로 sfinae 사용
mma를 계산하기 위해 qk_mma에서 공유 메모리로 청크로 스트리밍하고, 해제된 smem을 더 많은 캐싱에 사용할 수 있는지 확인하세요.
O(n) 1d 동적 위치 바이어스 지원
왜 smem 조각 캐싱이 성능 저하로 이어지는지 알아내세요. 말이 안 됩니다.
logsumexp 사용에 대해 생각해 보십시오. 작동하지만 추가 로그로 인해 성능이 저하됩니다.
A100(또는 f16)에서 허용되는 만큼의 캐싱을 허용하기 위해 smem 조각 캐싱 메커니즘을 준비합니다.
역방향 전달을 위해 주의 타일 크기 처리를 사용자 정의할 수 있도록 합니다.
mma 내부의 오버로드된 기능에 원자 추가 이동
축적에 사용되는 유형이 유연함
f16에서 64x96 타일을 테스트해 보세요.
일반 파이토치 코드를 사용하여 CPU 메모리 효율적인 버전을 가져옵니다(학습이 의미가 없으므로 추론용으로만).
뒤로 공유 메모리의 증가를 다르게 활용할 수 있는 경우 아키텍처(예: A100)에 대해 다르게 디스패치하는 방법을 알아냅니다.
주의 타일의 행과 열 크기를 분리합니다.
dk 및 dv는 이제 f16에 있습니다(단일 헤드 kv 아님).
더 많은 표준 헤드 치수 지원(wip)
헤드 크기 32에 대해 다시 바이어스 역방향 기울기를 디버그하고 수정합니다.
주의 편향 기울기 수정
PaLM에서와 같이 단일 방향 키/값을 허용합니다.
f16에 대한 원자 추가 수정
주의 편향은 주의 편향과 같은 Alphafold2의 경우 추가 배치 차원의 차원을 허용할 수 있어야 합니다.
버전을 패키지 이름의 접미사로 사용하여 커널 캐시 무효화를 자동화합니다.
f16 인과적인 수치 문제 해결
순방향 커널에서 역방향 커널까지 모든 학습 내용을 채택하고 최소한 A100보다 성능이 뛰어난지 확인하세요.
지금까지 코사인 유사성 주의는 업계에서 널리 사용되지 않았습니다. 지금까지 훈련된 유일한 대형 모델은 SwinV2입니다. 접근 방식을 무효화할 수 있는 사람이 있으면 문제를 공개하거나 이메일을 보내주세요. x-transformers 저장소를 사용하여 정기적인 주의를 기울여 실험을 실행할 수 있습니다.
업데이트: Boris Dayma는 실제 모델 설정에서 고정 척도 10으로 코사인 유사성 주의를 검증하기 위한 실험(파란색과 빨간색을 기준으로 함)을 친절하게 시작했습니다.
업데이트 2: 코사인 유사성 관심은 10
의 상수 척도를 사용하여 실제 텍스트-이미지 주의 네트워크에서 입증되었습니다. 정기적인 관심보다 나쁘지 않습니다. 실험을 실행하는 데 시간을 투자하고 기술을 둘러싼 의심을 제거한 Boris Dayma에게 감사의 뜻을 전합니다.
업데이트 3: Robin Rombach는 텍스트-이미지 모델에서 헤드 크기 64, 고정 스케일 10으로 이 저장소의 커널을 테스트했으며 일반적인 주의와 차이가 없음을 관찰했습니다. 더 많은 평가가 보류 중입니다.
업데이트 4: Boris의 실험에서 볼 수 있는 성능 향상은 코사인 시뮬레이션을 통해 변환기에서 사전 계층 표준에서 사후 계층 표준 구성으로 전환할 수 있다는 사실에 기인한 것 같습니다(l2norm이 사전 계층 표준을 효과적으로 대체하므로). 레이어표준). 코사인 시뮬레이션 어텐션은 변환기에 다른 변경 사항을 적용하지 않고도 일반 어텐션과 동일한 결과를 얻을 가능성이 높습니다.
테스트 출력 및 기울기는 비자기회귀 시나리오와 자기회귀 시나리오에서 동일합니다.
$ python setup.py test
먼저 CUDA 커널을 설치해야 합니다.
$ python setup . py install
그 다음에
$ python benchmark . py
순방향 또는 역방향 벤치마크만 수행하려면 위에 --only-forwards
또는 --only-backwards
플래그를 추가하세요. 자동회귀를 벤치마킹하려면 --causal
추가하세요.
앞으로
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.24ms baseline: 0.23ms
seq_len: 256 slower: 1.27x kernel: 0.38ms baseline: 0.30ms
seq_len: 512 slower: 1.28x kernel: 0.87ms baseline: 0.68ms
seq_len: 1024 slower: 1.15x kernel: 2.63ms baseline: 2.28ms
seq_len: 2048 slower: 0.99x kernel: 7.99ms baseline: 8.10ms
seq_len: 4096 slower: 0.88x kernel: 30.82ms baseline: 34.84ms
seq_len: 8192 slower: 0.00x kernel: 121.96ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.20ms baseline: 0.24ms
seq_len: 256 slower: 0.97x kernel: 0.24ms baseline: 0.25ms
seq_len: 512 slower: 1.22x kernel: 0.43ms baseline: 0.35ms
seq_len: 1024 slower: 0.95x kernel: 0.93ms baseline: 0.98ms
seq_len: 2048 slower: 0.90x kernel: 3.16ms baseline: 3.50ms
seq_len: 4096 slower: 0.85x kernel: 11.06ms baseline: 13.07ms
seq_len: 8192 slower: 0.00x kernel: 42.61ms baseline: oom
거꾸로 - 아직 작업이 필요함
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.07x kernel: 0.61ms baseline: 0.57ms
seq_len: 256 slower: 1.40x kernel: 0.91ms baseline: 0.65ms
seq_len: 512 slower: 1.70x kernel: 2.34ms baseline: 1.38ms
seq_len: 1024 slower: 1.26x kernel: 5.67ms baseline: 4.50ms
seq_len: 2048 slower: 1.29x kernel: 20.60ms baseline: 15.91ms
seq_len: 4096 slower: 1.30x kernel: 78.93ms baseline: 60.81ms
seq_len: 8192 slower: 0.00x kernel: 314.51ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.50ms baseline: 0.55ms
seq_len: 256 slower: 1.06x kernel: 0.58ms baseline: 0.55ms
seq_len: 512 slower: 1.13x kernel: 0.81ms baseline: 0.72ms
seq_len: 1024 slower: 0.97x kernel: 2.09ms baseline: 2.16ms
seq_len: 2048 slower: 0.96x kernel: 7.06ms baseline: 7.35ms
seq_len: 4096 slower: 0.97x kernel: 26.08ms baseline: 26.84ms
seq_len: 8192 slower: 0.00x kernel: 101.02ms baseline: oom
앞으로 및 뒤로 - F32는 확실히 느립니다.
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 1.05x kernel: 0.83ms baseline: 0.79ms
seq_len: 256 slower: 1.34x kernel: 1.26ms baseline: 0.95ms
seq_len: 512 slower: 1.44x kernel: 3.14ms baseline: 2.18ms
seq_len: 1024 slower: 1.15x kernel: 7.83ms baseline: 6.81ms
seq_len: 2048 slower: 1.20x kernel: 28.83ms baseline: 24.03ms
seq_len: 4096 slower: 1.20x kernel: 111.13ms baseline: 92.51ms
seq_len: 8192 slower: 0.00x kernel: 441.70ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.89x kernel: 0.68ms baseline: 0.77ms
seq_len: 256 slower: 1.03x kernel: 0.80ms baseline: 0.77ms
seq_len: 512 slower: 1.06x kernel: 1.16ms baseline: 1.10ms
seq_len: 1024 slower: 0.93x kernel: 2.94ms baseline: 3.16ms
seq_len: 2048 slower: 0.93x kernel: 10.06ms baseline: 10.87ms
seq_len: 4096 slower: 0.93x kernel: 37.09ms baseline: 39.96ms
seq_len: 8192 slower: 0.00x kernel: 143.13ms baseline: oom
자동 회귀의 경우 확실한 승리 python benchmark.py --causal
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.97x kernel: 0.81ms baseline: 0.84ms
seq_len: 256 slower: 1.07x kernel: 1.12ms baseline: 1.05ms
seq_len: 512 slower: 0.83x kernel: 2.23ms baseline: 2.68ms
seq_len: 1024 slower: 0.55x kernel: 4.83ms baseline: 8.82ms
seq_len: 2048 slower: 0.49x kernel: 15.89ms baseline: 32.68ms
seq_len: 4096 slower: 0.46x kernel: 57.50ms baseline: 126.00ms
seq_len: 8192 slower: 0.00x kernel: 224.76ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.69ms baseline: 0.84ms
seq_len: 256 slower: 0.95x kernel: 0.79ms baseline: 0.83ms
seq_len: 512 slower: 0.78x kernel: 1.06ms baseline: 1.37ms
seq_len: 1024 slower: 0.50x kernel: 2.10ms baseline: 4.24ms
seq_len: 2048 slower: 0.37x kernel: 5.85ms baseline: 15.92ms
seq_len: 4096 slower: 0.31x kernel: 19.80ms baseline: 64.42ms
seq_len: 8192 slower: 0.00x kernel: 75.25ms baseline: oom
마스킹이 포함된 가변 길이 시퀀스의 경우에도 확실한 승리입니다. 평균 25%의 토큰이 마스크되었다고 가정합니다 python benchmark.py --mask-prob 0.25
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.95x kernel: 0.84ms baseline: 0.89ms
seq_len: 256 slower: 1.19x kernel: 1.28ms baseline: 1.08ms
seq_len: 512 slower: 1.23x kernel: 3.19ms baseline: 2.59ms
seq_len: 1024 slower: 0.92x kernel: 8.19ms baseline: 8.88ms
seq_len: 2048 slower: 0.92x kernel: 30.08ms baseline: 32.57ms
seq_len: 4096 slower: 0.94x kernel: 123.20ms baseline: 131.22ms
seq_len: 8192 slower: 0.00x kernel: 461.77ms baseline: oom
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.85x kernel: 0.77ms baseline: 0.90ms
seq_len: 256 slower: 0.93x kernel: 0.86ms baseline: 0.93ms
seq_len: 512 slower: 0.93x kernel: 1.31ms baseline: 1.40ms
seq_len: 1024 slower: 0.76x kernel: 3.31ms baseline: 4.35ms
seq_len: 2048 slower: 0.71x kernel: 11.19ms baseline: 15.65ms
seq_len: 4096 slower: 0.70x kernel: 41.27ms baseline: 59.01ms
seq_len: 8192 slower: 0.00x kernel: 158.60ms baseline: oom
테스트를 위해 A100에 대한 액세스를 제공한 Stability에 감사드립니다. 아직 액세스 권한이 없을 때 시간을 내어 일부 벤치마크를 실행해 준 Enrico에게 감사드립니다.
A100은 아직 진행중인 작업입니다. 공유 메모리는 아직 완전히 활용되지 않았습니다. 이상하게도 F32가 F16보다 잘하는 것 같습니다.
포워드
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.98x kernel: 0.29ms baseline: 0.30ms
seq_len: 256 slower: 1.19x kernel: 0.35ms baseline: 0.29ms
seq_len: 512 slower: 0.94x kernel: 0.52ms baseline: 0.55ms
seq_len: 1024 slower: 0.75x kernel: 1.23ms baseline: 1.65ms
seq_len: 2048 slower: 0.88x kernel: 4.17ms baseline: 4.73ms
seq_len: 4096 slower: 0.79x kernel: 14.53ms baseline: 18.36ms
seq_len: 8192 slower: 0.64x kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.84x kernel: 0.24ms baseline: 0.29ms
seq_len: 256 slower: 1.02x kernel: 0.29ms baseline: 0.29ms
seq_len: 512 slower: 1.24x kernel: 0.36ms baseline: 0.29ms
seq_len: 1024 slower: 1.48x kernel: 0.79ms baseline: 0.54ms
seq_len: 2048 slower: 1.31x kernel: 2.08ms baseline: 1.59ms
seq_len: 4096 slower: 1.21x kernel: 6.89ms baseline: 5.70ms
seq_len: 8192 slower: 1.07x kernel: 24.80ms baseline: 23.15ms
뒤로
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.94x kernel: 0.57ms baseline: 0.60ms
seq_len: 256 slower: 1.29x kernel: 0.75ms baseline: 0.58ms
seq_len: 512 slower: 1.16x kernel: 1.30ms baseline: 1.12ms
seq_len: 1024 slower: 0.98x kernel: 3.14ms baseline: 3.19ms
seq_len: 2048 slower: 1.05x kernel: 11.13ms baseline: 10.63ms
seq_len: 4096 slower: 0.98x kernel: 40.11ms baseline: 40.79ms
seq_len: 8192 slower: 0.97x kernel: 154.96ms baseline: 159.70ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.91x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.62ms baseline: 0.60ms
seq_len: 512 slower: 1.36x kernel: 0.82ms baseline: 0.60ms
seq_len: 1024 slower: 1.52x kernel: 1.52ms baseline: 1.01ms
seq_len: 2048 slower: 1.37x kernel: 4.14ms baseline: 3.03ms
seq_len: 4096 slower: 1.33x kernel: 14.23ms baseline: 10.71ms
seq_len: 8192 slower: 1.34x kernel: 53.90ms baseline: 40.28ms
앞으로 & 뒤로
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.92x kernel: 0.80ms baseline: 0.87ms
seq_len: 256 slower: 1.23x kernel: 1.07ms baseline: 0.87ms
seq_len: 512 slower: 1.08x kernel: 1.80ms baseline: 1.66ms
seq_len: 1024 slower: 0.94x kernel: 4.33ms baseline: 4.62ms
seq_len: 2048 slower: 0.99x kernel: 15.26ms baseline: 15.44ms
seq_len: 4096 slower: 0.93x kernel: 54.78ms baseline: 59.21ms
seq_len: 8192 slower: 0.91x kernel: 210.38ms baseline: 230.97ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.90x kernel: 0.78ms baseline: 0.86ms
seq_len: 256 slower: 1.00x kernel: 0.87ms baseline: 0.87ms
seq_len: 512 slower: 1.36x kernel: 1.18ms baseline: 0.86ms
seq_len: 1024 slower: 1.49x kernel: 2.31ms baseline: 1.55ms
seq_len: 2048 slower: 1.33x kernel: 6.17ms baseline: 4.63ms
seq_len: 4096 slower: 1.28x kernel: 21.08ms baseline: 16.44ms
seq_len: 8192 slower: 1.24x kernel: 78.75ms baseline: 63.45ms
자기회귀
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.82ms baseline: 1.01ms
seq_len: 256 slower: 1.02x kernel: 1.00ms baseline: 0.98ms
seq_len: 512 slower: 0.82x kernel: 1.55ms baseline: 1.89ms
seq_len: 1024 slower: 0.51x kernel: 2.79ms baseline: 5.44ms
seq_len: 2048 slower: 0.45x kernel: 8.37ms baseline: 18.67ms
seq_len: 4096 slower: 0.40x kernel: 29.16ms baseline: 72.97ms
seq_len: 8192 slower: 0.38x kernel: 108.68ms baseline: 285.47ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.82x kernel: 0.81ms baseline: 0.98ms
seq_len: 256 slower: 0.90x kernel: 0.88ms baseline: 0.98ms
seq_len: 512 slower: 1.16x kernel: 1.13ms baseline: 0.97ms
seq_len: 1024 slower: 0.80x kernel: 1.68ms baseline: 2.10ms
seq_len: 2048 slower: 0.54x kernel: 3.66ms baseline: 6.81ms
seq_len: 4096 slower: 0.45x kernel: 11.43ms baseline: 25.32ms
seq_len: 8192 slower: 0.41x kernel: 40.58ms baseline: 99.14ms
가변 길이 시퀀스(최대 25% 토큰이 마스크됨)
------------------------------------------------------------
float32 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.85ms baseline: 1.07ms
seq_len: 256 slower: 1.07x kernel: 1.15ms baseline: 1.08ms
seq_len: 512 slower: 1.00x kernel: 1.94ms baseline: 1.94ms
seq_len: 1024 slower: 0.84x kernel: 4.64ms baseline: 5.55ms
seq_len: 2048 slower: 0.84x kernel: 15.86ms baseline: 18.86ms
seq_len: 4096 slower: 0.76x kernel: 55.19ms baseline: 72.47ms
seq_len: 8192 slower: 0.75x kernel: 212.48ms baseline: 282.71ms
------------------------------------------------------------
float16 batch: 4 heads: 8 dim 64
------------------------------------------------------------
seq_len: 128 slower: 0.80x kernel: 0.83ms baseline: 1.04ms
seq_len: 256 slower: 0.90x kernel: 0.93ms baseline: 1.03ms
seq_len: 512 slower: 1.18x kernel: 1.22ms baseline: 1.04ms
seq_len: 1024 slower: 1.10x kernel: 2.40ms baseline: 2.17ms
seq_len: 2048 slower: 0.89x kernel: 6.27ms baseline: 7.06ms
seq_len: 4096 slower: 0.82x kernel: 21.19ms baseline: 25.95ms
seq_len: 8192 slower: 0.78x kernel: 79.45ms baseline: 101.83ms
$ make train
8192 시퀀스 길이를 사용해 보세요. 느리지만 작동할 것입니다(일반적인 주의는 > 2048에서 중단됩니다. --use-cuda-kernel
플래그를 제거하면 이를 볼 수 있습니다)
$ python train . py - - seq - len 8192 - - use - cuda - kernel
@article { Dao2022FlashAttentionFA ,
title = { FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness } ,
author = { Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2205.14135 }
}
@misc { rabe2021selfattention ,
title = { Self-attention Does Not Need $O(n^2)$ Memory } ,
author = { Markus N. Rabe and Charles Staats } ,
year = { 2021 } ,
eprint = { 2112.05682 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@inproceedings { Henry2020QueryKeyNF ,
title = { Query-Key Normalization for Transformers } ,
author = { Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen } ,
booktitle = { FINDINGS } ,
year = { 2020 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 }
}