@inproceedings{Wu2020LiteTransformer, title={Lite Transformer with Long-Short Range Attention}, author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, booktitle={International Conference on Learning Representations (ICLR)}, year={2020} }
Versão Python >= 3.6
Versão PyTorch >= 1.0.0
configurarparse >= 0,14
Para treinar novos modelos, você também precisará de uma GPU NVIDIA e NCCL
Base de código
Para instalar o fairseq da fonte e desenvolver localmente:
pip instalar --editável.
Módulos Personalizados
Também precisamos construir lightconv
e dynamicconv
para suporte de GPU.
Lightconv_layer
cd fairseq/modules/lightconv_layer python cuda_function_gen.py instalação do python setup.py
Camada_conv dinâmica
cd fairseq/modules/dynamicconv_layer python cuda_function_gen.py instalação do python setup.py
Acompanhamos a preparação dos dados no fairseq. Para baixar e pré-processar os dados, pode-se executar
configurações do bash/iwslt14.de-en/prepare.sh
Seguimos o pré-processamento dos dados em fairseq. Para baixar e pré-processar os dados, pode-se executar
configurações do bash/wmt14.en-fr/prepare.sh
Seguimos o pré-processamento dos dados em fairseq. Deve-se primeiro baixar os dados pré-processados do Google Drive fornecido pelo Google. Para binarizar os dados, pode-se executar
bash configs/wmt16.en-de/prepare.sh [caminho para o arquivo zip baixado]
Como a tarefa de modelo de linguagem possui muitos códigos adicionais, colocamos ela em outro ramo: language-model
. Seguimos o pré-processamento dos dados em fairseq. Para baixar e pré-processar os dados, pode-se executar
modelo de linguagem git checkout configurações do bash/wikitext-103/prepare.sh
Por exemplo, para testar os modelos no WMT'14 En-Fr, pode-se executar
configs/wmt14.en-fr/test.sh [caminho para os pontos de verificação do modelo] [gpu-id] [test|valid]
Por exemplo, para avaliar o Lite Transformer na GPU 0 (com a pontuação BLEU no conjunto de testes do WMT'14 En-Fr), pode-se executar
configs/wmt14.en-fr/test.sh embed496/ 0 teste
Fornecemos vários modelos pré-treinados na parte inferior. Você pode baixar o modelo e extrair o arquivo clicando
tar -xzvf [nome do arquivo]
Fornecemos vários exemplos para treinar o Lite Transformer com este repositório:
Para treinar o Lite Transformer no WMT'14 En-Fr (com 8 GPUs), pode-se executar
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
Para treinar o Lite Transformer com menos GPUs, por exemplo, 4 GPUS, pode-se executar
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
Em geral, para treinar um modelo, pode-se executar
python train.py [caminho para o binário de dados] --configs [caminho para o arquivo de configuração] [opções de substituição]
Observe que --update-freq
deve ser ajustado de acordo com os números da GPU (16 para 8 GPUs, 32 para 4 GPUs).
Treinar Lite Transformer de forma distribuída. Por exemplo, em dois nós de GPU com um total de 16 GPUs.
# Em host1python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=host1 --master_port=8080 train.py dados/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distribuído-no-spawn --update-freq 8# Em host2python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=host1 --master_port=8080 train.py dados/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distribuído-no-spawn --update-freq 8
Fornecemos os pontos de verificação para nosso Lite Transformer relatados no artigo:
Conjunto de dados | #Mult-Adições | Pontuação do teste | Modelo e conjunto de testes |
---|---|---|---|
WMT'14 En-Fr | 90 milhões | 35,3 | download |
360M | 39,1 | download | |
527 milhões | 39,6 | download | |
WMT'16 En-De | 90 milhões | 22,5 | download |
360M | 25,6 | download | |
527 milhões | 26,5 | download | |
CNN/DailyMail | 800 milhões | 38,3 (RL) | download |
WIKITEXT-103 | 1147M | 22.2 (PPL) | download |