使用保形预测的图像分类器的不确定性集
@article{angelopoulos2020sets,
title={Uncertainty Sets for Image Classifiers using Conformal Prediction},
author={Angelopoulos, Anastasios N and Bates, Stephen and Malik, Jitendra and Jordan, Michael I},
journal={arXiv preprint arXiv:2009.14193},
year={2020}
}
此代码库修改任何 PyTorch 分类器以输出一个预测集,该预测集可证明包含具有您指定的概率的真实类。它使用一种称为正则化自适应预测集(RAPS)的方法,我们在随附的论文中介绍了该方法。该过程与 Platt 缩放一样简单且快速,但为每个模型和数据集提供了正式的有限样本覆盖保证。
Imagenet 上的预测集示例。我们展示了狐狸松鼠类的三个示例以及我们的方法生成的 95% 预测集,以说明集合大小如何根据测试时图像的难度而变化。我们编写了一个 Colab,可让您探索RAPS
和保形分类。您无需安装任何东西即可运行 Colab。该笔记本将引导您从预训练模型构建预测集。您还可以可视化 ImageNet 中的示例及其相应的RAPS
集,并使用正则化参数。
您可以通过单击下面的盾牌来访问 Colab。
如果您想在自己的项目中使用我们的代码并重现我们的实验,我们提供以下工具。请注意,虽然我们的代码库不是一个包,但它很容易像包一样使用,并且我们在上面的 Colab 笔记本中这样做。
从根目录,安装依赖项并通过执行以下命令来运行我们的示例:
git clone https://github.com/aangelopoulos/conformal-classification
cd conformal-classification
conda env create -f environment.yml
conda activate conformal
python example.py 'path/to/imagenet/val/'
查看example.py
中的一个最小示例,该示例修改预训练分类器以输出 90% 的预测集。
如果您想在您自己的模型上使用我们的代码库,请首先将其放在文件的顶部:
from conformal.py import *
from utils.py import *
然后使用如下行创建用于保形校准的保留集:
calib, val = random_split(mydataset, [num_calib,total-num_calib])
最后,您可以创建模型
model = ConformalModel(model, calib_loader, alpha=0.1, lamda_criterion='size')
ConformalModel
对象采用布尔标志randomized
。当randomized=True
,在测试时,集合将不会被随机化。这将导致保守的报道,但确定性的行为。
ConformalModel
对象采用第二个布尔标志allow_zero_sets
。当allow_zero_sets=True
时,在测试时,不允许使用大小为零的集合。这将导致保守的覆盖范围,但不会出现零大小集。
请参阅下面有关手动选择alpha
、 kreg
和lamda
讨论。
example.py
的输出应该是:
Begin Platt scaling.
Computing logits for model (only happens once).
100%|███████████████████████████████████████| 79/79 [02:24<00:00, 1.83s/it]
Optimal T=1.1976691484451294
Model calibrated and conformalized! Now evaluate over remaining data.
N: 40000 | Time: 1.686 (2.396) | Cvg@1: 0.766 (0.782) | Cvg@5: 0.969 (0.941) | Cvg@RAPS: 0.891 (0.914) | Size@RAPS: 2.953 (2.982)
Complete!
括号中的值为运行平均值。前面的值仅适用于最近的批次。您的系统上的计时值会有所不同,但其余数字应该完全相同。如果您的终端窗口很小,进度条可能会打印很多行。
实验的预期输出存储在experiments/outputs
中,它们与我们论文中报告的结果完全相同。安装我们的依赖项后,您可以通过执行“./experiments/”中的 python 脚本来重现结果。对于表 2,我们使用了 ImageNet-V2 的matched-frequencies
版本。
alpha
、 kreg
和lamda
alpha
是您愿意容忍的最大错误比例。因此,目标覆盖范围为1-alpha
。较小的alpha
通常会导致较大的集合,因为所需的覆盖范围更严格。
我们提供了两个挑选“kreg”和“lamda”的最佳程序。如果您想要小尺寸的集合,请设置 'lamda_criterion='size''。如果您想要设置近似条件覆盖范围,请设置“lamda_criterion='adaptiveness”。
麻省理工学院许可证