This code base is the implementation of LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding.
$ git clone [email protected]:facebookresearch/LayerSkip.git
$ cd LayerSkip
$ conda create --name layer_skip python=3.10
$ conda activate layer_skip
$ pip install -r requirements.txt
Access models: In order to observe speedup, you need to access LLMs that have been trained using the LayerSkip recipe. We provide 6 checkpoints on HuggingFace of different Llama models continually pretrained using the LayerSkip recipe:
facebook/layerskip-llama2-7B
facebook/layerskip-llama2-13B
facebook/layerskip-codellama-7B
facebook/layerskip-codellama-34B
facebook/layerskip-llama3-8B
facebook/layerskip-llama3.2-1B
In order to access each model, log in to HuggingFace:
$ huggingface-cli login
You will be prompted to provide the access token you obtained from HuggingFace. Once you run those steps, the commands below to run the LayerSkip checkpoints should work.
To run one of our models in interactive mode using regular autoregressive decoding:
$ torchrun generate.py --model facebook/layerskip-llama2-7B \
    --sample True \
    --max_steps 512
In order to observe speedup, you need to use self-speculative decoding to generate tokens, specifying --exit_layer, the layer at which the draft stage exits, and --num_speculations, the number of draft tokens:
$ torchrun generate.py --model facebook/layerskip-llama2-7B \
    --sample True \
    --max_steps 512 \
    --generation_strategy self_speculative \
    --exit_layer 8 \
    --num_speculations 6
Tips:
- You may change --model to any HuggingFace model, but in order to observe speedup with self-speculative decoding, use a model trained using the LayerSkip recipe, such as those we have open sourced on HuggingFace.
- You may change the sampling behaviour using the --sample, --temperature, --top_p, and --top_k arguments.
- Run python generate.py --help for details on different command-line arguments.
To benchmark on a dataset:
$ torchrun benchmark.py --model facebook/layerskip-llama2-7B \
    --dataset cnn_dm_summarization \
    --num_samples 100 \
    --generation_strategy self_speculative \
    --exit_layer 8 \
    --num_speculations 6 \
    --output_dir ./logs
Tips:
- We support several datasets via the --dataset argument:
  - cnn_dm_summarization: CNN/DM Summarization
  - xsum_summarization: XSUM Summarization
  - cnn_dm_lm: CNN/DM Language Modeling (given the first few words of an article, generate the remaining article)
  - human_eval: HumanEval Coding
- You may change to n-shot evaluation by specifying the --n_shot argument.
- You may change the sampling behaviour using the --sample, --temperature, --top_p, and --top_k arguments.
- Run python benchmark.py --help for details on different command-line arguments.
We have integrated our generation scripts with the Eleuther Language Model Evaluation Harness to enable a large number of tasks and properly post-process generated text:
$ torchrun eval.py --model facebook/layerskip-llama2-7B \
    --tasks gsm8k \
    --limit 10 \
    --generation_strategy self_speculative \
    --exit_layer 8 \
    --num_speculations 6 \
    --output_dir ./logs
Tips:
- Speedup is only observed on generation tasks (e.g., gsm8k or cnn_dailymail), while classification tasks, i.e., multiple-choice question tasks (e.g., piqa, social_iqa) or True/False question tasks (e.g., boolq), will not lead to speedup.
- You may specify any number of tasks via the --tasks argument. To get a list of all possible tasks, check this link.
- Similar to the generate.py and benchmark.py scripts, you may specify different models, datasets, and sampling parameters.
- Run python eval.py --help for details on different command-line arguments.
Our inference hyperparameters, exit_layer and num_speculations, determine the speedup during inference:
- exit_layer: a smaller value makes the draft stage faster but less accurate, while a larger value makes it more accurate but slower.
- num_speculations: a smaller value means the verify stage amortizes the draft stage over fewer tokens, while a larger value can increase speedup but also increases the chance of draft tokens being rejected.
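To build intuition for how these two knobs interact, here is a toy, pure-Python sketch of the greedy draft-then-verify loop. The `draft_next` and `verify_next` callables are hypothetical stand-ins for the early-exit pass (the first `exit_layer` layers) and the full model; this is not the repository's actual implementation.

```python
# Toy sketch of greedy self-speculative decoding: `draft_next` plays the
# role of the cheap early-exit pass and `verify_next` the full model; both
# are hypothetical callables mapping a token list to the next token.

def speculative_generate(prompt, draft_next, verify_next,
                         num_speculations, max_new_tokens):
    """Greedily accept the longest draft prefix the full model agrees
    with, then append one token from the full model."""
    tokens = list(prompt)
    verify_calls = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        # Draft stage: cheaply propose `num_speculations` tokens.
        draft = []
        for _ in range(num_speculations):
            draft.append(draft_next(tokens + draft))
        # Verify stage: in a real implementation this is a single batched
        # forward pass; here we score each draft position one by one.
        verify_calls += 1
        for d in draft:
            target = verify_next(tokens)
            if d == target:
                tokens.append(d)       # draft token accepted
            else:
                tokens.append(target)  # rejected: keep the full-model token
                break
        else:
            # Every draft token accepted: the verify pass yields a bonus token.
            tokens.append(verify_next(tokens))
    return tokens[len(prompt):][:max_new_tokens], verify_calls
```

When the draft always agrees with the full model, each verify call yields num_speculations + 1 tokens; lowering exit_layer or raising num_speculations increases the draft's reach but also the chance that draft work is thrown away.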
The optimal combination of exit_layer and num_speculations may change with the model, dataset, and sampling parameters. Hence, we provide a script to sweep over a grid of different exit_layer and num_speculations values:
$ torchrun sweep.py --model facebook/layerskip-llama2-7B \
    --dataset human_eval \
    --generation_strategy self_speculative \
    --num_samples 150 \
    --max_steps 256 \
    --output_dir ./logs/ \
    --sample False
This will create a CSV file in the directory specified by the --output_dir argument.
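Once the sweep finishes, you can pick the best configuration from that CSV. A minimal sketch, assuming hypothetical column names (exit_layer, num_speculations, and a numeric speedup metric); check the actual file's header for the real names:

```python
import csv

def best_config(csv_path, metric):
    """Return the sweep row with the highest value of `metric`.

    Assumes `metric` names a numeric column in the sweep CSV; the
    column names here are illustrative, not guaranteed by the script.
    """
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    return max(rows, key=lambda r: float(r[metric]))
```

For example, `best_config("./logs/sweep.csv", "speedup")` would return the row whose exit_layer / num_speculations pair achieved the highest measured speedup.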
Tips:
- Similar to the generate.py and benchmark.py scripts, you may specify different models, datasets, and sampling parameters.
- Run python sweep.py --help for details on different command-line arguments.
In order to verify that the generated tokens of our self-speculative decoding algorithm are correct, we have created a script to compare the outputs of autoregressive decoding with self-speculative decoding. Note that we can only guarantee equivalence of outputs when there is no sampling (i.e., --sample False):
$ torchrun correctness.py --model facebook/layerskip-llama2-7B \
    --dataset human_eval \
    --generation_strategy self_speculative \
    --num_speculations 6 \
    --exit_layer 4 \
    --num_samples 10 \
    --sample False \
    --output_dir ./logs
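Conceptually, this comparison is a token-level diff of the two decoding strategies. A minimal sketch of that idea (not the script's actual code):

```python
def first_divergence(ar_tokens, ss_tokens):
    """Return the index of the first mismatch between autoregressive
    and self-speculative outputs, or None if they are identical."""
    for i, (a, b) in enumerate(zip(ar_tokens, ss_tokens)):
        if a != b:
            return i
    if len(ar_tokens) != len(ss_tokens):
        # One sequence is a strict prefix of the other.
        return min(len(ar_tokens), len(ss_tokens))
    return None
```

With greedy decoding (--sample False), `first_divergence` should return None for every sample; any index indicates a correctness bug in the speculative path.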
Kindly check DOCKER.md to set up the project using Docker.
We also have other implementations of LayerSkip inference that enable further speedups by combining it with torch.compile(), quantization, and tensor parallelism.
Our training implementation is work-in-progress. You can check this pull request for details and discussions.
LayerSkip is licensed under the CC BY-NC license. Refer to the LICENSE file in the top-level directory.
We welcome contributions to LayerSkip. If you are interested in contributing please see this document.
If you use LayerSkip in your research, please use the following BibTeX entry:
@inproceedings{layerskip,
    title = "LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding",
    author = "Mostafa Elhoushi and Akshat Shrivastava and Diana Liskovich and Basil Hosmer and Bram Wasti and Liangzhen Lai and Anas Mahmoud and Bilge Acun and Saurabh Agarwal and Ahmed Roman and Ahmed A Aly and Beidi Chen and Carole-Jean Wu",
    booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
month = aug,
year = "2024",
address = "Bangkok, Thailand",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2024.acl-long.681",
doi = "10.18653/v1/2024.acl-long.681",
pages = "12622--12642",
}