@inproceedings{Wu2020LiteTransformer, title={Lite Transformer with Long-Short Range Attention}, author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, booktitle={International Conference on Learning Representations (ICLR)}, year={2020} }
Python バージョン >= 3.6
PyTorch バージョン >= 1.0.0
configargparse >= 0.14
新しいモデルをトレーニングするには、NVIDIA GPU と NCCL も必要になります。
コードベース
ソースから Faireq をインストールし、ローカルで開発するには:
pip install --editable 。
カスタマイズされたモジュール
GPU をサポートするためにlightconv
とdynamicconv
ビルドする必要もあります。
Lightconv_layer
cd Fairseq/モジュール/lightconv_layer Python cuda_function_gen.py Python setup.py インストール
Dynamicconv_layer
cd Fairseq/modules/dynamicconv_layer Python cuda_function_gen.py Python setup.py インストール
Fairseq でのデータ準備に従います。データをダウンロードして前処理するには、次のコマンドを実行します。
bash configs/iwslt14.de-en/prepare.sh
Fairseq でのデータの前処理に従います。 データをダウンロードして前処理するには、次のコマンドを実行します。
bash configs/wmt14.en-fr/prepare.sh
Fairseq でのデータの前処理に従います。まず、Google が提供する Google ドライブから前処理されたデータをダウンロードする必要があります。データを二値化するには、次のコマンドを実行できます。
bash configs/wmt16.en-de/prepare.sh [ダウンロードした zip ファイルへのパス]
言語モデル タスクには多くの追加コードがあるため、それを別のブランチlanguage-model
に配置します。 Fairseq でのデータの前処理に従います。 データをダウンロードして前処理するには、次のコマンドを実行します。
git checkout 言語モデル bash configs/wikitext-103/prepare.sh
たとえば、WMT'14 En-Fr でモデルをテストするには、次のように実行できます。
configs/wmt14.en-fr/test.sh [モデル チェックポイントへのパス] [gpu-id] [test|valid]
たとえば、(WMT'14 En-Fr のテスト セットの BLEU スコアを使用して) GPU 0 で Lite Transformer を評価するには、次のコマンドを実行できます。
configs/wmt14.en-fr/test.sh embed496/ 0 テスト
下部にはいくつかの事前トレーニング済みモデルが用意されています。モデルをダウンロードしてファイルを抽出できます。
tar -xzvf [ファイル名]
このリポジトリで Lite Transformer をトレーニングするための例をいくつか提供しました。
WMT'14 En-Fr (8 GPU 搭載) で Lite Transformer をトレーニングするには、次のコマンドを実行します。
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
少ない GPU (例: 4 GPUS) で Lite Transformer をトレーニングするには、次のコマンドを実行します。
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
一般に、モデルをトレーニングするには、次のように実行できます。
python train.py [データバイナリへのパス] --configs [設定ファイルへのパス] [オーバーライドオプション]
--update-freq
GPU 番号 (8 GPU の場合は 16、4 GPU の場合は 32) に応じて調整する必要があることに注意してください。
Lite Transformer を分散方式でトレーニングします。たとえば、合計 16 個の GPU を備えた 2 つの GPU ノードの場合です。
# host1python -m torch.distributed.launch 上 --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=host1 --master_port=8080 train.py データ/バイナリ/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distributed-no-spawn --update-freq 8# host2python -m torch.distributed.launch 上 --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=host1 --master_port=8080 train.py データ/バイナリ/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --distributed-no-spawn --更新頻度 8
論文で報告されている Lite Transformer のチェックポイントを提供します。
データセット | #複数追加 | テストのスコア | モデルとテストセット |
---|---|---|---|
WMT'14 日-フランス | 90M | 35.3 | ダウンロード |
360M | 39.1 | ダウンロード | |
527M | 39.6 | ダウンロード | |
WMT'16 エンデ | 90M | 22.5 | ダウンロード |
360M | 25.6 | ダウンロード | |
527M | 26.5 | ダウンロード | |
CNN / デイリーメール | 800M | 38.3 (標準) | ダウンロード |
ウィキテキスト-103 | 1147M | 22.2 (PPL) | ダウンロード |