Papier | Jax CodeBase | Configuration | Démarrage rapide | Inférence Benchmark
Il s'agit de la mise en œuvre officielle du modèle Pytorch de l'apprentissage à (apprendre au moment du test): RNN avec des états cachés expressifs. Nous ne recommandons pas de formation avec cette base de code, car elle est écrite en pytorch pur sans aucune optimisation de systèmes, donc la formation sera lente, surtout lorsque la taille du lot par demande est petite.
Pour le code de formation ou pour reproduire les résultats de notre article, veuillez consulter notre base de code JAX. Pour les grains d'inférence, ou pour reproduire les repères de vitesse de notre article, veuillez consulter nos implémentations du noyau.
L'auto-attention fonctionne bien dans un contexte long mais a une complexité quadratique. Les couches RNN existantes ont une complexité linéaire, mais leur performance dans un contexte long est limitée par le pouvoir expressif de leur état caché. Nous proposons une nouvelle classe de couches de modélisation de séquence avec une complexité linéaire et un état caché expressif. L'idée clé est de faire de l'état caché un modèle d'apprentissage automatique lui-même, et la règle de mise à jour une étape d'apprentissage auto-supervisé.
Étant donné que l'état caché est mis à jour par la formation même sur les séquences de test, nos couches sont appelées couches de formation de test (TTT) . Nous considérons deux instanciations: TTT-linéaire et TTT-MLP, dont l'état caché est respectivement un modèle linéaire et un MLP à deux couches.
pip install " transformers[torch] "
Notre implémentation est basée sur des transformateurs HuggingFace. Vous pouvez utiliser le code suivant pour charger le modèle et générer du texte.
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
Remarque: Il s'agit d'une implémentation naïve des couches TTT à des fins de tutoriel. Ce modèle peut être formé à l'aide de HuggingFace Accelerate ou de boucles de formation personnalisées. Nous avons publié notre noyau d'inférence plus rapide et sa référence de vitesse ici.