Mamba: Modelagem de Sequência de Tempo Linear com Espaços de Estado Seletivos
Albert Gu*, Tri Dao*
Artigo: https://arxiv.org/abs/2312.00752
Transformadores são SSMs: Modelos Generalizados e Algoritmos Eficientes
Através da Dualidade do Espaço de Estados Estruturados
Tri Dao*, Albert Gu*
Artigo: https://arxiv.org/abs/2405.21060
Mamba é uma nova arquitetura de modelo de espaço de estados que mostra desempenho promissor em dados densos em informações, como modelagem de linguagem, onde os modelos subquadráticos anteriores ficam aquém dos Transformers. Baseia-se na linha de progresso em modelos de espaço de estados estruturados, com um design eficiente e consciente de hardware e implementação no espírito do FlashAttention.
pip install causal-conv1d>=1.4.0
: uma implementação eficiente de uma camada Conv1d causal simples usada dentro do bloco Mamba.pip install mamba-ssm
: o pacote principal do Mamba.pip install mamba-ssm[causal-conv1d]
: Para instalar o pacote principal do Mamba e causal-conv1d.pip install mamba-ssm[dev]
: Para instalar o pacote principal do Mamba e dependências de desenvolvimento. Ele também pode ser compilado a partir do código-fonte com pip install .
deste repositório.
Se pip
reclamar das versões do PyTorch, tente passar --no-build-isolation
para pip
.
Outros requisitos:
Para placas AMD, consulte os pré-requisitos adicionais abaixo.
Expomos vários níveis de interface com o modelo Mamba.
O Mamba é baseado em uma camada SSM seletiva, que é o foco do artigo (Seção 3; Algoritmo 2).
Fonte: ops/selective_scan_interface.py.
O módulo principal deste repositório é o bloco de arquitetura Mamba que envolve o SSM seletivo.
Fonte: módulos/mamba_simple.py.
Uso:
import torch
from mamba_ssm import Mamba
batch , length , dim = 2 , 64 , 16
x = torch . randn ( batch , length , dim ). to ( "cuda" )
model = Mamba (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 16 , # SSM state expansion factor
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
O bloco Mamba-2 é implementado em module/mamba2.py.
Uma versão mais simples está em module/mamba2_simple.py
O uso é semelhante ao Mamba(-1):
from mamba_ssm import Mamba2
model = Mamba2 (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 64 , # SSM state expansion factor, typically 64 or 128
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
Uma versão mínima do módulo SSD interno (Listagem 1 do artigo Mamba-2) com conversão entre versões SSM "discretas" e "contínuas" está em module/ssd_minimal.py.
Finalmente, fornecemos um exemplo de um modelo de linguagem completo: um backbone de modelo de sequência profunda (com blocos repetidos do Mamba) + cabeça do modelo de linguagem.
Fonte: models/mixer_seq_simple.py.
Este é um exemplo de como integrar o Mamba em uma rede neural ponta a ponta. Este exemplo é usado nos scripts de geração abaixo.
Modelos pré-treinados são carregados no Hugging Face: mamba-130m
, mamba-370m
, mamba-790m
, mamba-1.4b
, mamba-2.8b
, mamba2-130m
, mamba2-370m
, mamba2-780m
, mamba2-1.3b
, mamba2-2.7b
, transformerpp-2.7b
, mamba2attn-2.7b
, treinado em tokens de 300 bilhões na pilha, bem como mamba-2.8b-slimpj
(treinado em tokens de 600 bilhões no conjunto de dados SlimPajama).
Os modelos serão baixados automaticamente pelo script de geração abaixo.
Esses modelos foram treinados no Pile e seguem as dimensões do modelo padrão descritas pelo GPT-3 e seguidos por muitos modelos de código aberto:
Parâmetros | Camadas | Modelo escuro. |
---|---|---|
130 milhões | 24 | 768 |
370 milhões | 48 | 1024 |
790 milhões | 48 | 1536 |
1,4B | 48 | 2048 |
2,8B | 64 | 2560 |
(A contagem de camadas do Mamba dobra a de um Transformer de tamanho semelhante, pois são necessários dois blocos Mamba para cada "camada" (bloco MHA + bloco MLP) de um Transformer.)
Nota: estes são modelos básicos treinados apenas para tokens de 300B, sem qualquer forma de modificação downstream (ajuste de instruções, etc.). Espera-se que o desempenho seja comparável ou melhor do que outras arquiteturas treinadas em dados semelhantes, mas não corresponda a modelos maiores ou ajustados.
Para executar avaliações zero-shot de modelos (correspondentes à Tabela 3 do artigo), usamos a biblioteca lm-evaluation-harness.
lm-evaluation-harness
por pip install lm-eval==0.4.2
.lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
Para reproduzir os resultados no modelo mamba-2.8b-slimpj
relatado nas postagens do blog:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
Para executar avaliações em modelos Mamba-2, basta substituir os nomes dos modelos:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
Observe que o resultado de cada tarefa pode diferir dos valores relatados em 0,1-0,3 devido ao ruído no processo de avaliação.
O script benchmarks/benchmark_generation_mamba_simple.py
Outras opções configuráveis incluem a probabilidade top-p (amostragem de núcleo) e a temperatura softmax.
Para testar a latência de geração (por exemplo, tamanho do lote = 1) com diferentes estratégias de amostragem:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
Para testar o rendimento da geração com prompts aleatórios (por exemplo, tamanho de lote grande):
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --batch 64
Com o Mamba-2, você só precisa alterar o nome do modelo:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba2-2.7b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
Nossos modelos foram treinados usando PyTorch AMP para precisão mista. O AMP mantém os parâmetros do modelo em float32 e converte com meia precisão quando necessário. Por outro lado, outras estruturas como DeepSpeed armazenam parâmetros em float16 e upcasts quando necessário (por exemplo, para acumulação de otimizador).
Observamos que pode ser necessária maior precisão para os principais parâmetros do modelo, porque os SSMs são sensíveis à sua dinâmica recorrente. Se você estiver enfrentando instabilidades, como primeiro passo, tente um framework que armazene parâmetros em fp32 (como AMP).
Algumas partes do modelo possuem inicializações herdadas de trabalhos anteriores em modelos S4. Por exemplo, o nn.Linear
como zero). Se for esse o caso, talvez seja necessário adicionar uma lógica personalizada (por exemplo, esta linha desativa a reinicialização em nosso treinador, mas não funcionaria em qualquer outra estrutura) que seja específica da estrutura de treinamento.
Se você estiver no ROCm 6.0, execute as etapas a seguir para evitar erros durante a compilação. Isso não é necessário para ROCm 6.1 em diante.
Localize o diretório de instalação do ROCm. Normalmente é encontrado em /opt/rocm/
, mas pode variar dependendo da sua instalação.
Aplique o patch. Execute com sudo
caso encontre problemas de permissão.
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
Se você usa esta base de código ou considera nosso trabalho valioso, cite Mamba:
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}