llama classification
v1.1.1
此儲存庫提供了使用 LLaMA 進行文字分類的基本程式碼庫。
如果您需要有關硬體的其他信息,請提出問題。
從此處從官方 LLaMA 儲存庫取得檢查點。
1-1.我假設檢查點位於專案根方向,內容安排如下。
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
準備你的Python環境。我建議使用 anaconda 來隔離本機的 CUDA 版本。
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
就是比較條件機率p(y|x)
。
使用以下腳本預處理 Huggingface 資料集中的資料。從現在開始,我們使用 ag_news 資料集。
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
使用 LLaMA 計算條件機率並預測類別的推理。
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
是用標定法改進直接法。
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
是比較條件機率p(x|y)
。
使用以下腳本預處理 Huggingface 資料集中的資料。從現在開始,我們使用 ag_news 資料集。
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
使用 LLaMA 計算條件機率並預測類別的推理。
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
模式進行評估,您可以使用預處理的直接版本。 torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
數據集 | 範例數 | k | 方法 | 準確性 | 推理時間 |
---|---|---|---|---|---|
農業新聞 | 7600 | 1 | 直接的 | 0.7682 | 00:38:40 |
農業新聞 | 7600 | 1 | 直接+校準 | 0.8567 | 00:38:40 |
農業新聞 | 7600 | 1 | 頻道 | 0.7825 | 00:38:37 |
如果您使用我的程式碼庫進行研究,我們將歡迎引用我的工作。
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}