Implementation of axial attention in PyTorch, a simple but powerful technique for attending to multi-dimensional data efficiently. It has worked wonders for me and many other researchers.
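To put a rough number on "efficiently", the sketch below counts query-key pairs for a 256 x 256 image. It assumes that full attention lets every pixel attend to every other pixel, and that axial attention lets each pixel attend only along its own row and its own column. The helper names are illustrative and are not part of this library.

```python
def full_attention_pairs(height: int, width: int) -> int:
    """Query-key pairs when every pixel attends to every pixel."""
    n = height * width
    return n * n


def axial_attention_pairs(height: int, width: int) -> int:
    """Query-key pairs when each pixel attends along its row and its column."""
    n = height * width
    return n * (height + width)


h, w = 256, 256
full = full_attention_pairs(h, w)
axial = axial_attention_pairs(h, w)

print(full)           # 4294967296
print(axial)          # 33554432
print(full // axial)  # 128
```

Under these assumptions the pair count grows with `H * W * (H + W)` rather than `(H * W) ** 2`, which is where the savings on large inputs come from.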
Simply add some positional encoding to your data and pass it into this handy class, specifying which dimension is the embedding and how many axial dimensions to rotate through. All the permuting and reshaping will be taken care of for you.
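For intuition about the permuting and reshaping involved, here is a minimal sketch in plain PyTorch of attention along a single axis. It is not this library's implementation: `attend_along_axis` is a hypothetical helper, and it skips the learned query/key/value projections and multiple heads by using the input itself as queries, keys, and values.

```python
import torch


def attend_along_axis(x: torch.Tensor, axis: int, dim_index: int = 1) -> torch.Tensor:
    """Self-attention along one axis; every other non-embedding axis acts as batch.

    Hypothetical helper: no learned projections, queries = keys = values = x.
    """
    # Move the attended axis to position -2 and the embedding axis to position -1.
    x_moved = x.movedim((axis, dim_index), (-2, -1))
    moved_shape = x_moved.shape
    seq_len, dim = moved_shape[-2], moved_shape[-1]

    # Fold all remaining axes into a single batch axis: (batch', seq_len, dim).
    flat = x_moved.reshape(-1, seq_len, dim)

    # Plain scaled dot-product attention over the sequence axis.
    scores = flat @ flat.transpose(-1, -2) / dim ** 0.5
    out = scores.softmax(dim=-1) @ flat

    # Undo the fold and the permutation so the output matches the input layout.
    return out.reshape(moved_shape).movedim((-2, -1), (axis, dim_index))


img = torch.randn(1, 8, 6, 5)  # (batch, embedding, height, width)
out = attend_along_axis(img, axis=2) + attend_along_axis(img, axis=3)
print(out.shape)  # torch.Size([1, 8, 6, 5])
```

Because the other axes are folded into the batch, attention along the width never mixes information between different rows, and vice versa.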
This paper was actually rejected on the basis of being too simple. And yet it has since been used successfully in a number of applications, among them weather prediction and all-attention image segmentation. Just goes to show.
$ pip install axial_attention
Image
import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (images is 2, video is 3, or more)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

attn(img) # (1, 3, 256, 256)
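The `sum_axial_out` comment above describes two ways of combining the per-axis attention layers. The sketch below illustrates the difference with plain callables standing in for those layers. `combine_summed`, `combine_sequential`, and the two stand-in functions are hypothetical names for illustration, not part of the library.

```python
from functools import reduce
from typing import Callable, Sequence


def combine_summed(x, axial_fns: Sequence[Callable]):
    """Apply every per-axis function to the same input and add the results."""
    outputs = [fn(x) for fn in axial_fns]
    return reduce(lambda a, b: a + b, outputs)


def combine_sequential(x, axial_fns: Sequence[Callable]):
    """Feed the output of one per-axis function into the next."""
    for fn in axial_fns:
        x = fn(x)
    return x


# Stand-ins for "attention along height" and "attention along width".
double = lambda v: v * 2
add_three = lambda v: v + 3

print(combine_summed(5, [double, add_three]))      # 18
print(combine_sequential(5, [double, add_three]))  # 13
```

In the summed variant each axis sees the original input, while in the sequential variant the second axis sees the output of the first.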
Channel-last image latents
import torch
from axial_attention import AxialAttention

img = torch.randn(1, 20, 20, 512)

attn = AxialAttention(
    dim = 512,           # embedding dimension
    dim_index = -1,      # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 2,  # number of axial dimensions (images is 2, video is 3, or more)
)

attn(img) # (1, 20, 20, 512)
Video
import torch
from axial_attention import AxialAttention

video = torch.randn(1, 5, 128, 256, 256)

attn = AxialAttention(
    dim = 128,           # embedding dimension
    dim_index = 2,       # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 3,  # number of axial dimensions (images is 2, video is 3, or more)
)

attn(video) # (1, 5, 128, 256, 256)
Image Transformer, with reversible network
import torch
from torch import nn
from axial_attention import AxialImageTransformer

conv1x1 = nn.Conv2d(3, 128, 1)  # project the 3 input channels up to the transformer width

transformer = AxialImageTransformer(
    dim = 128,
    depth = 12,
    reversible = True
)

img = torch.randn(1, 3, 512, 512)

transformer(conv1x1(img)) # (1, 128, 512, 512)
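As background on what "reversible" usually means, here is a minimal sketch of a RevNet-style coupling on plain floats. It is an assumption that the `reversible = True` option works along these lines; the function names are hypothetical and this is not the library's code. The point is that the inputs of a block can be recomputed from its outputs, so activations do not all need to be stored for the backward pass.

```python
import math
from typing import Callable, Tuple

Fn = Callable[[float], float]


def reversible_forward(x1: float, x2: float, f: Fn, g: Fn) -> Tuple[float, float]:
    """Coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def reversible_inverse(y1: float, y2: float, f: Fn, g: Fn) -> Tuple[float, float]:
    """Recover the inputs from the outputs by undoing the two additions in reverse."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


f = math.tanh            # stand-in for an attention layer
g = lambda v: 0.5 * v    # stand-in for a feedforward layer

y1, y2 = reversible_forward(1.0, 2.0, f, g)
x1, x2 = reversible_inverse(y1, y2, f, g)
print(math.isclose(x1, 1.0), math.isclose(x2, 2.0))  # True True
```

The same arithmetic applies elementwise when `x1` and `x2` are two halves of a tensor and `f` and `g` are network layers.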
With axial positional embedding
import torch
from axial_attention import AxialAttention, AxialPositionalEmbedding

img = torch.randn(1, 512, 20, 20)

attn = AxialAttention(
    dim = 512,
    heads = 8,
    dim_index = 1
)

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    shape = (20, 20)
)

img = pos_emb(img)  # (1, 512, 20, 20) - now positionally embedded
img = attn(img)     # (1, 512, 20, 20)
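One plausible way to build an axial positional embedding is to learn a separate table per axis and add them to the input through broadcasting. The sketch below shows that idea for channel-first images. `SketchAxialPositionalEmbedding` is a hypothetical stand-in and is not claimed to match the internals of `AxialPositionalEmbedding`.

```python
import torch
from torch import nn


class SketchAxialPositionalEmbedding(nn.Module):
    """Hypothetical stand-in: one learned table per axis, summed via broadcasting."""

    def __init__(self, dim: int, height: int, width: int):
        super().__init__()
        # Each table varies along exactly one spatial axis and has size 1 along the other.
        self.row_emb = nn.Parameter(torch.randn(1, dim, height, 1))
        self.col_emb = nn.Parameter(torch.randn(1, dim, 1, width))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is channel-first: (batch, dim, height, width).
        return x + self.row_emb + self.col_emb


sketch_pos_emb = SketchAxialPositionalEmbedding(dim=512, height=20, width=20)
sample = torch.randn(1, 512, 20, 20)

print(sketch_pos_emb(sample).shape)                         # torch.Size([1, 512, 20, 20])
print(sum(p.numel() for p in sketch_pos_emb.parameters()))  # 20480
```

With this factorization the parameter count is `dim * (height + width)` (20480 here) instead of `dim * height * width` (204800) for a full per-position table.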
@misc{ho2019axial,
    title         = {Axial Attention in Multidimensional Transformers},
    author        = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year          = {2019},
    archivePrefix = {arXiv}
}
@misc{wang2020axialdeeplab,
    title         = {Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation},
    author        = {Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen},
    year          = {2020},
    eprint        = {2003.07853},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
@inproceedings{huang2019ccnet,
    title     = {Ccnet: Criss-cross attention for semantic segmentation},
    author    = {Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
    pages     = {603--612},
    year      = {2019}
}