Kertas | Jax CodeBase | Pengaturan | Mulai Cepat | Benchmark inferensi
Ini adalah implementasi model Pytorch resmi untuk belajar (belajar pada waktu tes): RNN dengan negara tersembunyi ekspresif. Kami tidak merekomendasikan pelatihan dengan basis kode ini, karena ditulis dalam pytorch murni tanpa optimasi sistem, sehingga pelatihan akan lambat, terutama ketika ukuran batch per-perangkat kecil.
Untuk kode pelatihan, atau untuk mereplikasi hasil dari makalah kami, silakan lihat basis kode JAX kami. Untuk kernel inferensi, atau untuk mereplikasi tolok ukur kecepatan dari makalah kami, silakan lihat implementasi kernel kami.
Perhatian diri bekerja dengan baik dalam konteks yang panjang tetapi memiliki kompleksitas kuadratik. Lapisan RNN yang ada memiliki kompleksitas linier, tetapi kinerja mereka dalam konteks panjang dibatasi oleh kekuatan ekspresif dari keadaan tersembunyi mereka. Kami mengusulkan kelas baru lapisan pemodelan urutan dengan kompleksitas linier dan keadaan tersembunyi ekspresif. Gagasan kuncinya adalah menjadikan status tersembunyi sebagai model pembelajaran mesin itu sendiri, dan pembaruan aturan langkah pembelajaran yang di-swadaya.
Karena keadaan tersembunyi diperbarui dengan pelatihan bahkan pada urutan tes, lapisan kami disebut lapisan pelatihan waktu tes (TTT) . Kami mempertimbangkan dua instansiasi: TTT-Linear dan TTT-MLP, yang keadaan tersembunyi masing-masing adalah model linier dan dua lapis MLP.
pip install " transformers[torch] "
Implementasi kami didasarkan pada transformator huggingface. Anda dapat menggunakan kode berikut untuk memuat model dan menghasilkan teks.
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
Catatan: Ini adalah implementasi naif dari lapisan TTT untuk tujuan tutorial. Model ini dapat dilatih menggunakan huggingface accelerate, atau loop pelatihan khusus. Kami telah merilis kernel inferensi kami yang lebih cepat dan tolok ukur kecepatannya di sini.