กระดาษ | Jax Codebase | การตั้งค่า | เริ่มต้นอย่างรวดเร็ว | เกณฑ์มาตรฐานการอนุมาน
นี่คือการใช้แบบจำลอง Pytorch อย่างเป็นทางการของการเรียนรู้ที่จะ (เรียนรู้ ณ เวลาทดสอบ): RNNs ที่มีสถานะซ่อนเร้น เรา ไม่แนะนำให้ฝึกอบรม กับ codebase นี้เพราะมันถูกเขียนด้วย pytorch บริสุทธิ์โดยไม่มีการเพิ่มประสิทธิภาพระบบใด ๆ ดังนั้นการฝึกอบรมจะช้าโดยเฉพาะอย่างยิ่งเมื่อขนาดแบทช์ต่ออุปกรณ์มีขนาดเล็ก
สำหรับรหัสการฝึกอบรมหรือเพื่อทำซ้ำผลลัพธ์จากบทความของเราโปรดดู JAX Codebase ของเรา สำหรับเมล็ดการอนุมานหรือเพื่อทำซ้ำการเปรียบเทียบความเร็วจากกระดาษของเราโปรดดูการใช้งานเคอร์เนลของเรา
การดูแลตนเองทำงานได้ดีในบริบทที่ยาวนาน แต่มีความซับซ้อนเป็นกำลังสอง เลเยอร์ RNN ที่มีอยู่มีความซับซ้อนเชิงเส้น แต่ประสิทธิภาพของพวกเขาในบริบทที่ยาวนานนั้นถูก จำกัด ด้วยพลังที่แสดงออกของสถานะที่ซ่อนอยู่ เราเสนอเลเยอร์การสร้างแบบจำลองลำดับคลาสใหม่ที่มีความซับซ้อนเชิงเส้นและสถานะที่ซ่อนอยู่ แนวคิดหลักคือการทำให้สถานะที่ซ่อนอยู่เป็นรูปแบบการเรียนรู้ของเครื่องเองและการอัปเดตกฎขั้นตอนของการเรียนรู้ที่ดูแลตนเอง
เนื่องจากสถานะที่ซ่อนอยู่ได้รับการปรับปรุงโดยการฝึกอบรมแม้ในลำดับการทดสอบเลเยอร์ของเราจึงเรียกว่า เลเยอร์การฝึกอบรมเวลาทดสอบ (TTT) เราพิจารณาสองอินสแตนซ์: TTT-Linear และ TTT-MLP ซึ่งสถานะที่ซ่อนอยู่เป็นแบบจำลองเชิงเส้นและ MLP สองชั้นตามลำดับ
pip install " transformers[torch] "
การใช้งานของเราขึ้นอยู่กับหม้อแปลง HuggingFace คุณสามารถใช้รหัสต่อไปนี้เพื่อโหลดโมเดลและสร้างข้อความ
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
หมายเหตุ: นี่คือการใช้เลเยอร์ TTT ที่ไร้เดียงสาเพื่อการสอน รุ่นนี้สามารถผ่านการฝึกอบรมโดยใช้ HuggingFace Accelerate หรือลูปการฝึกอบรมที่กำหนดเอง เราได้เปิดตัวเคอร์เนลการอนุมานที่เร็วขึ้นและเกณฑ์มาตรฐานความเร็วที่นี่