Basis kode ini adalah implementasi LayerSkip: Mengaktifkan Inferensi Keluar Dini dan Decoding Spekulatif Mandiri.
$ git clone [email protected]:facebookresearch/LayerSkip.git
$ cd LayerSkip
$ conda create --name layer_skip python=3.10
$ conda activate layer_skip
$ pip install -r requirements.txt
Model akses: Untuk mengamati percepatan, Anda perlu mengakses LLM yang telah dilatih menggunakan resep LayerSkip. Kami menyediakan 6 pos pemeriksaan di HuggingFace dari model Llama berbeda yang terus dilatih sebelumnya menggunakan resep LayerSkip:
facebook/layerskip-llama2-7B
facebook/layerskip-llama2-13B
facebook/layerskip-codellama-7B
facebook/layerskip-codellama-34B
facebook/layerskip-llama3-8B
facebook/layerskip-llama3.2-1B
Untuk mengakses setiap model:
huggingface-cli login
, dan Anda akan diminta untuk memberikan token yang telah Anda peroleh di Langkah 3.Setelah Anda menjalankan langkah-langkah tersebut, perintah di bawah ini untuk menjalankan pos pemeriksaan LayerSkip akan berfungsi.
Untuk menjalankan salah satu model kami dalam mode interaktif menggunakan decoding autoregresif biasa:
$ torchrun generate.py --model facebook/layerskip-llama2-7B
--sample True
--max_steps 512
Untuk mengamati percepatan, Anda perlu menggunakan decoding spekulatif mandiri untuk menghasilkan token, dan menentukan --exit_layer
, lapisan tahap draf untuk keluar, dan --num_speculations
, jumlah token draf:
$ torchrun generate.py --model facebook/layerskip-llama2-7B
--sample True
--max_steps 512
--generation_strategy self_speculative
--exit_layer 8
--num_speculations 6
Kiat:
--model
ke model HuggingFace apa pun, tetapi untuk mengamati percepatan dengan decoding spekulatif mandiri, gunakan model yang dilatih menggunakan resep LayerSkip, seperti yang kami miliki sebagai sumber terbuka di HuggingFace.--sample
, --temperature
, --top_p
, dan --top_k
.python generate.py --help
untuk detail tentang argumen baris perintah yang berbeda. Untuk melakukan benchmark pada kumpulan data:
$ torchrun benchmark.py --model facebook/layerskip-llama2-7B
--dataset cnn_dm_summarization
--num_samples 100
--generation_strategy self_speculative
--exit_layer 8
--num_speculations 6
--output_dir ./logs
Kiat:
--dataset
:cnn_dm_summarization
: Ringkasan CNN/DMxsum_summarization
: Peringkasan XSUMcnn_dm_lm
: Pemodelan Bahasa CNN/DM (mengingat beberapa kata pertama dari sebuah artikel, menghasilkan artikel yang tersisa)human_eval
: Pengkodean HumanEvaln
-shot tertentu dengan menentukan argumen --n_shot
.--sample
, --temperature
, --top_p
, dan --top_k
.python benchmark.py --help
untuk detail tentang argumen baris perintah yang berbeda. Kami telah mengintegrasikan skrip generasi kami dengan Eleuther Language Model Evaluation Harness untuk mengaktifkan sejumlah besar tugas dan menghasilkan teks pasca-proses dengan benar.
$ torchrun eval.py --model facebook/layerskip-llama2-7B
--tasks gsm8k
--limit 10
--generation_strategy self_speculative
--exit_layer 8
--num_speculations 6
--output_dir ./logs
Kiat:
gsm8k
atau cnn_dailymail
), sedangkan tugas klasifikasi, yaitu tugas pertanyaan pilihan ganda (misalnya piqa
, social_iqa
) atau tugas pertanyaan Benar/Salah (misalnya boolq
) akan tidak menyebabkan percepatan.--tasks
. Untuk mendapatkan daftar semua kemungkinan tugas, periksa tautan ini.generate.py
dan benchmark.py
, Anda dapat menentukan model, kumpulan data, dan parameter pengambilan sampel yang berbedapython benchmark.py --help
untuk detail tentang argumen baris perintah yang berbeda. Hyperparameter inferensi kami, exit_layer
dan num_speculations
menentukan percepatan selama inferensi:
exit_layer
:num_speculations
: Kombinasi optimal exit_layer
dan num_speculations
dapat berubah seiring dengan model, kumpulan data, dan parameter pengambilan sampel. Oleh karena itu, kami menyediakan skrip untuk menyapu grid exit_layer
dan num_speculations
yang berbeda :
$ torchrun sweep.py --model facebook/layerskip-llama2-7B
--dataset human_eval
--generation_strategy self_speculative
--num_samples 150
--max_steps 256
--output_dir ./logs/
--sample False
Ini akan membuat file CSV di direktori yang ditentukan dalam argumen --outpu_dir
.
Kiat:
generate.py
dan benchmark.py
, Anda dapat menentukan model, kumpulan data, dan parameter pengambilan sampel yang berbedapython sweep.py --help
untuk mengetahui detail tentang argumen baris perintah yang berbeda. Untuk memverifikasi bahwa token yang dihasilkan dari algoritme pengodean spekulatif mandiri kami benar, kami telah membuat skrip untuk membandingkan keluaran pengodean autoregresif dengan pengodean spekulatif mandiri. Perhatikan bahwa keluaran yang kami jamin hanya setara jika tidak ada pengambilan sampel (yaitu, --sample False
):
$ torchrun correctness.py --model facebook/layerskip-llama2-7B
--dataset human_eval
--generation_strategy self_speculative
--num_speculations 6
--exit_layer 4
--num_samples 10
--sample False
--output_dir ./logs
Silakan periksa DOCKER.md untuk mengatur proyek menggunakan buruh pelabuhan
Kami juga memiliki implementasi lain dari inferensi LayerSkip:
torch.compile()
, kuantisasi, dan paralelisme tensor.Implementasi pelatihan kami sedang dalam proses. Anda dapat memeriksa permintaan tarik ini untuk detail dan diskusi.
LayerSkip dilisensikan di bawah lisensi CC-by-NC. Lihat file LISENSI di direktori tingkat atas.
Kami menyambut kontribusi ke LayerSkip. Jika Anda tertarik untuk berkontribusi silakan lihat dokumen ini.
Jika Anda menggunakan LayerSkip dalam penelitian Anda, silakan gunakan entri BibTex berikut:
@misc { layerskip ,
title = { LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding } ,
author = { Mostafa Elhoushi and Akshat Shrivastava and Diana Liskovich and Basil Hosmer and Bram Wasti and Liangzhen Lai and Anas Mahmoud and Bilge Acun and Saurabh Agarwal and Ahmed Roman and Ahmed A Aly and Beidi Chen and Carole-Jean Wu } ,
booktitle = " Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers) " ,
month = aug,
year = " 2024 " ,
address = " Bangkok, Thailand " ,
publisher = " Association for Computational Linguistics " ,
url = " https://aclanthology.org/2024.acl-long.681 " ,
doi = " 10.18653/v1/2024.acl-long.681 " ,
pages = " 12622--12642 " ,
}