llama classification
v1.1.1
このリポジトリは、LLaMA を使用したテキスト分類のための基本的なコードベースを提供します。
ハードウェアに関するその他の情報が必要な場合は、問題を開いてください。
ここから公式 LLaMA リポジトリからチェックポイントを取得します。
1-1.チェックポイントはプロジェクトのルート方向にあり、内容は以下のような配置になると思います。
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
Python 環境を準備します。 anaconda を使用してローカル マシンの CUDA バージョンを分離することをお勧めします。
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
条件付き確率p(y|x)
を比較することです。
次のスクリプトを使用して、huggingface データセットからのデータを前処理します。これからは、ag_news データセットを使用します。
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
LLaMA を使用して条件付き確率を計算し、クラスを予測する推論。
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
、直接法をキャリブレーション法で改良したものです。
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
は、条件付き確率p(x|y)
を比較します。
次のスクリプトを使用して、huggingface データセットからのデータを前処理します。これからは、ag_news データセットを使用します。
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
LLaMA を使用して条件付き確率を計算し、クラスを予測する推論。
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
モードを使用して評価するには、前処理された直接バージョンを使用できます。 torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
データセット | num_examples | k | 方法 | 正確さ | 推論時間 |
---|---|---|---|---|---|
ag_ニュース | 7600 | 1 | 直接 | 0.7682 | 00:38:40 |
ag_ニュース | 7600 | 1 | 直接+校正済み | 0.8567 | 00:38:40 |
ag_ニュース | 7600 | 1 | チャネル | 0.7825 | 00:38:37 |
研究に私のコードベースを使用する場合は、私の作品を引用することを歓迎します。
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}