Une bibliothèque haïku utilisant les opérateurs xmap
/ pjit
en JAX pour le parallélisme des modèles de transformateurs.
Le schéma de parallélisme est similaire au Megatron-LM original, qui est efficace sur les TPU grâce au réseau maillé 2D à haut débit. Il existe également une version de modèle expérimental qui implémente le partitionnement de style ZeRo.
Cette bibliothèque est conçue pour une évolutivité jusqu'à environ 40 B de paramètres sur les TPUv3, au-delà desquels différentes stratégies de parallélisme doivent être utilisées. Voir d'autres implémentations telles que GPT-NeoX ou DeepSpeed pour cela.
Une orientation future de la recherche consiste à intégrer cette base de code à swarm-jax, pour atteindre une évolutivité accrue grâce au parallélisme des pipelines.
12-07-21 : Ajout d'un guide de réglage fin
Un modèle de génération de texte autorégressif de 6 milliards de paramètres formé sur The Pile.
Téléchargez les poids minces (poids bf16 uniquement, pour inférence, 9 Go)
Télécharger les poids complets (y compris les paramètres d'optimisation, 61 Go)
Points de contrôle partiellement formés
Démo Colab
Démo Web
Article de blog d'Aran
Ce projet n'aurait pas été possible sans le calcul généreusement fourni par TPU Research Cloud avec l'aide d'EleutherAI.
Merci à l'équipe Cloud TPU de Google pour avoir fourni un accès anticipé à la version alpha de la VM Cloud TPU (maintenant disponible publiquement !)
Merci à tous ceux qui ont aidé d'une manière ou d'une autre (classés par ordre alphabétique) :
Les poids de GPT-J-6B sont sous licence sous la version 2.0 de la licence Apache.
Hyperparamètre | Valeur |
---|---|
n_paramètres | 6 053 381 344 |
n_couches | 28* |
d_modèle | 4 096 |
d_ff | 16 384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2 048 |
n_vocab | 50 257 (même tokeniseur que GPT-2/3) |
codage de position | Codages de position rotatifs (RoPE) |
Dimensions du câble | 64 |
*
chaque couche se compose d'un bloc de rétroaction et d'un bloc d'auto-attention
Le modèle se compose de 28 couches avec une dimension de modèle de 4 096 et une dimension de rétroaction de 16 384. La dimension du modèle est divisée en 16 têtes, chacune avec une dimension de 256. Des codages de position rotatifs (RoPE) ont été appliqués à 64 dimensions de chaque tête. . Le modèle est formé avec un vocabulaire de tokenisation de 50257, en utilisant le même ensemble de BPE que GPT-2/GPT-3.
Modèles grossièrement triés par performances, ou par FLOP s'ils ne sont pas disponibles.
Modèle | Poids | FLOP de formation | LAMBADA PPL ↓ | LAMBADA Acc ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Taille de l'ensemble de données (Go) |
---|---|---|---|---|---|---|---|---|
Chance | ✔ | 0 | ~beaucoup | ~0% | 50% | 25% | 25% | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51,6% | 52,9% | 43,4% | 70,5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51,21% | 59,4% | 50,9% | 70,8% | 40 |
GPTNéo-1.3B‡ | ✔ | 3.0e21 | 7h50 | 57,2% | 55,0% | 48,9% | 71,1% | 825 |
Mégatron-2,5B* | ✘ | 2.4e21 | ----- | 61,7% | ----- | ----- | ----- | 174 |
GPTNéo-2.7B‡ | ✔ | 6.8e21 | 5,63 | 62,2% | 56,5% | 55,8% | 73,0% | 825 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63,6% | 58,7% | 54,7% | 75,1% | ~800 |
GPT-3-Babbage‡ | ✘ | ----- | 5,58 | 62,4% | 59,0% | 54,5% | 75,5% | ----- |
Mégatron-8.3B* | ✘ | 7.8e21 | ----- | 66,5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4,60 | 67,1% | 62,3% | 62,8% | 75,6% | ~800 |
Mégatron-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1.5e22 | 3,99 | 69,7% | 65,3% | 66,1% | 76,5% | 825 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 16h00 | 70,3% | 64,5% | 67,4% | 78,0% | ~800 |
GPT-3-Curie‡ | ✘ | ----- | 16h00 | 69,3% | 65,6% | 68,5% | 77,9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3,56 | 72,5% | 67,9% | 70,9% | 78,5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3h00 | 76,2% | 70,2% | 78,9% | 81,0% | ~800 |
GPT-3-Davinci‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
Gopher 230B* | ✘ | 6.31E+23 | ----- | 74,50% | 70,10% | 79,20% | 81,80% | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76,6% | 73,0% | 80,2% | 82,0% | ----- |
*
représente les nombres d'évaluation rapportés par leurs auteurs respectifs, tous les autres nombres sont fournis en exécutant le lm-evaluation-harnais soit avec les poids publiés, soit avec un accès API. En raison de différences subtiles de mise en œuvre ainsi que de différents cadrages de tâches zéro tir, celles-ci peuvent ne pas être directement comparables. Consultez cet article de blog pour plus de détails.
†
Le modèle Megatron-11B ne fournit aucune métrique comparable, et plusieurs implémentations utilisant les poids publiés ne reproduisent pas la qualité de la génération et les évaluations. (voir 1 2 3) Ainsi, aucune évaluation n'a été tentée.
‡
Ces modèles ont été formés avec des données contenant une éventuelle contamination de l'ensemble de test. Les modèles OpenAI GPT-3 n'ont pas réussi à dédupliquer les données d'entraînement pour certains ensembles de tests, tandis que les modèles GPT-Neo ainsi que celui-ci sont formés sur The Pile, qui n'a été dédupliqué sur aucun ensemble de tests.
La plupart des scripts de ce référentiel sont conçus pour être exécutés sur des TPU, qui, dans l'architecture TPU-VM, sont des machines virtuelles pouvant exécuter du code arbitraire. La plupart des scripts sont conçus pour lancer un TPU, y accéder via SSH pour configurer les dépendances et copier le code à partir du répertoire local, puis démarrer un travailleur Ray qui peut accepter les appels RPC.
Les TPUVM gèrent l'exécution des étapes de formation et de l'évaluation du modèle, la sauvegarde et le chargement des points de contrôle, tandis que le programme pilote Python gère le chargement des données et l'orchestration générale (comme quand enregistrer les points de contrôle, etc.).
Cela signifie que la plupart des scripts ( train.py
, eval_harness.py
etc.) s'attendent à s'exécuter sur une machine virtuelle GCE dans la même région que les TPU, afin de minimiser la latence RPC et le coût de transfert de données. D'autres scripts (généralement ceux qui ne prennent pas d'argument --tpu
, tels que device_sample.py
, device_serve.py
ou device_train.py
) s'attendent à être exécutés directement sur une TPUVM. Les scripts device_* ne fonctionnent que sur une v3-8 et non sur des pods plus grands.
De plus, il existe un exemple ( resharding_example.py
) montrant comment convertir les points de contrôle fournis (qui ont 8 fragments dans le cas de GPT-J-6B) en un nombre plus petit, par exemple lors d'une exécution sur un ou plusieurs GPU.
Pour affiner le modèle, exécutez device_train.py
sur une VM TPU. À l’aide d’un TPU v3-8, vous pouvez affiner le réglage à un rythme d’environ 5 000 jetons/seconde, ce qui devrait être suffisant pour les ensembles de données de petite à moyenne taille.
Veuillez lire le guide étape par étape pour obtenir des instructions de réglage détaillées.
Notez que cette bibliothèque a des exigences spécifiques pour la version JAX. Plus précisément, pour utiliser les modèles v1 (y compris GPT-J 6B), jax==0.2.12
est requis. Cela dépend à son tour de jaxlib==0.1.68
. Si cela n'est pas fait, vous obtiendrez des erreurs xmap cryptiques
Cependant, pour utiliser le code du modèle v2 (pas de poids rendu public), la version JAX la plus récente peut être utilisée.
Pour citer ce référentiel :
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Pour citer les poids du GPT-J-6B :
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Si vous utilisez ce référentiel ou l'un des poids pré-entraînés pour faire quelque chose de sympa, nous serions ravis d'en entendre parler. N'hésitez pas à ouvrir un problème github ou à nous contacter par e-mail (dans le profil).