畳み込みネットワークを再びシフト不変にする
リチャード・チャン。 ICMLにて、2019年。
pip install antialiased-cnns
実行します
import antialiased_cnns
model = antialiased_cnns . resnet50 ( pretrained = True )
すでにモデルがあり、アンチエイリアスを適用してトレーニングを続行したい場合は、古い重みをコピーします。
import torchvision . models as models
old_model = models . resnet50 ( pretrained = True ) # old (aliased) model
antialiased_cnns . copy_params_buffers ( old_model , model ) # copy the weights over
独自のモデルを変更する場合は、BlurPool レイヤーを使用します。提供されているモデルと BlurPool の使用方法の詳細については、以下をご覧ください。
C = 10 # example feature channel size
blurpool = antialiased_cnns . BlurPool ( C , stride = 2 ) # BlurPool layer; use to downsample a feature map
ex_tens = torch . Tensor ( 1 , C , 128 , 128 )
print ( blurpool ( ex_tens ). shape ) # 1xCx64x64 tensor
アップデート
pip install antialiased-cnns
、 pretrained=True
フラグを使用してモデルを読み込むこともできるようになりました。BlurPool
レイヤーを使用して独自のモデルをアンチエイリアスする手順このパッケージを PIP インストールします
pip install antialiased-cnns
または、このリポジトリのクローンを作成し、要件 (特に PyTorch) をインストールします。
https://github.com/adobe/antialiased-cnns.git
cd antialiased-cnns
pip install -r requirements.txt
以下は、おそらくアプリケーションのバックボーンとして、事前トレーニングされたアンチエイリアス モデルを読み込みます。
import antialiased_cnns
model = antialiased_cnns . resnet50 ( pretrained = True , filter_size = 4 )
また、アンチエイリアスされたAlexNet
、 VGG16(bn)
、 Resnet18,34,50,101
、 Densenet121
、およびMobileNetv2
の重みも提供します (example_usage.py を参照)。
antialiased_cnns
モジュールには、ブラー+サブサンプリングを行うBlurPool
クラスが含まれています。 pip install antialiased-cnns
実行するか、 antialiased_cnns
サブディレクトリをコピーします。
方法論 方法論は単純です。最初にストライド 1 で評価し、次にBlurPool
レイヤーを使用してアンチエイリアス ダウンサンプリングを実行します。次のアーキテクチャ上の変更を加えます。
import antialiased_cnns
# MaxPool --> MaxBlurPool
baseline = nn . MaxPool2d ( kernel_size = 2 , stride = 2 )
antialiased = [ nn . MaxPool2d ( kernel_size = 2 , stride = 1 ),
antialiased_cnns . BlurPool ( C , stride = 2 )]
# Conv --> ConvBlurPool
baseline = [ nn . Conv2d ( Cin , C , kernel_size = 3 , stride = 2 , padding = 1 ),
nn . ReLU ( inplace = True )]
antialiased = [ nn . Conv2d ( Cin , C , kernel_size = 3 , stride = 1 , padding = 1 ),
nn . ReLU ( inplace = True ),
antialiased_cnns . BlurPool ( C , stride = 2 )]
# AvgPool --> BlurPool
baseline = nn . AvgPool2d ( kernel_size = 2 , stride = 2 )
antialiased = antialiased_cnns . BlurPool ( C , stride = 2 )
入力テンソルにはC
チャネルがあると仮定します。ストライド 2 ではなくストライド 1 でレイヤーを計算すると、メモリと実行時間が追加されます。そのため、大幅な増加を防ぐために、通常は最高解像度 (ネットワークの初期段階) でのアンチエイリアスをスキップします。
アンチエイリアスを追加してトレーニングを続行するすでにモデルをトレーニングしてからアンチエイリアスを追加する場合は、その古いモデルから微調整できます。
antialiased_cnns . copy_params_buffers ( old_model , antialiased_model )
これが機能しない場合は、パラメータ (バッファではなく) をコピーするだけです。アンチエイリアスを追加してもパラメータは追加されないため、パラメータ リストは同じです。 (バッファーを追加するため、バッファーを照合するために何らかのヒューリスティックが使用され、エラーがスローされる可能性があります。)
antialiased_cnns . copy_params ( old_model , antialiased_model )
精度(画像が正しく分類される頻度) と一貫性(同じ画像の 2 つのシフトが同じに分類される頻度) の両方が向上していることが観察されています。
正確さ | ベースライン | アンチエイリアス済み | デルタ |
---|---|---|---|
アレックスネット | 56.55 | 56.94 | +0.39 |
vgg11 | 69.02 | 70.51 | +1.49 |
vgg13 | 69.93 | 71.52 | +1.59 |
vgg16 | 71.59 | 72.96 | +1.37 |
vgg19 | 72.38 | 73.54 | +1.16 |
vgg11_bn | 70.38 | 72.63 | +2.25 |
vgg13_bn | 71.55 | 73.61 | +2.06 |
vgg16_bn | 73.36 | 75.13 | +1.77 |
vgg19_bn | 74.24 | 75.68 | +1.44 |
レスネット18 | 69.74 | 71.67 | +1.93 |
レスネット34 | 73.30 | 74.60 | +1.30 |
レスネット50 | 76.16 | 77.41 | +1.25 |
レスネット101 | 77.37 | 78.38 | +1.01 |
レスネット152 | 78.31 | 79.07 | +0.76 |
resnext50_32x4d | 77.62 | 77.93 | +0.31 |
resnext101_32x8d | 79.31 | 79.33 | +0.02 |
ワイドレスネット50_2 | 78.47 | 78.70 | +0.23 |
ワイドレスネット101_2 | 78.85 | 78.99 | +0.14 |
デンスネット121 | 74.43 | 75.79 | +1.36 |
デンスネット169 | 75.60 | 76.73 | +1.13 |
デンスネット201 | 76.90 | 77.31 | +0.41 |
デンスネット161 | 77.14 | 77.88 | +0.74 |
モバイルネット_v2 | 71.88 | 72.72 | +0.84 |
一貫性 | ベースライン | アンチエイリアス済み | デルタ |
---|---|---|---|
アレックスネット | 78.18 | 83.31 | +5.13 |
vgg11 | 86.58 | 90.09 | +3.51 |
vgg13 | 86.92 | 90.31 | +3.39 |
vgg16 | 88.52 | 90.91 | +2.39 |
vgg19 | 89.17 | 91.08 | +1.91 |
vgg11_bn | 87.16 | 90.67 | +3.51 |
vgg13_bn | 88.03 | 91.09 | +3.06 |
vgg16_bn | 89.24 | 91.58 | +2.34 |
vgg19_bn | 89.59 | 91.60 | +2.01 |
レスネット18 | 85.11 | 88.36 | +3.25 |
レスネット34 | 87.56 | 89.77 | +2.21 |
レスネット50 | 89.20 | 91.32 | +2.12 |
レスネット101 | 89.81 | 91.97 | +2.16 |
レスネット152 | 90.92 | 92.42 | +1.50 |
resnext50_32x4d | 90.17 | 91.48 | +1.31 |
resnext101_32x8d | 91.33 | 92.67 | +1.34 |
ワイドレスネット50_2 | 90.77 | 92.46 | +1.69 |
ワイドレスネット101_2 | 90.93 | 92.10 | +1.17 |
デンスネット121 | 88.81 | 90.35 | +1.54 |
デンスネット169 | 89.68 | 90.61 | +0.93 |
デンスネット201 | 90.36 | 91.32 | +0.96 |
デンスネット161 | 90.82 | 91.66 | +0.84 |
モバイルネット_v2 | 86.50 | 87.73 | +1.23 |
煩雑さを軽減するために、拡張結果 (さまざまなフィルター サイズ) がここにあります。結果の向上にご協力ください。
この作品は、クリエイティブ コモンズ 表示 - 非営利 - 継承 4.0 国際ライセンスに基づいてライセンスされています。
すべての素材は、Adobe Inc. のクリエイティブ コモンズ BY-NC-SA 4.0 ライセンスに基づいて利用可能です。論文を引用し、変更を示すことで適切なクレジットを付与する限り、素材を非営利目的で使用、再配布、改変することができます。あなたが作ったもの。
このリポジトリは、PyTorch サンプル リポジトリと torchvision モデル リポジトリから構築されています。これらは BSD スタイルのライセンスを取得しています。
これがあなたの研究に役立つと思われる場合は、この bibtex を引用することを検討してください。コメントやフィードバックがございましたら、Richard Zhang <rizhang at adobe dot com> までご連絡ください。