Repositori ini berisi kode implementasi PyTorch untuk metode pembelajaran berkelanjutan L2P yang mengagumkan,
Wang, Zifeng, dkk. "Belajar untuk mendorong pembelajaran terus-menerus." CVPR. 2022.
Implementasi resmi Jax ada di sini.
Sistem yang saya gunakan dan uji
Pertama, kloning repositori secara lokal:
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
Kemudian, instal paket-paket di bawah ini:
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
Paket-paket ini dapat diinstal dengan mudah melalui
pip install -r requirements.txt
Jika Anda sudah memiliki CIFAR-100 atau 5 Kumpulan Data (MNIST, Fashion-MNIST, NotMNIST, CIFAR10, SVHN), teruskan jalur kumpulan data Anda ke --data-path
.
Kumpulan data belum siap, ubah argumen download di datasets.py
sebagai berikut
CIFAR-100
datasets.CIFAR100(download=True)
5-Kumpulan Data
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
Untuk melatih model melalui baris perintah:
Node tunggal dengan GPU tunggal
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Node tunggal dengan multi GPU
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Juga tersedia di sistem Slurm dengan mengubah opsi di train_cifar100_l2p.sh
atau train_five_datasets.sh
dengan benar.
Pelatihan terdistribusi tersedia melalui Slurm dan kirimkan:
pip install submitit
Untuk melatih model pada 2 node dengan masing-masing 4 GPU:
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
Jalur Absolut folder bersama harus dapat diakses dari semua node.
Sesuai dengan lingkungan Anda, Anda dapat menggunakan NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
secara opsional.
Untuk mengevaluasi model terlatih:
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
Hasil pengujian pada satu GPU.
Nama | Akun@1 | Lupa |
---|---|---|
Implementasi Pytorch | 83,77 | 6.63 |
Reproduksi Implementasi Resmi | 82.59 | 7.88 |
Hasil Makalah | 83.83 | 7.63 |
Nama | Akun@1 | Lupa |
---|---|---|
Implementasi Pytorch | 80.22 | 3.81 |
Reproduksi Implementasi Resmi | 79.68 | 3.71 |
Hasil Makalah | 81.14 | 4.64 |
Berikut adalah metrik yang digunakan dalam pengujian, dan maknanya yang terkait:
Metrik | Keterangan |
---|---|
Akun@1 | Akurasi evaluasi rata-rata hingga tugas terakhir |
Lupa | Rata-rata lupa sampai tugas terakhir |
Repositori ini dirilis di bawah lisensi Apache 2.0 seperti yang terdapat dalam file LICENSE.
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}