MaxText ist ein leistungsstarkes , hoch skalierbares Open-Source -LLM, das in reinem Python/Jax geschrieben ist und für Training und Inferenz auf Google Cloud-TPUs und -GPUs abzielt. MaxText erreicht hohe MFUs und skaliert von einem einzelnen Host bis hin zu sehr großen Clustern und bleibt dabei dank der Leistung von Jax und dem XLA-Compiler einfach und „optimierungsfrei“.
MaxText möchte ein Ausgangspunkt für ambitionierte LLM-Projekte sowohl in der Forschung als auch in der Produktion sein. Wir empfehlen Benutzern, zunächst sofort mit MaxText zu experimentieren und MaxText dann zu teilen und zu modifizieren, um es an ihre Bedürfnisse anzupassen.
Wir haben MaxText verwendet, um ein leistungsstarkes, gut konvergentes Training in int8 zu demonstrieren und das Training auf ~51.000 Chips zu skalieren.
Wichtige unterstützte Funktionen:
Wenn Sie MaxText zum ersten Mal ausführen, stellen wir Ihnen spezifische Anweisungen zur Verfügung.
MaxText unterstützt Training und Inferenz verschiedener offener Modelle. Befolgen Sie die Benutzerhandbücher im Ordner „Erste Schritte“, um mehr zu erfahren.
Einige besonders hilfreiche Anleitungen:
Zusätzlich zu den Einführungshandbüchern gibt es immer noch weitere MaxText-Funktionen, die ständig hinzugefügt werden! Die vollständige Suite von End-to-End-Tests finden Sie in end_to_end. Wir lassen sie im nächtlichen Rhythmus laufen. Sie können eine gute Quelle zum Verständnis von MaxText sein. Alternativ können Sie sich die kontinuierlichen Unit-Tests ansehen, die fast kontinuierlich ausgeführt werden.
Weitere Details zum Reproduzieren dieser Ergebnisse finden Sie in MaxText/configs/README.md.
Anzahl der Parameter | Beschleunigertyp | TFLOP/Chip/Sek | Modell-Flops-Auslastung (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71,47 % |
64B | v5p-128 | 3.23e+02 | 70,31 % |
128B | v5p-256 | 3.15e+02 | 68,68 % |
128B | v5p-512 | 3.15e+02 | 68,53 % |
256B | v5p-1024 | 3.16e+02 | 68,82 % |
512B | v5p-1024 | 2.94e+02 | 63,99 % |
1024B | v5p-2048 | 2.49e+02 | 64,05 % |
1024B | v5p-4096 | 2,97e+02 | 64,80 % |
1160B | v5p-7680 | 2,95e+02 | 64,27 % |
1160B | v5p-12288 | 3.04e+02 | 66,23 % |
Für 16B-, 32B-, 64B- und 128B-Modelle. Vollständige Ausführungskonfigurationen finden Sie in MaxText/configs/v5e/ als 16b.sh
, 32b.sh
, 64b.sh
, 128b.sh
Hardware | 16B TFLOP/Sek./Chip | 16B MFU | 32B TFLOP/Sek./Chip | 32B MFU | 64B TFLOP/Sek./Chip | 64B MFU | 128 B TFLOP/Sek./Chip | 128B MFU |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61,10 % | 132 | 66,86 % | 118 | 59,90 % | 110 | 56,06 % |
2x v5e-256 | 117 | 59,37 % | 128 | 64,81 % | 112 | 56,66 % | 110 | 55,82 % |
4x v5e-256 | 117 | 59,14 % | 126 | 64,10 % | 110 | 55,85 % | 108 | 54,93 % |
8x v5e-256 | 115 | 58,27 % | 125 | 63,67 % | 108 | 54,96 % | 104 | 52,93 % |
16x v5e-256 | 111 | 56,56 % | 123 | 62,26 % | 105 | 53,29 % | 100 | 50,86 % |
32x v5e-256 | 108 | 54,65 % | 119 | 60,40 % | 99 | 50,18 % | 91 | 46,25 % |
MaxText ist stark von MinGPT/NanoGPT inspiriert, eleganten eigenständigen GPT-Implementierungen, die in PyTorch geschrieben wurden und auf Nvidia-GPUs abzielen. MaxText ist komplexer, unterstützt mehr Industriestandardmodelle und lässt sich auf Zehntausende Chips skalieren. Letztendlich hat MaxText eine MFU, die mehr als dreimal so hoch ist wie die zuletzt mit dieser Codebasis gemeldeten 17 %, ist enorm skalierbar und implementiert einen Schlüsselwert-Cache für eine effiziente autoregressive Dekodierung.
MaxText ähnelt eher Nvidia/Megatron-LM, einer sehr gut abgestimmten LLM-Implementierung, die auf Nvidia-GPUs abzielt. Die beiden Implementierungen erreichen vergleichbare MFUs. Der Unterschied in den Codebasen verdeutlicht die unterschiedlichen Programmierstrategien. MaxText ist reines Python und verlässt sich stark auf den XLA-Compiler, um eine hohe Leistung zu erzielen. Im Gegensatz dazu ist Megatron-LM eine Mischung aus Python und CUDA und basiert auf gut optimierten CUDA-Kerneln, um eine hohe Leistung zu erzielen.
MaxText ist auch mit Pax vergleichbar. Wie Pax bietet MaxText leistungsstarke und skalierbare Implementierungen von LLMs in Jax. Pax konzentriert sich auf die Aktivierung leistungsstarker Konfigurationsparameter, sodass Entwickler das Modell durch Bearbeiten von Konfigurationsparametern ändern können. Im Gegensatz dazu ist MaxText eine einfache, konkrete Implementierung verschiedener LLMs, die Benutzer dazu ermutigt, durch Verzweigung und direkte Bearbeitung des Quellcodes zu erweitern.
Beim Ausführen eines SPMD-Jobs (Single Program, Multiple Data) auf Beschleunigern kann der Gesamtprozess hängen bleiben, wenn ein Fehler auftritt oder eine VM aus irgendeinem Grund hängt/abstürzt. In diesem Szenario hilft die Erfassung von Stack-Traces dabei, die Probleme für die auf TPU-VMs ausgeführten Jobs zu identifizieren und zu beheben.
Die folgenden Konfigurationen helfen beim Debuggen eines Fehlers oder wenn ein Programm hängen bleibt oder irgendwo hängt, indem sie Stack-Traces sammeln. Ändern Sie die Parameterwerte entsprechend in MaxText/configs/base.yml
:
collect_stack_trace: True
fest, um die Erfassung von Stack-Traces bei Fehlern oder beim Aufhängen des Programms zu ermöglichen. Mit dieser Einstellung werden die Ablaufverfolgungen regelmäßig ausgegeben, damit das Programm beim Debuggen hilft. Um dies zu deaktivieren, legen Sie collect_stack_trace: False
fest.stack_trace_to_cloud: False
fest, um Stack-Traces auf der Konsole anzuzeigen. stack_trace_to_cloud: True
erstellt eine temporäre Datei in /tmp/debugging
in den TPUs, um die Stack-Traces zu speichern. Auf TPU-VMs läuft ein Agent, der die Traces regelmäßig aus dem temporären Verzeichnis in die Cloud-Protokollierung im GCP-Projekt hochlädt. Sie können die Traces im Logs Explorer in Cloud Logging mit der folgenden Abfrage anzeigen: logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
gibt die Dauer in Sekunden zwischen den einzelnen Stack-Trace-Erfassungsereignissen an. Wenn Sie stack_trace_interval_seconds: 600
festlegen, werden die Stack-Traces alle 600 Sekunden (10 Minuten) erfasst.Hier ist das zugehörige PyPI-Paket: https://pypi.org/project/cloud-tpu-diagnostics.
Um Ihren Trainingslauf vorab zu kompilieren, stellen wir das Tool train_compile.py
zur Verfügung. Mit diesem Tool können Sie den Haupt train_step
in train.py
für Zielhardware (z. B. eine große Anzahl von v5e-Geräten) kompilieren, ohne den gesamten Cluster zu verwenden.
Sie können für die Vorkompilierung für einen TPU-Cluster nur eine CPU oder eine einzelne VM aus einer anderen Familie verwenden. Diese Zusammenstellung hilft bei zwei Hauptzielen:
Es markiert alle Informationen über nicht genügend Arbeitsspeicher (OOM), z. B. wenn per_device_batch_size
zu hoch eingestellt ist, mit einem identischen OOM-Stack-Trace, als ob es auf der Zielhardware kompiliert worden wäre.
Die Vorabkompilierung kann gespeichert und dann geladen werden, um schnelle Start- und Neustartzeiten auf der Zielhardware zu ermöglichen.
Das Tool train_compile.py
ist eng mit train.py
verknüpft und verwendet dieselbe Konfigurationsdatei configs/base.yml
. Obwohl Sie nicht auf einer TPU laufen müssen, müssen Sie jax[tpu]
zusätzlich zu anderen Abhängigkeiten installieren. Wir empfehlen daher setup.sh
auszuführen, um diese zu installieren, falls Sie dies noch nicht getan haben.
Nachdem Sie die oben aufgeführten Abhängigkeiten installiert haben, können Sie mit der vorzeitigen Kompilierung beginnen:
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
Dadurch wird ein MaxText-Modell mit 16B Parametern auf 2 v5e-Pods kompiliert.
Hier ist ein Beispiel, das den kompilierten train_step
speichert und dann lädt, beginnend mit dem Speichern:
Schritt 1: Führen Sie AOT aus und speichern Sie die kompilierte Funktion
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
Schritt 2: Führen Sie train.py aus und laden Sie die kompilierte Funktion
Um den kompilierten train_step zu laden, müssen Sie nur compiled_trainstep_file=my_compiled_train.pickle
an train.py
übergeben:
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Im Speicherschritt von Beispiel 2 oben haben wir den Export der Compiler-Flags LIBTPU_INIT_ARGS
und learning_rate
einbezogen, da diese sich auf das kompilierte Objekt my_compiled_train.pickle.
Die Größen des Modells (z. B. global_parameter_scale
, max_sequence_length
und per_device_batch
) sind festgelegt, wenn Sie zum ersten Mal über compile_train.py
kompilieren. Wenn Sie versuchen, das gespeicherte kompilierte Objekt mit anderen Größen als beim Kompilieren auszuführen, wird ein Größenfehler angezeigt. Ein subtiler Hinweis ist jedoch, dass der Zeitplan für die Lernrate auch festgelegt ist, wenn Sie compile_train
ausführen – was sowohl von steps
als auch von learning_rate
bestimmt wird. Die Optimierungsparameter wie adam_b1
werden nur als geformte Objekte an den Compiler übergeben – daher werden ihre tatsächlichen Werte bestimmt, wenn Sie train.py
ausführen, nicht während der Kompilierung. Wenn Sie unterschiedliche Formen übergeben (z. B. per_device_batch
), erhalten Sie eine eindeutige Fehlermeldung, die besagt, dass die kompilierte Signatur andere erwartete Formen aufweist als die eingegebenen. Wenn Sie versuchen, auf einer anderen Hardware als den über compile_topology
angeforderten Kompilierungszielen auszuführen, erhalten Sie eine Fehlermeldung, die besagt, dass die Zuordnung der Geräte von der kompilierten zu Ihren realen Geräten fehlgeschlagen ist. Die Verwendung anderer XLA-Flags oder einer LIBTPU als der, die kompiliert wurde, wird in der Umgebung, in der Sie kompiliert haben, wahrscheinlich unbeaufsichtigt und ohne Fehler ausgeführt. Eine Verhaltensgarantie gibt es in diesem Fall allerdings nicht; Sie sollten in derselben Umgebung ausgeführt werden, in der Sie kompiliert haben.
Die Ahead-of-Time-Kompilierung wird auch für GPUs unterstützt, mit einigen Unterschieden zu TPUs:
Die GPU unterstützt keine hardwareübergreifende Kompilierung: Zum Ausführen der AoT-Kompilierung ist weiterhin ein GPU-Host erforderlich, aber ein einzelner GPU-Host kann ein Programm für einen größeren Cluster derselben Hardware kompilieren.
Bei A3-Cloud-GPUs beträgt die maximale „Slice“-Größe einen einzelnen Host, und der Parameter compile_topology_num_slices
stellt die Anzahl der A3-Maschinen dar, für die vorkompiliert werden soll.
Dieses Beispiel veranschaulicht die Flags, die für eine Multihost-GPU-Kompilierung verwendet werden sollen, die auf einen Cluster von 4 A3-Hosts abzielt:
Schritt 1: Führen Sie AOT aus und speichern Sie die kompilierte Funktion
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
Schritt 2: Führen Sie train.py aus und laden Sie die kompilierte Funktion
Um den kompilierten train_step zu laden, müssen Sie nur compiled_trainstep_file=my_compiled_train.pickle
an train.py
übergeben:
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Beachten Sie wie im TPU-Fall, dass die Kompilierungsumgebung mit der Ausführungsumgebung übereinstimmen muss, in diesem Fall durch Festlegen derselben XLA_FLAGS
.
MaxText unterstützt das automatische Hochladen von in einem Verzeichnis gesammelten Protokollen auf eine Tensorboard-Instanz in Vertex AI. Befolgen Sie die Bedienungsanleitung, um mehr zu erfahren.