Mamba: Pemodelan Urutan Waktu Linier dengan Ruang Keadaan Selektif
Albert Gu*, Tri Dao*
Makalah: https://arxiv.org/abs/2312.00752
Transformer adalah SSM: Model Umum dan Algoritma Efisien
Melalui Dualitas Ruang Negara Terstruktur
Tri Dao*, Albert Gu*
Makalah: https://arxiv.org/abs/2405.21060
Mamba adalah arsitektur model ruang negara baru yang menunjukkan kinerja menjanjikan pada data padat informasi seperti pemodelan bahasa, dimana model subkuadrat sebelumnya tidak mampu dibandingkan dengan Transformers. Hal ini didasarkan pada kemajuan dalam model ruang negara terstruktur, dengan desain dan implementasi berbasis perangkat keras yang efisien dalam semangat FlashAttention.
pip install causal-conv1d>=1.4.0
: implementasi efisien dari lapisan Conv1d kausal sederhana yang digunakan di dalam blok Mamba.pip install mamba-ssm
: paket inti Mamba.pip install mamba-ssm[causal-conv1d]
: Untuk menginstal paket inti Mamba dan causal-conv1d.pip install mamba-ssm[dev]
: Untuk menginstal paket inti Mamba dan dependensi dev. Itu juga dapat dibangun dari sumber dengan pip install .
dari repositori ini.
Jika pip
mengeluh tentang versi PyTorch, coba teruskan --no-build-isolation
ke pip
.
Persyaratan lainnya:
Untuk kartu AMD, lihat prasyarat tambahan di bawah.
Kami memaparkan beberapa tingkat antarmuka dengan model Mamba.
Mamba didasarkan pada lapisan SSM selektif, yang merupakan fokus makalah ini (Bagian 3; Algoritma 2).
Sumber: ops/selective_scan_interface.py.
Modul utama repositori ini adalah blok arsitektur Mamba yang membungkus SSM selektif.
Sumber: modul/mamba_simple.py.
Penggunaan:
import torch
from mamba_ssm import Mamba
batch , length , dim = 2 , 64 , 16
x = torch . randn ( batch , length , dim ). to ( "cuda" )
model = Mamba (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 16 , # SSM state expansion factor
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
Blok Mamba-2 diimplementasikan di module/mamba2.py.
Versi yang lebih sederhana ada di module/mamba2_simple.py
Penggunaannya mirip dengan Mamba(-1):
from mamba_ssm import Mamba2
model = Mamba2 (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 64 , # SSM state expansion factor, typically 64 or 128
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
Versi minimal modul SSD bagian dalam (Daftar 1 dari makalah Mamba-2) dengan konversi antara versi SSM "diskrit" dan "kontinu" ada di module/ssd_minimal.py.
Terakhir, kami memberikan contoh model bahasa lengkap: tulang punggung model urutan dalam (dengan blok Mamba berulang) + kepala model bahasa.
Sumber: models/mixer_seq_simple.py.
Ini adalah contoh cara mengintegrasikan Mamba ke dalam jaringan saraf ujung ke ujung. Contoh ini digunakan dalam skrip generasi di bawah ini.
Model yang telah dilatih sebelumnya diunggah ke Hugging Face: mamba-130m
, mamba-370m
, mamba-790m
, mamba-1.4b
, mamba-2.8b
, mamba2-130m
, mamba2-370m
, mamba2-780m
, mamba2-1.3b
, mamba2-2.7b
, transformerpp-2.7b
, mamba2attn-2.7b
, dilatih dengan 300 miliar token di Pile, serta mamba-2.8b-slimpj
(dilatih dengan 600 miliar token di kumpulan data SlimPajama).
Model akan diunduh secara otomatis oleh skrip pembuatan di bawah.
Model ini dilatih di Pile, dan mengikuti dimensi model standar yang dijelaskan oleh GPT-3 dan diikuti oleh banyak model sumber terbuka:
Parameter | Lapisan | Modelnya redup. |
---|---|---|
130M | 24 | 768 |
370M | 48 | 1024 |
790M | 48 | 1536 |
1.4B | 48 | 2048 |
2.8B | 64 | 2560 |
(Jumlah lapisan Mamba dua kali lipat dari Transformer dengan ukuran serupa, karena diperlukan dua blok Mamba untuk setiap "lapisan" (blok MHA + blok MLP) dari Transformer.)
Catatan: ini adalah model dasar yang dilatih hanya untuk token 300 miliar, tanpa modifikasi hilir apa pun (penyetelan instruksi, dll.). Performanya diharapkan dapat dibandingkan atau lebih baik dibandingkan arsitektur lain yang dilatih berdasarkan data serupa, namun tidak dapat menandingi model yang lebih besar atau lebih baik.
Untuk menjalankan evaluasi model zero-shot (sesuai dengan Tabel 3 makalah), kami menggunakan perpustakaan lm-evaluation-harness.
lm-evaluation-harness
dengan pip install lm-eval==0.4.2
.lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
Untuk mereproduksi hasil pada model mamba-2.8b-slimpj
yang dilaporkan di postingan blog:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
Untuk menjalankan evaluasi pada model Mamba-2, cukup ganti nama model:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
Perhatikan bahwa hasil setiap tugas mungkin berbeda dari nilai yang dilaporkan sebesar 0,1-0,3 karena adanya gangguan dalam proses evaluasi.
Skrip benchmark/benchmark_generasi_mamba_simple.py
Opsi lain yang dapat dikonfigurasi mencakup probabilitas top-p (pengambilan sampel inti), dan suhu softmax.
Untuk menguji latensi pembangkitan (misalnya ukuran batch = 1) dengan strategi pengambilan sampel yang berbeda:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
Untuk menguji throughput pembangkitan dengan perintah acak (misalnya ukuran batch besar):
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --batch 64
Dengan Mamba-2, Anda hanya perlu mengubah nama model:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba2-2.7b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
Model kami dilatih menggunakan PyTorch AMP untuk presisi campuran. AMP menyimpan parameter model di float32 dan menghasilkan setengah presisi bila diperlukan. Di sisi lain, kerangka kerja lain seperti DeepSpeed menyimpan parameter di float16 dan upcast bila diperlukan (misalnya untuk akumulasi pengoptimal).
Kami telah mengamati bahwa presisi yang lebih tinggi untuk parameter model utama mungkin diperlukan, karena SSM sensitif terhadap dinamika berulangnya. Jika Anda mengalami ketidakstabilan, sebagai langkah awal silakan coba framework yang menyimpan parameter di fp32 (seperti AMP).
Beberapa bagian model memiliki inisialisasi yang diwarisi dari pekerjaan sebelumnya pada model S4. Misalnya, nn.Linear
ke nol). Jika hal ini terjadi, Anda mungkin harus menambahkan logika kustom (misalnya baris ini menonaktifkan inisialisasi ulang di trainer kami, namun tidak boleh digunakan di framework lain) yang khusus untuk framework pelatihan.
Jika Anda menggunakan ROCm 6.0, jalankan langkah-langkah berikut untuk menghindari kesalahan selama kompilasi. Ini tidak diperlukan untuk ROCm 6.1 dan seterusnya.
Temukan direktori instalasi ROCm Anda. Ini biasanya ditemukan di /opt/rocm/
, tetapi dapat bervariasi tergantung pada instalasi Anda.
Terapkan Patch. Jalankan dengan sudo
jika Anda mengalami masalah izin.
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
Jika Anda menggunakan basis kode ini, atau menganggap karya kami berharga, harap kutip Mamba:
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}