Rendre les réseaux convolutifs à nouveau invariants
Richard Zhang. Dans ICML, 2019.
Exécutez pip install antialiased-cnns
import antialiased_cnns
model = antialiased_cnns . resnet50 ( pretrained = True )
Si vous avez déjà un modèle et que vous souhaitez effectuer un antialias et continuer l'entraînement, copiez vos anciennes pondérations :
import torchvision . models as models
old_model = models . resnet50 ( pretrained = True ) # old (aliased) model
antialiased_cnns . copy_params_buffers ( old_model , model ) # copy the weights over
Si vous souhaitez modifier votre propre modèle, utilisez le calque BlurPool. Plus d’informations sur nos modèles fournis et sur la façon d’utiliser BlurPool sont ci-dessous.
C = 10 # example feature channel size
blurpool = antialiased_cnns . BlurPool ( C , stride = 2 ) # BlurPool layer; use to downsample a feature map
ex_tens = torch . Tensor ( 1 , C , 128 , 128 )
print ( blurpool ( ex_tens ). shape ) # 1xCx64x64 tensor
Mises à jour
pip install antialiased-cnns
et charger des modèles avec l'indicateur pretrained=True
.BlurPool
Pip installe ce paquet
pip install antialiased-cnns
Ou clonez ce référentiel et installez les exigences (notamment PyTorch)
https://github.com/adobe/antialiased-cnns.git
cd antialiased-cnns
pip install -r requirements.txt
Ce qui suit charge un modèle anticrénelé pré-entraîné, peut-être comme épine dorsale de votre application.
import antialiased_cnns
model = antialiased_cnns . resnet50 ( pretrained = True , filter_size = 4 )
Nous fournissons également des pondérations pour AlexNet
, VGG16(bn)
, Resnet18,34,50,101
, Densenet121
et MobileNetv2
anticrénelés (voir example_usage.py).
Le module antialiased_cnns
contient la classe BlurPool
, qui effectue le flou + le sous-échantillonnage. Exécutez pip install antialiased-cnns
ou copiez le sous-répertoire antialiased_cnns
.
Méthodologie La méthodologie est simple : évaluez d'abord avec la foulée 1, puis utilisez notre couche BlurPool
pour effectuer un sous-échantillonnage avec anticrénelage. Apportez les modifications architecturales suivantes.
import antialiased_cnns
# MaxPool --> MaxBlurPool
baseline = nn . MaxPool2d ( kernel_size = 2 , stride = 2 )
antialiased = [ nn . MaxPool2d ( kernel_size = 2 , stride = 1 ),
antialiased_cnns . BlurPool ( C , stride = 2 )]
# Conv --> ConvBlurPool
baseline = [ nn . Conv2d ( Cin , C , kernel_size = 3 , stride = 2 , padding = 1 ),
nn . ReLU ( inplace = True )]
antialiased = [ nn . Conv2d ( Cin , C , kernel_size = 3 , stride = 1 , padding = 1 ),
nn . ReLU ( inplace = True ),
antialiased_cnns . BlurPool ( C , stride = 2 )]
# AvgPool --> BlurPool
baseline = nn . AvgPool2d ( kernel_size = 2 , stride = 2 )
antialiased = antialiased_cnns . BlurPool ( C , stride = 2 )
Nous supposons que le tenseur entrant a des canaux C
Le calcul d'une couche à la foulée 1 au lieu de la foulée 2 ajoute de la mémoire et du temps d'exécution. En tant que tel, nous ignorons généralement l’anticrénelage à la résolution la plus élevée (au début du réseau), pour éviter des augmentations importantes.
Ajoutez un anticrénelage, puis continuez l'entraînement. Si vous avez déjà entraîné un modèle, puis ajoutez un anticrénelage, vous pouvez affiner l'ajustement à partir de cet ancien modèle :
antialiased_cnns . copy_params_buffers ( old_model , antialiased_model )
Si cela ne fonctionne pas, vous pouvez simplement copier les paramètres (et non les tampons). L'ajout d'un anticrénelage n'ajoute aucun paramètre, les listes de paramètres sont donc identiques. (Il ajoute des tampons, donc une heuristique est utilisée pour faire correspondre les tampons, ce qui peut générer une erreur.)
antialiased_cnns . copy_params ( old_model , antialiased_model )
Nous observons des améliorations à la fois en termes de précision (à quelle fréquence l'image est classée correctement) et de cohérence (à quelle fréquence deux décalages de la même image sont classés de la même manière).
PRÉCISION | Référence | Anticrénelé | Delta |
---|---|---|---|
Alexnet | 56.55 | 56,94 | +0,39 |
vgg11 | 69.02 | 70.51 | +1,49 |
vgg13 | 69,93 | 71.52 | +1,59 |
vgg16 | 71.59 | 72,96 | +1,37 |
vgg19 | 72.38 | 73.54 | +1,16 |
vgg11_bn | 70.38 | 72,63 | +2,25 |
vgg13_bn | 71.55 | 73.61 | +2,06 |
vgg16_bn | 73.36 | 75.13 | +1,77 |
vgg19_bn | 74.24 | 75,68 | +1,44 |
resnet18 | 69,74 | 71,67 | +1,93 |
resnet34 | 73h30 | 74.60 | +1.30 |
resnet50 | 76.16 | 77.41 | +1,25 |
resnet101 | 77.37 | 78.38 | +1,01 |
resnet152 | 78.31 | 79.07 | +0,76 |
resnext50_32x4d | 77.62 | 77,93 | +0,31 |
resnext101_32x8d | 79.31 | 79.33 | +0,02 |
large_resnet50_2 | 78.47 | 78.70 | +0,23 |
large_resnet101_2 | 78,85 | 78,99 | +0,14 |
densenet121 | 74.43 | 75,79 | +1,36 |
densenet169 | 75.60 | 76,73 | +1,13 |
densenet201 | 76.90 | 77.31 | +0,41 |
densenet161 | 77.14 | 77,88 | +0,74 |
mobilenet_v2 | 71,88 | 72,72 | +0,84 |
COHÉRENCE | Référence | Anticrénelé | Delta |
---|---|---|---|
Alexnet | 78.18 | 83.31 | +5.13 |
vgg11 | 86,58 | 90.09 | +3,51 |
vgg13 | 86,92 | 90.31 | +3,39 |
vgg16 | 88.52 | 90.91 | +2,39 |
vgg19 | 89.17 | 91.08 | +1,91 |
vgg11_bn | 87.16 | 90,67 | +3,51 |
vgg13_bn | 88.03 | 91.09 | +3.06 |
vgg16_bn | 89.24 | 91.58 | +2,34 |
vgg19_bn | 89.59 | 91.60 | +2.01 |
resnet18 | 85.11 | 88.36 | +3,25 |
resnet34 | 87,56 | 89,77 | +2,21 |
resnet50 | 89.20 | 91.32 | +2,12 |
resnet101 | 89,81 | 91,97 | +2,16 |
resnet152 | 90.92 | 92.42 | +1,50 |
resnext50_32x4d | 90.17 | 91.48 | +1,31 |
resnext101_32x8d | 91.33 | 92,67 | +1,34 |
large_resnet50_2 | 90,77 | 92.46 | +1,69 |
large_resnet101_2 | 90.93 | 92.10 | +1,17 |
densenet121 | 88,81 | 90.35 | +1,54 |
densenet169 | 89,68 | 90.61 | +0,93 |
densenet201 | 90.36 | 91.32 | +0,96 |
densenet161 | 90,82 | 91,66 | +0,84 |
mobilenet_v2 | 86,50 | 87,73 | +1,23 |
Pour réduire l'encombrement, des résultats étendus (différentes tailles de filtre) sont ici. Aidez-nous à améliorer les résultats !
Ce travail est sous licence internationale Creative Commons Attribution-Pas d’Utilisation Commerciale-Partage dans les mêmes conditions 4.0.
Tout le matériel est mis à disposition sous la licence Creative Commons BY-NC-SA 4.0 par Adobe Inc. Vous pouvez utiliser, redistribuer et adapter le matériel à des fins non commerciales , à condition de donner le crédit approprié en citant notre article et en indiquant toute modification. que tu as fait.
Le référentiel s'appuie sur le référentiel d'exemples PyTorch et le référentiel de modèles Torchvision. Ceux-ci sont sous licence de style BSD.
Si vous trouvez cela utile pour votre recherche, pensez à citer ce bibtex. Veuillez contacter Richard Zhang <rizhang at adobe dot com> pour tout commentaire ou réaction.