論文| Jax CodeBase |セットアップ|クイックスタート|推論ベンチマーク
これは、学習の公式Pytorchモデルの実装です(テスト時に学習):表現力豊かな隠された状態を持つRNN。このコードベースでのトレーニングはお勧めしません。これは、システムの最適化なしに純粋なPytorchで書かれているため、特にデバイスごとのバッチサイズが小さい場合はトレーニングが遅くなります。
トレーニングコード、または私たちの論文の結果を再現するには、JAX CodeBaseをご覧ください。推論カーネルについて、または私たちの論文から速度ベンチマークを複製するには、カーネルの実装をご覧ください。
自己関節は長い文脈でうまく機能しますが、二次的な複雑さを持っています。既存のRNN層には線形の複雑さがありますが、長いコンテキストでのパフォーマンスは、隠された状態の表現力によって制限されています。直線的な複雑さと表現力豊かな隠された状態を備えた新しいクラスのシーケンスモデリングレイヤーを提案します。重要なアイデアは、隠された状態を機械学習モデル自体にすることであり、更新が自己教師の学習のステップを支配することです。
隠された状態はテストシーケンスでもトレーニングによって更新されるため、レイヤーはテスト時間トレーニング(TTT)レイヤーと呼ばれます。 2つのインスタンス化を検討します。TTT-LINEARとTTT-MLP。隠された状態はそれぞれ線形モデルと2層MLPです。
pip install " transformers[torch] "
私たちの実装は、ハギングフェイストランスに基づいています。次のコードを使用してモデルを読み込み、テキストを生成できます。
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
注:これは、チュートリアルの目的でTTTレイヤーの素朴な実装です。このモデルは、Huggingface Accelerateまたはカスタムトレーニングループを使用してトレーニングできます。ここで、より速い推論カーネルとその速度ベンチマークをリリースしました。