Mamba: modelado de secuencias de tiempo lineal con espacios de estados selectivos
Albert Gu*, Tri Dao*
Documento: https://arxiv.org/abs/2312.00752
Los transformadores son SSM: modelos generalizados y algoritmos eficientes
A través de la dualidad del espacio de estados estructurado
Tri Dao*, Albert Gu*
Documento: https://arxiv.org/abs/2405.21060
Mamba es una nueva arquitectura de modelo de espacio de estados que muestra un rendimiento prometedor en datos densos en información, como el modelado de lenguaje, donde los modelos subcuadráticos anteriores no alcanzan a Transformers. Se basa en la línea de progreso sobre modelos de espacio de estados estructurados, con un diseño e implementación eficiente consciente del hardware en el espíritu de FlashAttention.
pip install causal-conv1d>=1.4.0
: una implementación eficiente de una capa Conv1d causal simple utilizada dentro del bloque Mamba.pip install mamba-ssm
: el paquete principal de Mamba.pip install mamba-ssm[causal-conv1d]
: para instalar el paquete principal de Mamba y causal-conv1d.pip install mamba-ssm[dev]
: para instalar el paquete principal de Mamba y las dependencias de desarrollo. También se puede construir desde el código fuente con pip install .
de este repositorio.
Si pip
se queja de las versiones de PyTorch, intente pasar --no-build-isolation
a pip
.
Otros requisitos:
Para tarjetas AMD, consulte los requisitos previos adicionales a continuación.
Exponemos varios niveles de interfaz con el modelo Mamba.
Mamba se basa en una capa SSM selectiva, que es el tema central del artículo (Sección 3; Algoritmo 2).
Fuente: ops/selective_scan_interface.py.
El módulo principal de este repositorio es el bloque de arquitectura Mamba que envuelve el SSM selectivo.
Fuente: módulos/mamba_simple.py.
Uso:
import torch
from mamba_ssm import Mamba
batch , length , dim = 2 , 64 , 16
x = torch . randn ( batch , length , dim ). to ( "cuda" )
model = Mamba (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 16 , # SSM state expansion factor
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
El bloque Mamba-2 se implementa en módulos/mamba2.py.
Una versión más simple está en module/mamba2_simple.py
El uso es similar a Mamba(-1):
from mamba_ssm import Mamba2
model = Mamba2 (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 64 , # SSM state expansion factor, typically 64 or 128
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
Una versión mínima del módulo SSD interno (Listado 1 del documento Mamba-2) con conversión entre versiones SSM "discretas" y "continuas" se encuentra en module/ssd_minimal.py.
Finalmente, proporcionamos un ejemplo de un modelo de lenguaje completo: una columna vertebral de modelo de secuencia profunda (con bloques Mamba repetidos) + cabeza de modelo de lenguaje.
Fuente: modelos/mixer_seq_simple.py.
Este es un ejemplo de cómo integrar Mamba en una red neuronal de un extremo a otro. Este ejemplo se utiliza en los scripts de generación siguientes.
Los modelos previamente entrenados se cargan en Hugging Face: mamba-130m
, mamba-370m
, mamba-790m
, mamba-1.4b
, mamba mamba2-2.7b
mamba-2.8b
, mamba2-130m
, mamba2-370m
, mamba2-780m
, mamba2-1.3b
mamba2-2.7b
, transformerpp-2.7b
, mamba2attn-2.7b
, entrenado en 300 mil millones de tokens en Pile, así como mamba-2.8b-slimpj
(entrenado en 600 mil millones de tokens en el conjunto de datos SlimPajama).
Los modelos se descargarán automáticamente mediante el siguiente script de generación.
Estos modelos se entrenaron en Pile y siguen las dimensiones del modelo estándar descritas por GPT-3 y seguidas por muchos modelos de código abierto:
Parámetros | capas | Modelo tenue. |
---|---|---|
130M | 24 | 768 |
370M | 48 | 1024 |
790M | 48 | 1536 |
1.4B | 48 | 2048 |
2.8B | 64 | 2560 |
(El número de capas de Mamba duplica el de un Transformer con tamaño similar, ya que se necesitan dos bloques Mamba para cada "capa" (bloque MHA + bloque MLP) de un Transformer).
Nota: estos son modelos base entrenados solo para tokens 300B, sin ningún tipo de modificación posterior (ajuste de instrucciones, etc.). Se espera que el rendimiento sea comparable o mejor que el de otras arquitecturas entrenadas con datos similares, pero no que coincida con modelos más grandes o ajustados.
Para ejecutar evaluaciones de tiro cero de modelos (correspondientes a la Tabla 3 del artículo), utilizamos la biblioteca lm-evaluación-arnés.
lm-evaluation-harness
mediante pip install lm-eval==0.4.2
.lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
Para reproducir los resultados del modelo mamba-2.8b-slimpj
informados en las publicaciones del blog:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
Para ejecutar evaluaciones en modelos Mamba-2, simplemente reemplace los nombres de los modelos:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
Tenga en cuenta que el resultado de cada tarea puede diferir de los valores informados entre 0,1 y 0,3 debido al ruido en el proceso de evaluación.
El script benchmarks/benchmark_generación_mamba_simple.py
Otras opciones configurables incluyen la probabilidad top-p (muestreo de núcleos) y la temperatura softmax.
Para probar la latencia de generación (por ejemplo, tamaño de lote = 1) con diferentes estrategias de muestreo:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
Para probar el rendimiento de generación con indicaciones aleatorias (por ejemplo, lotes de gran tamaño):
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --batch 64
Con Mamba-2, sólo necesitas cambiar el nombre del modelo:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba2-2.7b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
Nuestros modelos fueron entrenados usando PyTorch AMP para precisión mixta. AMP mantiene los parámetros del modelo en float32 y lanza a la mitad de precisión cuando es necesario. Por otro lado, otros frameworks como DeepSpeed almacenan parámetros en float16 y actualizan cuando es necesario (por ejemplo, para la acumulación de optimizadores).
Hemos observado que puede ser necesaria una mayor precisión para los principales parámetros del modelo, porque los SSM son sensibles a su dinámica recurrente. Si experimenta inestabilidades, como primer paso, pruebe con un marco que almacene parámetros en fp32 (como AMP).
Algunas partes del modelo tienen inicializaciones heredadas de trabajos anteriores en modelos S4. Por ejemplo, el nn.Linear
en cero). Si este es el caso, es posible que tenga que agregar una lógica personalizada (por ejemplo, esta línea desactiva la reinicialización en nuestro entrenador, pero no sería operativa en cualquier otro marco) que sea específica del marco de capacitación.
Si está en ROCm 6.0, ejecute los siguientes pasos para evitar errores durante la compilación. Esto no es necesario para ROCm 6.1 en adelante.
Localice su directorio de instalación de ROCm. Esto normalmente se encuentra en /opt/rocm/
, pero puede variar según su instalación.
Aplicar el parche. Ejecute con sudo
en caso de que encuentre problemas de permisos.
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
Si utiliza este código base o encuentra valioso nuestro trabajo, cite a Mamba:
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}