Este repositório fornece uma base de código básica para classificação de texto usando LLaMA.
Se você precisar de outras informações sobre hardware, abra um problema.
Obtenha o ponto de verificação do repositório oficial do LLaMA aqui.
1-1. Presumo que o ponto de verificação estaria localizado na direção raiz do projeto e o conteúdo seria organizado da seguinte forma.
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
Prepare seu ambiente python. Eu recomendo usar o anaconda para segregar a versão CUDA da sua máquina local.
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
é comparar a probabilidade condicional p(y|x)
.
Pré-processe os dados dos conjuntos de dados huggingface usando os scripts a seguir. De agora em diante, usaremos o conjunto de dados ag_news.
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
Inferência para calcular a probabilidade condicional usando LLaMA e classe de previsão.
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
é melhorar o método direto com o método de calibração.
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
é comparar a probabilidade condicional p(x|y)
.
Pré-processe os dados dos conjuntos de dados huggingface usando os scripts a seguir. De agora em diante, usaremos o conjunto de dados ag_news.
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
Inferência para calcular a probabilidade condicional usando LLaMA e classe de previsão.
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
, você pode usar a versão direta pré-processada. torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Conjunto de dados | num_exemplos | k | método | precisão | tempo de inferência |
---|---|---|---|---|---|
ag_news | 7600 | 1 | direto | 0,7682 | 00:38:40 |
ag_news | 7600 | 1 | direto+calibrado | 0,8567 | 00:38:40 |
ag_news | 7600 | 1 | canal | 0,7825 | 00:38:37 |
Seria bem-vindo citar meu trabalho se você usar minha base de código para sua pesquisa.
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}