MaxText é um LLM de código aberto , altamente escalonável e de alto desempenho , escrito em Python/Jax puro e direcionado a TPUs e GPUs do Google Cloud para treinamento e inferência . MaxText atinge altos MFUs e escala desde um único host até clusters muito grandes, mantendo-se simples e "livre de otimização" graças ao poder do Jax e do compilador XLA.
MaxText pretende ser um ponto de lançamento para ambiciosos projetos de LLM tanto em pesquisa quanto em produção. Incentivamos os usuários a começar experimentando o MaxText pronto para uso e, em seguida, bifurcar e modificar o MaxText para atender às suas necessidades.
Usamos MaxText para demonstrar treinamento de alto desempenho e bem convergente em int8 e escalar o treinamento para chips de aproximadamente 51 mil.
Principais recursos suportados:
Na primeira vez que você executa o MaxText, fornecemos instruções específicas.
MaxText suporta treinamento e inferência de vários modelos abertos. Siga os guias do usuário na pasta de primeiros passos para saber mais.
Alguns guias extras úteis:
Além dos guias de introdução, há sempre outros recursos do MaxText que são constantemente adicionados! O conjunto completo de testes ponta a ponta está em end_to_end. Nós os executamos com uma cadência noturna. Eles podem ser uma boa fonte para entender o MaxText. Alternativamente, você pode ver os testes de unidade contínuos que são executados quase continuamente.
Mais detalhes sobre a reprodução desses resultados podem ser encontrados em MaxText/configs/README.md.
Nº de parâmetros | Tipo de acelerador | TFLOP/chip/s | Utilização de flops de modelo (MFU) |
---|---|---|---|
32B | v5p-128 | 3.28e+02 | 71,47% |
64B | v5p-128 | 3.23e+02 | 70,31% |
128B | v5p-256 | 3.15e+02 | 68,68% |
128B | v5p-512 | 3.15e+02 | 68,53% |
256B | v5p-1024 | 3.16e+02 | 68,82% |
512B | v5p-1024 | 2.94e+02 | 63,99% |
1024B | v5p-2048 | 2.49e+02 | 64,05% |
1024B | v5p-4096 | 2.97e+02 | 64,80% |
1160B | v5p-7680 | 2.95e+02 | 64,27% |
1160B | v5p-12288 | 3.04e+02 | 66,23% |
Para modelos 16B, 32B, 64B e 128B. Veja as configurações de execução completa em MaxText/configs/v5e/ como 16b.sh
, 32b.sh
, 64b.sh
, 128b.sh
.
Hardware | 16B TFLOP/seg/chip | 16B UMF | 32B TFLOP/seg/chip | 32B MFU | 64B TFLOP/seg/chip | 64B MFU | 128B TFLOP/seg/chip | 128B MFU |
---|---|---|---|---|---|---|---|---|
1x v5e-256 | 120 | 61,10% | 132 | 66,86% | 118 | 59,90% | 110 | 56,06% |
2x v5e-256 | 117 | 59,37% | 128 | 64,81% | 112 | 56,66% | 110 | 55,82% |
4x v5e-256 | 117 | 59,14% | 126 | 64,10% | 110 | 55,85% | 108 | 54,93% |
8x v5e-256 | 115 | 58,27% | 125 | 63,67% | 108 | 54,96% | 104 | 52,93% |
16x v5e-256 | 111 | 56,56% | 123 | 62,26% | 105 | 53,29% | 100 | 50,86% |
32x v5e-256 | 108 | 54,65% | 119 | 60,40% | 99 | 50,18% | 91 | 46,25% |
MaxText é fortemente inspirado em MinGPT/NanoGPT, elegantes implementações GPT autônomas escritas em PyTorch e voltadas para GPUs Nvidia. MaxText é mais complexo, suportando mais modelos padrão da indústria e podendo ser dimensionado para dezenas de milhares de chips. Em última análise, o MaxText tem um MFU mais de três vezes maior que os 17% relatados mais recentemente com essa base de código, é extremamente escalonável e implementa um cache de valor-chave para decodificação auto-regressiva eficiente.
MaxText é mais semelhante ao Nvidia/Megatron-LM, uma implementação LLM muito bem ajustada voltada para GPUs Nvidia. As duas implementações alcançam MFUs comparáveis. A diferença nas bases de código destaca as diferentes estratégias de programação. MaxText é Python puro, dependendo fortemente do compilador XLA para obter alto desempenho. Por outro lado, Megatron-LM é uma mistura de Python e CUDA, contando com kernels CUDA bem otimizados para alcançar alto desempenho.
MaxText também é comparável ao Pax. Assim como Pax, MaxText fornece implementações escalonáveis e de alto desempenho de LLMs em Jax. Pax se concentra em habilitar parâmetros de configuração poderosos, permitindo que os desenvolvedores alterem o modelo editando parâmetros de configuração. Por outro lado, MaxText é uma implementação simples e concreta de vários LLMs que incentiva os usuários a estender bifurcando e editando diretamente o código-fonte.
Ao executar um trabalho de Programa Único, Vários Dados (SPMD) em aceleradores, o processo geral pode travar se houver algum erro ou qualquer VM travar/travar por algum motivo. Neste cenário, a captura de rastreamentos de pilha ajudará a identificar e solucionar os problemas dos trabalhos em execução nas VMs da TPU.
As configurações a seguir ajudarão a depurar uma falha ou quando um programa estiver travado ou travado em algum lugar, coletando rastreamentos de pilha. Altere os valores dos parâmetros adequadamente em MaxText/configs/base.yml
:
collect_stack_trace: True
para ativar a coleta de rastreamentos de pilha em falhas ou quando o programa estiver travado. Esta configuração irá despejar periodicamente os rastreamentos do programa para ajudar na depuração. Para desabilitar isso, defina collect_stack_trace: False
.stack_trace_to_cloud: False
para exibir rastreamentos de pilha no console. stack_trace_to_cloud: True
criará um arquivo temporário em /tmp/debugging
nas TPUs para armazenar os rastreamentos de pilha. Há um agente em execução nas VMs da TPU que fará upload periódico dos rastreamentos do diretório temporário para o registro em nuvem no projeto gcp. É possível visualizar os traces no Logs Explorer no Cloud Logging usando a seguinte consulta: logName="projects/<project_name>/logs/tpu.googleapis.com%2Fruntime_monitor"
jsonPayload.verb="stacktraceanalyzer"
stack_trace_interval_seconds
significa a duração em segundos entre cada evento de coleta de rastreamento de pilha. Definir stack_trace_interval_seconds: 600
coletará os rastreamentos de pilha a cada 600 segundos (10 minutos).Aqui está o pacote PyPI relacionado: https://pypi.org/project/cloud-tpu-diagnostics.
Para compilar seu treinamento com antecedência, fornecemos uma ferramenta train_compile.py
. Esta ferramenta permite compilar o train_step
principal em train.py
para hardware de destino (por exemplo, um grande número de dispositivos v5e) sem usar o cluster completo.
Você pode usar apenas uma CPU ou uma única VM de uma família diferente para pré-compilar um cluster TPU. Esta compilação ajuda com dois objetivos principais:
Ele sinalizará qualquer informação de falta de memória (OOM), como quando per_device_batch_size
estiver definido muito alto, com um rastreamento de pilha OOM idêntico como se tivesse sido compilado no hardware de destino.
A compilação antecipada pode ser salva e carregada para tempos de inicialização e reinicialização rápidos no hardware de destino.
A ferramenta train_compile.py
está intimamente ligada a train.py
e usa o mesmo arquivo de configuração configs/base.yml
. Embora não seja necessário executar em uma TPU, você precisa instalar jax[tpu]
além de outras dependências, portanto, recomendamos executar setup.sh
para instalá-las, caso ainda não tenha feito isso.
Depois de instalar as dependências listadas acima, você estará pronto para compilar antecipadamente:
# Run the below on a single machine, e.g. a CPU
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2
global_parameter_scale=16 per_device_batch_size=4
Isso compilará um modelo MaxText de parâmetro 16B em 2 pods v5e.
Aqui está um exemplo que salva e carrega o train_step
compilado, começando com o save:
Etapa 1: execute AOT e salve a função compilada
# Run the below on a single machine, e.g. a CPU
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=v5e-256
compile_topology_num_slices=2
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
per_device_batch_size=4 steps=10000 learning_rate=1e-3
Etapa 2: execute train.py e carregue a função compilada
Para carregar o train_step compilado, você só precisa passar compiled_trainstep_file=my_compiled_train.pickle
para train.py
:
# Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256
export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Na etapa de salvamento do exemplo 2 acima, incluímos a exportação do sinalizador do compilador LIBTPU_INIT_ARGS
e learning_rate
porque eles afetam o objeto compilado my_compiled_train.pickle.
Os tamanhos do modelo (por exemplo, global_parameter_scale
, max_sequence_length
e per_device_batch
) são fixos quando você compila inicialmente via compile_train.py
, você verá um erro de tamanho se tentar executar o objeto compilado salvo com tamanhos diferentes daqueles com os quais você compilou. No entanto, uma observação sutil é que a programação da taxa de aprendizagem também é fixa quando você executa compile_train
- que é determinado por steps
e learning_rate
. Os parâmetros do otimizador, como adam_b1
são passados apenas como objetos moldados para o compilador - portanto, seus valores reais são determinados quando você executa train.py
, não durante a compilação. Se você passar formas diferentes (por exemplo, per_device_batch
), receberá uma mensagem de erro clara informando que a assinatura compilada tem formas esperadas diferentes daquelas que foram inseridas. Se você tentar executar em hardware diferente dos destinos de compilação solicitados via compile_topology
, você receberá um erro informando que há uma falha ao mapear os dispositivos compilados para seus dispositivos reais. Usar sinalizadores XLA ou LIBTPU diferentes daqueles que foram compilados provavelmente será executado silenciosamente com o ambiente em que você compilou, sem erros. Contudo não há comportamento garantido neste caso; você deve executar no mesmo ambiente em que compilou.
A compilação antecipada também é suportada para GPUs com algumas diferenças em relação às TPUs:
A GPU não oferece suporte à compilação em hardware: um host GPU ainda é necessário para executar a compilação AoT, mas um único host GPU pode compilar um programa para um cluster maior do mesmo hardware.
Para GPUs A3 Cloud, o tamanho máximo da "fatia" é um único host, e o parâmetro compile_topology_num_slices
representa o número de máquinas A3 para pré-compilar.
Este exemplo ilustra os sinalizadores a serem usados para uma compilação de GPU multihost visando um cluster de 4 hosts A3:
Etapa 1: execute AOT e salve a função compilada
# Run the below on a single A3 machine
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train_compile.py MaxText/configs/base.yml compile_topology=a3
compile_topology_num_slices=4
compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16
attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3
Etapa 2: execute train.py e carregue a função compilada
Para carregar o train_step compilado, você só precisa passar compiled_trainstep_file=my_compiled_train.pickle
para train.py
:
# Run the below on each of the 4 target A3 hosts.
export XLA_FLAGS="--xla_gpu_enable_async_collectives=true"
python3 MaxText/train.py MaxText/configs/base.yml run_name=example_load_compile
compiled_trainstep_file=my_compiled_train.pickle
attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3
base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
Assim como no caso da TPU, observe que o ambiente de compilação deve corresponder ao ambiente de execução, neste caso definindo o mesmo XLA_FLAGS
.
MaxText oferece suporte ao upload automático de logs coletados em um diretório para uma instância do Tensorboard no Vertex AI. Siga o guia do usuário para saber mais.