Ce référentiel fournit une base de code de base pour la classification de texte à l'aide de LLaMA.
Si vous avez besoin d'autres informations sur le matériel, veuillez ouvrir un problème.
Obtenez le point de contrôle du référentiel officiel LLaMA à partir d'ici.
1-1. Je suppose que le point de contrôle serait situé dans la direction racine du projet et que le contenu serait organisé comme suit.
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
Préparez votre environnement Python. Je recommande d'utiliser anaconda pour séparer la version CUDA de votre machine locale.
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
consiste à comparer la probabilité conditionnelle p(y|x)
.
Prétraitez les données des ensembles de données Huggingface à l'aide des scripts suivants. Désormais, nous utilisons le jeu de données ag_news.
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
Inférence pour calculer la probabilité conditionnelle à l'aide de LLaMA et de la classe de prédiction.
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
consiste à améliorer la méthode directe avec la méthode d'étalonnage.
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
consiste à comparer la probabilité conditionnelle p(x|y)
.
Prétraitez les données des ensembles de données Huggingface à l'aide des scripts suivants. Désormais, nous utilisons le jeu de données ag_news.
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
Inférence pour calculer la probabilité conditionnelle à l'aide de LLaMA et de la classe de prédiction.
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
, vous pouvez utiliser la version directe prétraitée. torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Ensemble de données | num_examples | k | méthode | précision | temps d'inférence |
---|---|---|---|---|---|
ag_nouvelles | 7600 | 1 | direct | 0,7682 | 00:38:40 |
ag_nouvelles | 7600 | 1 | direct+calibré | 0,8567 | 00:38:40 |
ag_nouvelles | 7600 | 1 | canal | 0,7825 | 00:38:37 |
Ce serait bien de citer mon travail si vous utilisez ma base de code pour vos recherches.
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}