Una biblioteca de haiku que utiliza los operadores xmap
/ pjit
en JAX para el paralelismo de modelos de transformadores.
El esquema de paralelismo es similar al Megatron-LM original, que es eficiente en TPU debido a la red de malla 2D de alta velocidad. También hay una versión de modelo experimental que implementa la fragmentación al estilo ZeRo.
Esta biblioteca está diseñada para ofrecer escalabilidad hasta aproximadamente 40 mil millones de parámetros en TPUv3, más allá de los cuales se deben usar diferentes estrategias de paralelismo. Consulte otras implementaciones como GPT-NeoX o DeepSpeed para eso.
Una dirección futura para la investigación es integrar este código base con swarm-jax, para lograr una mayor escalabilidad con el paralelismo de canalización.
07/12/21 : Se agregó una guía para realizar ajustes finos.
Un modelo de generación de texto autorregresivo de 6 mil millones de parámetros entrenado en The Pile.
Descargue pesas delgadas (solo pesas bf16, a modo de inferencia, 9 GB)
Descargue los pesos completos (incluidos los parámetros del optimizador, 61 GB)
Puntos de control parcialmente capacitados
Demostración de colaboración
Demostración web
Publicación del blog de Aran
Este proyecto no habría sido posible sin la computación proporcionada generosamente por TPU Research Cloud con la ayuda de EleutherAI.
Gracias al equipo de Cloud TPU de Google por brindar acceso temprano a Cloud TPU VM alpha (¡ahora disponible públicamente!)
Gracias a todos los que han ayudado de una forma u otra (en orden alfabético):
Las pesas de GPT-J-6B tienen la licencia de la versión 2.0 de la licencia Apache.
Hiperparámetro | Valor |
---|---|
n_parámetros | 6.053.381.344 |
n_capas | 28* |
d_modelo | 4.096 |
maldita sea | 16.384 |
n_cabezas | 16 |
cabeza_d_ | 256 |
n_ctx | 2.048 |
n_vocab | 50,257 (mismo tokenizador que GPT-2/3) |
codificación de posición | Codificaciones de posición rotatoria (RoPE) |
Dimensiones del cable | 64 |
*
cada capa consta de un bloque de avance y un bloque de atención personal
El modelo consta de 28 capas con una dimensión de modelo de 4096 y una dimensión de avance de 16384. La dimensión del modelo se divide en 16 cabezales, cada uno con una dimensión de 256. Se aplicaron codificaciones de posición giratoria (RoPE) a 64 dimensiones de cada cabezal. . El modelo se entrena con un vocabulario de tokenización de 50257, utilizando el mismo conjunto de BPE que GPT-2/GPT-3.
Modelos ordenados aproximadamente por rendimiento o por FLOP si no están disponibles.
Modelo | Pesos | FLOP de entrenamiento | LAMBADA PPL ↓ | Cuenta LAMBADA ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Tamaño del conjunto de datos (GB) |
---|---|---|---|---|---|---|---|---|
Oportunidad | ✔ | 0 | ~mucho | ~0% | 50% | 25% | 25% | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9,95 | 51,6% | 52,9% | 43,4% | 70,5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51,21% | 59,4% | 50,9% | 70,8% | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57,2% | 55,0% | 48,9% | 71,1% | 825 |
Megatrón-2.5B* | ✘ | 2.4e21 | ----- | 61,7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62,2% | 56,5% | 55,8% | 73,0% | 825 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63,6% | 58,7% | 54,7% | 75,1% | ~800 |
GPT-3-Babbage‡ | ✘ | ----- | 5.58 | 62,4% | 59,0% | 54,5% | 75,5% | ----- |
Megatrón-8.3B* | ✘ | 7.8e21 | ----- | 66,5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67,1% | 62,3% | 62,8% | 75,6% | ~800 |
Megatrón-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1.5e22 | 3,99 | 69,7% | 65,3% | 66,1% | 76,5% | 825 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70,3% | 64,5% | 67,4% | 78,0% | ~800 |
GPT-3-Curie‡ | ✘ | ----- | 4.00 | 69,3% | 65,6% | 68,5% | 77,9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72,5% | 67,9% | 70,9% | 78,5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76,2% | 70,2% | 78,9% | 81,0% | ~800 |
GPT-3-Davinci‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
Tuza 230B* | ✘ | 6.31E+23 | ----- | 74,50% | 70,10% | 79,20% | 81,80% | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76,6% | 73,0% | 80,2% | 82,0% | ----- |
*
representa los números de evaluación informados por sus respectivos autores; todos los demás números se obtienen ejecutando lm-evaluación-arnés con los pesos publicados o con acceso API. Debido a diferencias sutiles en la implementación, así como a diferentes encuadres de tareas de tiro cero, es posible que no sean directamente comparables. Consulte esta publicación de blog para obtener más detalles.
†
El modelo Megatron-11B no proporciona métricas comparables y varias implementaciones que utilizan los pesos publicados no reproducen la calidad de generación ni las evaluaciones. (ver 1 2 3) Por lo tanto, no se intentó la evaluación.
‡
Estos modelos han sido entrenados con datos que contienen una posible contaminación del equipo de prueba. Los modelos OpenAI GPT-3 no lograron deduplicar los datos de entrenamiento para ciertos conjuntos de prueba, mientras que los modelos GPT-Neo y este se entrenan en The Pile, que no se ha deduplicado en ningún conjunto de prueba.
La mayoría de los scripts de este repositorio están diseñados para ejecutarse en TPU, que bajo la arquitectura TPU-VM son máquinas virtuales que pueden ejecutar código arbitrario. La mayoría de los scripts están diseñados para activar una TPU, SSH en ella para configurar las dependencias y copiar el código desde el directorio local, y luego iniciar un trabajador Ray que pueda aceptar llamadas RPC.
Los TPUVM manejan los pasos de entrenamiento y evaluación del modelo en ejecución, el guardado y la carga de puntos de control, mientras que el programa Python del controlador maneja la carga de datos y la orquestación general (como cuándo guardar los puntos de control, etc.).
Esto significa que la mayoría de los scripts ( train.py
, eval_harness.py
, etc.) esperan ejecutarse en una máquina virtual GCE en la misma región que las TPU, para minimizar la latencia de RPC y el costo de transferencia de datos. Otros scripts (generalmente aquellos que no toman un argumento --tpu
, como device_sample.py
, device_serve.py
o device_train.py
) esperan ejecutarse directamente en una TPUVM. Los scripts device_* solo funcionan en v3-8 y no en pods más grandes.
Además, hay un ejemplo ( resharding_example.py
) de cómo convertir los puntos de control proporcionados (que tienen 8 fragmentos en el caso de GPT-J-6B) a un número menor, como cuando se ejecuta en GPU.
Para ajustar el modelo, ejecute device_train.py
en una máquina virtual de TPU. Con una TPU v3-8, puede realizar ajustes a una velocidad de ~5000 tokens/segundo, lo que debería ser suficiente para conjuntos de datos de tamaño pequeño a mediano.
Lea la guía paso a paso para obtener instrucciones detalladas de ajuste.
Tenga en cuenta que esta biblioteca tiene algunos requisitos específicos para la versión JAX. Específicamente, para utilizar los modelos v1 (incluido GPT-J 6B), se requiere jax==0.2.12
. Esto a su vez depende de jaxlib==0.1.68
. Si no lo hace, obtendrá errores crípticos de xmap
Sin embargo, para utilizar el código del modelo v2 (sin pesos publicados públicamente), se puede utilizar la versión JAX más reciente.
Para citar este repositorio:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Para citar los pesos de GPT-J-6B:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Si utiliza este repositorio o cualquiera de los pesos previamente entrenados para hacer algo interesante, nos encantaría saberlo. No dudes en abrir un problema de github o comunicarte por correo electrónico (en el perfil).