Papel | Jax CodeBase | Configuração | Iniciar rápido | Referência de inferência
Esta é a implementação oficial do modelo Pytorch do aprendizado (Aprenda no tempo de teste): RNNs com estados ocultos expressivos. Não recomendamos o treinamento com esta base de código, porque é escrita em pytorch puro sem qualquer otimização de sistemas, portanto o treinamento será lento, especialmente quando o tamanho do lote por dispositivo for pequeno.
Para o código de treinamento ou para replicar os resultados do nosso artigo, consulte nossa base de código JAX. Para kernels de inferência ou para replicar os benchmarks de velocidade em nosso artigo, consulte nossas implementações do kernel.
A auto-ataque tem um bom desempenho em longo contexto, mas tem complexidade quadrática. As camadas RNN existentes têm complexidade linear, mas seu desempenho em longo contexto é limitado pelo poder expressivo de seu estado oculto. Propomos uma nova classe de camadas de modelagem de sequência com complexidade linear e um estado oculto expressivo. A idéia principal é tornar o estado oculto um modelo de aprendizado de máquina e a regra de atualização uma etapa do aprendizado auto-supervisionado.
Como o estado oculto é atualizado por treinamento, mesmo em sequências de teste, nossas camadas são chamadas de camadas de treinamento em tempo de teste (TTT) . Consideramos duas instanciações: TTT-linear e TTT-MLP, cujo estado oculto é um modelo linear e um MLP de duas camadas, respectivamente.
pip install " transformers[torch] "
Nossa implementação é baseada em Transformers Huggingface. Você pode usar o código a seguir para carregar o modelo e gerar texto.
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
Nota: Esta é uma implementação ingênua de camadas TTT para fins de tutorial. Este modelo pode ser treinado usando acelerar o HuggingFace ou loops de treinamento personalizados. Lançamos nosso kernel de inferência mais rápido e sua referência de velocidade aqui.