Mamba: Linearzeitsequenzmodellierung mit selektiven Zustandsräumen
Albert Gu*, Tri Dao*
Papier: https://arxiv.org/abs/2312.00752
Transformatoren sind SSMs: Verallgemeinerte Modelle und effiziente Algorithmen
Durch strukturierte Zustandsraum-Dualität
Tri Dao*, Albert Gu*
Papier: https://arxiv.org/abs/2405.21060
Mamba ist eine neue Zustandsraummodellarchitektur, die eine vielversprechende Leistung bei informationsreichen Daten wie der Sprachmodellierung zeigt, wo frühere subquadratische Modelle hinter Transformers zurückbleiben. Es basiert auf dem Fortschritt strukturierter Zustandsraummodelle mit einem effizienten hardwarebewussten Design und einer Implementierung im Geiste von FlashAttention.
pip install causal-conv1d>=1.4.0
: eine effiziente Implementierung einer einfachen kausalen Conv1d-Schicht, die innerhalb des Mamba-Blocks verwendet wird.pip install mamba-ssm
: das Kernpaket von Mamba.pip install mamba-ssm[causal-conv1d]
: Zum Installieren des Mamba-Kernpakets und von causal-conv1d.pip install mamba-ssm[dev]
: Zum Installieren des Mamba-Kernpakets und der Entwicklungsabhängigkeiten. Es kann auch mit pip install .
aus diesem Repository.
Wenn pip
sich über PyTorch-Versionen beschwert, versuchen Sie, --no-build-isolation
an pip
zu übergeben.
Weitere Anforderungen:
Für AMD-Karten siehe zusätzliche Voraussetzungen unten.
Wir legen mehrere Schnittstellenebenen mit dem Mamba-Modell offen.
Mamba basiert auf einer selektiven SSM-Schicht, die im Mittelpunkt des Artikels steht (Abschnitt 3; Algorithmus 2).
Quelle: ops/selective_scan_interface.py.
Das Hauptmodul dieses Repositorys ist der Mamba-Architekturblock, der den selektiven SSM umschließt.
Quelle: module/mamba_simple.py.
Verwendung:
import torch
from mamba_ssm import Mamba
batch , length , dim = 2 , 64 , 16
x = torch . randn ( batch , length , dim ). to ( "cuda" )
model = Mamba (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 16 , # SSM state expansion factor
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
Der Mamba-2-Block ist unter module/mamba2.py implementiert.
Eine einfachere Version finden Sie unter module/mamba2_simple.py
Die Verwendung ähnelt Mamba(-1):
from mamba_ssm import Mamba2
model = Mamba2 (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 64 , # SSM state expansion factor, typically 64 or 128
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
Eine Minimalversion des inneren SSD-Moduls (Listing 1 aus dem Mamba-2-Artikel) mit Konvertierung zwischen „diskreten“ und „kontinuierlichen“ SSM-Versionen finden Sie unter module/ssd_minimal.py.
Abschließend stellen wir ein Beispiel für ein vollständiges Sprachmodell bereit: ein Deep-Sequence-Modell-Backbone (mit sich wiederholenden Mamba-Blöcken) + Sprachmodell-Kopf.
Quelle: models/mixer_seq_simple.py.
Dies ist ein Beispiel für die Integration von Mamba in ein durchgängiges neuronales Netzwerk. Dieses Beispiel wird in den folgenden Generierungsskripten verwendet.
Vortrainierte Modelle werden auf Hugging Face hochgeladen: mamba-130m
, mamba-370m
, mamba-790m
, mamba-1.4b
, mamba-2.8b
, mamba2-130m
, mamba2-370m
, mamba2-780m
, mamba2-1.3b
, mamba2-2.7b
, transformerpp-2.7b
, mamba2attn-2.7b
, trainiert auf 300B-Tokens auf dem Pile, sowie mamba-2.8b-slimpj
(trainiert auf 600B-Tokens auf dem SlimPajama-Datensatz).
Die Modelle werden durch das Generierungsskript unten automatisch heruntergeladen.
Diese Modelle wurden auf dem Pile trainiert und folgen den von GPT-3 beschriebenen Standardmodelldimensionen, denen viele Open-Source-Modelle folgen:
Parameter | Schichten | Modellabm. |
---|---|---|
130M | 24 | 768 |
370M | 48 | 1024 |
790M | 48 | 1536 |
1,4B | 48 | 2048 |
2,8B | 64 | 2560 |
(Die Layer-Anzahl von Mamba verdoppelt die eines Transformers mit ähnlicher Größe, da für jede „Layer“ (MHA-Block + MLP-Block) eines Transformers zwei Mamba-Blöcke benötigt werden.)
Hinweis: Hierbei handelt es sich um Basismodelle, die nur für 300B-Tokens trainiert wurden, ohne jegliche nachgelagerte Modifikation (Anweisungsoptimierung usw.). Es wird erwartet, dass die Leistung mit der anderer Architekturen vergleichbar oder besser ist, die auf ähnlichen Daten trainiert werden, jedoch nicht mit größeren oder fein abgestimmten Modellen mithalten kann.
Um Zero-Shot-Bewertungen von Modellen durchzuführen (entsprechend Tabelle 3 des Artikels), verwenden wir die lm-evaluation-harness-Bibliothek.
lm-evaluation-harness
mit pip install lm-eval==0.4.2
.lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
Um die in den Blogbeiträgen berichteten Ergebnisse für das mamba-2.8b-slimpj
-Modell zu reproduzieren:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
Um Auswertungen für Mamba-2-Modelle durchzuführen, ersetzen Sie einfach die Modellnamen:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
Beachten Sie, dass das Ergebnis jeder Aufgabe aufgrund von Störungen im Bewertungsprozess um 0,1–0,3 von den gemeldeten Werten abweichen kann.
Das Skript benchmarks/benchmark_generation_mamba_simple.py
Weitere konfigurierbare Optionen sind die Top-P-Wahrscheinlichkeit (Kernprobenahme) und die Softmax-Temperatur.
So testen Sie die Generierungslatenz (z. B. Batchgröße = 1) mit verschiedenen Stichprobenstrategien:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
So testen Sie den Generierungsdurchsatz mit zufälligen Eingabeaufforderungen (z. B. große Batchgröße):
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --batch 64
Bei Mamba-2 müssen Sie lediglich den Modellnamen ändern:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba2-2.7b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
Unsere Modelle wurden mit PyTorch AMP für gemischte Präzision trainiert. AMP behält Modellparameter in float32 und wandelt sie bei Bedarf mit halber Genauigkeit um. Andererseits speichern andere Frameworks wie DeepSpeed Parameter in float16 und übertragen sie bei Bedarf (z. B. zur Optimiererakkumulation) hoch.
Wir haben beobachtet, dass möglicherweise eine höhere Präzision für die Hauptmodellparameter erforderlich ist, da SSMs empfindlich auf ihre wiederkehrende Dynamik reagieren. Wenn bei Ihnen Instabilitäten auftreten, versuchen Sie es bitte als ersten Schritt mit einem Framework, das Parameter in fp32 speichert (z. B. AMP).
Einige Teile des Modells verfügen über Initialisierungen, die aus früheren Arbeiten an S4-Modellen übernommen wurden. Zum Beispiel die nn.Linear
Modulen auf Null). Wenn dies der Fall ist, müssen Sie möglicherweise eine benutzerdefinierte Logik hinzufügen (z. B. diese Zeile deaktiviert die Neuinitialisierung in unserem Trainer, wäre aber in jedem anderen Framework ein No-Op), die spezifisch für das Trainingsframework ist.
Wenn Sie ROCm 6.0 verwenden, führen Sie die folgenden Schritte aus, um Fehler während der Kompilierung zu vermeiden. Dies ist ab ROCm 6.1 nicht erforderlich.
Suchen Sie Ihr ROCm-Installationsverzeichnis. Dies ist normalerweise unter /opt/rocm/
zu finden, kann aber je nach Installation variieren.
Wenden Sie den Patch an. Führen Sie es mit sudo
aus, falls Sie auf Berechtigungsprobleme stoßen.
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
Wenn Sie diese Codebasis verwenden oder unsere Arbeit anderweitig wertvoll finden, zitieren Sie bitte Mamba:
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}