MaxText es un LLM de código abierto , altamente escalable y de alto rendimiento escrito en Python/Jax puro y dirigido a TPU y GPU de Google Cloud para entrenamiento e inferencia . MaxText logra altas MFU y escala desde un solo host hasta clústeres muy grandes sin dejar de ser simple y "libre de optimización" gracias al poder de Jax y el compilador XLA.
MaxText pretende ser un punto de partida para ambiciosos proyectos LLM tanto en investigación como en producción. Alentamos a los usuarios a comenzar experimentando con MaxText de forma inmediata y luego bifurcar y modificar MaxText para satisfacer sus necesidades.
Hemos utilizado MaxText para demostrar un entrenamiento bien convergente y de alto rendimiento en int8 y escalar el entrenamiento a ~51K chips.
Funciones clave compatibles:
Para la primera vez que ejecuta MaxText, le proporcionamos instrucciones específicas.
MaxText admite el entrenamiento y la inferencia de varios modelos abiertos. Siga las guías del usuario en la carpeta de introducción para saber más.
Algunas guías adicionales útiles:
Además de las guías de introducción, siempre hay otras capacidades de MaxText que se agregan constantemente. El conjunto completo de pruebas de un extremo a otro se encuentra en end_to_end. Los ejecutamos con una cadencia nocturna. Pueden ser una buena fuente para comprender MaxText. Alternativamente, puede ver las pruebas unitarias continuas que se ejecutan casi continuamente.
Se pueden encontrar más detalles sobre la reproducción de estos resultados en MaxText/configs/README.md.
No. de parámetros | Tipo de acelerador | TFLOP/chip/seg | Utilización de fracasos del modelo (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71,47% |
64B | v5p-128 | 3.23e+02 | 70,31% |
128B | v5p-256 | 3.15e+02 | 68,68% |
128B | v5p-512 | 3.15e+02 | 68,53% |
256B | v5p-1024 | 3.16e+02 | 68,82% |
512B | v5p-1024 | 2.94e+02 | 63,99% |
1024B | v5p-2048 | 2.49e+02 | 64,05% |
1024B | v5p-4096 | 2.97e+02 | 64,80% |
1160B | v5p-7680 | 2.95e+02 | 64,27% |
1160B | v5p-12288 | 3.04e+02 | 66,23% |
Para modelos 16B, 32B, 64B y 128B. Vea las configuraciones de ejecución completa en MaxText/configs/v5e/ como 16b.sh
, 32b.sh
, 64b.sh
, 128b.sh
Hardware | 16B TFLOP/seg/chip | MFU 16B | 32B TFLOP/seg/chip | 32B MFU | 64B TFLOP/seg/chip | MFU 64B | 128B TFLOP/seg/chip | MFU 128B |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61,10% | 132 | 66,86% | 118 | 59,90% | 110 | 56,06% |
2x v5e-256 | 117 | 59,37% | 128 | 64,81% | 112 | 56,66% | 110 | 55,82% |
4xv5e-256 | 117 | 59,14% | 126 | 64,10% | 110 | 55,85% | 108 | 54,93% |
8x v5e-256 | 115 | 58,27% | 125 | 63,67% | 108 | 54,96% | 104 | 52,93% |
16x v5e-256 | 111 | 56,56% | 123 | 62,26% | 105 | 53,29% | 100 | 50,86% |
32x v5e-256 | 108 | 54,65% | 119 | 60,40% | 99 | 50,18% | 91 | 46,25% |
MaxText está fuertemente inspirado en MinGPT/NanoGPT, elegantes implementaciones GPT independientes escritas en PyTorch y dirigidas a GPU Nvidia. MaxText es más complejo, admite más modelos estándar de la industria y escala a decenas de miles de chips. En última instancia, MaxText tiene una MFU más de tres veces el 17% reportado más recientemente con esa base de código, es enormemente escalable e implementa un caché de valores clave para una decodificación autorregresiva eficiente.
MaxText es más similar a Nvidia/Megatron-LM, una implementación LLM muy bien adaptada dirigida a las GPU de Nvidia. Las dos implementaciones logran MFU comparables. La diferencia en las bases de código resalta las diferentes estrategias de programación. MaxText es Python puro y depende en gran medida del compilador XLA para lograr un alto rendimiento. Por el contrario, Megatron-LM es una combinación de Python y CUDA, que se basa en núcleos CUDA bien optimizados para lograr un alto rendimiento.
MaxText también es comparable a Pax. Al igual que Pax, MaxText proporciona implementaciones escalables y de alto rendimiento de LLM en Jax. Pax se centra en habilitar potentes parámetros de configuración, lo que permite a los desarrolladores cambiar el modelo editando los parámetros de configuración. Por el contrario, MaxText es una implementación simple y concreta de varios LLM que anima a los usuarios a ampliar bifurcando y editando directamente el código fuente.
Cuando se ejecuta un trabajo de programa único, datos múltiples (SPMD) en aceleradores, el proceso general puede bloquearse si hay algún error o si alguna máquina virtual se bloquea/bloquea por algún motivo. En este escenario, capturar seguimientos de la pila ayudará a identificar y solucionar los problemas de los trabajos que se ejecutan en las máquinas virtuales de TPU.
Las siguientes configuraciones ayudarán a depurar una falla o cuando un programa está bloqueado o colgado en algún lugar mediante la recopilación de seguimientos de la pila. Cambie los valores de los parámetros en consecuencia en MaxText/configs/base.yml
:
collect_stack_trace: True
para habilitar la recopilación de seguimientos de la pila en caso de fallas o cuando el programa se bloquea. Esta configuración volcará periódicamente los rastros del programa para ayudar en la depuración. Para deshabilitar esto, configure collect_stack_trace: False
.stack_trace_to_cloud: False
para mostrar seguimientos de pila en la consola. stack_trace_to_cloud: True
creará un archivo temporal en /tmp/debugging
en las TPU para almacenar los seguimientos de la pila. Hay un agente que se ejecuta en las máquinas virtuales de TPU que cargará periódicamente los seguimientos del directorio temporal al registro en la nube en el proyecto gcp. Puede ver los seguimientos en Logs Explorer en Cloud Logging mediante la siguiente consulta: logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
significa la duración en segundos entre cada evento de recopilación de seguimiento de pila. Configurar stack_trace_interval_seconds: 600
recopilará los seguimientos de la pila cada 600 segundos (10 minutos).Aquí está el paquete PyPI relacionado: https://pypi.org/project/cloud-tpu-diagnostics.
Para compilar su entrenamiento con anticipación, proporcionamos una herramienta train_compile.py
. Esta herramienta le permite compilar el train_step
principal en train.py
para el hardware de destino (por ejemplo, una gran cantidad de dispositivos v5e) sin utilizar el clúster completo.
Puede utilizar solo una CPU o una única máquina virtual de una familia diferente para precompilar un clúster de TPU. Esta compilación ayuda con dos objetivos principales:
Señalará cualquier información de falta de memoria (OOM), como cuando per_device_batch_size
está configurado en un valor demasiado alto, con un seguimiento de pila OOM idéntico al que se compilaría en el hardware de destino.
La compilación anticipada se puede guardar y luego cargar para tiempos de inicio y reinicio rápidos en el hardware de destino.
La herramienta train_compile.py
está estrechamente vinculada a train.py
y utiliza el mismo archivo de configuración configs/base.yml
. Aunque no es necesario ejecutarlo en una TPU, sí necesita instalar jax[tpu]
además de otras dependencias, por lo que recomendamos ejecutar setup.sh
para instalarlas si aún no lo ha hecho.
Después de instalar las dependencias enumeradas anteriormente, estará listo para compilar con anticipación:
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
Esto compilará un modelo MaxText de 16B de parámetros en 2 pods v5e.
Aquí hay un ejemplo que guarda y luego carga el train_step
compilado, comenzando con guardar:
Paso 1: ejecute AOT y guarde la función compilada
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
Paso 2: ejecute train.py y cargue la función compilada
Para cargar el train_step compilado, solo necesita pasar compiled_trainstep_file=my_compiled_train.pickle
a train.py
:
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
En el paso de guardar del ejemplo 2 anterior, incluimos la exportación del indicador del compilador LIBTPU_INIT_ARGS
y learning_rate
porque afectan el objeto compilado my_compiled_train.pickle.
Los tamaños del modelo (por ejemplo, global_parameter_scale
, max_sequence_length
y per_device_batch
) se fijan cuando compila inicialmente a través de compile_train.py
; verá un error de tamaño si intenta ejecutar el objeto compilado guardado con tamaños diferentes a los que compiló. Sin embargo, una nota sutil es que el programa de tasa de aprendizaje también se fija cuando ejecuta compile_train
, que está determinado tanto por steps
como por learning_rate
. Los parámetros del optimizador, como adam_b1
se pasan solo como objetos con forma al compilador; por lo tanto, sus valores reales se determinan cuando ejecuta train.py
, no durante la compilación. Si pasa diferentes formas (por ejemplo, per_device_batch
), recibirá un mensaje de error claro informando que la firma compilada tiene formas esperadas diferentes a las que se ingresó. Si intenta ejecutar en un hardware diferente al de los objetivos de compilación solicitados a través de compile_topology
, recibirá un error que indica que no se pudieron asignar los dispositivos compilados a sus dispositivos reales. El uso de indicadores XLA o un LIBTPU diferentes a los que se compiló probablemente se ejecutará silenciosamente con el entorno en el que compiló sin errores. Sin embargo, en este caso no existe un comportamiento garantizado; debes ejecutarlo en el mismo entorno en el que compilaste.
La compilación anticipada también es compatible con las GPU, con algunas diferencias con respecto a las TPU:
La GPU no admite la compilación entre hardware: aún se requiere un host de GPU para ejecutar la compilación de AoT, pero un único host de GPU puede compilar un programa para un clúster más grande del mismo hardware.
Para las GPU en la nube A3, el tamaño máximo de "porción" es un solo host y el parámetro compile_topology_num_slices
representa la cantidad de máquinas A3 para las que se precompilará.
Este ejemplo ilustra los indicadores que se utilizarán para una compilación de GPU multihost dirigida a un clúster de 4 hosts A3:
Paso 1: ejecute AOT y guarde la función compilada
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
Paso 2: ejecute train.py y cargue la función compilada
Para cargar el train_step compilado, solo necesita pasar compiled_trainstep_file=my_compiled_train.pickle
a train.py
:
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Como en el caso de TPU, tenga en cuenta que el entorno de compilación debe coincidir con el entorno de ejecución, en este caso configurando el mismo XLA_FLAGS
.
MaxText admite la carga automática de registros recopilados en un directorio a una instancia de Tensorboard en Vertex AI. Siga la guía del usuario para saber más.