@inproceedings{Wu2020LiteTransformer, title={Lite Transformer with Long-Short Range Attention}, author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, booktitle={International Conference on Learning Representations (ICLR)}, year={2020} }
Versión de Python >= 3.6
Versión de PyTorch >= 1.0.0
configurargparse >= 0.14
Para entrenar nuevos modelos, también necesitarás una GPU NVIDIA y NCCL
Base de código
Para instalar fairseq desde el código fuente y desarrollar localmente:
instalación de pip --editable.
Módulos personalizados
También necesitamos crear lightconv
dynamicconv
para compatibilidad con GPU.
capa_conv_luz
cd fairseq/modules/lightconv_layer Python cuda_function_gen.py instalación de python setup.py
Capa_conv_dinámica
cd fairseq/modules/dynamicconv_layer Python cuda_function_gen.py instalación de python setup.py
Seguimos la preparación de datos en fairseq. Para descargar y preprocesar los datos, se puede ejecutar
bash configs/iwslt14.de-en/prepare.sh
Seguimos el preprocesamiento de datos en fairseq. Para descargar y preprocesar los datos, se puede ejecutar
bash configs/wmt14.en-fr/prepare.sh
Seguimos el preprocesamiento de datos en fairseq. Primero se deben descargar los datos preprocesados de Google Drive proporcionado por Google. Para binarizar los datos, se puede ejecutar
bash configs/wmt16.en-de/prepare.sh [ruta al archivo zip descargado]
Como la tarea del modelo de lenguaje tiene muchos códigos adicionales, la ubicamos en otra rama: language-model
. Seguimos el preprocesamiento de datos en fairseq. Para descargar y preprocesar los datos, se puede ejecutar
modelo de lenguaje de pago git configuraciones de bash/wikitext-103/prepare.sh
Por ejemplo, para probar los modelos en WMT'14 En-Fr, se puede ejecutar
configs/wmt14.en-fr/test.sh [ruta a los puntos de control del modelo] [gpu-id] [prueba|válido]
Por ejemplo, para evaluar Lite Transformer en GPU 0 (con la puntuación BLEU en el conjunto de pruebas de WMT'14 En-Fr), se puede ejecutar
configs/wmt14.en-fr/test.sh embed496/ 0 prueba
Proporcionamos varios modelos previamente entrenados en la parte inferior. Puede descargar el modelo y extraer el archivo haciendo
tar -xzvf [nombre de archivo]
Proporcionamos varios ejemplos para entrenar Lite Transformer con este repositorio:
Para entrenar Lite Transformer en WMT'14 En-Fr (con 8 GPU), se puede ejecutar
python train.py datos/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
Para entrenar Lite Transformer con menos GPU, por ejemplo, 4 GPUS, se puede ejecutar
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
En general, para entrenar un modelo, se puede ejecutar
python train.py [ruta al binario de datos] --configs [ruta al archivo de configuración] [opciones de anulación]
Tenga en cuenta que --update-freq
debe ajustarse según los números de GPU (16 para 8 GPU, 32 para 4 GPU).
Entrenar Lite Transformer de manera distribuida. Por ejemplo, en dos nodos de GPU con un total de 16 GPU.
# En host1python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=host1 --master_port=8080 datos train.py/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distribuido-sin-generación --update-freq 8# En host2python -m torch.distributed.launch --nproc_per_node=8 --nnodos=2 --node_rank=1 --master_addr=host1 --master_port=8080 datos train.py/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distribuido-sin-generación --frecuencia de actualización 8
Proporcionamos los puntos de control para nuestro Lite Transformer informados en el documento:
Conjunto de datos | #Agregaciones múltiples | Puntuación de la prueba | Modelo y conjunto de prueba |
---|---|---|---|
WMT'14 En-Fr | 90M | 35.3 | descargar |
360M | 39.1 | descargar | |
527M | 39,6 | descargar | |
WMT'16 En-De | 90M | 22,5 | descargar |
360M | 25.6 | descargar | |
527M | 26,5 | descargar | |
CNN/Correo diario | 800M | 38.3 (RL) | descargar |
WIKITEXT-103 | 1147M | 22.2 (PLP) | descargar |