Implémentation de l'attention axiale dans Pytorch. Une technique simple mais puissante pour traiter efficacement les données multidimensionnelles. Cela a fait des merveilles pour moi et pour de nombreux autres chercheurs.
Ajoutez simplement un codage de position à vos données et transmettez-le dans cette classe pratique, en spécifiant quelle dimension est considérée comme l'intégration et combien de dimensions axiales doivent être parcourues. Toutes les permutations, remodelages, seront pris en charge pour vous.
Ce document a en fait été rejeté parce qu’il était trop simple. Et pourtant, il a depuis été utilisé avec succès dans de nombreuses applications, parmi lesquelles la prévision météorologique ou la segmentation d’images. C'est juste pour montrer.
$ pip install axial_attention
Image
import torch
from axial_attention import AxialAttention
img = torch . randn ( 1 , 3 , 256 , 256 )
attn = AxialAttention (
dim = 3 , # embedding dimension
dim_index = 1 , # where is the embedding dimension
dim_heads = 32 , # dimension of each head. defaults to dim // heads if not supplied
heads = 1 , # number of heads for multi-head attention
num_dimensions = 2 , # number of axial dimensions (images is 2, video is 3, or more)
sum_axial_out = True # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)
attn ( img ) # (1, 3, 256, 256)
Latentes de la dernière image du canal
import torch
from axial_attention import AxialAttention
img = torch . randn ( 1 , 20 , 20 , 512 )
attn = AxialAttention (
dim = 512 , # embedding dimension
dim_index = - 1 , # where is the embedding dimension
heads = 8 , # number of heads for multi-head attention
num_dimensions = 2 , # number of axial dimensions (images is 2, video is 3, or more)
)
attn ( img ) # (1, 20, 20 ,512)
Vidéo
import torch
from axial_attention import AxialAttention
video = torch . randn ( 1 , 5 , 128 , 256 , 256 )
attn = AxialAttention (
dim = 128 , # embedding dimension
dim_index = 2 , # where is the embedding dimension
heads = 8 , # number of heads for multi-head attention
num_dimensions = 3 , # number of axial dimensions (images is 2, video is 3, or more)
)
attn ( video ) # (1, 5, 128, 256, 256)
Transformateur d'image, avec réseau réversible
import torch
from torch import nn
from axial_attention import AxialImageTransformer
conv1x1 = nn . Conv2d ( 3 , 128 , 1 )
transformer = AxialImageTransformer (
dim = 128 ,
depth = 12 ,
reversible = True
)
img = torch . randn ( 1 , 3 , 512 , 512 )
transformer ( conv1x1 ( img )) # (1, 3, 512, 512)
Avec intégration positionnelle axiale
import torch
from axial_attention import AxialAttention , AxialPositionalEmbedding
img = torch . randn ( 1 , 512 , 20 , 20 )
attn = AxialAttention (
dim = 512 ,
heads = 8 ,
dim_index = 1
)
pos_emb = AxialPositionalEmbedding (
dim = 512 ,
shape = ( 20 , 20 )
)
img = pos_emb ( img ) # (1, 512, 20, 20) - now positionally embedded
img = attn ( img ) # (1, 512, 20, 20)
@misc { ho2019axial ,
title = { Axial Attention in Multidimensional Transformers } ,
author = { Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans } ,
year = { 2019 } ,
archivePrefix = { arXiv }
}
@misc { wang2020axialdeeplab ,
title = { Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation } ,
author = { Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen } ,
year = { 2020 } ,
eprint = { 2003.07853 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@inproceedings { huang2019ccnet ,
title = { Ccnet: Criss-cross attention for semantic segmentation } ,
author = { Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu } ,
booktitle = { Proceedings of the IEEE/CVF International Conference on Computer Vision } ,
pages = { 603--612 } ,
year = { 2019 }
}