MambaTransformer
1.0.0
將 Mamba/SSM 與 Transformer 整合以增強長上下文和高品質序列建模。
這是我設計的 100% 新穎的架構,結合了 SSM 和 Attention 的優點和缺點,形成了一種全新的高級架構,其目的是超越我們的舊限制。更快的處理速度、更長的上下文長度、更低的長序列困惑度、增強且卓越的推理,同時保持小而緊湊。
這個架構本質上是: x -> norm -> mamba -> norm -> transformer -> norm -> ffn -> norm -> out
。
我添加了許多標準化,因為我相信預設情況下,由於 2 個外部架構相互集成,訓練穩定性會嚴重降低。
pip3 install mambatransformer
import torch
from mamba_transformer import MambaTransformer
# Generate a random tensor of shape (1, 10) with values between 0 and 99
x = torch . randint ( 0 , 100 , ( 1 , 10 ))
# Create an instance of the MambaTransformer model
model = MambaTransformer (
num_tokens = 100 , # Number of tokens in the input sequence
dim = 512 , # Dimension of the model
heads = 8 , # Number of attention heads
depth = 4 , # Number of transformer layers
dim_head = 64 , # Dimension of each attention head
d_state = 512 , # Dimension of the state
dropout = 0.1 , # Dropout rate
ff_mult = 4 , # Multiplier for the feed-forward layer dimension
return_embeddings = False , # Whether to return the embeddings,
transformer_depth = 2 , # Number of transformer blocks
mamba_depth = 10 , # Number of Mamba blocks,
use_linear_attn = True , # Whether to use linear attention
)
# Pass the input tensor through the model and print the output shape
out = model ( x )
print ( out . shape )
# After many training
model . eval ()
# Would you like to train this model? Zeta Corporation offers unmatchable GPU clusters at unbeatable prices, let's partner!
# Tokenizer
model . generate ( text )
麻省理工學院