Repositori ini menyediakan basis kode dasar untuk klasifikasi teks menggunakan LLaMA.
Jika Anda memerlukan informasi lain tentang perangkat keras, silakan buka terbitan.
Dapatkan pos pemeriksaan dari repositori resmi LLaMA dari sini.
1-1. Saya berasumsi bahwa pos pemeriksaan akan ditempatkan di arah akar proyek dan isinya akan diatur sebagai berikut.
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
Siapkan lingkungan python Anda. Saya sarankan menggunakan anaconda untuk memisahkan versi CUDA mesin lokal Anda.
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
adalah membandingkan probabilitas bersyarat p(y|x)
.
Proses awal data dari kumpulan data pelukan menggunakan skrip berikut. Mulai sekarang, kami menggunakan dataset ag_news.
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
Inferensi untuk menghitung probabilitas bersyarat menggunakan LLaMA dan kelas prediksi.
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
adalah menyempurnakan metode langsung dengan metode kalibrasi.
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
adalah membandingkan probabilitas bersyarat p(x|y)
.
Proses awal data dari kumpulan data pelukan menggunakan skrip berikut. Mulai sekarang, kami menggunakan dataset ag_news.
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
Inferensi untuk menghitung probabilitas bersyarat menggunakan LLaMA dan kelas prediksi.
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
, Anda dapat menggunakan versi langsung yang telah diproses sebelumnya. torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Kumpulan data | nomor_contoh | k | metode | ketepatan | waktu inferensi |
---|---|---|---|---|---|
ag_news | 7600 | 1 | langsung | 0,7682 | 00:38:40 |
ag_news | 7600 | 1 | langsung+dikalibrasi | 0,8567 | 00:38:40 |
ag_news | 7600 | 1 | saluran | 0,7825 | 00:38:37 |
Sebaiknya Anda mengutip karya saya jika Anda menggunakan basis kode saya untuk penelitian Anda.
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}