يوفر هذا المستودع قاعدة بيانات أساسية لتصنيف النص باستخدام LLaMA.
إذا كنت بحاجة إلى معلومات أخرى حول الأجهزة، يرجى فتح مشكلة.
احصل على نقطة التفتيش من مستودع LLaMA الرسمي من هنا.
1-1. أفترض أن نقطة التفتيش ستكون موجودة في اتجاه جذر المشروع وسيتم ترتيب المحتويات على النحو التالي.
checkpoints
├── llama
│ ├── 7B
│ │ ├── checklist.chk
│ │ ├── consolidated.00.pth
│ │ └── params.json
│ └── tokenizer.model
قم بإعداد بيئة بايثون الخاصة بك. أوصي باستخدام اناكوندا لفصل إصدار CUDA الخاص بجهازك المحلي.
conda create -y -n llama-classification python=3.8
conda activate llama-classification
conda install cudatoolkit=11.7 -y -c nvidia
conda list cudatoolkit # to check what cuda version is installed (11.7)
pip install -r requirements.txt
Direct
هو مقارنة الاحتمال الشرطي p(y|x)
.
المعالجة المسبقة للبيانات من مجموعات بيانات Huggingface باستخدام البرامج النصية التالية. من الآن فصاعدا، نستخدم مجموعة بيانات ag_news.
python run_preprocess_direct_ag_news.py
python run_preprocess_direct_ag_news.py --sample=False --data_path=real/inputs_direct_ag_news.json # Use it for full evaluation
الاستدلال لحساب الاحتمال الشرطي باستخدام LLaMA والتنبؤ بالفئة.
torchrun --nproc_per_node 1 run_evaluate_direct_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_direct_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Calibration
هي تحسين الطريقة المباشرة بطريقة المعايرة.
torchrun --nproc_per_node 1 run_evaluate_direct_calibrate_llama.py
--direct_input_path samples/inputs_direct_ag_news.json
--direct_output_path samples/outputs_direct_ag_news.json
--output_path samples/outputs_direct_calibrate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
Channel
هي مقارنة الاحتمال الشرطي p(x|y)
.
المعالجة المسبقة للبيانات من مجموعات بيانات Huggingface باستخدام البرامج النصية التالية. من الآن فصاعدا، نستخدم مجموعة بيانات ag_news.
python run_preprocess_channel_ag_news.py
python run_preprocess_channel_ag_news.py --sample=False --data_path=real/inputs_channel_ag_news.json # Use it for full evaluation
الاستدلال لحساب الاحتمال الشرطي باستخدام LLaMA والتنبؤ بالفئة.
torchrun --nproc_per_node 1 run_evaluate_channel_llama.py
--data_path samples/inputs_channel_ag_news.json
--output_path samples/outputs_channel_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
generate
، يمكنك استخدام الإصدار المباشر الذي تمت معالجته مسبقًا. torchrun --nproc_per_node 1 run_evaluate_generate_llama.py
--data_path samples/inputs_direct_ag_news.json
--output_path samples/outputs_generate_ag_news.json
--ckpt_dir checkpoints/llama/7B
--tokenizer_path checkpoints/llama/tokenizer.model
مجموعة البيانات | num_examples | ك | طريقة | دقة | وقت الاستدلال |
---|---|---|---|---|---|
ag_news | 7600 | 1 | مباشر | 0.7682 | 00:38:40 |
ag_news | 7600 | 1 | مباشر + معايرة | 0.8567 | 00:38:40 |
ag_news | 7600 | 1 | قناة | 0.7825 | 00:38:37 |
سيكون موضع ترحيب الاستشهاد بعملي إذا كنت تستخدم قاعدة التعليمات البرمجية الخاصة بي لبحثك.
@software{Lee_Simple_Text_Classification_2023,
author = {Lee, Seonghyeon},
month = {3},
title = {{Simple Text Classification Codebase using LLaMA}},
url = {https://github.com/github/sh0416/llama-classification},
version = {1.1.0},
year = {2023}
}