Este repositório contém código de implementação PyTorch para um incrível método de aprendizagem contínua L2P,
Wang, Zifeng et al. "Aprender a estimular o aprendizado contínuo." CVPR. 2022.
A implementação oficial do Jax está aqui.
O sistema que usei e testei em
Primeiro, clone o repositório localmente:
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
Em seguida, instale os pacotes abaixo:
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
Esses pacotes podem ser instalados facilmente por
pip install -r requirements.txt
Se você já possui CIFAR-100 ou 5-Datasets (MNIST, Fashion-MNIST, NotMNIST, CIFAR10, SVHN), passe o caminho do conjunto de dados para --data-path
.
Os conjuntos de dados não estão prontos, altere o argumento de download em datasets.py
da seguinte maneira
CIFAR-100
datasets.CIFAR100(download=True)
5-Conjuntos de dados
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
Para treinar um modelo via linha de comando:
Nó único com GPU único
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Nó único com vários gpus
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Também disponível no sistema Slurm alterando as opções em train_cifar100_l2p.sh
ou train_five_datasets.sh
corretamente.
O treinamento distribuído está disponível via Slurm e submit:
pip install submitit
Para treinar um modelo em 2 nós com 4 GPUs cada:
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
O caminho absoluto da pasta compartilhada deve estar acessível em todos os nós.
De acordo com seu ambiente, você pode usar NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
opcionalmente.
Para avaliar um modelo treinado:
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
Resultados de teste em uma única GPU.
Nome | Conta@1 | Esquecendo |
---|---|---|
Implementação Pytorch | 83,77 | 6,63 |
Reproduzir implementação oficial | 82,59 | 7,88 |
Resultados do artigo | 83,83 | 7,63 |
Nome | Conta@1 | Esquecendo |
---|---|---|
Implementação Pytorch | 80,22 | 3,81 |
Reproduzir implementação oficial | 79,68 | 3,71 |
Resultados do artigo | 81.14 | 4,64 |
Aqui estão as métricas usadas no teste e seus significados correspondentes:
Métrica | Descrição |
---|---|
Conta@1 | Precisão média da avaliação até a última tarefa |
Esquecendo | Esquecimento médio até a última tarefa |
Este repositório é lançado sob a licença Apache 2.0 conforme encontrado no arquivo LICENSE.
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}