Uma biblioteca haiku usando os operadores xmap
/ pjit
em JAX para paralelismo de modelo de transformadores.
O esquema de paralelismo é semelhante ao Megatron-LM original, que é eficiente em TPUs devido à rede mesh 2d de alta velocidade. Há também uma versão experimental do modelo que implementa fragmentação no estilo ZeRo.
Esta biblioteca foi projetada para escalabilidade de até aproximadamente 40B de parâmetros em TPUv3s, além dos quais diferentes estratégias de paralelismo devem ser utilizadas. Veja outras implementações como GPT-NeoX ou DeepSpeed para isso.
Uma direção futura para pesquisa é integrar esta base de código com swarm-jax, para alcançar maior escalabilidade com paralelismo de pipeline.
12/07/21 : Adicionado guia para ajuste fino
Um modelo de geração de texto autoregressivo de 6 bilhões de parâmetros treinado no The Pile.
Baixe pesos finos (somente pesos bf16, para inferência, 9 GB)
Baixe pesos completos (incluindo parâmetros do otimizador, 61 GB)
Pontos de verificação parcialmente treinados
Demonstração do Colab
Demonstração na Web
Postagem do blog de Aran
Este projeto não teria sido possível sem a computação generosamente fornecida pela TPU Research Cloud com a assistência da EleutherAI.
Agradecemos à equipe do Cloud TPU do Google por fornecer acesso antecipado à VM alfa do Cloud TPU (agora disponível publicamente!)
Obrigado a todos que ajudaram de uma forma ou de outra (listados em ordem alfabética):
Os pesos do GPT-J-6B são licenciados sob a versão 2.0 da Licença Apache.
Hiperparâmetro | Valor |
---|---|
n_parâmetros | 6.053.381.344 |
n_camadas | 28* |
d_modelo | 4.096 |
d_ff | 16.384 |
n_cabeças | 16 |
d_head | 256 |
n_ctx | 2.048 |
n_vocab | 50.257 (mesmo tokenizador do GPT-2/3) |
codificação de posição | Codificações de posição rotativa (RoPE) |
Dimensões do RoPE | 64 |
*
cada camada consiste em um bloco feedforward e um bloco de autoatenção
O modelo consiste em 28 camadas com uma dimensão de modelo de 4.096 e uma dimensão feedforward de 16.384. A dimensão do modelo é dividida em 16 cabeças, cada uma com uma dimensão de 256. Codificações de posição rotativa (RoPE) foram aplicadas a 64 dimensões de cada cabeça. . O modelo é treinado com um vocabulário de tokenização de 50257, usando o mesmo conjunto de BPEs do GPT-2/GPT-3.
Modelos classificados aproximadamente por desempenho ou por FLOPs, se não estiverem disponíveis.
Modelo | Pesos | Treinamento de FLOPs | LAMBADA PPL ↓ | LAMBADA Acc ↑ | Winogrande ↑ | Hellaswag ↑ | PIQA ↑ | Tamanho do conjunto de dados (GB) |
---|---|---|---|---|---|---|---|---|
Chance | ✔ | 0 | ~muito | ~0% | 50% | 25% | 25% | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9,95 | 51,6% | 52,9% | 43,4% | 70,5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10,63 | 51,21% | 59,4% | 50,9% | 70,8% | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7h50 | 57,2% | 55,0% | 48,9% | 71,1% | 825 |
Megatron-2.5B* | ✘ | 2.4e21 | ----- | 61,7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5,63 | 62,2% | 56,5% | 55,8% | 73,0% | 825 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5,44 | 63,6% | 58,7% | 54,7% | 75,1% | ~800 |
GPT-3-Babbage‡ | ✘ | ----- | 5,58 | 62,4% | 59,0% | 54,5% | 75,5% | ----- |
Megatron-8.3B* | ✘ | 7.8e21 | ----- | 66,5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4,60 | 67,1% | 62,3% | 62,8% | 75,6% | ~800 |
Megatron-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1.5e22 | 3,99 | 69,7% | 65,3% | 66,1% | 76,5% | 825 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4h00 | 70,3% | 64,5% | 67,4% | 78,0% | ~800 |
GPT-3-Curie‡ | ✘ | ----- | 4h00 | 69,3% | 65,6% | 68,5% | 77,9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3,56 | 72,5% | 67,9% | 70,9% | 78,5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3h00 | 76,2% | 70,2% | 78,9% | 81,0% | ~800 |
GPT-3-Davinci‡ | ✘ | ----- | 3,0 | 75% | 72% | 78% | 80% | ----- |
Esquilo 230B* | ✘ | 6.31E+23 | ----- | 74,50% | 70,10% | 79,20% | 81,80% | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76,6% | 73,0% | 80,2% | 82,0% | ----- |
*
representa os números de avaliação relatados por seus respectivos autores, todos os outros números são fornecidos executando o lm-evaluation-harness com os pesos liberados ou com acesso à API. Devido a diferenças sutis de implementação, bem como diferentes enquadramentos de tarefas zero shot, estes podem não ser diretamente comparáveis. Veja esta postagem do blog para mais detalhes.
†
O modelo Megatron-11B não fornece métricas comparáveis, e diversas implementações que utilizam os pesos liberados não reproduzem a qualidade e as avaliações da geração. (ver 1 2 3) Assim, a avaliação não foi tentada.
‡
Esses modelos foram treinados com dados que contêm possível contaminação do conjunto de testes. Os modelos OpenAI GPT-3 não conseguiram desduplicar os dados de treinamento para determinados conjuntos de testes, enquanto os modelos GPT-Neo, bem como este, são treinados no The Pile, que não foi desduplicado em nenhum conjunto de testes.
A maioria dos scripts neste repositório são projetados para serem executados em TPUs, que na arquitetura TPU-VM são máquinas virtuais que podem executar código arbitrário. A maioria dos scripts é projetada para ativar um TPU, SSH nele para configurar as dependências e copiar o código do diretório local e, em seguida, iniciar um trabalhador Ray que pode aceitar chamadas RPC.
Os TPUVMs lidam com a execução de etapas de treinamento e avaliação do modelo, salvamento e carregamento de pontos de verificação, enquanto o programa driver python lida com carregamento de dados e orquestração geral (como quando salvar pontos de verificação, etc.).
Isso significa que a maioria dos scripts ( train.py
, eval_harness.py
etc) espera estar em execução em uma máquina virtual GCE na mesma região que as TPUs, para minimizar a latência RPC e o custo de transferência de dados. Outros scripts (geralmente aqueles que não aceitam um argumento --tpu
, como device_sample.py
, device_serve.py
ou device_train.py
) esperam ser executados diretamente em uma TPUVM. Os scripts device_* funcionam apenas em v3-8 e não em pods maiores.
Além disso, há um exemplo ( resharding_example.py
) de como converter os pontos de verificação fornecidos (que possuem 8 fragmentos no caso do GPT-J-6B) para um número menor, como quando executado em GPU(s).
Para ajustar o modelo, execute device_train.py
em uma VM TPU. Usando uma TPU v3-8, você pode fazer o ajuste fino a uma taxa de aproximadamente 5.000 tokens/segundo, o que deve ser suficiente para conjuntos de dados de pequeno a médio porte.
Leia o guia passo a passo para obter instruções completas de ajuste fino.
Observe que esta biblioteca possui alguns requisitos específicos para a versão JAX. Especificamente, para usar os modelos v1 (incluindo GPT-J 6B), é necessário jax==0.2.12
. Isso, por sua vez, depende de jaxlib==0.1.68
. Se isso não for feito, você receberá erros enigmáticos de xmap
No entanto, para usar o código do modelo v2 (sem pesos divulgados publicamente), a versão JAX mais recente pode ser usada.
Para citar este repositório:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Para citar os pesos do GPT-J-6B:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
Se você usar este repositório ou qualquer um dos pesos pré-treinados para fazer algo legal, adoraríamos saber mais sobre isso. Sinta-se à vontade para abrir um problema no github ou entrar em contato por e-mail (no perfil).