@inproceedings{Wu2020LiteTransformer, title={Lite Transformer with Long-Short Range Attention}, author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, booktitle={International Conference on Learning Representations (ICLR)}, year={2020} }
Python-Version >= 3.6
PyTorch-Version >= 1.0.0
configparse >= 0,14
Zum Trainieren neuer Modelle benötigen Sie außerdem eine NVIDIA-GPU und NCCL
Codebasis
So installieren Sie fairseq aus dem Quellcode und entwickeln es lokal:
pip install --editable .
Maßgeschneiderte Module
Wir müssen auch lightconv
und dynamicconv
für die GPU-Unterstützung erstellen.
Lightconv_layer
cd fairseq/modules/lightconv_layer Python cuda_function_gen.py Python setup.py installieren
Dynamicconv_layer
cd fairseq/modules/dynamicconv_layer Python cuda_function_gen.py Python setup.py installieren
Wir verfolgen die Datenaufbereitung in fairseq. Um die Daten herunterzuladen und vorzuverarbeiten, kann man Folgendes ausführen
bash configs/iwslt14.de-en/prepare.sh
Wir verfolgen die Datenvorverarbeitung in fairseq. Um die Daten herunterzuladen und vorzuverarbeiten, kann man Folgendes ausführen
bash configs/wmt14.en-fr/prepare.sh
Wir verfolgen die Datenvorverarbeitung in fairseq. Man sollte zunächst die vorverarbeiteten Daten vom von Google bereitgestellten Google Drive herunterladen. Um die Daten zu binarisieren, kann man Folgendes ausführen
bash configs/wmt16.en-de/prepare.sh [Pfad zur heruntergeladenen ZIP-Datei]
Da die Sprachmodellaufgabe viele zusätzliche Codes enthält, platzieren wir sie in einem anderen Zweig: language-model
. Wir verfolgen die Datenvorverarbeitung in fairseq. Um die Daten herunterzuladen und vorzuverarbeiten, kann man Folgendes ausführen
Git Checkout Sprachmodell bash configs/wikitext-103/prepare.sh
Um beispielsweise die Modelle auf der WMT'14 En-Fr zu testen, kann man sie ausführen
configs/wmt14.en-fr/test.sh [Pfad zu den Modellprüfpunkten] [gpu-id] [test|valid]
Um beispielsweise Lite Transformer auf GPU 0 zu evaluieren (mit dem BLEU-Score im Testsatz von WMT'14 En-Fr), kann man es ausführen
configs/wmt14.en-fr/test.sh embed496/ 0 test
Unten stellen wir mehrere vorab trainierte Modelle zur Verfügung. Sie können das Modell herunterladen und die Datei extrahieren
tar -xzvf [Dateiname]
Wir haben mehrere Beispiele zum Trainieren von Lite Transformer mit diesem Repo bereitgestellt:
Um Lite Transformer auf WMT'14 En-Fr (mit 8 GPUs) zu trainieren, kann man es ausführen
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
Um Lite Transformer mit weniger GPUs, z. B. 4 GPUS, zu trainieren, kann man es ausführen
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
Um ein Modell zu trainieren, kann man im Allgemeinen laufen
python train.py [Pfad zur Datenbinärdatei] --configs [Pfad zur Konfigurationsdatei] [Optionen überschreiben]
Beachten Sie, dass --update-freq
entsprechend der GPU-Nummern angepasst werden sollte (16 für 8 GPUs, 32 für 4 GPUs).
Um Lite Transformer auf verteilte Weise zu trainieren. Zum Beispiel auf zwei GPU-Knoten mit insgesamt 16 GPUs.
# Auf host1python -m Torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=host1 --master_port=8080 train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distributed-no-spawn --update-freq 8# Auf host2python -m Torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=host1 --master_port=8080 train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distributed-no-spawn --update-freq 8
Wir stellen die im Artikel genannten Prüfpunkte für unseren Lite Transformer bereit:
Datensatz | #Mult-Adds | Testergebnis | Modell und Testset |
---|---|---|---|
WMT'14 En-Fr | 90M | 35.3 | herunterladen |
360M | 39.1 | herunterladen | |
527M | 39.6 | herunterladen | |
WMT'16 En-De | 90M | 22.5 | herunterladen |
360M | 25.6 | herunterladen | |
527M | 26.5 | herunterladen | |
CNN / DailyMail | 800M | 38,3 (RL) | herunterladen |
WIKITEXT-103 | 1147M | 22,2 (PPL) | herunterladen |