@inproceedings{Wu2020LiteTransformer, title={Lite Transformer with Long-Short Range Attention}, author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, booktitle={International Conference on Learning Representations (ICLR)}, year={2020} }
Python 版本 >= 3.6
PyTorch 版本 >= 1.0.0
配置參數解析 >= 0.14
為了訓練新模型,您還需要 NVIDIA GPU 和 NCCL
程式碼庫
從原始碼安裝 fairseq 並在本地開發:
pip install --可編輯。
客製化模組
我們還需要建立lightconv
和dynamicconv
以支援 GPU。
Lightconv_layer
cd fairseq/模組/lightconv_layer python cuda_function_gen.py python setup.py 安裝
動態卷積層
cd fairseq/模組/dynamicconv_layer python cuda_function_gen.py python setup.py 安裝
我們遵循 fairseq 中的資料準備。要下載並預處理數據,可以運行
bash 配置/iwslt14.de-en/prepare.sh
我們遵循fairseq中的資料預處理。 要下載並預處理數據,可以運行
bash 設定/wmt14.en-fr/prepare.sh
我們遵循fairseq中的資料預處理。首先應從 Google 提供的 Google Drive 下載預處理後的資料。要對資料進行二值化,可以運行
bash configs/wmt16.en-de/prepare.sh [下載的 zip 檔案的路徑]
由於語言模型任務有很多額外的程式碼,我們將其放在另一個分支: language-model
。我們遵循fairseq中的資料預處理。 要下載並預處理數據,可以運行
git checkout 語言模型 bash 配置/wikitext-103/prepare.sh
例如,要在 WMT'14 En-Fr 上測試模型,可以運行
configs/wmt14.en-fr/test.sh [模型檢查點的路徑] [gpu-id] [test|valid]
例如,要在 GPU 0 上評估 Lite Transformer(在 WMT'14 En-Fr 測試集上獲得 BLEU 分數),可以運行
configs/wmt14.en-fr/test.sh embed496/ 0 測試
我們在底部提供了幾個預訓練模型。您可以透過以下方式下載模型並解壓縮文件
tar -xzvf [檔名]
我們提供了幾個使用此儲存庫訓練 Lite Transformer 的範例:
要在 WMT'14 En-Fr(具有 8 個 GPU)上訓練 Lite Transformer,可以運行
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
要使用較少的 GPU(例如 4 個 GPU)訓練 Lite Transformer,可以運行
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 322
一般來說,要訓練模型,可以運行
python train.py [資料二進位檔案的路徑] --configs [設定檔的路徑] [覆寫選項]
請注意, --update-freq
應根據 GPU 數量進行調整(8 個 GPU 為 16,4 個 GPU 為 32)。
以分散式方式訓練 Lite Transformer。例如,在兩個 GPU 節點上,總共有 16 個 GPU。
# 在 host1python -m torch.distributed.launch 上 --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=主機1 --master_port=8080 train.py 資料/二進位/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --分散式無生成 --update-freq 8# 在 host2python -m torch.distributed.launch 上 --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=主機1 --master_port=8080 train.py 資料/二進位/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --分散式無生成 --更新頻率 8
我們提供了論文中報導的 Lite Transformer 的檢查點:
數據集 | #多添加 | 測驗成績 | 模型和測試集 |
---|---|---|---|
WMT'14 英語-法語 | 90M | 35.3 | 下載 |
360M | 39.1 | 下載 | |
527M | 39.6 | 下載 | |
WMT'16 恩-德 | 90M | 22.5 | 下載 |
360M | 25.6 | 下載 | |
527M | 26.5 | 下載 | |
CNN/每日郵報 | 800M | 38.3(RL) | 下載 |
維基文本-103 | 1147M | 22.2(PPL) | 下載 |