MaxText est un LLM open source hautes performances , hautement évolutif , écrit en Python/Jax pur et ciblant les TPU et GPU Google Cloud pour la formation et l'inférence . MaxText atteint des MFU élevés et évolue d'un hôte unique à de très grands clusters tout en restant simple et « sans optimisation » grâce à la puissance de Jax et du compilateur XLA.
MaxText vise à être un point de départ pour des projets LLM ambitieux tant en recherche qu'en production. Nous encourageons les utilisateurs à commencer par expérimenter MaxText directement, puis à créer et modifier MaxText pour répondre à leurs besoins.
Nous avons utilisé MaxText pour démontrer une formation haute performance et bien convergente en int8 et étendre la formation à environ 51 000 puces.
Principales fonctionnalités prises en charge :
Pour votre première exécution de MaxText, nous fournissons des instructions spécifiques.
MaxText prend en charge la formation et l'inférence de divers modèles ouverts. Suivez les guides d'utilisation dans le dossier de démarrage pour en savoir plus.
Quelques guides supplémentaires utiles :
En plus des guides de démarrage, il existe toujours d'autres fonctionnalités MaxText qui sont constamment ajoutées ! La suite complète des tests de bout en bout se trouve dans end_to_end. Nous les exécutons avec une cadence nocturne. Ils peuvent être une bonne source pour comprendre MaxText. Vous pouvez également voir les tests unitaires continus qui sont exécutés presque en continu.
Plus de détails sur la reproduction de ces résultats peuvent être trouvés dans MaxText/configs/README.md.
Nombre de paramètres | Type d'accélérateur | TFLOP/puce/s | Utilisation des échecs du modèle (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71,47% |
64B | v5p-128 | 3.23e+02 | 70,31% |
128B | v5p-256 | 3.15e+02 | 68,68% |
128B | v5p-512 | 3.15e+02 | 68,53% |
256B | v5p-1024 | 3.16e+02 | 68,82% |
512B | v5p-1024 | 2.94e+02 | 63,99% |
1024B | v5p-2048 | 2.49e+02 | 64,05% |
1024B | v5p-4096 | 2.97e+02 | 64,80% |
1160B | v5p-7680 | 2,95e+02 | 64,27% |
1160B | v5p-12288 | 3.04e+02 | 66,23% |
Pour les modèles 16B, 32B, 64B et 128B. Voir les configurations complètes dans MaxText/configs/v5e/ comme 16b.sh
, 32b.sh
, 64b.sh
, 128b.sh
.
Matériel | 16 B TFLOP/sec/puce | 16B MFU | 32B TFLOP/sec/puce | 32B MFU | 64 B TFLOP/sec/puce | 64B MFU | 128 B TFLOP/sec/puce | 128B MFU |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61,10% | 132 | 66,86% | 118 | 59,90% | 110 | 56,06% |
2x v5e-256 | 117 | 59,37% | 128 | 64,81% | 112 | 56,66% | 110 | 55,82% |
4x v5e-256 | 117 | 59,14% | 126 | 64,10% | 110 | 55,85% | 108 | 54,93% |
8x v5e-256 | 115 | 58,27% | 125 | 63,67% | 108 | 54,96% | 104 | 52,93% |
16x v5e-256 | 111 | 56,56% | 123 | 62,26% | 105 | 53,29% | 100 | 50,86% |
32x v5e-256 | 108 | 54,65% | 119 | 60,40% | 99 | 50,18% | 91 | 46,25% |
MaxText s'inspire fortement de MinGPT/NanoGPT, d'élégantes implémentations GPT autonomes écrites en PyTorch et ciblant les GPU Nvidia. MaxText est plus complexe, prend en charge davantage de modèles standard de l'industrie et s'adapte à des dizaines de milliers de puces. En fin de compte, MaxText a un MFU plus de trois fois supérieur aux 17 % signalés le plus récemment avec cette base de code, est massivement évolutif et implémente un cache clé-valeur pour un décodage auto-régressif efficace.
MaxText est plus similaire à Nvidia/Megatron-LM, une implémentation LLM très bien optimisée ciblant les GPU Nvidia. Les deux implémentations atteignent des MFU comparables. La différence entre les bases de code met en évidence les différentes stratégies de programmation. MaxText est du pur Python, s'appuyant fortement sur le compilateur XLA pour atteindre des performances élevées. En revanche, Megatron-LM est un mélange de Python et CUDA, s'appuyant sur des noyaux CUDA bien optimisés pour atteindre des performances élevées.
MaxText est également comparable à Pax. Comme Pax, MaxText fournit des implémentations hautes performances et évolutives de LLM dans Jax. Pax se concentre sur l'activation de paramètres de configuration puissants, permettant aux développeurs de modifier le modèle en modifiant les paramètres de configuration. En revanche, MaxText est une implémentation simple et concrète de divers LLM qui encourage les utilisateurs à étendre leurs connaissances en bifurquant et en éditant directement le code source.
Lors de l'exécution d'une tâche SPMD (Programme unique, données multiples) sur des accélérateurs, le processus global peut se bloquer en cas d'erreur ou si une VM se bloque/plante pour une raison quelconque. Dans ce scénario, la capture des traces de pile aidera à identifier et à résoudre les problèmes liés aux tâches exécutées sur les machines virtuelles TPU.
Les configurations suivantes aideront à déboguer une erreur ou lorsqu'un programme est bloqué ou bloqué quelque part en collectant des traces de pile. Modifiez les valeurs des paramètres en conséquence dans MaxText/configs/base.yml
:
collect_stack_trace: True
pour activer la collecte des traces de pile en cas d'erreurs ou lorsque le programme est bloqué. Ce paramètre videra périodiquement les traces du programme pour faciliter le débogage. Pour désactiver cela, définissez collect_stack_trace: False
.stack_trace_to_cloud: False
pour afficher les traces de pile sur la console. stack_trace_to_cloud: True
créera un fichier temporaire dans /tmp/debugging
dans les TPU pour stocker les traces de pile. Un agent s'exécute sur les machines virtuelles TPU qui télécharge périodiquement les traces du répertoire temporaire vers la journalisation cloud dans le projet gcp. Vous pouvez afficher les traces dans l'Explorateur de journaux sur Cloud Logging à l'aide de la requête suivante : logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
signifie la durée en secondes entre chaque événement de collecte de trace de pile. Le réglage stack_trace_interval_seconds: 600
collectera les traces de pile toutes les 600 secondes (10 minutes).Voici le package PyPI associé : https://pypi.org/project/cloud-tpu-diagnostics.
Pour compiler votre formation à l'avance, nous fournissons un outil train_compile.py
. Cet outil vous permet de compiler le train_step
principal dans train.py
pour le matériel cible (par exemple un grand nombre de périphériques v5e) sans utiliser le cluster complet.
Vous pouvez utiliser uniquement un processeur ou une seule machine virtuelle d'une famille différente pour précompiler un cluster TPU. Cette compilation répond à deux objectifs principaux :
Il signalera toute information de mémoire insuffisante (MOO), par exemple lorsque per_device_batch_size
est défini trop haut, avec une trace de pile MOO identique à celle si elle avait été compilée sur le matériel cible.
La compilation anticipée peut être enregistrée puis chargée pour des temps de démarrage et de redémarrage rapides sur le matériel cible.
L'outil train_compile.py
est étroitement lié à train.py
et utilise le même fichier de configuration configs/base.yml
. Bien que vous n'ayez pas besoin d'exécuter sur un TPU, vous devez installer jax[tpu]
en plus d'autres dépendances, nous vous recommandons donc d'exécuter setup.sh
pour les installer si vous ne l'avez pas déjà fait.
Après avoir installé les dépendances répertoriées ci-dessus, vous êtes prêt à compiler à l'avance :
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
Cela compilera un modèle MaxText de paramètres 16B sur 2 pods v5e.
Voici un exemple qui enregistre puis charge le train_step
compilé, en commençant par la sauvegarde :
Étape 1 : Exécutez AOT et enregistrez la fonction compilée
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
Étape 2 : Exécutez train.py et chargez la fonction compilée
Pour charger le train_step compilé, il vous suffit de passer compiled_trainstep_file=my_compiled_train.pickle
dans train.py
:
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Dans l'étape de sauvegarde de l'exemple 2 ci-dessus, nous avons inclus l'exportation de l'indicateur du compilateur LIBTPU_INIT_ARGS
et learning_rate
car ceux-ci affectent l'objet compilé my_compiled_train.pickle.
Les tailles du modèle (par exemple global_parameter_scale
, max_sequence_length
et per_device_batch
) sont fixes lors de la compilation initiale via compile_train.py
, vous verrez une erreur de taille si vous essayez d'exécuter l'objet compilé enregistré avec des tailles différentes de celles avec lesquelles vous avez compilé. Cependant, une remarque subtile est que le calendrier du taux d'apprentissage est également fixé lorsque vous exécutez compile_train
- qui est déterminé à la fois par steps
et par learning_rate
. Les paramètres de l'optimiseur tels que adam_b1
sont transmis uniquement sous forme d'objets façonnés au compilateur - leurs valeurs réelles sont donc déterminées lorsque vous exécutez train.py
, pas pendant la compilation. Si vous transmettez différentes formes (par exemple per_device_batch
), vous recevrez un message d'erreur clair indiquant que la signature compilée a des formes attendues différentes de celles entrées. Si vous essayez d'exécuter sur un matériel différent de celui des cibles de compilation demandées via compile_topology
, vous obtiendrez une erreur indiquant qu'il y a un échec dans le mappage des périphériques compilés vers vos périphériques réels. L'utilisation d'indicateurs XLA ou d'un LIBTPU différents de ceux qui ont été compilés s'exécutera probablement silencieusement avec l'environnement dans lequel vous avez compilé sans erreur. Cependant, il n'y a aucun comportement garanti dans ce cas ; vous devez exécuter dans le même environnement que celui dans lequel vous avez compilé.
La compilation anticipée est également prise en charge pour les GPU, avec quelques différences par rapport aux TPU :
Le GPU ne prend pas en charge la compilation sur le matériel : un hôte GPU est toujours requis pour exécuter la compilation AoT, mais un seul hôte GPU peut compiler un programme pour un cluster plus grand du même matériel.
Pour les GPU A3 Cloud, la taille maximale de « tranche » est un hôte unique et le paramètre compile_topology_num_slices
représente le nombre de machines A3 pour lesquelles précompiler.
Cet exemple illustre les flags à utiliser pour une compilation GPU multi-hôtes ciblant un cluster de 4 hôtes A3 :
Étape 1 : Exécutez AOT et enregistrez la fonction compilée
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
Étape 2 : Exécutez train.py et chargez la fonction compilée
Pour charger le train_step compilé, il vous suffit de passer compiled_trainstep_file=my_compiled_train.pickle
dans train.py
:
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Comme dans le cas TPU, notez que l'environnement de compilation doit correspondre à l'environnement d'exécution, dans ce cas en définissant le même XLA_FLAGS
.
MaxText prend en charge le téléchargement automatique des journaux collectés dans un répertoire vers une instance Tensorboard dans Vertex AI. Suivez le guide de l'utilisateur pour en savoir plus.