该存储库是 INTR 的官方实现:用于细粒度图像分类和分析的简单可解释转换器。它目前包括用于解释细粒度数据的代码和模型。当本文在线发布时,我们将提供即将到来的 ICLR 2024 会议记录的链接。
INTR 是 Transformer 的一种新颖用法,使图像分类变得可解释。在 INTR 中,我们研究了一种主动的分类方法,要求每个类别在图像中寻找自己。我们学习特定于类的查询(每个类一个)作为解码器的输入,使它们能够通过交叉注意来查找图像中的存在。我们表明,INTR 本质上鼓励每个班级以不同的方式参加;因此,交叉注意力权重为模型的预测提供了有意义的解释。有趣的是,通过多头交叉注意力,INTR 可以学习定位一个类的不同属性,使其特别适合细粒度的分类和分析。
在INTR模型中,解码器中的每个查询负责一个类的预测。因此,查询会查看自身以从特征图中查找特定于类的特征。首先,我们可视化特征图,即 Transformer 架构的值矩阵,以查看图像中对象的重要部分。为了找到模型在值矩阵中关注的特定特征,我们显示了模型关注的热图。为了避免分类中的外部干扰,我们使用共享权重向量进行分类,因此注意力权重解释了模型的预测。
DETR-R50 主干上的 INTR、分类性能以及不同数据集上的微调模型。
数据集 | 帐户@1 | ACC@5 | 模型 |
---|---|---|---|
幼兽 | 71.8 | 89.3 | 检查点下载 |
鸟 | 97.4 | 99.2 | 检查点下载 |
蝴蝶 | 95.0 | 98.3 | 检查点下载 |
创建python环境(可选)
conda create -n intr python=3.8 -y
conda activate intr
克隆存储库
git clone https://github.com/dipanjyoti/INTR.git
cd INTR
安装python依赖项
pip install -r requirements.txt
请遵循以下数据格式。
datasets
├── dataset_name
│ ├── train
│ │ ├── class1
│ │ │ ├── img1.jpeg
│ │ │ ├── img2.jpeg
│ │ │ └── ...
│ │ ├── class2
│ │ │ ├── img3.jpeg
│ │ │ └── ...
│ │ └── ...
│ └── val
│ ├── class1
│ │ ├── img4.jpeg
│ │ ├── img5.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img6.jpeg
│ │ └── ...
│ └── ...
要评估 INTR 在CUB数据集上的性能,请在多 GPU(例如 4 个 GPU)设置上执行以下命令。 INTR 检查点可在微调模型和结果中使用。
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 --use_env main.py --eval --resume < path/to/intr_checkpoint_cub_detr_r50.pth > --dataset_path < path/to/datasets > --dataset_name < dataset_name >
要生成 INTR 解释的视觉表示,请执行下面提供的命令。此命令将显示索引为
python -m tools.visualization --eval --resume < path/to/intr_checkpoint_cub_detr_r50.pth > --dataset_path < path/to/datasets > --dataset_name < dataset_name > --class_index < class_number >
推理时间单图像预测和可视化:我们还提供了 Jupyter Notebook demo.ipynb,专为推理过程中的单图像预测和可视化而设计。请注意,该演示主要针对 CUB 数据集。
要准备 INTR 进行训练,请使用预训练模型 DETR-R50。要训练特定数据集,请修改“--num_queries”,将其设置为数据集中的类数。在 INTR 架构中,解码器中的每个查询都被分配了捕获特定于类的特征的任务,这意味着每个查询都可以通过学习过程进行调整。因此,模型参数的总数将与数据集中的类数量成比例增长。要在多 GPU 系统(例如 4 个 GPU)上训练 INTR,请执行以下命令。
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 --use_env main.py --finetune < path/to/detr-r50-e632da11.pth > --dataset_path < path/to/datasets > --dataset_name < dataset_name > --num_queries < num_of_classes >
我们的模型受到 DEtection TRansformer (DETR) 方法的启发。
我们感谢 DETR 的作者所做的如此出色的工作。
如果您发现我们的工作对您的研究有帮助,请考虑引用 BibTeX 条目。
@inproceedings{paul2024simple,
title={A Simple Interpretable Transformer for Fine-Grained Image Classification and Analysis},
author={Paul, Dipanjyoti and Chowdhury, Arpita and Xiong, Xinqi and Chang, Feng-Ju and Carlyn, David and Stevens, Samuel and Provost, Kaiya and Karpatne, Anuj and Carstens, Bryan and Rubenstein, Daniel and Stewart, Charles and Berger-Wolf, Tanya and Su, Yu and Chao, Wei-Lun},
booktitle={International Conference on Learning Representations},
year={2024}
}