纸| JAX代码库|设置|快速启动|推理基准
这是学习的官方模型实现(在测试时学习):具有表现力的隐藏状态的RNN。我们不建议使用此代码库进行培训,因为它是用纯pytorch编写的,没有任何系统优化,因此培训会很慢,尤其是当人均批处理大小很小时。
有关培训代码,或复制我们的论文结果,请查看我们的JAX代码库。有关推理内核,或复制我们论文中的速度基准,请查看我们的内核实现。
自我注意力在远面表现良好,但具有二次复杂性。现有的RNN层具有线性复杂性,但是它们在长篇小说中的性能受到其隐藏状态的表现力的限制。我们提出了具有线性复杂性和表达性隐藏状态的新的序列建模层。关键的想法是使隐藏状态成为机器学习模型本身,而更新规则是自我监督学习的步骤。
由于隐藏状态甚至通过在测试序列上进行培训来更新,因此我们的层被称为测试时间培训(TTT)层。我们考虑了两个实例:TTT线性和TTT-MLP,其隐藏状态分别是线性模型和两层MLP。
pip install " transformers[torch] "
我们的实现基于拥抱面变压器。您可以使用以下代码加载模型并生成文本。
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
注意:这是用于教程目的的TTT层的幼稚实现。可以使用拥抱面的加速或定制培训循环来训练该模型。我们在此处发布了更快的推理内核及其速度基准。