يحتوي هذا المستودع على كود تنفيذ PyTorch لطريقة التعلم المستمر الرائعة L2P،
وانغ، زيفنغ، وآخرون. "التعلم للمطالبة بالتعلم المستمر." CVPR. 2022.
تطبيق Jax الرسمي موجود هنا.
النظام الذي استخدمته واختبرته
أولاً، قم باستنساخ المستودع محليًا:
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
ثم قم بتثبيت الحزم أدناه:
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
يمكن تثبيت هذه الحزم بسهولة عن طريق
pip install -r requirements.txt
إذا كان لديك بالفعل CIFAR-100 أو 5-Datasets (MNIST، Fashion-MNIST، NotMNIST، CIFAR10، SVHN)، قم بتمرير مسار مجموعة البيانات الخاصة بك إلى --data-path
.
مجموعات البيانات ليست جاهزة، قم بتغيير وسيطة التنزيل في datasets.py
على النحو التالي
سيفار-100
datasets.CIFAR100(download=True)
5-مجموعات البيانات
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
لتدريب نموذج عبر سطر الأوامر:
عقدة واحدة مع وحدة معالجة رسومات واحدة
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
عقدة واحدة مع وحدات معالجة رسومية متعددة
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
متوفر أيضًا في نظام Slurm عن طريق تغيير الخيارات في train_cifar100_l2p.sh
أو train_five_datasets.sh
بشكل صحيح.
التدريب الموزع متاح عبر Slurm و Submitit:
pip install submitit
لتدريب نموذج على عقدتين مع 4 وحدات معالجة رسومية لكل منهما:
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
يجب أن يكون المسار المطلق للمجلد المشترك قابلاً للوصول من جميع العقد.
وفقًا لبيئتك، يمكنك استخدام NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
بشكل اختياري.
لتقييم نموذج مدرب:
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
نتائج الاختبار على وحدة معالجة الرسومات واحدة.
اسم | حساب @1 | النسيان |
---|---|---|
تنفيذ الشعلة | 83.77 | 6.63 |
إعادة إنتاج التنفيذ الرسمي | 82.59 | 7.88 |
نتائج الورق | 83.83 | 7.63 |
اسم | حساب @1 | النسيان |
---|---|---|
تنفيذ الشعلة | 80.22 | 3.81 |
إعادة إنتاج التنفيذ الرسمي | 79.68 | 3.71 |
نتائج الورق | 81.14 | 4.64 |
فيما يلي المقاييس المستخدمة في الاختبار والمعاني المقابلة لها:
متري | وصف |
---|---|
حساب @1 | متوسط دقة التقييم حتى المهمة الأخيرة |
النسيان | متوسط النسيان حتى المهمة الأخيرة |
تم إصدار هذا المستودع بموجب ترخيص Apache 2.0 كما هو موجود في ملف الترخيص.
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}