@inproceedings{Wu2020LiteTransformer, title={Lite Transformer with Long-Short Range Attention}, author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, booktitle={International Conference on Learning Representations (ICLR)}, year={2020} }
Version Python >= 3.6
Version PyTorch >= 1.0.0
configurationparse >= 0,14
Pour entraîner de nouveaux modèles, vous aurez également besoin d'un GPU NVIDIA et d'un NCCL
Base de code
Pour installer fairseq à partir des sources et développer localement :
pip install --editable .
Modules personnalisés
Nous devons également créer lightconv
et dynamicconv
pour la prise en charge du GPU.
Lightconv_layer
cd fairseq/modules/lightconv_layer python cuda_function_gen.py installation de python setup.py
Dynamicconv_layer
cd fairseq/modules/dynamicconv_layer python cuda_function_gen.py installation de python setup.py
Nous suivons la préparation des données dans fairseq. Pour télécharger et prétraiter les données, on peut exécuter
bash configs/iwslt14.de-en/prepare.sh
Nous suivons le prétraitement des données dans fairseq. Pour télécharger et prétraiter les données, on peut exécuter
bash configs/wmt14.en-fr/prepare.sh
Nous suivons le prétraitement des données dans fairseq. Il faut d'abord télécharger les données prétraitées à partir du Google Drive fourni par Google. Pour binariser les données, on peut exécuter
bash configs/wmt16.en-de/prepare.sh [chemin d'accès au fichier zip téléchargé]
Comme la tâche de modèle de langage comporte de nombreux codes supplémentaires, nous la plaçons dans une autre branche : language-model
. Nous suivons le prétraitement des données dans fairseq. Pour télécharger et prétraiter les données, on peut exécuter
modèle de langage de paiement git bash configs/wikitext-103/prepare.sh
Par exemple, pour tester les modèles sur WMT'14 En-Fr, on peut lancer
configs/wmt14.en-fr/test.sh [chemin d'accès aux points de contrôle du modèle] [gpu-id] [test|valid]
Par exemple, pour évaluer Lite Transformer sur GPU 0 (avec le score BLEU sur l'ensemble de tests de WMT'14 En-Fr), on peut exécuter
configs/wmt14.en-fr/test.sh embed496/ 0 test
Nous proposons plusieurs modèles pré-entraînés en bas. Vous pouvez télécharger le modèle et extraire le fichier en
tar -xzvf [nom du fichier]
Nous avons fourni plusieurs exemples pour entraîner Lite Transformer avec ce dépôt :
Pour entraîner Lite Transformer sur WMT'14 En-Fr (avec 8 GPU), on peut exécuter
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
Pour entraîner Lite Transformer avec moins de GPU, par exemple 4 GPUS, on peut exécuter
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
En général, pour entraîner un modèle, on peut exécuter
python train.py [chemin d'accès au binaire de données] --configs [chemin d'accès au fichier de configuration] [options de remplacement]
Notez que --update-freq
doit être ajusté en fonction des numéros de GPU (16 pour 8 GPU, 32 pour 4 GPU).
Pour former Lite Transformer de manière distribuée. Par exemple sur deux nœuds GPU avec un total de 16 GPU.
# Sur host1python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=hôte1 --master_port=8080 train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distributed-no-spawn --update-freq 8# Sur host2python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=hôte1 --master_port=8080 train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distributed-no-spawn --update-freq 8
Nous fournissons les points de contrôle pour notre Lite Transformer rapportés dans le document :
Ensemble de données | #Mult-Ajouts | Résultat du test | Modèle et ensemble de tests |
---|---|---|---|
WMT'14 En-Fr | 90M | 35.3 | télécharger |
360M | 39.1 | télécharger | |
527M | 39,6 | télécharger | |
WMT'16 Fr-De | 90M | 22,5 | télécharger |
360M | 25.6 | télécharger | |
527M | 26,5 | télécharger | |
CNN / DailyMail | 800M | 38.3 (RL) | télécharger |
WIKITEXTE-103 | 1147M | 22.2 (LPP) | télécharger |