Eine Haiku-Bibliothek, die die xmap
/ pjit
-Operatoren in JAX für die Modellparallelität von Transformatoren verwendet.
Das Parallelitätsschema ähnelt dem ursprünglichen Megatron-LM, das aufgrund des Hochgeschwindigkeits-2D-Mesh-Netzwerks auf TPUs effizient ist. Es gibt auch eine experimentelle Modellversion, die Sharding im ZeRo-Stil implementiert.
Diese Bibliothek ist für eine Skalierbarkeit auf bis zu etwa 40B Parameter auf TPUv3s ausgelegt, darüber hinaus sollten verschiedene Parallelitätsstrategien verwendet werden. Sehen Sie sich dazu andere Implementierungen wie GPT-NeoX oder DeepSpeed an.
Eine zukünftige Forschungsrichtung ist die Integration dieser Codebasis mit Swarm-Jax, um eine weitere Skalierbarkeit durch Pipeline-Parallelität zu erreichen.
21.07.12 : Anleitung zur Feinabstimmung hinzugefügt
Ein autoregressives Textgenerierungsmodell mit 6 Milliarden Parametern, das auf The Pile trainiert wurde.
Schlanke Gewichte herunterladen (nur bf16-Gewichte, zur Veranschaulichung, 9 GB)
Vollständige Gewichtungen herunterladen (einschließlich Optimierungsparameter, 61 GB)
Teilweise ausgebildete Kontrollpunkte
Colab-Demo
Web-Demo
Arans Blogbeitrag
Dieses Projekt wäre ohne die großzügige Bereitstellung von Rechenleistung durch die TPU Research Cloud mit Unterstützung von EleutherAI nicht möglich gewesen.
Vielen Dank an das Cloud TPU-Team von Google für den frühen Zugriff auf die Cloud TPU VM Alpha (jetzt öffentlich verfügbar!)
Vielen Dank an alle, die auf die eine oder andere Weise geholfen haben (in alphabetischer Reihenfolge):
Die Gewichte von GPT-J-6B sind unter Version 2.0 der Apache-Lizenz lizenziert.
Hyperparameter | Wert |
---|---|
n_parameter | 6.053.381.344 |
n_layers | 28* |
d_model | 4.096 |
d_ff | 16.384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2.048 |
n_vocab | 50.257 (gleicher Tokenizer wie GPT-2/3) |
Positionskodierung | Rotatorische Positionskodierungen (RoPE) |
RoPE-Abmessungen | 64 |
*
Jede Schicht besteht aus einem Feedforward-Block und einem Selbstaufmerksamkeitsblock
Das Modell besteht aus 28 Schichten mit einer Modelldimension von 4096 und einer Feedforward-Dimension von 16384. Die Modelldimension ist in 16 Köpfe mit jeweils einer Dimension von 256 aufgeteilt. Rotierende Positionskodierungen (RoPE) wurden auf 64 Dimensionen jedes Kopfes angewendet . Das Modell wird mit einem Tokenisierungsvokabular von 50257 trainiert, wobei derselbe BPE-Satz wie GPT-2/GPT-3 verwendet wird.
Modelle grob sortiert nach Leistung oder nach FLOPs, falls nicht verfügbar.
Modell | Gewichte | Trainings-FLOPs | LAMBADA PPL ↓ | LAMBADA Acc ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Datensatzgröße (GB) |
---|---|---|---|---|---|---|---|---|
Chance | ✔ | 0 | ~viel | ~0% | 50 % | 25 % | 25 % | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9,95 | 51,6 % | 52,9 % | 43,4 % | 70,5 % | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51,21 % | 59,4 % | 50,9 % | 70,8 % | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7,50 | 57,2 % | 55,0 % | 48,9 % | 71,1 % | 825 |
Megatron-2.5B* | ✘ | 2.4e21 | ----- | 61,7 % | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62,2 % | 56,5 % | 55,8 % | 73,0 % | 825 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63,6 % | 58,7 % | 54,7 % | 75,1 % | ~800 |
GPT-3-Babbage‡ | ✘ | ----- | 5.58 | 62,4 % | 59,0 % | 54,5 % | 75,5 % | ----- |
Megatron-8.3B* | ✘ | 7.8e21 | ----- | 66,5 % | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4,60 | 67,1 % | 62,3 % | 62,8 % | 75,6 % | ~800 |
Megatron-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1,5e22 | 3,99 | 69,7 % | 65,3 % | 66,1 % | 76,5 % | 825 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70,3 % | 64,5 % | 67,4 % | 78,0 % | ~800 |
GPT-3-Curie‡ | ✘ | ----- | 4.00 | 69,3 % | 65,6 % | 68,5 % | 77,9 % | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3,56 | 72,5 % | 67,9 % | 70,9 % | 78,5 % | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3,00 | 76,2 % | 70,2 % | 78,9 % | 81,0 % | ~800 |
GPT-3-Davinci‡ | ✘ | ----- | 3,0 | 75 % | 72 % | 78 % | 80 % | ----- |
Gopher 230B* | ✘ | 6.31E+23 | ----- | 74,50 % | 70,10 % | 79,20 % | 81,80 % | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76,6 % | 73,0 % | 80,2 % | 82,0 % | ----- |
*
stellt die von den jeweiligen Autoren gemeldeten Bewertungszahlen dar, alle anderen Zahlen werden durch Ausführen des lm-evaluation-harness entweder mit den freigegebenen Gewichtungen oder mit API-Zugriff bereitgestellt. Aufgrund subtiler Implementierungsunterschiede sowie unterschiedlicher Zero-Shot-Aufgabenrahmen sind diese möglicherweise nicht direkt vergleichbar. Weitere Informationen finden Sie in diesem Blogbeitrag.
†
Das Megatron-11B-Modell bietet keine vergleichbaren Metriken und mehrere Implementierungen, die die veröffentlichten Gewichte verwenden, reproduzieren die Generierungsqualität und Bewertungen nicht. (siehe 1 2 3) Eine Bewertung wurde daher nicht versucht.
‡
Diese Modelle wurden mit Daten trainiert, die eine mögliche Kontamination des Testsatzes enthalten. Die OpenAI GPT-3-Modelle konnten die Trainingsdaten für bestimmte Testsätze nicht deduplizieren, während die GPT-Neo-Modelle und dieses Modell auf The Pile trainiert werden, das mit keinem Testsatz dedupliziert wurde.
Die meisten Skripte in diesem Repository sind für die Ausführung auf TPUs konzipiert, bei denen es sich im Rahmen der TPU-VM-Architektur um virtuelle Maschinen handelt, die beliebigen Code ausführen können. Die meisten Skripte sind darauf ausgelegt, eine TPU hochzufahren, per SSH darauf zuzugreifen, um die Abhängigkeiten einzurichten, Code aus dem lokalen Verzeichnis zu kopieren und dann einen Ray-Worker zu starten, der RPC-Aufrufe annehmen kann.
Die TPUVMs kümmern sich um die Ausführung von Modelltrainingsschritten und die Auswertung, das Speichern und Laden von Prüfpunkten, während das Treiber-Python-Programm das Laden von Daten und die allgemeine Orchestrierung übernimmt (z. B. wann Prüfpunkte gespeichert werden sollen usw.).
Das bedeutet, dass die meisten Skripte ( train.py
, eval_harness.py
usw.) davon ausgehen, dass sie auf einer virtuellen GCE-Maschine in derselben Region wie die TPUs ausgeführt werden, um RPC-Latenz und Datenübertragungskosten zu minimieren. Andere Skripte (normalerweise solche, die kein --tpu
Argument annehmen, wie z. B. device_sample.py
, device_serve.py
oder device_train.py
) werden voraussichtlich direkt auf einer TPUVM ausgeführt. Die Skripte „device_*“ funktionieren nur auf v3-8 und nicht auf größeren Pods.
Darüber hinaus gibt es ein Beispiel ( resharding_example.py
), wie die bereitgestellten Prüfpunkte (die im Fall von GPT-J-6B 8 Shards haben) auf eine kleinere Anzahl konvertiert werden, beispielsweise für die Ausführung auf GPU(s).
Um das Modell zu optimieren, führen Sie device_train.py
auf einer TPU-VM aus. Mit einer TPU v3-8 können Sie eine Feinabstimmung mit einer Rate von ~5000 Token/Sekunde durchführen, was für kleine bis mittelgroße Datensätze ausreichend sein sollte.
Bitte lesen Sie die Schritt-für-Schritt-Anleitung für ausführliche Anweisungen zur Feinabstimmung.
Beachten Sie, dass diese Bibliothek einige spezifische Anforderungen für die JAX-Version hat. Insbesondere ist für die Verwendung der v1-Modelle (einschließlich GPT-J 6B) jax==0.2.12
erforderlich. Dies wiederum hängt von jaxlib==0.1.68
ab. Geschieht dies nicht, kommt es zu kryptischen xmap-Fehlern
Um jedoch den v2-Modellcode (keine öffentlich veröffentlichten Gewichtungen) zu verwenden, kann die neueste JAX-Version verwendet werden.
Um dieses Repository zu zitieren:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Um die Gewichte von GPT-J-6B zu nennen:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Wenn Sie dieses Repository oder eines der vorab trainierten Gewichte verwenden, um etwas Cooles zu machen, würden wir uns freuen, davon zu hören. Fühlen Sie sich frei, ein Github-Problem zu eröffnen oder uns per E-Mail (im Profil) zu kontaktieren.