Papier | JAX CODEBASE | Setup | Schneller Start | Inferenz -Benchmark
Dies ist die offizielle Implementierung des Lernens von Pytorch -Modell (Lernen Sie zur Testzeit): RNNs mit ausdrucksstarken versteckten Zuständen. Wir empfehlen kein Training mit dieser Codebasis, da es ohne Systemoptimierung in reinem Pytorch geschrieben ist. Das Training ist also langsam, insbesondere wenn die Größe pro Gerät gering ist.
Für den Trainingscode oder um Ergebnisse aus unserem Papier zu replizieren, lesen Sie bitte unsere JAX -Codebasis. Für Inferenzkerne oder um Geschwindigkeitsbenchmarks aus unserem Artikel zu replizieren, sehen Sie sich bitte unsere Kernel -Implementierungen an.
Die Selbstbekämpfung ist im langen Kontext gut funktioniert, hat jedoch eine quadratische Komplexität. Bestehende RNN -Schichten haben eine lineare Komplexität, aber ihre Leistung im langen Kontext ist durch die ausdrucksstarke Kraft ihres versteckten Zustands begrenzt. Wir schlagen eine neue Klasse von Sequenzmodellierungsschichten mit linearer Komplexität und einem ausdrucksstarken versteckten Zustand vor. Die Hauptidee besteht darin, den versteckten Zustand selbst zu einem maschinellen Lernmodell und der Aktualisierungsregel zu einem Schritt des selbstbewerteten Lernens zu machen.
Da der versteckte Zustand auch in Testsequenzen durch Training aktualisiert wird, werden unsere Ebenen als TTT (Test-Time Training) bezeichnet. Wir betrachten zwei Instanziationen: TTT-Linear und TTT-MLP, deren versteckter Zustand ein lineares Modell bzw. ein zweischichtiger MLP ist.
pip install " transformers[torch] "
Unsere Implementierung basiert auf Umarmungsface -Transformatoren. Sie können den folgenden Code verwenden, um das Modell zu laden und Text zu generieren.
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
Hinweis: Dies ist eine naive Implementierung von TTT -Schichten für Tutorialzwecke. Dieses Modell kann mit Huggingface Accelerate oder benutzerdefinierten Trainingsschleifen geschult werden. Wir haben unseren schnelleren Inferenzkern und seinen Speed -Benchmark hier veröffentlicht.