Este repositorio proporciona una base de código básica para la clasificación de texto utilizando LLaMA.
Si necesita más información sobre el hardware, abra un problema.
Obtenga el punto de control del repositorio oficial de LLaMA desde aquí.
1-1. Supongo que el punto de control estaría ubicado en la dirección raíz del proyecto y el contenido se organizaría de la siguiente manera.
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
Prepare su entorno Python. Recomiendo usar anaconda para segregar la versión CUDA de su máquina local.
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
es comparar la probabilidad condicional p(y|x)
.
Preprocese los datos de los conjuntos de datos de Huggingface utilizando los siguientes scripts. A partir de ahora utilizaremos el conjunto de datos ag_news.
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
Inferencia para calcular la probabilidad condicional usando LLaMA y predecir clase.
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
consiste en mejorar el método directo con el método de calibración.
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
es comparar la probabilidad condicional p(x|y)
.
Preprocese los datos de los conjuntos de datos de Huggingface utilizando los siguientes scripts. A partir de ahora utilizaremos el conjunto de datos ag_news.
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
Inferencia para calcular la probabilidad condicional usando LLaMA y predecir clase.
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
, puede usar la versión directa preprocesada. torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Conjunto de datos | num_ejemplos | k | método | exactitud | tiempo de inferencia |
---|---|---|---|---|---|
noticias_ag | 7600 | 1 | directo | 0.7682 | 00:38:40 |
noticias_ag | 7600 | 1 | directo+calibrado | 0.8567 | 00:38:40 |
noticias_ag | 7600 | 1 | canal | 0,7825 | 00:38:37 |
Sería bienvenido citar mi trabajo si utiliza mi código base para su investigación.
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}