このリポジトリには、素晴らしい継続学習メソッド L2P の PyTorch 実装コードが含まれています。
王紫峰ら「継続的な学習を促す学習。」 CVPR。 2022年。
公式の Jax 実装はここにあります。
私が使用しテストしたシステム
まず、リポジトリのクローンをローカルに作成します。
git clone https://github.com/JH-LEE-KR/l2p-pytorch
cd l2p-pytorch
次に、以下のパッケージをインストールします。
pytorch==1.12.1
torchvision==0.13.1
timm==0.6.7
pillow==9.2.0
matplotlib==3.5.3
これらのパッケージは次の方法で簡単にインストールできます。
pip install -r requirements.txt
すでに CIFAR-100 または 5-Datasets (MNIST、Fashion-MNIST、NotMNIST、CIFAR10、SVHN) がある場合は、データセットのパスを--data-path
に渡します。
データセットの準備ができていませんdatasets.py
の download 引数を次のように変更します。
CIFAR-100
datasets.CIFAR100(download=True)
5-データセット
datasets.CIFAR10(download=True)
MNIST_RGB(download=True)
FashionMNIST(download=True)
NotMNIST(download=True)
SVHN(download=True)
コマンドライン経由でモデルをトレーニングするには:
単一の GPU を備えた単一ノード
python -m torch.distributed.launch
--nproc_per_node=1
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
マルチ GPU を備えた単一ノード
python -m torch.distributed.launch
--nproc_per_node=<Num GPUs>
--use_env main.py
<cifar100_l2p or five_datasets_l2p>
--model vit_base_patch16_224
--batch-size 16
--data-path /local_datasets/
--output_dir ./output
train_cifar100_l2p.sh
またはtrain_five_datasets.sh
のオプションを適切に変更することで、Slurm システムでも使用できます。
分散トレーニングは Slurm 経由で利用でき、次のように送信します。
pip install submitit
それぞれ 4 つの GPU を備えた 2 つのノードでモデルをトレーニングするには:
python run_with_submitit.py <cifar100_l2p or five_datasets_l2p> --shared_folder <Absolute Path of shared folder for all nodes>
共有フォルダーの絶対パスはすべてのノードからアクセスできる必要があります。
環境に応じて、 NCLL_SOCKET_IFNAME=<Your own IP interface to use for communication>
をオプションで使用できます。
トレーニングされたモデルを評価するには:
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py <cifar100_l2p or five_datasets_l2p> --eval
単一の GPU でのテスト結果。
名前 | ACC@1 | 忘れる |
---|---|---|
Pytorch の実装 | 83.77 | 6.63 |
正式実装を再現 | 82.59 | 7.88 |
論文結果 | 83.83 | 7.63 |
名前 | ACC@1 | 忘れる |
---|---|---|
Pytorch の実装 | 80.22 | 3.81 |
正式実装を再現 | 79.68 | 3.71 |
論文結果 | 81.14 | 4.64 |
テストで使用されるメトリクスと、それらの対応する意味は次のとおりです。
メトリック | 説明 |
---|---|
ACC@1 | 最後のタスクまでの平均評価精度 |
忘れる | 最後のタスクまでの平均物忘れ |
このリポジトリは、LICENSE ファイルに記載されている Apache 2.0 ライセンスに基づいてリリースされます。
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}