Dieses Repository bietet eine grundlegende Codebasis für die Textklassifizierung mit LLaMA.
Wenn Sie weitere Informationen zur Hardware benötigen, öffnen Sie bitte ein Problem.
Holen Sie sich den Checkpoint hier aus dem offiziellen LLaMA-Repository.
1-1. Ich gehe davon aus, dass sich der Prüfpunkt in Projektstammrichtung befindet und der Inhalt wie folgt angeordnet ist.
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
Bereiten Sie Ihre Python-Umgebung vor. Ich empfehle die Verwendung von Anaconda, um die CUDA-Version Ihres lokalen Computers zu trennen.
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
besteht darin, die bedingte Wahrscheinlichkeit p(y|x)
zu vergleichen.
Verarbeiten Sie die Daten aus Huggingface-Datensätzen mit den folgenden Skripts vor. Von nun an verwenden wir den ag_news-Datensatz.
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
Inferenz zur Berechnung der bedingten Wahrscheinlichkeit mithilfe von LLaMA und zur Vorhersage der Klasse.
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
soll die direkte Methode mit der Kalibrierungsmethode verbessern.
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
besteht darin, die bedingte Wahrscheinlichkeit p(x|y)
zu vergleichen.
Verarbeiten Sie die Daten aus Huggingface-Datensätzen mit den folgenden Skripts vor. Von nun an verwenden wir den ag_news-Datensatz.
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
Inferenz zur Berechnung der bedingten Wahrscheinlichkeit mithilfe von LLaMA und zur Vorhersage der Klasse.
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
können Sie die vorverarbeitete Direktversion verwenden. torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Datensatz | Anzahl_Beispiele | k | Verfahren | Genauigkeit | Inferenzzeit |
---|---|---|---|---|---|
ag_news | 7600 | 1 | direkt | 0,7682 | 00:38:40 |
ag_news | 7600 | 1 | direkt+kalibriert | 0,8567 | 00:38:40 |
ag_news | 7600 | 1 | Kanal | 0,7825 | 00:38:37 |
Es wäre willkommen, meine Arbeit zu zitieren, wenn Sie meine Codebasis für Ihre Forschung verwenden.
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}