axial attention
0.6.1
Pytorch에서 Axial attention을 구현합니다. 다차원 데이터를 효율적으로 관리할 수 있는 간단하지만 강력한 기술입니다. 그것은 나와 다른 많은 연구자들에게 놀라운 일이었습니다.
데이터에 일부 위치 인코딩을 추가하고 이를 이 편리한 클래스에 전달하여 임베딩으로 간주되는 차원과 회전할 축 차원 수를 지정하기만 하면 됩니다. 모든 순열, 재형성이 귀하를 위해 처리됩니다.
이 논문은 실제로 너무 단순하다는 이유로 거부되었습니다. 그러나 이후 날씨 예측, 집중 이미지 분할 등 다양한 응용 분야에서 성공적으로 사용되었습니다. 그냥 보여주러 갑니다.
$ pip install axial_attention
영상
import torch
from axial_attention import AxialAttention
img = torch . randn ( 1 , 3 , 256 , 256 )
attn = AxialAttention (
dim = 3 , # embedding dimension
dim_index = 1 , # where is the embedding dimension
dim_heads = 32 , # dimension of each head. defaults to dim // heads if not supplied
heads = 1 , # number of heads for multi-head attention
num_dimensions = 2 , # number of axial dimensions (images is 2, video is 3, or more)
sum_axial_out = True # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)
attn ( img ) # (1, 3, 256, 256)
채널 마지막 이미지 잠재성
import torch
from axial_attention import AxialAttention
img = torch . randn ( 1 , 20 , 20 , 512 )
attn = AxialAttention (
dim = 512 , # embedding dimension
dim_index = - 1 , # where is the embedding dimension
heads = 8 , # number of heads for multi-head attention
num_dimensions = 2 , # number of axial dimensions (images is 2, video is 3, or more)
)
attn ( img ) # (1, 20, 20 ,512)
동영상
import torch
from axial_attention import AxialAttention
video = torch . randn ( 1 , 5 , 128 , 256 , 256 )
attn = AxialAttention (
dim = 128 , # embedding dimension
dim_index = 2 , # where is the embedding dimension
heads = 8 , # number of heads for multi-head attention
num_dimensions = 3 , # number of axial dimensions (images is 2, video is 3, or more)
)
attn ( video ) # (1, 5, 128, 256, 256)
가역 네트워크를 갖춘 이미지 변환기
import torch
from torch import nn
from axial_attention import AxialImageTransformer
conv1x1 = nn . Conv2d ( 3 , 128 , 1 )
transformer = AxialImageTransformer (
dim = 128 ,
depth = 12 ,
reversible = True
)
img = torch . randn ( 1 , 3 , 512 , 512 )
transformer ( conv1x1 ( img )) # (1, 3, 512, 512)
축 위치 임베딩 사용
import torch
from axial_attention import AxialAttention , AxialPositionalEmbedding
img = torch . randn ( 1 , 512 , 20 , 20 )
attn = AxialAttention (
dim = 512 ,
heads = 8 ,
dim_index = 1
)
pos_emb = AxialPositionalEmbedding (
dim = 512 ,
shape = ( 20 , 20 )
)
img = pos_emb ( img ) # (1, 512, 20, 20) - now positionally embedded
img = attn ( img ) # (1, 512, 20, 20)
@misc { ho2019axial ,
title = { Axial Attention in Multidimensional Transformers } ,
author = { Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans } ,
year = { 2019 } ,
archivePrefix = { arXiv }
}
@misc { wang2020axialdeeplab ,
title = { Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation } ,
author = { Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen } ,
year = { 2020 } ,
eprint = { 2003.07853 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@inproceedings { huang2019ccnet ,
title = { Ccnet: Criss-cross attention for semantic segmentation } ,
author = { Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu } ,
booktitle = { Proceedings of the IEEE/CVF International Conference on Computer Vision } ,
pages = { 603--612 } ,
year = { 2019 }
}