พื้นที่เก็บข้อมูลนี้มีโค้ดการใช้งาน PyTorch สำหรับวิธีการเรียนรู้อย่างต่อเนื่องที่ยอดเยี่ยม L2P
วัง ซีเฟิง และคณะ “การเรียนรู้เพื่อกระตุ้นให้เกิดการเรียนรู้อย่างต่อเนื่อง” ซีวีพีอาร์. 2022.
การใช้งาน Jax อย่างเป็นทางการอยู่ที่นี่
ระบบที่ผมใช้และทดสอบ
ขั้นแรก โคลนที่เก็บในเครื่อง:
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
จากนั้นติดตั้งแพ็คเกจด้านล่าง:
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
แพ็คเกจเหล่านี้สามารถติดตั้งได้อย่างง่ายดายโดย
pip install -r requirements.txt
หากคุณมีชุดข้อมูล CIFAR-100 หรือ 5 ชุดอยู่แล้ว (MNIST, Fashion-MNIST, NotMNIST, CIFAR10, SVHN) ให้ส่งพาธชุดข้อมูลของคุณไปที่ --data-path
ชุดข้อมูลไม่พร้อม เปลี่ยนอาร์กิวเมนต์การดาวน์โหลดใน datasets.py
ดังนี้
ซีฟาร์-100
datasets.CIFAR100(download=True)
5-ชุดข้อมูล
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
วิธีฝึกโมเดลผ่านบรรทัดคำสั่ง:
โหนดเดียวที่มี GPU เดียว
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
โหนดเดียวที่มีหลาย GPU
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
ยังมีอยู่ในระบบ Slurm โดยเปลี่ยนตัวเลือกบน train_cifar100_l2p.sh
หรือ train_five_datasets.sh
อย่างถูกต้อง
การฝึกอบรมแบบกระจายสามารถทำได้ผ่าน Slurm และส่ง:
pip install submitit
วิธีฝึกโมเดลบน 2 โหนดโดยแต่ละโหนดมี 4 GPU:
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
เส้นทางที่แน่นอนของโฟลเดอร์ที่ใช้ร่วมกันต้องสามารถเข้าถึงได้จากทุกโหนด
ตามสภาพแวดล้อมของคุณ คุณสามารถใช้ NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
เป็นทางเลือกได้
ในการประเมินแบบจำลองที่ได้รับการฝึก:
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
ผลการทดสอบบน GPU ตัวเดียว
ชื่อ | บัญชี@1 | ลืม |
---|---|---|
Pytorch-การใช้งาน | 83.77 | 6.63 |
ทำซ้ำการใช้งานอย่างเป็นทางการ | 82.59 | 7.88 |
ผลลัพธ์กระดาษ | 83.83 | 7.63 |
ชื่อ | บัญชี@1 | ลืม |
---|---|---|
Pytorch-การใช้งาน | 80.22 | 3.81 |
ทำซ้ำการใช้งานอย่างเป็นทางการ | 79.68 | 3.71 |
ผลลัพธ์กระดาษ | 81.14 | 4.64 |
ต่อไปนี้คือหน่วยเมตริกที่ใช้ในการทดสอบและความหมายที่เกี่ยวข้อง:
เมตริก | คำอธิบาย |
---|---|
บัญชี@1 | ความแม่นยำในการประเมินโดยเฉลี่ยจนถึงงานสุดท้าย |
ลืม | เฉลี่ยลืมจนงานสุดท้าย |
พื้นที่เก็บข้อมูลนี้เผยแพร่ภายใต้ลิขสิทธิ์ Apache 2.0 ตามที่พบในไฟล์ LICENSE
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}