紙| JAX代碼庫|設置|快速啟動|推理基準
這是學習的官方模型實現(在測試時學習):具有表現力的隱藏狀態的RNN。我們不建議使用此代碼庫進行培訓,因為它是用純pytorch編寫的,沒有任何系統優化,因此培訓會很慢,尤其是當人均批處理大小很小時。
有關培訓代碼,或複制我們的論文結果,請查看我們的JAX代碼庫。有關推理內核,或複制我們論文中的速度基準,請查看我們的內核實現。
自我注意力在遠面表現良好,但具有二次復雜性。現有的RNN層具有線性複雜性,但是它們在長篇小說中的性能受到其隱藏狀態的表現力的限制。我們提出了具有線性複雜性和表達性隱藏狀態的新的序列建模層。關鍵的想法是使隱藏狀態成為機器學習模型本身,而更新規則是自我監督學習的步驟。
由於隱藏狀態甚至通過在測試序列上進行培訓來更新,因此我們的層被稱為測試時間培訓(TTT)層。我們考慮了兩個實例:TTT線性和TTT-MLP,其隱藏狀態分別是線性模型和兩層MLP。
pip install " transformers[torch] "
我們的實現基於擁抱面變壓器。您可以使用以下代碼加載模型並生成文本。
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
注意:這是用於教程目的的TTT層的幼稚實現。可以使用擁抱面的加速或定制培訓循環來訓練該模型。我們在此處發布了更快的推理內核及其速度基準。