Papel | Jax CodeBase | Configuración | Inicio rápido | Punto de referencia de inferencia
Esta es la implementación oficial del modelo de Pytorch del aprendizaje para (aprender en el tiempo de prueba): RNN con estados ocultos expresivos. No recomendamos capacitar con esta base de código, porque está escrito en Pytorch puro sin ninguna optimización de sistemas, por lo que el entrenamiento será lento, especialmente cuando el tamaño por lotes por dispositivo sea pequeño.
Para el código de entrenamiento, o para replicar los resultados de nuestro documento, vea nuestra base de código JAX. Para los núcleos de inferencia, o para replicar los puntos de referencia de velocidad de nuestro artículo, vea nuestras implementaciones del núcleo.
La autoatación funciona bien en un contexto largo pero tiene complejidad cuadrática. Las capas RNN existentes tienen complejidad lineal, pero su rendimiento en un contexto largo está limitado por el poder expresivo de su estado oculto. Proponemos una nueva clase de capas de modelado de secuencia con complejidad lineal y un estado oculto expresivo. La idea clave es hacer del estado oculto un modelo de aprendizaje automático en sí, y la regla de actualización de un paso de aprendizaje auto-supervisado.
Dado que el estado oculto se actualiza mediante la capacitación incluso en secuencias de prueba, nuestras capas se denominan capas de entrenamiento de tiempo de prueba (TTT) . Consideramos dos instancias: TTT-Lineal y TTT-MLP, cuyo estado oculto es un modelo lineal y un MLP de dos capas, respectivamente.
pip install " transformers[torch] "
Nuestra implementación se basa en Huggingface Transformers. Puede usar el siguiente código para cargar el modelo y generar texto.
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
Nota: Esta es una implementación ingenua de capas TTT para fines tutoriales. Este modelo puede ser entrenado con Huggingface Accelerate o bucles de entrenamiento personalizados. Hemos lanzado nuestro núcleo de inferencia más rápido y su punto de referencia de velocidad aquí.