이 저장소에는 멋진 연속 학습 방법인 L2P를 위한 PyTorch 구현 코드가 포함되어 있습니다.
왕, Zifeng, 그 외 여러분. "지속적인 학습을 유도하는 학습." CVPR. 2022.
공식 Jax 구현은 여기에 있습니다.
내가 사용하고 테스트한 시스템
먼저 저장소를 로컬로 복제합니다.
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
그런 다음 아래 패키지를 설치하십시오.
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
이 패키지는 다음을 통해 쉽게 설치할 수 있습니다.
pip install -r requirements.txt
이미 CIFAR-100 또는 5개 데이터세트(MNIST, Fashion-MNIST, NotMNIST, CIFAR10, SVHN)가 있는 경우 데이터세트 경로를 --data-path
에 전달하세요.
데이터세트가 준비되지 않았습니다. 다음과 같이 datasets.py
에서 다운로드 인수를 변경하세요.
CIFAR-100
datasets.CIFAR100(download=True)
5-데이터세트
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
명령줄을 통해 모델을 학습하려면 다음 안내를 따르세요.
단일 GPU를 갖춘 단일 노드
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
다중 GPU를 갖춘 단일 노드
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
train_cifar100_l2p.sh
또는 train_five_datasets.sh
의 옵션을 적절하게 변경하여 Slurm 시스템에서도 사용할 수 있습니다.
분산 교육은 Slurm을 통해 제공되며 다음을 제출하세요.
pip install submitit
각각 4개의 GPU를 사용하는 2개의 노드에서 모델을 학습하려면 다음 안내를 따르세요.
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
공유 폴더의 절대 경로는 모든 노드에서 접근 가능해야 합니다.
사용자의 환경에 따라 선택적으로 NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
사용할 수 있습니다.
학습된 모델을 평가하려면 다음 안내를 따르세요.
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
단일 GPU에 대한 테스트 결과입니다.
이름 | Acc@1 | 망각 |
---|---|---|
Pytorch 구현 | 83.77 | 6.63 |
공식 구현 재현 | 82.59 | 7.88 |
논문 결과 | 83.83 | 7.63 |
이름 | Acc@1 | 망각 |
---|---|---|
Pytorch 구현 | 80.22 | 3.81 |
공식 구현 재현 | 79.68 | 3.71 |
논문 결과 | 81.14 | 4.64 |
테스트에 사용된 측정항목과 해당 의미는 다음과 같습니다.
미터법 | 설명 |
---|---|
Acc@1 | 마지막 작업까지 평균 평가 정확도 |
망각 | 마지막 작업까지 평균 잊어버림 |
이 저장소는 LICENSE 파일에 있는 Apache 2.0 라이센스에 따라 릴리스됩니다.
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}