axial attention
0.6.1
Pytorch での Axial 注意の実装。多次元データを効率的に処理するためのシンプルだが強力なテクニック。それは私や他の多くの研究者にとって驚異的な効果をもたらしました。
データに位置エンコーディングを追加し、それをこの便利なクラスに渡し、埋め込みとみなされる次元と回転する軸方向の次元の数を指定するだけです。並べ替えや再形成はすべてあなたのために行われます。
この論文は実際には単純すぎるという理由で却下されました。それでも、それ以来、天気予報や全注意画像セグメンテーションなど、多くのアプリケーションで使用されて成功しています。ただショーに行くだけです。
$ pip install axial_attention
画像
import torch
from axial_attention import AxialAttention
img = torch . randn ( 1 , 3 , 256 , 256 )
attn = AxialAttention (
dim = 3 , # embedding dimension
dim_index = 1 , # where is the embedding dimension
dim_heads = 32 , # dimension of each head. defaults to dim // heads if not supplied
heads = 1 , # number of heads for multi-head attention
num_dimensions = 2 , # number of axial dimensions (images is 2, video is 3, or more)
sum_axial_out = True # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)
attn ( img ) # (1, 3, 256, 256)
チャンネル最後の画像潜在
import torch
from axial_attention import AxialAttention
img = torch . randn ( 1 , 20 , 20 , 512 )
attn = AxialAttention (
dim = 512 , # embedding dimension
dim_index = - 1 , # where is the embedding dimension
heads = 8 , # number of heads for multi-head attention
num_dimensions = 2 , # number of axial dimensions (images is 2, video is 3, or more)
)
attn ( img ) # (1, 20, 20 ,512)
ビデオ
import torch
from axial_attention import AxialAttention
video = torch . randn ( 1 , 5 , 128 , 256 , 256 )
attn = AxialAttention (
dim = 128 , # embedding dimension
dim_index = 2 , # where is the embedding dimension
heads = 8 , # number of heads for multi-head attention
num_dimensions = 3 , # number of axial dimensions (images is 2, video is 3, or more)
)
attn ( video ) # (1, 5, 128, 256, 256)
Image Transformer、リバーシブルネットワーク付き
import torch
from torch import nn
from axial_attention import AxialImageTransformer
conv1x1 = nn . Conv2d ( 3 , 128 , 1 )
transformer = AxialImageTransformer (
dim = 128 ,
depth = 12 ,
reversible = True
)
img = torch . randn ( 1 , 3 , 512 , 512 )
transformer ( conv1x1 ( img )) # (1, 3, 512, 512)
軸方向位置埋め込みあり
import torch
from axial_attention import AxialAttention , AxialPositionalEmbedding
img = torch . randn ( 1 , 512 , 20 , 20 )
attn = AxialAttention (
dim = 512 ,
heads = 8 ,
dim_index = 1
)
pos_emb = AxialPositionalEmbedding (
dim = 512 ,
shape = ( 20 , 20 )
)
img = pos_emb ( img ) # (1, 512, 20, 20) - now positionally embedded
img = attn ( img ) # (1, 512, 20, 20)
@misc { ho2019axial ,
title = { Axial Attention in Multidimensional Transformers } ,
author = { Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans } ,
year = { 2019 } ,
archivePrefix = { arXiv }
}
@misc { wang2020axialdeeplab ,
title = { Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation } ,
author = { Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen } ,
year = { 2020 } ,
eprint = { 2003.07853 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@inproceedings { huang2019ccnet ,
title = { Ccnet: Criss-cross attention for semantic segmentation } ,
author = { Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu } ,
booktitle = { Proceedings of the IEEE/CVF International Conference on Computer Vision } ,
pages = { 603--612 } ,
year = { 2019 }
}