@inproceedings{Wu2020LiteTransformer, title={Lite Transformer with Long-Short Range Attention}, author={Zhanghao Wu* and Zhijian Liu* and Ji Lin and Yujun Lin and Song Han}, booktitle={International Conference on Learning Representations (ICLR)}, year={2020} }
Python 버전 >= 3.6
PyTorch 버전 >= 1.0.0
configargparse >= 0.14
새로운 모델을 훈련하려면 NVIDIA GPU 및 NCCL도 필요합니다
코드베이스
소스에서 fairseq를 설치하고 로컬로 개발하려면 다음을 수행하세요.
pip install --editable .
맞춤형 모듈
또한 GPU 지원을 위해 lightconv
및 dynamicconv
구축해야 합니다.
Lightconv_layer
CD fairseq/모듈/lightconv_layer 파이썬 cuda_function_gen.py 파이썬 setup.py 설치
Dynamicconv_layer
CD fairseq/모듈/dynamicconv_layer 파이썬 cuda_function_gen.py 파이썬 setup.py 설치
우리는 fairseq의 데이터 준비를 따릅니다. 데이터를 다운로드하고 전처리하려면 다음을 실행할 수 있습니다.
bash configs/iwslt14.de-en/prepare.sh
우리는 fairseq에서 데이터 전처리를 따릅니다. 데이터를 다운로드하고 전처리하려면 다음을 실행할 수 있습니다.
bash configs/wmt14.en-fr/prepare.sh
우리는 fairseq에서 데이터 전처리를 따릅니다. 먼저 구글에서 제공하는 구글 드라이브에서 전처리된 데이터를 다운로드 받아야 합니다. 데이터를 이진화하려면 다음을 실행할 수 있습니다.
bash configs/wmt16.en-de/prepare.sh [다운로드한 zip 파일 경로]
언어 모델 작업에는 추가 코드가 많기 때문에 이를 다른 분기인 language-model
에 배치합니다. 우리는 fairseq에서 데이터 전처리를 따릅니다. 데이터를 다운로드하고 전처리하려면 다음을 실행할 수 있습니다.
git checkout 언어 모델 bash 구성/wikitext-103/prepare.sh
예를 들어, WMT'14 En-Fr에서 모델을 테스트하려면 다음을 실행하세요.
configs/wmt14.en-fr/test.sh [모델 체크포인트 경로] [gpu-id] [test|valid]
예를 들어 GPU 0(WMT'14 En-Fr 테스트 세트의 BLEU 점수 사용)에서 Lite Transformer를 평가하려면 다음을 실행하면 됩니다.
configs/wmt14.en-fr/test.sh embed496/0 테스트
하단에는 사전 훈련된 여러 모델이 제공됩니다. 모델을 다운로드하고 파일을 추출할 수 있습니다.
tar -xzvf [파일 이름]
이 저장소를 사용하여 Lite Transformer를 교육하기 위한 몇 가지 예를 제공했습니다.
WMT'14 En-Fr(8개의 GPU 포함)에서 Lite Transformer를 훈련하려면 다음을 실행하면 됩니다.
python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml
더 적은 GPU(예: 4 GPU)를 사용하여 Lite Transformer를 훈련하려면 다음을 실행할 수 있습니다.
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py data/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --update-freq 32
일반적으로 모델을 훈련하려면 다음을 실행하면 됩니다.
python train.py [데이터 바이너리 경로] --configs [구성 파일 경로] [옵션 재정의]
--update-freq
는 GPU 수(8개 GPU의 경우 16개, 4개 GPU의 경우 32개)에 따라 조정되어야 합니다.
Lite Transformer를 분산 방식으로 훈련합니다. 예를 들어 총 16개의 GPU가 있는 2개의 GPU 노드의 경우입니다.
# Host1python에서 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=호스트1 --master_port=8080 train.py 데이터/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --분산 없음 생성 --update-freq 8# Host2python에서 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=호스트1 --master_port=8080 train.py 데이터/binary/wmt14_en_fr --configs configs/wmt14.en-fr/attention/multibranch_v2/embed496.yml --분산 없음 생성 --업데이트-빈도 8
우리는 논문에 보고된 Lite Transformer에 대한 체크포인트를 제공합니다.
데이터세트 | #다중 추가 | 시험 점수 | 모델 및 테스트 세트 |
---|---|---|---|
WMT'14 En-Fr | 90M | 35.3 | 다운로드 |
360M | 39.1 | 다운로드 | |
527M | 39.6 | 다운로드 | |
WMT'16 엔데 | 90M | 22.5 | 다운로드 |
360M | 25.6 | 다운로드 | |
527M | 26.5 | 다운로드 | |
CNN / 데일리메일 | 800M | 38.3(RL) | 다운로드 |
위키텍스트-103 | 1147M | 22.2(PPL) | 다운로드 |