종이 | JAX Codebase | 설정 | 빠른 시작 | 추론 벤치 마크
이것은 표현적인 숨겨진 상태를 가진 RNNS에 대한 학습의 공식 Pytorch 모델 구현입니다. 시스템 최적화없이 순수한 Pytorch로 작성 되었기 때문에이 코드베이스로 교육을 권장하지 않으므로 특히 부정 당 배치 크기가 작을 때 교육이 느려집니다.
교육 코드 또는 논문의 결과를 복제하려면 JAX 코드베이스를보십시오. 커널을 추론하거나 용지에서 속도 벤치 마크를 복제하려면 커널 구현을보십시오.
자기 변환은 긴 맥락에서 잘 수행되지만 2 차 복잡성 을가집니다. 기존 RNN 층은 선형 복잡성을 가지지 만, 긴 맥락에서의 성능은 숨겨진 상태의 표현력에 의해 제한됩니다. 우리는 선형 복잡성과 표현적인 숨겨진 상태를 가진 새로운 클래스의 시퀀스 모델링 층을 제안합니다. 핵심 아이디어는 숨겨진 상태를 기계 학습 모델 자체로 만드는 것입니다. 업데이트 규칙은 자체 감독 학습의 단계입니다.
숨겨진 상태는 테스트 시퀀스에서도 훈련에 의해 업데이트되므로 우리의 계층을 TTT (Test-Time Training) 레이어 라고합니다. 우리는 TTT-linear와 TTT-MLP의 두 가지 인스턴스화를 고려합니다.
pip install " transformers[torch] "
우리의 구현은 Huggingface Transformers를 기반으로합니다. 다음 코드를 사용하여 모델을로드하고 텍스트를 생성 할 수 있습니다.
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
참고 : 이것은 튜토리얼 목적으로 TTT 계층의 순진한 구현입니다. 이 모델은 Huggingface Accelerate 또는 맞춤형 교육 루프를 사용하여 교육을받을 수 있습니다. 우리는 더 빠른 추론 커널과 속도 벤치 마크를 여기에서 출시했습니다.