Dieses Repository enthält PyTorch-Implementierungscode für die fantastische kontinuierliche Lernmethode L2P.
Wang, Zifeng et al. „Lernen, zum kontinuierlichen Lernen aufzufordern.“ CVPR. 2022.
Die offizielle Jax-Implementierung ist hier.
Das System, das ich verwendet und getestet habe
Klonen Sie zunächst das Repository lokal:
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
Installieren Sie dann die folgenden Pakete:
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
Diese Pakete können einfach von installiert werden
pip install -r requirements.txt
Wenn Sie bereits über CIFAR-100- oder 5-Datensätze (MNIST, Fashion-MNIST, NotMNIST, CIFAR10, SVHN) verfügen, übergeben Sie Ihren Datensatzpfad an --data-path
.
Die Datensätze sind noch nicht bereit. Ändern Sie das Download-Argument in datasets.py
wie folgt
CIFAR-100
datasets.CIFAR100(download=True)
5-Datensätze
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
So trainieren Sie ein Modell über die Befehlszeile:
Einzelner Knoten mit einzelner GPU
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Einzelner Knoten mit mehreren GPUs
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Auch im Slurm-System verfügbar, indem die Optionen auf train_cifar100_l2p.sh
oder train_five_datasets.sh
ordnungsgemäß geändert werden.
Verteilte Schulungen sind über Slurm verfügbar und reichen Sie ein:
pip install submitit
So trainieren Sie ein Modell auf 2 Knoten mit jeweils 4 GPUs:
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
Der absolute Pfad des freigegebenen Ordners muss von allen Knoten aus zugänglich sein.
Abhängig von Ihrer Umgebung können Sie optional NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
verwenden.
So bewerten Sie ein trainiertes Modell:
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
Testergebnisse auf einer einzelnen GPU.
Name | Acc@1 | Vergessen |
---|---|---|
Pytorch-Implementierung | 83,77 | 6.63 |
Offizielle Implementierung reproduzieren | 82,59 | 7,88 |
Papierergebnisse | 83,83 | 7.63 |
Name | Acc@1 | Vergessen |
---|---|---|
Pytorch-Implementierung | 80,22 | 3,81 |
Offizielle Implementierung reproduzieren | 79,68 | 3,71 |
Papierergebnisse | 81.14 | 4,64 |
Hier sind die im Test verwendeten Metriken und ihre entsprechende Bedeutung:
Metrisch | Beschreibung |
---|---|
Acc@1 | Durchschnittliche Bewertungsgenauigkeit bis zur letzten Aufgabe |
Vergessen | Durchschnittliches Vergessen bis zur letzten Aufgabe |
Dieses Repository wird unter der Apache 2.0-Lizenz veröffentlicht, wie in der LICENSE-Datei zu finden.
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}