Mamba : modélisation de séquences temporelles linéaires avec des espaces d'états sélectifs
Albert Gu*, Tri Dao*
Article : https://arxiv.org/abs/2312.00752
Les transformateurs sont des SSM : modèles généralisés et algorithmes efficaces
Grâce à la dualité spatiale d’états structurés
Tri Dao*, Albert Gu*
Article : https://arxiv.org/abs/2405.21060
Mamba est une nouvelle architecture de modèle d'espace d'état montrant des performances prometteuses sur des données riches en informations telles que la modélisation du langage, là où les modèles sous-quadratiques précédents sont en deçà des Transformers. Il est basé sur la ligne de progrès des modèles spatiaux d'états structurés, avec une conception et une mise en œuvre efficaces tenant compte du matériel dans l'esprit de FlashAttention.
pip install causal-conv1d>=1.4.0
: une implémentation efficace d'une simple couche causale Conv1d utilisée à l'intérieur du bloc Mamba.pip install mamba-ssm
: le package principal de Mamba.pip install mamba-ssm[causal-conv1d]
: Pour installer le package principal Mamba et causal-conv1d.pip install mamba-ssm[dev]
: Pour installer le package Mamba principal et les dépendances de développement. Il peut également être construit à partir des sources avec pip install .
à partir de ce référentiel.
Si pip
se plaint des versions de PyTorch, essayez de transmettre --no-build-isolation
à pip
.
Autres exigences :
Pour les cartes AMD, consultez les conditions préalables supplémentaires ci-dessous.
Nous exposons plusieurs niveaux d'interface avec le modèle Mamba.
Mamba est basé sur une couche SSM sélective, qui est au centre de l'article (Section 3 ; Algorithme 2).
Source : ops/selective_scan_interface.py.
Le module principal de ce référentiel est le bloc d'architecture Mamba encapsulant le SSM sélectif.
Source : modules/mamba_simple.py.
Usage:
import torch
from mamba_ssm import Mamba
batch , length , dim = 2 , 64 , 16
x = torch . randn ( batch , length , dim ). to ( "cuda" )
model = Mamba (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 16 , # SSM state expansion factor
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
Le bloc Mamba-2 est implémenté dans modules/mamba2.py.
Une version plus simple se trouve sur modules/mamba2_simple.py
L'utilisation est similaire à Mamba(-1) :
from mamba_ssm import Mamba2
model = Mamba2 (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 64 , # SSM state expansion factor, typically 64 or 128
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
Une version minimale du module SSD interne (liste 1 de l'article Mamba-2) avec conversion entre les versions SSM « discrètes » et « continues » se trouve sur modules/ssd_minimal.py.
Enfin, nous fournissons un exemple de modèle de langage complet : un squelette de modèle de séquence profonde (avec des blocs Mamba répétitifs) + une tête de modèle de langage.
Source : models/mixer_seq_simple.py.
Ceci est un exemple de la façon d'intégrer Mamba dans un réseau neuronal de bout en bout. Cet exemple est utilisé dans les scripts de génération ci-dessous.
Les modèles pré-entraînés sont téléchargés sur Hugging Face : mamba-130m
, mamba-370m
, mamba-790m
, mamba-1.4b
, mamba-2.8b
, mamba2-130m
, mamba2-370m
, mamba2-780m
, mamba2-1.3b
, mamba2-2.7b
, transformerpp-2.7b
, mamba2attn-2.7b
, formé sur 300 B de jetons sur la pile, ainsi que mamba-2.8b-slimpj
(entraîné sur 600 B de jetons sur l'ensemble de données SlimPajama).
Les modèles seront automatiquement téléchargés par le script de génération ci-dessous.
Ces modèles ont été formés sur la Pile et suivent les dimensions du modèle standard décrites par GPT-3 et suivies par de nombreux modèles open source :
Paramètres | Calques | Modèle faible. |
---|---|---|
130M | 24 | 768 |
370M | 48 | 1024 |
790M | 48 | 1536 |
1,4B | 48 | 2048 |
2,8 milliards | 64 | 2560 |
(Le nombre de couches de Mamba double celui d'un Transformer de taille similaire, car deux blocs Mamba sont nécessaires pour chaque "couche" (bloc MHA + bloc MLP) d'un Transformer.)
Remarque : il s'agit de modèles de base formés uniquement pour les jetons 300B, sans aucune forme de modification en aval (réglage des instructions, etc.). Les performances devraient être comparables ou meilleures que celles d'autres architectures formées sur des données similaires, mais ne pas correspondre à des modèles plus grands ou affinés.
Pour exécuter des évaluations zéro-shot des modèles (correspondant au tableau 3 de l'article), nous utilisons la bibliothèque lm-evaluation-harness.
lm-evaluation-harness
par pip install lm-eval==0.4.2
.lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
Pour reproduire les résultats sur le modèle mamba-2.8b-slimpj
rapportés dans les articles du blog :
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
Pour exécuter des évaluations sur les modèles Mamba-2, remplacez simplement les noms des modèles :
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
Notez que le résultat de chaque tâche peut différer des valeurs rapportées de 0,1 à 0,3 en raison du bruit dans le processus d'évaluation.
Le script benchmarks/benchmark_generation_mamba_simple.py
D'autres options configurables incluent la probabilité top-p (échantillonnage du noyau) et la température softmax.
Pour tester la latence de génération (par exemple taille du lot = 1) avec différentes stratégies d'échantillonnage :
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
Pour tester le débit de génération avec des invites aléatoires (par exemple, un lot de grande taille) :
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --batch 64
Avec Mamba-2, il vous suffit de changer le nom du modèle :
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba2-2.7b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
Nos modèles ont été formés à l'aide de PyTorch AMP pour une précision mixte. AMP conserve les paramètres du modèle dans float32 et convertit en demi-précision si nécessaire. D'un autre côté, d'autres frameworks comme DeepSpeed stockent les paramètres dans float16 et les upcasts si nécessaire (par exemple pour l'accumulation d'optimiseur).
Nous avons observé qu'une plus grande précision pour les principaux paramètres du modèle peut être nécessaire, car les SSM sont sensibles à leur dynamique récurrente. Si vous rencontrez des instabilités, essayez dans un premier temps un framework stockant les paramètres dans fp32 (comme AMP).
Certaines parties du modèle ont des initialisations héritées de travaux antérieurs sur les modèles S4. Par exemple, le nn.Linear
à zéro). Si tel est le cas, vous devrez peut-être ajouter une logique personnalisée (par exemple, cette ligne désactive la réinitialisation dans notre entraîneur, mais ne serait pas opérationnelle dans tout autre framework) spécifique au framework de formation.
Si vous utilisez ROCm 6.0, exécutez les étapes suivantes pour éviter les erreurs lors de la compilation. Ceci n’est pas requis à partir de la version ROCm 6.1.
Localisez votre répertoire d'installation ROCm. Ceci se trouve généralement dans /opt/rocm/
, mais peut varier en fonction de votre installation.
Appliquez le patch. Exécutez avec sudo
au cas où vous rencontreriez des problèmes d'autorisation.
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
Si vous utilisez cette base de code, ou si vous trouvez notre travail précieux, veuillez citer Mamba :
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}