llama classification
v1.1.1
이 저장소는 LLaMA를 사용하여 텍스트 분류를 위한 기본 코드베이스를 제공합니다.
하드웨어에 대한 기타 정보가 필요한 경우 이슈를 열어주세요.
여기에서 공식 LLaMA 저장소의 체크포인트를 가져옵니다.
1-1. 체크포인트는 프로젝트 루트 방향에 위치한다고 가정하고 내용은 다음과 같이 정리할 것이다.
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
Python 환경을 준비합니다. 로컬 머신의 CUDA 버전을 분리하려면 아나콘다를 사용하는 것이 좋습니다.
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
조건부 확률 p(y|x)
비교하는 것입니다.
다음 스크립트를 사용하여 포옹얼굴 데이터세트의 데이터를 전처리합니다. 이제부터는 ag_news 데이터세트를 사용합니다.
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
LLaMA 및 예측 클래스를 사용하여 조건부 확률을 계산하는 추론입니다.
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
Calibration 방식으로 Direct 방식을 개선한 것입니다.
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
조건부 확률 p(x|y)
비교하는 것입니다.
다음 스크립트를 사용하여 포옹얼굴 데이터세트의 데이터를 전처리합니다. 이제부터는 ag_news 데이터세트를 사용합니다.
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
LLaMA 및 예측 클래스를 사용하여 조건부 확률을 계산하는 추론입니다.
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
모드 사용을 평가하려면 전처리된 직접 버전을 사용할 수 있습니다. torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
데이터세트 | num_examples | 케이 | 방법 | 정확성 | 추론 시간 |
---|---|---|---|---|---|
ag_news | 7600 | 1 | 직접 | 0.7682 | 00:38:40 |
ag_news | 7600 | 1 | 직접+보정됨 | 0.8567 | 00:38:40 |
ag_news | 7600 | 1 | 채널 | 0.7825 | 00:38:37 |
연구를 위해 내 코드베이스를 사용한다면 내 작업을 인용하는 것이 좋습니다.
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}