Автор: Генри Ндубуаку (значки Discord и Docs кликабельны)
N/B: Коды реализованы педагогически за счет повторения. Каждая модель целенаправленно содержится в файле без межфайловых зависимостей.
Разработка и обучение моделей на основе трансформаторов обычно требует больших ресурсов и времени, и экспертам в области искусственного интеллекта и машинного обучения часто приходится создавать уменьшенные версии этих моделей для решения конкретных задач. Jax, малоресурсный, но мощный фреймворк, ускоряет разработку нейронных сетей и абстрагирует распределенное обучение, но существующие ресурсы для разработки преобразователей в Jax ограничены. NanoDL решает эту проблему благодаря следующим функциям:
Широкий набор блоков и слоев, облегчающий создание индивидуальных моделей трансформаторов с нуля.
Обширный выбор моделей, таких как Gemma, LlaMa3, Mistral, GPT3, GPT4 (предполагаемый), T5, Whisper, ViT, Микшеры, CLIP и т. д.
Распределенные модели тренажеров с параллельным использованием данных на нескольких графических процессорах или TPU без необходимости использования циклов ручного обучения.
Загрузчики данных, делающие процесс обработки данных для Jax/Flax более простым и эффективным.
Слои, отсутствующие в Flax/Jax, такие как RoPE, GQA, MQA и SW, заслуживают внимания, что позволяет более гибко разрабатывать модели.
Классические модели машинного обучения с ускорением на GPU/TPU, такие как PCA, KMeans, регрессия, гауссовские процессы и т. д.
Настоящие генераторы случайных чисел в Jax, которым не нужен подробный код.
Ряд продвинутых алгоритмов для задач НЛП и компьютерного зрения, таких как Gaussian Blur, BLEU, Tokenizer и т. д.
Каждая модель содержится в одном файле без каких-либо внешних зависимостей, поэтому исходный код также можно легко использовать.
Настоящие генераторы случайных чисел в Jax, которым не нужен подробный код (примеры показаны в следующих разделах).
В репозитории есть экспериментальные и/или незавершенные функции (такие как MAMBA, KAN, BitNet, GAT и RLHF), которые еще не доступны через пакет, но могут быть скопированы из этого репозитория. Отзывы о любых наших обсуждениях, проблемах и запросах на включение приветствуются! Пожалуйста, сообщайте о любых запросах функций, проблемах, вопросах или проблемах в Discord или просто дайте нам знать, над чем вы работаете!
Вам понадобится Python 3.9 или более поздняя версия, а также работающая установка JAX, установка FLAX, установка OPTAX (с поддержкой графического процессора для запуска обучения, без нее можно поддерживать только творения). Модели можно проектировать и тестировать на ЦП, но все тренажеры являются параллельными с распределенными данными, для чего потребуется графический процессор с числом от 1 до N GPU/TPUS. Для версии JAX только для ЦП:
pip install --upgrade pip # To support manylinux2010 wheels. pip install jax flax optax
Затем установите nanodl из PyPi:
pip install nanodl
Мы предоставляем различные примеры использования API nanodl.
import jaximport nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import GPT4, GPTDataParallelTrainer# Подготовка набора данныхbatch_size = 8max_length = 50vocab_size = 1000# Создайте случайные данныеdata = nanodl.uniform(shape=(batch_size, max_length), minval=0, maxval=vocab_size-1).astype(jnp.int32)# Shift для создания набора данных прогнозирования следующего токенаdummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]# Создать набор данных и dataloaderdataset = ArrayDataset(dummy_inputs , dummy_targets)dataloader = DataLoader(набор данных, Batch_size=batch_size, shuffle=True, drop_last=False)# параметры моделиhyperparams = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0,1,'vocab_size': vocab_size,' embed_dim': 256,'max_length': max_length, «start_token»: 0, «end_token»: 50, }# Предполагаемая модель GPT4 model = GPT4(**hyperparams)trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')trainer.train(train_loader=dataloader, num_epochs=100, val_loader=dataloader) # используем фактические данные val # Генерируем стартовый токенstart_tokens = jnp.array([[123, 456]])# Не забудьте загрузить обученные параметры params = Trainer.load_params('params.pkl')outputs = model.apply( {'params': params}, start_tokens,rngs={'dropout': nanodl.time_rng_key()}, метод=model.generate)
Пример видения
import nanodlimport jax.numpy как jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import DiffusionModel, DiffusionDataParallelTrainerimage_size = 32block_length = 2batch_size = 8widths = [32, 64, 128]input_shape = (101, image_size, image_size, 3)images = nanodl.normal(shape=input_shape)# Используйте свой собственный набор данных изображений = ArrayDataset(images) dataloader = DataLoader (набор данных, Batch_size = Batch_size, Shuffle = True, drop_last = False) # Создайте модель диффузииdiffusion_model = DiffusionModel(image_size, widths,block_length)# Обучение на вашем datatrainer = DiffusionDataParallelTrainer(diffusion_model, input_shape=images.shape, Weights_filename='params.pkl', Learning_rate=1e-4)trainer.train(dataloader, 10)# Сгенерируйте несколько образцов: каждая модель представляет собой модуль Flax.linen# Используйте, как обычно, params = train.load_params('params.pkl')generated_images = Diffusion_model.apply( {'параметры': параметры}, число_изображений = 5, диффузия_шаги = 5, метод=diffusion_model.generate)
Аудио пример
import jaximport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Whisper, WhisperDataParallelTrainer# Параметры фиктивных данныхbatch_size = 8max_length = 50embed_dim = 256 vocab_size = 1000 # Генерация данных: заменить фактическими токенизированными/квантованными данными dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)dummy_inputs = jnp.ones((101, max_length, embed_dim))dataset = ArrayDataset(dummy_inputs, dummy_targets)dataloader = DataLoader(dataset,atch_size=batch_size, shuffle= Верно, drop_last=False)# модель параметрыгиперпарамс = {'num_layers': 1,'hidden_dim': 256,'num_heads': 2,'feedforward_dim': 256,'dropout': 0,1, 'vocab_size': 1000,'embed_dim': embed_dim,'max_length': max_length ,'start_token': 0,'end_token': 50, }# Инициализация modelmodel = Whisper(**hyperparams)# Обучение на вашем тренажере данных = WhisperDataParallelTrainer(model, dummy_inputs.shape, dummy_targets.shape, 'params.pkl')trainer.train(dataloader, 2, dataloader)# Пример вывода параметров = тренер.load_params('params.pkl')# для более чем одного примера, часто используйте model.generate_batchtranscripts = model.apply({'params ': параметры}, dummy_inputs[:1], метод=model.generate)
Пример модели вознаграждения для RLHF
import nanodlimport jax.numpy as jnpfrom nanodl import ArrayDataset, DataLoaderfrom nanodl import Mistral, RewardModel, RewardDataParallelTrainer# Создать фиктивный databatch_size = 8max_length = 10# Заменить фактическими токенизированными данными dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)# Создать набор данных и dataloaderdataset = ArrayDataset(dummy_chosen, dummy_rejected)dataloader = DataLoader(dataset,atch_size=batch_size, shuffle=True, drop_last=False) # параметры моделиhyperparams = {'num_layers': 1, 'hidden_dim': 256, 'num_heads': 2, 'feedforward_dim': 256, 'dropout': 0,1, 'vocab_size': 1000, 'embed_dim': 256, 'max_length': max_length,' start_token': 0,'end_token': 50,'num_groups': 2,'window_size': 5,'shift_size': 2}# Инициализировать модель вознаграждения из Mistralmodel = Mistral(**hyperparams)reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout= 0.1)# Обучение модели вознагражденияtrainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')trainer.train(dataloader, 5, dataloader)params = train.load_params('reward_model_weights.pkl')# Вызов так же, как и обычная модель Flaxrewards = вознаграждение_модель.применить({'params' : params}, dummy_chosen, rngs={'dropout': nanodl.time_rng_key()})
Пример PCA
import nanodlfrom nanodl import PCA# Использовать фактические данныеdata = nanodl.normal(shape=(1000, 10))# Инициализировать и обучить модель PCApca = PCA(n_comComponents=2)pca.fit(data)# Получить PCA Transformstransformed_data = pca.transform( data)# Получить обратное преобразованиеoriginal_data = pca.inverse_transform(transformed_data)# Образец из распределениеX_sampled = pca.sample(n_samples=1000, ключ=Нет)
Он все еще находится в разработке, работает отлично, но ожидаются грубости, поэтому вклад очень приветствуется!
Вносите изменения, не меняя шаблоны проектирования.
При необходимости напишите тесты для ваших изменений.
Установите локально с помощью pip3 install -e .
.
Запускайте тесты с помощью python3 -m unittest discover -s tests
.
Затем отправьте запрос на вытягивание.
Взносы могут осуществляться в различных формах:
Написание документации.
Исправление ошибок.
Реализация документов.
Написание тестов с высоким покрытием.
Оптимизация существующих кодов.
Экспериментируйте и отправляйте реальные примеры в раздел примеров.
Сообщаем об ошибках.
Реагирование на сообщения о проблемах.
Присоединяйтесь к серверу Discord, чтобы узнать больше.
Название «NanoDL» означает Nano Deep Learning. Размеры моделей стремительно растут, поэтому эксперты и компании с ограниченными ресурсами не могут создавать гибкие модели без непомерно высоких затрат. После успеха моделей Phi долгосрочная цель состоит в том, чтобы создать и обучить нановерсии всех доступных моделей, гарантируя при этом, что они будут конкурировать с оригинальными моделями по производительности, с общим количеством параметров, не превышающим 1B. Тренированные веса будут доступны через эту библиотеку. Любая форма спонсорства и финансирования поможет с учебными ресурсами. Вы можете спонсировать проект через GitHub здесь или связаться с нами по адресу [email protected].
Чтобы процитировать этот репозиторий:
@software{nanodl2024github, author = {Henry Ndubuaku}, title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.}, url = {http://github.com/hmunachi/nanodl}, year = {2024}, }