Este repositorio contiene código de implementación de PyTorch para un increíble método de aprendizaje continuo L2P.
Wang, Zifeng y col. "Aprender a impulsar el aprendizaje continuo". CVPR. 2022.
La implementación oficial de Jax está aquí.
El sistema que utilicé y probé en
Primero, clona el repositorio localmente:
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
Luego, instale los siguientes paquetes:
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
Estos paquetes se pueden instalar fácilmente mediante
pip install -r requirements.txt
Si ya tiene CIFAR-100 o 5-Datasets (MNIST, Fashion-MNIST, NotMNIST, CIFAR10, SVHN), pase la ruta de su conjunto de datos a --data-path
.
Los conjuntos de datos no están listos, cambie el argumento de descarga en datasets.py
de la siguiente manera
CIFAR-100
datasets.CIFAR100(download=True)
5-conjuntos de datos
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
Para entrenar un modelo a través de la línea de comando:
Un solo nodo con una sola gpu
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Nodo único con múltiples gpus
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
También disponible en el sistema Slurm cambiando las opciones en train_cifar100_l2p.sh
o train_five_datasets.sh
correctamente.
La capacitación distribuida está disponible a través de Slurm y submitit:
pip install submitit
Para entrenar un modelo en 2 nodos con 4 gpus cada uno:
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
La ruta absoluta de la carpeta compartida debe ser accesible desde todos los nodos.
Según su entorno, puede utilizar NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
opcionalmente.
Para evaluar un modelo entrenado:
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
Resultados de la prueba en una sola gpu.
Nombre | Acc@1 | Olvidando |
---|---|---|
Implementación de Pytorch | 83,77 | 6.63 |
Reproducir implementación oficial | 82,59 | 7,88 |
Resultados en papel | 83,83 | 7.63 |
Nombre | Acc@1 | Olvidando |
---|---|---|
Implementación de Pytorch | 80.22 | 3.81 |
Reproducir implementación oficial | 79,68 | 3.71 |
Resultados en papel | 81.14 | 4.64 |
Estas son las métricas utilizadas en la prueba y sus significados correspondientes:
Métrico | Descripción |
---|---|
Acc@1 | Precisión media de la evaluación hasta la última tarea |
Olvidando | Olvido promedio hasta la última tarea. |
Este repositorio se publica bajo la licencia Apache 2.0 como se encuentra en el archivo LICENCIA.
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}