使卷积网络再次保持平移不变
张理查德.在 ICML,2019 年。
运行pip install antialiased-cnns
import antialiased_cnns
model = antialiased_cnns . resnet50 ( pretrained = True )
如果您已经有了模型并且想要抗锯齿并继续训练,请将旧权重复制到:
import torchvision . models as models
old_model = models . resnet50 ( pretrained = True ) # old (aliased) model
antialiased_cnns . copy_params_buffers ( old_model , model ) # copy the weights over
如果要修改自己的模型,请使用 BlurPool 图层。有关我们提供的模型以及如何使用 BlurPool 的更多信息如下。
C = 10 # example feature channel size
blurpool = antialiased_cnns . BlurPool ( C , stride = 2 ) # BlurPool layer; use to downsample a feature map
ex_tens = torch . Tensor ( 1 , C , 128 , 128 )
print ( blurpool ( ex_tens ). shape ) # 1xCx64x64 tensor
更新
pip install antialiased-cnns
并使用pretrained=True
标志加载模型。BlurPool
图层对您自己的模型进行抗锯齿处理的说明pip 安装这个包
pip install antialiased-cnns
或者克隆此存储库并安装要求(特别是 PyTorch)
https://github.com/adobe/antialiased-cnns.git
cd antialiased-cnns
pip install -r requirements.txt
下面加载一个预训练的抗锯齿模型,也许作为您的应用程序的骨干。
import antialiased_cnns
model = antialiased_cnns . resnet50 ( pretrained = True , filter_size = 4 )
我们还提供抗锯齿AlexNet
、 VGG16(bn)
、 Resnet18,34,50,101
、 Densenet121
和MobileNetv2
的权重(请参阅 example_usage.py)。
antialiased_cnns
模块包含BlurPool
类,它执行模糊+子采样。运行pip install antialiased-cnns
或复制antialiased_cnns
子目录。
方法方法很简单——首先使用步幅 1 进行评估,然后使用我们的BlurPool
层进行抗锯齿下采样。进行以下架构更改。
import antialiased_cnns
# MaxPool --> MaxBlurPool
baseline = nn . MaxPool2d ( kernel_size = 2 , stride = 2 )
antialiased = [ nn . MaxPool2d ( kernel_size = 2 , stride = 1 ),
antialiased_cnns . BlurPool ( C , stride = 2 )]
# Conv --> ConvBlurPool
baseline = [ nn . Conv2d ( Cin , C , kernel_size = 3 , stride = 2 , padding = 1 ),
nn . ReLU ( inplace = True )]
antialiased = [ nn . Conv2d ( Cin , C , kernel_size = 3 , stride = 1 , padding = 1 ),
nn . ReLU ( inplace = True ),
antialiased_cnns . BlurPool ( C , stride = 2 )]
# AvgPool --> BlurPool
baseline = nn . AvgPool2d ( kernel_size = 2 , stride = 2 )
antialiased = antialiased_cnns . BlurPool ( C , stride = 2 )
我们假设输入张量有C
通道。以步幅 1(而不是步幅 2)计算层会增加内存和运行时间。因此,我们通常会在最高分辨率下(在网络的早期)跳过抗锯齿,以防止大幅增加。
添加抗锯齿,然后继续训练如果您已经训练了模型,然后添加抗锯齿,则可以从旧模型进行微调:
antialiased_cnns . copy_params_buffers ( old_model , antialiased_model )
如果这不起作用,您可以只复制参数(而不是缓冲区)。添加抗锯齿功能不会添加任何参数,因此参数列表是相同的。 (它确实添加了缓冲区,因此使用一些启发式方法来匹配缓冲区,这可能会引发错误。)
antialiased_cnns . copy_params ( old_model , antialiased_model )
我们观察到准确性(图像被正确分类的频率)和一致性(同一图像的两个移位被分类为相同的频率)方面都有所提高。
准确性 | 基线 | 抗锯齿 | 三角洲 |
---|---|---|---|
亚历克斯网 | 56.55 | 56.94 | +0.39 |
VG11 | 69.02 | 70.51 | +1.49 |
VG13 | 69.93 | 71.52 | +1.59 |
VG16 | 71.59 | 72.96 | +1.37 |
VG19 | 72.38 | 73.54 | +1.16 |
vgg11_bn | 70.38 | 72.63 | +2.25 |
vgg13_bn | 71.55 | 73.61 | +2.06 |
vgg16_bn | 73.36 | 75.13 | +1.77 |
vgg19_bn | 74.24 | 75.68 | +1.44 |
资源网18 | 69.74 | 71.67 | +1.93 |
资源网34 | 73.30 | 74.60 | +1.30 |
resnet50 | 76.16 | 77.41 | +1.25 |
资源网101 | 77.37 | 78.38 | +1.01 |
资源网152 | 78.31 | 79.07 | +0.76 |
resnext50_32x4d | 77.62 | 77.93 | +0.31 |
resnext101_32x8d | 79.31 | 79.33 | +0.02 |
Wide_resnet50_2 | 78.47 | 78.70 | +0.23 |
Wide_resnet101_2 | 78.85 | 78.99 | +0.14 |
密集网121 | 74.43 | 75.79 | +1.36 |
密集网169 | 75.60 | 76.73 | +1.13 |
密集网201 | 76.90 | 77.31 | +0.41 |
密集网161 | 77.14 | 77.88 | +0.74 |
移动网络_v2 | 71.88 | 72.72 | +0.84 |
一致性 | 基线 | 抗锯齿 | 三角洲 |
---|---|---|---|
亚历克斯网 | 78.18 | 83.31 | +5.13 |
VG11 | 86.58 | 90.09 | +3.51 |
VG13 | 86.92 | 90.31 | +3.39 |
VG16 | 88.52 | 90.91 | +2.39 |
VG19 | 89.17 | 91.08 | +1.91 |
vgg11_bn | 87.16 | 90.67 | +3.51 |
vgg13_bn | 88.03 | 91.09 | +3.06 |
vgg16_bn | 89.24 | 91.58 | +2.34 |
vgg19_bn | 89.59 | 91.60 | +2.01 |
资源网18 | 85.11 | 88.36 | +3.25 |
资源网34 | 87.56 | 89.77 | +2.21 |
resnet50 | 89.20 | 91.32 | +2.12 |
资源网101 | 89.81 | 91.97 | +2.16 |
资源网152 | 90.92 | 92.42 | +1.50 |
resnext50_32x4d | 90.17 | 91.48 | +1.31 |
resnext101_32x8d | 91.33 | 92.67 | +1.34 |
Wide_resnet50_2 | 90.77 | 92.46 | +1.69 |
Wide_resnet101_2 | 90.93 | 92.10 | +1.17 |
密集网121 | 88.81 | 90.35 | +1.54 |
密集网169 | 89.68 | 90.61 | +0.93 |
密集网201 | 90.36 | 91.32 | +0.96 |
密集网161 | 90.82 | 91.66 | +0.84 |
移动网络_v2 | 86.50 | 87.73 | +1.23 |
为了减少混乱,这里提供了扩展结果(不同的过滤器尺寸)。帮助改善结果!
本作品根据 Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License 获得许可。
所有材料均由 Adobe Inc. 根据 Creative Commons BY-NC-SA 4.0 许可提供。您可以出于非商业目的使用、重新分发和改编这些材料,只要您通过引用我们的论文并注明任何更改来给予适当的认可你所做的。
该存储库基于 PyTorch 示例存储库和 torchvision 模型存储库构建。这些是 BSD 风格的许可。
如果您发现这对您的研究有用,请考虑引用此 bibtex。如有任何意见或反馈,请联系 Richard Zhu <rizhang at adobe dot com>。