Библиотека haiku, использующая операторы xmap
/ pjit
в JAX для параллелизма моделей преобразователей.
Схема параллелизма аналогична исходной схеме Megatron-LM, которая эффективна на TPU благодаря высокоскоростной 2D-ячеистой сети. Существует также экспериментальная версия модели, реализующая шардинг в стиле ZeRo.
Эта библиотека предназначена для масштабирования примерно до 40B параметров на TPUv3, после чего следует использовать различные стратегии параллелизма. Для этого ознакомьтесь с другими реализациями, такими как GPT-NeoX или DeepSpeed.
Одним из будущих направлений исследований является интеграция этой кодовой базы с swarm-jax для достижения дальнейшей масштабируемости за счет конвейерного параллелизма.
07.12.21 : Добавлено руководство по тонкой настройке.
Авторегрессионная модель генерации текста с 6 миллиардами параметров, обученная на The Pile.
Загрузите тонкие гири (только гири bf16, для примера, 9 ГБ)
Загрузить полную информацию (включая параметры оптимизатора, 61 ГБ)
Частично обученные контрольно-пропускные пункты
Демо-версия колаба
Веб-демо
Сообщение в блоге Арана
Этот проект был бы невозможен без вычислительных ресурсов, щедро предоставленных TPU Research Cloud при поддержке EleutherAI.
Спасибо команде Cloud TPU в Google за предоставление раннего доступа к альфа-версии Cloud TPU VM (теперь общедоступной!)
Спасибо всем, кто так или иначе помог (в алфавитном порядке):
Гири GPT-J-6B лицензированы по лицензии Apache версии 2.0.
Гиперпараметр | Ценить |
---|---|
n_параметров | 6 053 381 344 |
n_layers | 28* |
d_модель | 4096 |
д_фф | 16 384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2048 |
н_вокаб | 50 257 (тот же токенизатор, что и у GPT-2/3) |
кодирование положения | Кодировки поворотного положения (RoPE) |
Размеры веревки из полиэтилена | 64 |
*
каждый слой состоит из одного блока прямой связи и одного блока самообслуживания.
Модель состоит из 28 слоев с размером модели 4096 и размером прямой связи 16384. Размерность модели разделена на 16 головок, каждая с размером 256. Кодирование поворотного положения (RoPE) было применено к 64 измерениям каждой головки. . Модель обучается с помощью словаря токенизации 50257 с использованием того же набора BPE, что и GPT-2/GPT-3.
Модели грубо отсортированы по производительности или по количеству провалов, если они недоступны.
Модель | Веса | Тренировочные флопы | ЛАМБАДА PPL ↓ | ЛАМБАДА Acc ↑ | Виногранде ↑ | Хелласваг ↑ | ПИКА ↑ | Размер набора данных (ГБ) |
---|---|---|---|---|---|---|---|---|
Шанс | ✔ | 0 | ~много | ~0% | 50% | 25% | 25% | 0 |
GPT-3-Ада‡ | ✘ | ----- | 9,95 | 51,6% | 52,9% | 43,4% | 70,5% | ----- |
ГПТ-2-1,5Б | ✔ | ----- | 10.63 | 51,21% | 59,4% | 50,9% | 70,8% | 40 |
ГПТНео-1.3Б‡ | ✔ | 3.0e21 | 7.50 | 57,2% | 55,0% | 48,9% | 71,1% | 825 |
Мегатрон-2,5Б* | ✘ | 2.4e21 | ----- | 61,7% | ----- | ----- | ----- | 174 |
ГПТНео-2.7Б‡ | ✔ | 6.8e21 | 5,63 | 62,2% | 56,5% | 55,8% | 73,0% | 825 |
ГПТ-3-1.3Б*‡ | ✘ | 2.4e21 | 5.44 | 63,6% | 58,7% | 54,7% | 75,1% | ~800 |
GPT-3-Бэббидж‡ | ✘ | ----- | 5,58 | 62,4% | 59,0% | 54,5% | 75,5% | ----- |
Мегатрон-8.3Б* | ✘ | 7.8e21 | ----- | 66,5% | ----- | ----- | ----- | 174 |
ГПТ-3-2,7Б*‡ | ✘ | 4.8e21 | 4,60 | 67,1% | 62,3% | 62,8% | 75,6% | ~800 |
Мегатрон-11Б† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
ГПТ-J-6B ‡ | ✔ | 1.5e22 | 3,99 | 69,7% | 65,3% | 66,1% | 76,5% | 825 |
ГПТ-3-6.7Б*‡ | ✘ | 1.2e22 | 4.00 | 70,3% | 64,5% | 67,4% | 78,0% | ~800 |
GPT-3-Кюри‡ | ✘ | ----- | 4.00 | 69,3% | 65,6% | 68,5% | 77,9% | ----- |
ГПТ-3-13Б*‡ | ✘ | 2.3e22 | 3,56 | 72,5% | 67,9% | 70,9% | 78,5% | ~800 |
ГПТ-3-175Б*‡ | ✘ | 3.1e23 | 3.00 | 76,2% | 70,2% | 78,9% | 81,0% | ~800 |
GPT-3-Давинчи‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
Гофер 230Б* | ✘ | 6.31E+23 | ----- | 74,50% | 70,10% | 79,20% | 81,80% | 1344 |
МТ-НЛГ 530Б*‡ | ✘ | ----- | ----- | 76,6% | 73,0% | 80,2% | 82,0% | ----- |
*
представляет собой оценочные числа, сообщенные соответствующими авторами, все остальные цифры получены при запуске lm-evaluation-harness либо с выпущенными весами, либо с доступом к API. Из-за небольших различий в реализации, а также из-за разной постановки задач с нулевым выстрелом их нельзя напрямую сравнивать. Более подробную информацию можно найти в этом сообщении в блоге.
†
Модель Мегатрон-11Б не обеспечивает сопоставимых показателей, а некоторые реализации, использующие выпущенные веса, не воспроизводят качество генерации и оценки. (см. 1 2 3) Таким образом, попытка оценки не проводилась.
‡
Эти модели были обучены на данных, которые содержат возможное загрязнение набора тестов. В моделях OpenAI GPT-3 не удалось дедуплицировать данные обучения для определенных наборов тестов, в то время как модели GPT-Neo, как и эта, обучаются на The Pile, дедупликация которого не производилась ни на одном наборе тестов.
Большинство сценариев в этом репозитории предназначены для запуска на TPU, которые в архитектуре TPU-VM представляют собой виртуальные машины, способные выполнять произвольный код. Большинство сценариев предназначены для запуска TPU, подключения к нему SSH для настройки зависимостей и копирования кода из локального каталога, а затем запуска исполнителя Ray, который может принимать вызовы RPC.
TPUVM обрабатывает этапы обучения и оценки модели, сохранение и загрузку контрольных точек, в то время как программа Python-драйвера отвечает за загрузку данных и общую оркестровку (например, когда сохранять контрольные точки и т. д.).
Это означает, что большинство сценариев ( train.py
, eval_harness.py
и т. д.) ожидаются на виртуальной машине GCE в том же регионе, что и TPU, чтобы минимизировать задержку RPC и стоимость передачи данных. Другие сценарии (обычно те, которые не принимают аргумент --tpu
, такие как device_sample.py
, device_serve.py
или device_train.py
) ожидают запуска непосредственно на TPUVM. Скрипты device_* работают только на v3-8 , но не на более крупных модулях.
Кроме того, есть пример ( resharding_example.py
) того, как преобразовать предоставленные контрольные точки (которые имеют 8 сегментов в случае GPT-J-6B) в меньшее количество, например, при работе на графических процессорах.
Чтобы точно настроить модель, запустите device_train.py
на виртуальной машине TPU. Используя TPU v3-8, вы можете выполнить точную настройку со скоростью ~ 5000 токенов в секунду, чего должно быть достаточно для наборов данных малого и среднего размера.
Пожалуйста, прочитайте пошаговое руководство для получения подробных инструкций по точной настройке.
Обратите внимание, что эта библиотека имеет некоторые особые требования для версии JAX. В частности, для использования моделей v1 (включая GPT-J 6B) требуется jax==0.2.12
. Это, в свою очередь, зависит от jaxlib==0.1.68
. Если этого не сделать, вы получите загадочные ошибки xmap.
Однако для использования кода модели v2 (без общедоступных весов) можно использовать новейшую версию JAX.
Чтобы процитировать этот репозиторий:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Приведу вес GPT-J-6B:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Если вы используете этот репозиторий или любой из предварительно обученных весов, чтобы сделать что-нибудь классное, мы будем рады услышать об этом. Не стесняйтесь открыть вопрос на GitHub или написать по электронной почте (в профиле).