Этот репозиторий содержит код реализации PyTorch для замечательного метода непрерывного обучения L2P.
Ван, Цзифэн и др. «Обучение побуждает к постоянному обучению». ЦВПР. 2022.
Официальная реализация Jax находится здесь.
Система, которую я использовал и тестировал в
Сначала клонируйте репозиторий локально:
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
Затем установите пакеты ниже:
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
Эти пакеты можно легко установить,
pip install -r requirements.txt
Если у вас уже есть наборы данных CIFAR-100 или 5 (MNIST, Fashion-MNIST, NotMNIST, CIFAR10, SVHN), передайте путь к набору данных в --data-path
.
Наборы данных не готовы. Измените аргумент загрузки в datasets.py
следующим образом.
СИФАР-100
datasets.CIFAR100(download=True)
5-Наборы данных
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
Чтобы обучить модель через командную строку:
Один узел с одним графическим процессором
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Один узел с несколькими графическими процессорами
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
Также доступно в системе Slurm, если правильно изменить параметры в train_cifar100_l2p.sh
или train_five_datasets.sh
.
Распределенное обучение доступно через Slurm и отправьте его:
pip install submitit
Чтобы обучить модель на двух узлах по 4 графических процессора каждый:
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
Абсолютный путь к общей папке должен быть доступен со всех узлов.
В зависимости от вашей среды вы можете дополнительно использовать NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
.
Чтобы оценить обученную модель:
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
Результаты тестирования на одном графическом процессоре.
Имя | Акк@1 | Забывчивость |
---|---|---|
Pytorch-реализация | 83,77 | 6,63 |
Воспроизведение официальной реализации | 82,59 | 7,88 |
Бумажные результаты | 83,83 | 7,63 |
Имя | Акк@1 | Забывчивость |
---|---|---|
Pytorch-реализация | 80,22 | 3,81 |
Воспроизведение официальной реализации | 79,68 | 3,71 |
Бумажные результаты | 81,14 | 4,64 |
Вот метрики, использованные в тесте, и их соответствующие значения:
Метрика | Описание |
---|---|
Акк@1 | Средняя точность оценки до последнего задания |
Забывчивость | Среднее забывание до последнего задания |
Этот репозиторий выпущен под лицензией Apache 2.0, указанной в файле LICENSE.
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}