使卷積網路再次維持平移不變
張理查德.在 ICML,2019 年。
運行pip install antialiased-cnns
import antialiased_cnns
model = antialiased_cnns . resnet50 ( pretrained = True )
如果您已經有了模型並且想要抗鋸齒並繼續訓練,請將舊權重複製到:
import torchvision . models as models
old_model = models . resnet50 ( pretrained = True ) # old (aliased) model
antialiased_cnns . copy_params_buffers ( old_model , model ) # copy the weights over
如果要修改自己的模型,請使用 BlurPool 圖層。有關我們提供的模型以及如何使用 BlurPool 的更多資訊如下。
C = 10 # example feature channel size
blurpool = antialiased_cnns . BlurPool ( C , stride = 2 ) # BlurPool layer; use to downsample a feature map
ex_tens = torch . Tensor ( 1 , C , 128 , 128 )
print ( blurpool ( ex_tens ). shape ) # 1xCx64x64 tensor
更新
pip install antialiased-cnns
並使用pretrained=True
標誌載入模型。BlurPool
圖層對您自己的模型進行抗鋸齒處理的說明pip 安裝這個包
pip install antialiased-cnns
或複製此儲存庫並安裝要求(特別是 PyTorch)
https://github.com/adobe/antialiased-cnns.git
cd antialiased-cnns
pip install -r requirements.txt
下面載入一個預先訓練的抗鋸齒模型,也許作為您的應用程式的骨幹。
import antialiased_cnns
model = antialiased_cnns . resnet50 ( pretrained = True , filter_size = 4 )
我們也提供抗鋸齒AlexNet
、 VGG16(bn)
、 Resnet18,34,50,101
、 Densenet121
和MobileNetv2
的權重(請參閱 example_usage.py)。
antialiased_cnns
模組包含BlurPool
類,它執行模糊+子取樣。運行pip install antialiased-cnns
或複製antialiased_cnns
子目錄。
方法方法很簡單—首先使用步幅 1 進行評估,然後使用我們的BlurPool
層進行抗鋸齒下取樣。進行以下架構變更。
import antialiased_cnns
# MaxPool --> MaxBlurPool
baseline = nn . MaxPool2d ( kernel_size = 2 , stride = 2 )
antialiased = [ nn . MaxPool2d ( kernel_size = 2 , stride = 1 ),
antialiased_cnns . BlurPool ( C , stride = 2 )]
# Conv --> ConvBlurPool
baseline = [ nn . Conv2d ( Cin , C , kernel_size = 3 , stride = 2 , padding = 1 ),
nn . ReLU ( inplace = True )]
antialiased = [ nn . Conv2d ( Cin , C , kernel_size = 3 , stride = 1 , padding = 1 ),
nn . ReLU ( inplace = True ),
antialiased_cnns . BlurPool ( C , stride = 2 )]
# AvgPool --> BlurPool
baseline = nn . AvgPool2d ( kernel_size = 2 , stride = 2 )
antialiased = antialiased_cnns . BlurPool ( C , stride = 2 )
我們假設輸入張量有C
頻道。以步幅 1(而不是步幅 2)計算層會增加記憶體和運行時間。因此,我們通常會在最高解析度下(在網路的早期)跳過抗鋸齒,以防止大幅增加。
添加抗鋸齒,然後繼續訓練如果您已經訓練了模型,然後添加抗鋸齒,則可以從舊模型進行微調:
antialiased_cnns . copy_params_buffers ( old_model , antialiased_model )
如果這不起作用,您可以只複製參數(而不是緩衝區)。添加抗鋸齒功能不會添加任何參數,因此參數列表是相同的。 (它確實添加了緩衝區,因此使用一些啟發式方法來匹配緩衝區,這可能會引發錯誤。)
antialiased_cnns . copy_params ( old_model , antialiased_model )
我們觀察到準確性(影像被正確分類的頻率)和一致性(同一影像的兩個移位被分類為相同的頻率)方面都有所提高。
準確性 | 基線 | 抗鋸齒 | 三角洲 |
---|---|---|---|
亞歷克斯網 | 56.55 | 56.94 | +0.39 |
VG11 | 69.02 | 70.51 | +1.49 |
VG13 | 69.93 | 71.52 | +1.59 |
VG16 | 71.59 | 72.96 | +1.37 |
VG19 | 72.38 | 73.54 | +1.16 |
vgg11_bn | 70.38 | 72.63 | +2.25 |
vgg13_bn | 71.55 | 73.61 | +2.06 |
vgg16_bn | 73.36 | 75.13 | +1.77 |
vgg19_bn | 74.24 | 75.68 | +1.44 |
資源網18 | 69.74 | 71.67 | +1.93 |
資源網34 | 73.30 | 74.60 | +1.30 |
resnet50 | 76.16 | 77.41 | +1.25 |
資源網101 | 77.37 | 78.38 | +1.01 |
資源網152 | 78.31 | 79.07 | +0.76 |
resnext50_32x4d | 77.62 | 77.93 | +0.31 |
resnext101_32x8d | 79.31 | 79.33 | +0.02 |
Wide_resnet50_2 | 78.47 | 78.70 | +0.23 |
Wide_resnet101_2 | 78.85 | 78.99 | +0.14 |
密集網121 | 74.43 | 75.79 | +1.36 |
密集網169 | 75.60 | 76.73 | +1.13 |
密集網201 | 76.90 | 77.31 | +0.41 |
密集網161 | 77.14 | 77.88 | +0.74 |
行動網路_v2 | 71.88 | 72.72 | +0.84 |
一致性 | 基線 | 抗鋸齒 | 三角洲 |
---|---|---|---|
亞歷克斯網 | 78.18 | 83.31 | +5.13 |
VG11 | 86.58 | 90.09 | +3.51 |
VG13 | 86.92 | 90.31 | +3.39 |
VG16 | 88.52 | 90.91 | +2.39 |
VG19 | 89.17 | 91.08 | +1.91 |
vgg11_bn | 87.16 | 90.67 | +3.51 |
vgg13_bn | 88.03 | 91.09 | +3.06 |
vgg16_bn | 89.24 | 91.58 | +2.34 |
vgg19_bn | 89.59 | 91.60 | +2.01 |
資源網18 | 85.11 | 88.36 | +3.25 |
資源網34 | 87.56 | 89.77 | +2.21 |
resnet50 | 89.20 | 91.32 | +2.12 |
資源網101 | 89.81 | 91.97 | +2.16 |
資源網152 | 90.92 | 92.42 | +1.50 |
resnext50_32x4d | 90.17 | 91.48 | +1.31 |
resnext101_32x8d | 91.33 | 92.67 | +1.34 |
Wide_resnet50_2 | 90.77 | 92.46 | +1.69 |
Wide_resnet101_2 | 90.93 | 92.10 | +1.17 |
密集網121 | 88.81 | 90.35 | +1.54 |
密集網169 | 89.68 | 90.61 | +0.93 |
密集網201 | 90.36 | 91.32 | +0.96 |
密集網161 | 90.82 | 91.66 | +0.84 |
行動網路_v2 | 86.50 | 87.73 | +1.23 |
為了減少混亂,這裡提供了擴展結果(不同的過濾器尺寸)。幫助改善結果!
本作品根據 Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License 授權。
所有資料均由 Adobe Inc. 根據 Creative Commons BY-NC-SA 4.0授權提供。來給予適當的認可你所做的。
此儲存庫基於 PyTorch 範例儲存庫和 torchvision 模型儲存庫建置。這些是 BSD 風格的許可。
如果您發現這對您的研究有用,請考慮引用此 bibtex。如有任何意見或回饋,請聯絡 Richard Zhu <rizhang at adobe dot com>。