ورقة | Jax Codebase | الإعداد | بداية سريعة | معيار الاستنتاج
هذا هو تنفيذ نموذج Pytorch الرسمي للتعلم (التعلم في وقت الاختبار): RNNs مع حالات خفية تعبيرية. لا نوصي بالتدريب باستخدام قاعدة البيانات هذه ، لأنه مكتوب في Pytorch الخالص دون أي تحسين للأنظمة ، لذلك سيكون التدريب بطيئًا ، خاصةً عندما يكون حجم دفعة كل جهاز صغير.
للحصول على رمز التدريب ، أو لتكرار النتائج من ورقتنا ، يرجى الاطلاع على قاعدة كود Jax. بالنسبة إلى نواة الاستدلال ، أو لتكرار معايير السرعة من ورقتنا ، يرجى الاطلاع على تطبيقات kernel.
يعمل الاهتمام الذاتي جيدًا في سياق طويل ولكنه يتمتع بالتعقيد التربيعي. تحتوي طبقات RNN الحالية على تعقيد خطي ، لكن أدائها في السياق الطويل يقتصر على القوة التعبيرية لحالتها المخفية. نقترح فئة جديدة من طبقات نمذجة التسلسل مع التعقيد الخطي وحالة خفية تعبيرية. الفكرة الرئيسية هي جعل الحالة المخفية نموذجًا للتعلم الآلي نفسه ، وحكم التحديث خطوة من التعلم الخاضع للإشراف ذاتيًا.
نظرًا لأن الحالة المخفية يتم تحديثها عن طريق التدريب حتى على تسلسل الاختبار ، فإن طبقاتنا تسمى طبقات التدريب في وقت الاختبار (TTT) . نحن نعتبر اثنين من مثيلتين: TTT-Linar و TTT-MLP ، الذي تعتبر حالته المخفية نموذجًا خطيًا و MLP ثنائي الطبقات على التوالي.
pip install " transformers[torch] "
يعتمد تنفيذنا على محولات Huggingface. يمكنك استخدام الكود التالي لتحميل النموذج وإنشاء النص.
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
ملاحظة: هذا تطبيق ساذج لطبقات TTT لأغراض البرنامج التعليمي. يمكن تدريب هذا النموذج باستخدام تسريع Luggingface ، أو حلقات التدريب المخصصة. لقد أصدرنا نواة الاستدلال الأسرع لدينا وقياس السرعة هنا.