@inproceedings{Wu2020LiteTransformer, title={Lite Transformer with Long-Short Range Attention}, author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, booktitle={International Conference on Learning Representations (ICLR)}, year={2020} }
Python 版本 >= 3.6
PyTorch 版本 >= 1.0.0
配置参数解析 >= 0.14
为了训练新模型,您还需要 NVIDIA GPU 和 NCCL
代码库
从源代码安装 fairseq 并在本地开发:
pip install --可编辑。
定制模块
我们还需要构建lightconv
和dynamicconv
以支持 GPU。
Lightconv_layer
cd fairseq/模块/lightconv_layer python cuda_function_gen.py python setup.py 安装
动态卷积层
cd fairseq/模块/dynamicconv_layer python cuda_function_gen.py python setup.py 安装
我们遵循 fairseq 中的数据准备。要下载并预处理数据,可以运行
bash 配置/iwslt14.de-en/prepare.sh
我们遵循fairseq中的数据预处理。 要下载并预处理数据,可以运行
bash 配置/wmt14.en-fr/prepare.sh
我们遵循fairseq中的数据预处理。首先应从 Google 提供的 Google Drive 下载预处理后的数据。要对数据进行二值化,可以运行
bash configs/wmt16.en-de/prepare.sh [下载的 zip 文件的路径]
由于语言模型任务有很多额外的代码,我们将其放在另一个分支: language-model
。我们遵循fairseq中的数据预处理。 要下载并预处理数据,可以运行
git checkout 语言模型 bash 配置/wikitext-103/prepare.sh
例如,要在 WMT'14 En-Fr 上测试模型,可以运行
configs/wmt14.en-fr/test.sh [模型检查点的路径] [gpu-id] [test|valid]
例如,要在 GPU 0 上评估 Lite Transformer(在 WMT'14 En-Fr 测试集上获得 BLEU 分数),可以运行
configs/wmt14.en-fr/test.sh embed496/ 0 测试
我们在底部提供了几个预训练模型。您可以通过以下方式下载模型并解压文件
tar -xzvf [文件名]
我们提供了几个使用此存储库训练 Lite Transformer 的示例:
要在 WMT'14 En-Fr(具有 8 个 GPU)上训练 Lite Transformer,可以运行
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
要使用较少的 GPU(例如 4 个 GPU)训练 Lite Transformer,可以运行
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
一般来说,要训练模型,可以运行
python train.py [数据二进制文件的路径] --configs [配置文件的路径] [覆盖选项]
请注意, --update-freq
应根据 GPU 数量进行调整(8 个 GPU 为 16,4 个 GPU 为 32)。
以分布式方式训练 Lite Transformer。例如,在两个 GPU 节点上,总共有 16 个 GPU。
# 在 host1python -m torch.distributed.launch 上 --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=主机1 --master_port=8080 train.py 数据/二进制/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --分布式无生成 --update-freq 8# 在 host2python -m torch.distributed.launch 上 --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=主机1 --master_port=8080 train.py 数据/二进制/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --分布式无生成 --更新频率 8
我们提供了论文中报道的 Lite Transformer 的检查点:
数据集 | #多添加 | 测试成绩 | 模型和测试集 |
---|---|---|---|
WMT'14 英语-法语 | 90M | 35.3 | 下载 |
360M | 39.1 | 下载 | |
527M | 39.6 | 下载 | |
WMT'16 恩-德 | 90M | 22.5 | 下载 |
360M | 25.6 | 下载 | |
527M | 26.5 | 下载 | |
美国有线电视新闻网/每日邮报 | 800M | 38.3(RL) | 下载 |
维基文本-103 | 1147M | 22.2(PPL) | 下载 |