Esta base de código é a implementação de LayerSkip: Habilitando Inferência de Saída Antecipada e Decodificação Autoespeculativa.
$ git clone [email protected]:facebookresearch/LayerSkip.git
$ cd LayerSkip
$ conda create --name layer_skip python=3.10
$ conda activate layer_skip
$ pip install -r requirements.txt
Modelos de acesso: Para observar a aceleração, você precisa acessar LLMs que foram treinados usando a receita LayerSkip. Fornecemos 6 pontos de verificação no HuggingFace de diferentes modelos de Llama continuamente pré-treinados usando a receita LayerSkip:
facebook/layerskip-llama2-7B
facebook/layerskip-llama2-13B
facebook/layerskip-codellama-7B
facebook/layerskip-codellama-34B
facebook/layerskip-llama3-8B
facebook/layerskip-llama3.2-1B
Para acessar cada modelo:
huggingface-cli login
e você será solicitado a fornecer o token obtido na Etapa 3.Depois de executar essas etapas, os comandos abaixo para executar os pontos de verificação do LayerSkip deverão funcionar.
Para executar um de nossos modelos em modo interativo usando decodificação autorregressiva regular:
$ torchrun generate.py --model facebook/layerskip-llama2-7B
--sample True
--max_steps 512
Para observar a aceleração, você precisa usar a decodificação autoespeculativa para gerar tokens e especificar --exit_layer
, a camada do estágio de rascunho para sair, e --num_speculations
, o número de tokens de rascunho:
$ torchrun generate.py --model facebook/layerskip-llama2-7B
--sample True
--max_steps 512
--generation_strategy self_speculative
--exit_layer 8
--num_speculations 6
Pontas:
--model
para qualquer modelo HuggingFace, mas para observar a aceleração com a decodificação autoespeculativa, use um modelo treinado usando a receita LayerSkip, como aqueles que temos código aberto no HuggingFace.--sample
, --temperature
, --top_p
e --top_k
.python generate.py --help
para obter detalhes sobre diferentes argumentos de linha de comando. Para comparar um conjunto de dados:
$ torchrun benchmark.py --model facebook/layerskip-llama2-7B
--dataset cnn_dm_summarization
--num_samples 100
--generation_strategy self_speculative
--exit_layer 8
--num_speculations 6
--output_dir ./logs
Pontas:
--dataset
:cnn_dm_summarization
: Resumo CNN/DMxsum_summarization
: Resumo XSUMcnn_dm_lm
: Modelagem de linguagem CNN/DM (dadas as primeiras palavras de um artigo, gere o artigo restante)human_eval
: codificação HumanEvaln
-shot especificado especificando o argumento --n_shot
.--sample
, --temperature
, --top_p
e --top_k
.python benchmark.py --help
para obter detalhes sobre diferentes argumentos de linha de comando. Integramos nossos scripts de geração com o Eleuther Language Model Evaluation Harness para permitir um grande número de tarefas e pós-processar adequadamente o texto gerado.
$ torchrun eval.py --model facebook/layerskip-llama2-7B
--tasks gsm8k
--limit 10
--generation_strategy self_speculative
--exit_layer 8
--num_speculations 6
--output_dir ./logs
Pontas:
gsm8k
ou cnn_dailymail
), enquanto tarefas de classificação, ou seja, tarefas de perguntas de múltipla escolha (por exemplo, piqa
, social_iqa
) ou tarefas de perguntas Verdadeiro/Falso (por exemplo, boolq
) irão não leva à aceleração.--tasks
. Para obter uma lista de todas as tarefas possíveis, verifique este link.generate.py
e benchmark.py
, você pode especificar diferentes modelos, conjuntos de dados e parâmetros de amostragempython benchmark.py --help
para obter detalhes sobre diferentes argumentos de linha de comando. Nossos hiperparâmetros de inferência, exit_layer
e num_speculations
determinam a aceleração durante a inferência:
exit_layer
:num_speculations
: A combinação ideal de exit_layer
e num_speculations
pode mudar com o modelo, conjunto de dados e parâmetros de amostragem. Portanto, fornecemos um script para varrer uma grade de diferentes exit_layer
e num_speculations
:
$ torchrun sweep.py --model facebook/layerskip-llama2-7B
--dataset human_eval
--generation_strategy self_speculative
--num_samples 150
--max_steps 256
--output_dir ./logs/
--sample False
Isso criará um arquivo CSV no diretório especificado no argumento --outpu_dir
.
Pontas:
generate.py
e benchmark.py
, você pode especificar diferentes modelos, conjuntos de dados e parâmetros de amostragempython sweep.py --help
para obter detalhes sobre diferentes argumentos de linha de comando. Para verificar se os tokens gerados pelo nosso algoritmo de decodificação autoespeculativa estão corretos, criamos um script para comparar as saídas da decodificação autorregressiva com a decodificação autoespeculativa. Observe que as saídas só podemos garantir equivalência quando não há amostragem (ou seja, --sample False
):
$ torchrun correctness.py --model facebook/layerskip-llama2-7B
--dataset human_eval
--generation_strategy self_speculative
--num_speculations 6
--exit_layer 4
--num_samples 10
--sample False
--output_dir ./logs
Verifique DOCKER.md para configurar o projeto usando docker
Também temos outras implementações de inferência do LayerSkip:
torch.compile()
, quantização e paralelismo de tensor.Nossa implementação de treinamento está em andamento. Você pode verificar esta solicitação pull para obter detalhes e discussões.
LayerSkip é licenciado sob licença CC-by-NC. Consulte o arquivo LICENSE no diretório de nível superior.
Aceitamos contribuições para o LayerSkip. Se você estiver interessado em contribuir, consulte este documento.
Se você usar LayerSkip em sua pesquisa, use a seguinte entrada BibTex:
@misc { layerskip ,
title = { LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding } ,
author = { Mostafa Elhoushi and Akshat Shrivastava and Diana Liskovich and Basil Hosmer and Bram Wasti and Liangzhen Lai and Anas Mahmoud and Bilge Acun and Saurabh Agarwal and Ahmed Roman and Ahmed A Aly and Beidi Chen and Carole-Jean Wu } ,
booktitle = " Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers) " ,
month = aug,
year = " 2024 " ,
address = " Bangkok, Thailand " ,
publisher = " Association for Computational Linguistics " ,
url = " https://aclanthology.org/2024.acl-long.681 " ,
doi = " 10.18653/v1/2024.acl-long.681 " ,
pages = " 12622--12642 " ,
}