Mamba: การสร้างแบบจำลองลำดับเวลาเชิงเส้นพร้อมช่องว่างสถานะแบบเลือก
อัลเบิร์ต กู*, ตรีดาว*
บทความ: https://arxiv.org/abs/2312.00752
Transformers คือ SSM: โมเดลทั่วไปและอัลกอริทึมที่มีประสิทธิภาพ
ผ่านความเป็นทวิภาคอวกาศที่มีโครงสร้าง
ตรีดาว*, อัลเบิร์ต กู*
บทความ: https://arxiv.org/abs/2405.21060
Mamba เป็นสถาปัตยกรรมโมเดลพื้นที่สถานะใหม่ที่แสดงประสิทธิภาพที่น่าคาดหวังกับข้อมูลที่มีความหนาแน่นสูง เช่น การสร้างแบบจำลองภาษา ซึ่งโมเดลย่อยกำลังสองก่อนหน้านี้ยังขาด Transformers มันขึ้นอยู่กับแนวความก้าวหน้าของแบบจำลองพื้นที่รัฐที่มีโครงสร้าง พร้อมด้วยการออกแบบที่คำนึงถึงฮาร์ดแวร์ที่มีประสิทธิภาพและการใช้งานตามจิตวิญญาณของ FlashAttention
pip install causal-conv1d>=1.4.0
: การใช้งานที่มีประสิทธิภาพของเลเยอร์ Conv1d เชิงสาเหตุแบบง่ายๆ ที่ใช้ภายในบล็อก Mambapip install mamba-ssm
: แพ็คเกจ Mamba หลักpip install mamba-ssm[causal-conv1d]
: เพื่อติดตั้งแพ็คเกจ Mamba หลักและ causal-conv1dpip install mamba-ssm[dev]
: เพื่อติดตั้งแพ็คเกจ Mamba หลักและการพึ่งพา dev นอกจากนี้ยังสามารถสร้างขึ้นจากแหล่งที่มาด้วย pip install .
จากที่เก็บข้อมูลนี้
หาก pip
บ่นเกี่ยวกับเวอร์ชันของ PyTorch ให้ลองส่ง --no-build-isolation
ไปยัง pip
ข้อกำหนดอื่นๆ:
สำหรับการ์ด AMD โปรดดูข้อกำหนดเบื้องต้นเพิ่มเติมด้านล่าง
เราเปิดเผยอินเทอร์เฟซหลายระดับด้วยโมเดล Mamba
Mamba ขึ้นอยู่กับเลเยอร์ SSM แบบเลือกสรร ซึ่งเป็นจุดสนใจของรายงาน (ส่วนที่ 3; อัลกอริทึม 2)
ที่มา: ops/selective_scan_interface.py
โมดูลหลักของพื้นที่เก็บข้อมูลนี้คือบล็อกสถาปัตยกรรม Mamba ที่ห่อ SSM แบบเลือกไว้
ที่มา: modules/mamba_simple.py
การใช้งาน:
import torch
from mamba_ssm import Mamba
batch , length , dim = 2 , 64 , 16
x = torch . randn ( batch , length , dim ). to ( "cuda" )
model = Mamba (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 16 , # SSM state expansion factor
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
บล็อก Mamba-2 ถูกใช้งานที่ modules/mamba2.py
เวอร์ชันที่เรียบง่ายกว่าอยู่ที่ modules/mamba2_simple.py
การใช้งานคล้ายกับ Mamba(-1):
from mamba_ssm import Mamba2
model = Mamba2 (
# This module uses roughly 3 * expand * d_model^2 parameters
d_model = dim , # Model dimension d_model
d_state = 64 , # SSM state expansion factor, typically 64 or 128
d_conv = 4 , # Local convolution width
expand = 2 , # Block expansion factor
). to ( "cuda" )
y = model ( x )
assert y . shape == x . shape
โมดูล SSD ภายในเวอร์ชันขั้นต่ำ (รายการ 1 จากกระดาษ Mamba-2) ที่มีการแปลงระหว่างเวอร์ชัน SSM "ไม่ต่อเนื่อง" และ "ต่อเนื่อง" อยู่ที่ modules/ssd_minimal.py
สุดท้ายนี้ เราได้จัดเตรียมตัวอย่างของโมเดลภาษาที่สมบูรณ์: แกนหลักของโมเดลลำดับเชิงลึก (พร้อมบล็อก Mamba ที่ทำซ้ำ) + ส่วนหัวของโมเดลภาษา
ที่มา: models/mixer_seq_simple.py
นี่คือตัวอย่างวิธีการรวม Mamba เข้ากับโครงข่ายประสาทเทียมแบบ end-to-end ตัวอย่างนี้ใช้ในสคริปต์การสร้างด้านล่าง
โมเดลที่ฝึกไว้ล่วงหน้าจะถูกอัปโหลดไปยัง Hugging Face: mamba-130m
, mamba-370m
, mamba-790m
, mamba-1.4b
, mamba-2.8b
, mamba2-130m
, mamba2-370m
, mamba2-780m
, mamba2-1.3b
, mamba2-2.7b
, transformerpp-2.7b
, mamba2attn-2.7b
ฝึกฝนบนโทเค็น 300B บน Pile เช่นเดียวกับ mamba-2.8b-slimpj
(ฝึกฝนบนโทเค็น 600B บนชุดข้อมูล SlimPajama)
โมเดลจะถูกดาวน์โหลดอัตโนมัติโดยสคริปต์การสร้างด้านล่าง
โมเดลเหล่านี้ได้รับการฝึกฝนบน Pile และเป็นไปตามขนาดโมเดลมาตรฐานที่อธิบายโดย GPT-3 และตามด้วยโมเดลโอเพ่นซอร์สหลายตัว:
พารามิเตอร์ | เลเยอร์ | โมเดลสลัว |
---|---|---|
130ม | 24 | 768 |
370ม | 48 | 1,024 |
790ม | 48 | 1536 |
1.4B | 48 | 2048 |
2.8B | 64 | 2560 |
(จำนวนเลเยอร์ของ Mamba จะเพิ่มเป็นสองเท่าของ Transformer ที่มีขนาดใกล้เคียงกัน เนื่องจากต้องใช้บล็อก Mamba สองบล็อกสำหรับแต่ละ "เลเยอร์" (บล็อก MHA + บล็อก MLP) ของ Transformer)
หมายเหตุ: โมเดลเหล่านี้เป็นโมเดลพื้นฐานที่ได้รับการฝึกสำหรับโทเค็น 300B เท่านั้น โดยไม่มีการปรับเปลี่ยนดาวน์สตรีมในรูปแบบใดๆ (การปรับแต่งคำสั่ง ฯลฯ) ประสิทธิภาพคาดว่าจะเทียบเคียงหรือดีกว่าสถาปัตยกรรมอื่นๆ ที่ได้รับการฝึกด้วยข้อมูลที่คล้ายคลึงกัน แต่ไม่ตรงกับโมเดลที่ใหญ่กว่าหรือได้รับการปรับแต่งอย่างละเอียด
ในการรันการประเมินแบบ Zero-Shot ของแบบจำลอง (สอดคล้องกับตารางที่ 3 ของรายงาน) เราใช้ไลบรารี lm-evalue-harness
lm-evaluation-harness
โดย pip install lm-eval==0.4.2
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
หากต้องการทำซ้ำผลลัพธ์ในโมเดล mamba-2.8b-slimpj
ที่รายงานในบล็อกโพสต์:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
หากต้องการรันการประเมินโมเดล Mamba-2 เพียงเปลี่ยนชื่อโมเดล:
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
โปรดทราบว่าผลลัพธ์ของแต่ละงานอาจแตกต่างกันไปจากค่าที่รายงาน 0.1-0.3 เนื่องจากเสียงรบกวนในกระบวนการประเมิน
เกณฑ์มาตรฐานสคริปต์/benchmark_generation_mamba_simple.py
ตัวเลือกที่กำหนดค่าได้อื่นๆ ได้แก่ ความน่าจะเป็น top-p (การสุ่มตัวอย่างนิวเคลียส) และอุณหภูมิซอฟต์แม็กซ์
วิธีทดสอบเวลาแฝงในการสร้าง (เช่น ขนาดแบทช์ = 1) ด้วยกลยุทธ์การสุ่มตัวอย่างที่แตกต่างกัน:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --prompt " My cat wrote all this CUDA code for a new language model and " --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
หากต้องการทดสอบปริมาณงานการสร้างด้วยการแจ้งเตือนแบบสุ่ม (เช่น ขนาดแบตช์มาก):
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba-2.8b " --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name " EleutherAI/pythia-2.8b " --batch 64
ด้วย Mamba-2 คุณเพียงแค่ต้องเปลี่ยนชื่อรุ่น:
python benchmarks/benchmark_generation_mamba_simple.py --model-name " state-spaces/mamba2-2.7b " --prompt " My cat wrote all this CUDA code for a new language model and " --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
โมเดลของเราได้รับการฝึกฝนโดยใช้ PyTorch AMP เพื่อความแม่นยำแบบผสม AMP เก็บพารามิเตอร์โมเดลไว้ใน float32 และแปลงให้มีความแม่นยำเพียงครึ่งหนึ่งเมื่อจำเป็น ในทางกลับกัน เฟรมเวิร์กอื่นๆ เช่น DeepSpeed จะจัดเก็บพารามิเตอร์ใน float16 และอัปคาสต์เมื่อจำเป็น (เช่น สำหรับการสะสมของเครื่องมือเพิ่มประสิทธิภาพ)
เราสังเกตว่าอาจจำเป็นต้องมีความแม่นยำสูงกว่าสำหรับพารามิเตอร์โมเดลหลัก เนื่องจาก SSM มีความไวต่อไดนามิกที่เกิดซ้ำ หากคุณประสบปัญหาความไม่เสถียร ขั้นแรกโปรดลองใช้เฟรมเวิร์กที่จัดเก็บพารามิเตอร์ใน fp32 (เช่น AMP)
บางส่วนของโมเดลมีการกำหนดค่าเริ่มต้นที่สืบทอดมาจากงานก่อนหน้าในรุ่น S4 ตัวอย่างเช่น nn.Linear
ให้เป็นศูนย์) หากเป็นกรณีนี้ คุณอาจต้องเพิ่มตรรกะที่กำหนดเอง (เช่น บรรทัดนี้จะปิดการเริ่มต้นใหม่ในเทรนเนอร์ของเรา แต่จะเป็นการไม่ต้องดำเนินการในเฟรมเวิร์กอื่นใด) ที่เฉพาะเจาะจงกับเฟรมเวิร์กการฝึกอบรม
หากคุณใช้ ROCm 6.0 ให้รันขั้นตอนต่อไปนี้เพื่อหลีกเลี่ยงข้อผิดพลาดระหว่างการคอมไพล์ สิ่งนี้ไม่จำเป็นสำหรับ ROCm 6.1 เป็นต้นไป
ค้นหาไดเรกทอรีการติดตั้ง ROCm ของคุณ โดยทั่วไปจะพบได้ที่ /opt/rocm/
แต่อาจแตกต่างกันไปขึ้นอยู่กับการติดตั้งของคุณ
ใช้แพทช์ ทำงานด้วย sudo
ในกรณีที่คุณประสบปัญหาในการอนุญาต
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
หากคุณใช้โค้ดเบสนี้ หรือพบว่างานของเรามีคุณค่า โปรดอ้างอิง Mamba:
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
@inproceedings{mamba2,
title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}