Ce référentiel contient le code d'implémentation de PyTorch pour la superbe méthode d'apprentissage continu L2P,
Wang, Zifeng et coll. "Apprendre à inciter à un apprentissage continu." CVPR. 2022.
L'implémentation officielle de Jax est ici.
Le système que j'ai utilisé et testé dans
Tout d'abord, clonez le référentiel localement :
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
Ensuite, installez les packages ci-dessous :
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
Ces packages peuvent être installés facilement par
pip install -r requirements.txt
Si vous disposez déjà de CIFAR-100 ou de 5-Datasets (MNIST, Fashion-MNIST, NotMNIST, CIFAR10, SVHN), transmettez le chemin de votre ensemble de données à --data-path
.
Les ensembles de données ne sont pas prêts, modifiez l'argument de téléchargement dans datasets.py
comme suit
CIFAR-100
datasets.CIFAR100(download=True)
5-Ensembles de données
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
Pour entraîner un modèle via la ligne de commande :
Nœud unique avec un seul GPU
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Nœud unique avec plusieurs GPU
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Également disponible dans le système Slurm en modifiant correctement les options sur train_cifar100_l2p.sh
ou train_five_datasets.sh
.
La formation distribuée est disponible via Slurm et soumettez-la :
pip install submitit
Pour entraîner un modèle sur 2 nœuds avec 4 GPU chacun :
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
Le chemin absolu du dossier partagé doit être accessible depuis tous les nœuds.
Selon votre environnement, vous pouvez utiliser NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
en option.
Pour évaluer un modèle entraîné :
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
Résultats des tests sur un seul GPU.
Nom | Acc@1 | Oubli |
---|---|---|
Implémentation de Pytorch | 83,77 | 6,63 |
Reproduire la mise en œuvre officielle | 82,59 | 7,88 |
Résultats papier | 83,83 | 7.63 |
Nom | Acc@1 | Oubli |
---|---|---|
Implémentation de Pytorch | 80.22 | 3,81 |
Reproduire la mise en œuvre officielle | 79,68 | 3,71 |
Résultats papier | 81.14 | 4,64 |
Voici les métriques utilisées dans le test et leurs significations correspondantes :
Métrique | Description |
---|---|
Acc@1 | Précision moyenne de l'évaluation jusqu'à la dernière tâche |
Oubli | Oubli moyen jusqu'à la dernière tâche |
Ce référentiel est publié sous la licence Apache 2.0 telle que trouvée dans le fichier LICENSE.
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}