Бумага | JAX CODEBASE | Настройка | Быстрый старт | ИСПРАВЛЕНИЕ
Это официальная реализация модели Pytorch обучения (обучение во время теста): RNN с выразительными скрытыми состояниями. Мы не рекомендуем обучение с этой кодовой базой, потому что она написана в Pure Pytorch без какой-либо системной оптимизации, поэтому обучение будет медленным, особенно когда размер партии для каждого пакета невелик.
Для обучения или для воспроизведения результатов из нашей статьи, пожалуйста, просмотрите нашу кодовую базу JAX. Для ядра вывода или для воспроизведения резервуаров скорости из нашей статьи, пожалуйста, просмотрите наши реализации ядра.
Самоализация хорошо работает в длинном контексте, но имеет квадратичную сложность. Существующие слои RNN имеют линейную сложность, но их эффективность в длинном контексте ограничена выразительной силой их скрытого состояния. Мы предлагаем новый класс слоев моделирования последовательностей с линейной сложностью и выразительным скрытым состоянием. Ключевая идея состоит в том, чтобы сделать скрытое состояние самой моделью машинного обучения, а обновление правило шагом самоотверженного обучения.
Поскольку скрытое состояние обновляется обучением даже на тестовых последовательностях, наши слои называются слоями обучения времени тестирования (TTT) . Мы рассматриваем два экземпляра: TTT-Linear и TTT-MLP, чье скрытое состояние представляет собой линейную модель и двухслойный MLP соответственно.
pip install " transformers[torch] "
Наша реализация основана на трансформаторах HuggingFace. Вы можете использовать следующий код для загрузки модели и генерации текста.
from transformers import AutoTokenizer
from ttt import TTTForCausalLM , TTTConfig , TTT_STANDARD_CONFIGS
# Initializing a TTT ttt-1b style configuration
# configuration = TTTConfig(**TTT_STANDARD_CONFIGS['1b']) is equivalent to the following
configuration = TTTConfig ()
# Initializing a model from the ttt-1b style configuration
model = TTTForCausalLM ( configuration )
model . eval ()
# Accessing the model configuration
configuration = model . config
# Tokenizer
tokenizer = AutoTokenizer . from_pretrained ( 'meta-llama/Llama-2-7b-hf' )
# Prefill
input_ids = tokenizer ( "Greeting from TTT!" , return_tensors = "pt" ). input_ids
logits = model ( input_ids = input_ids )
print ( logits )
# Decoding
out_ids = model . generate ( input_ids = input_ids , max_length = 50 )
out_str = tokenizer . batch_decode ( out_ids , skip_special_tokens = True )
print ( out_str )
Примечание. Это наивная реализация слоев TTT для учебных целей. Эта модель может быть обучена с помощью ускорения HuggingFace или пользовательских петлей обучения. Мы выпустили нашу более быстрое ядро вывода и его эталон скорости здесь.