axial attention
0.6.1
การดำเนินการตามความสนใจของแกนใน Pytorch เทคนิคง่ายๆ แต่ทรงพลังในการดูแลข้อมูลหลายมิติอย่างมีประสิทธิภาพ มันได้ผลอย่างมหัศจรรย์สำหรับฉันและนักวิจัยคนอื่นๆ อีกหลายคน
เพียงเพิ่มการเข้ารหัสตำแหน่งลงในข้อมูลของคุณและส่งผ่านไปยังคลาสที่มีประโยชน์นี้ โดยระบุว่ามิติใดที่ถือว่าเป็นการฝัง และจำนวนมิติตามแนวแกนที่จะหมุนเวียนผ่าน การสับเปลี่ยน การปรับรูปร่างใหม่ทั้งหมดจะได้รับการดูแลสำหรับคุณ
จริงๆ แล้วบทความนี้ถูกปฏิเสธเพราะเรียบง่ายเกินไป และตั้งแต่นั้นเป็นต้นมา ก็มีการใช้งานอย่างประสบความสำเร็จในแอปพลิเคชั่นจำนวนหนึ่ง รวมถึงการพยากรณ์อากาศ การแบ่งส่วนภาพที่ดึงดูดความสนใจทุกด้าน แค่ไปโชว์..
$ pip install axial_attention
ภาพ
import torch
from axial_attention import AxialAttention
img = torch . randn ( 1 , 3 , 256 , 256 )
attn = AxialAttention (
dim = 3 , # embedding dimension
dim_index = 1 , # where is the embedding dimension
dim_heads = 32 , # dimension of each head. defaults to dim // heads if not supplied
heads = 1 , # number of heads for multi-head attention
num_dimensions = 2 , # number of axial dimensions (images is 2, video is 3, or more)
sum_axial_out = True # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)
attn ( img ) # (1, 3, 256, 256)
ระยะแฝงของภาพสุดท้ายของช่อง
import torch
from axial_attention import AxialAttention
img = torch . randn ( 1 , 20 , 20 , 512 )
attn = AxialAttention (
dim = 512 , # embedding dimension
dim_index = - 1 , # where is the embedding dimension
heads = 8 , # number of heads for multi-head attention
num_dimensions = 2 , # number of axial dimensions (images is 2, video is 3, or more)
)
attn ( img ) # (1, 20, 20 ,512)
วีดีโอ
import torch
from axial_attention import AxialAttention
video = torch . randn ( 1 , 5 , 128 , 256 , 256 )
attn = AxialAttention (
dim = 128 , # embedding dimension
dim_index = 2 , # where is the embedding dimension
heads = 8 , # number of heads for multi-head attention
num_dimensions = 3 , # number of axial dimensions (images is 2, video is 3, or more)
)
attn ( video ) # (1, 5, 128, 256, 256)
Image Transformer พร้อมเครือข่ายแบบพลิกกลับได้
import torch
from torch import nn
from axial_attention import AxialImageTransformer
conv1x1 = nn . Conv2d ( 3 , 128 , 1 )
transformer = AxialImageTransformer (
dim = 128 ,
depth = 12 ,
reversible = True
)
img = torch . randn ( 1 , 3 , 512 , 512 )
transformer ( conv1x1 ( img )) # (1, 3, 512, 512)
ด้วยการฝังตำแหน่งตามแนวแกน
import torch
from axial_attention import AxialAttention , AxialPositionalEmbedding
img = torch . randn ( 1 , 512 , 20 , 20 )
attn = AxialAttention (
dim = 512 ,
heads = 8 ,
dim_index = 1
)
pos_emb = AxialPositionalEmbedding (
dim = 512 ,
shape = ( 20 , 20 )
)
img = pos_emb ( img ) # (1, 512, 20, 20) - now positionally embedded
img = attn ( img ) # (1, 512, 20, 20)
@misc { ho2019axial ,
title = { Axial Attention in Multidimensional Transformers } ,
author = { Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans } ,
year = { 2019 } ,
archivePrefix = { arXiv }
}
@misc { wang2020axialdeeplab ,
title = { Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation } ,
author = { Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen } ,
year = { 2020 } ,
eprint = { 2003.07853 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@inproceedings { huang2019ccnet ,
title = { Ccnet: Criss-cross attention for semantic segmentation } ,
author = { Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu } ,
booktitle = { Proceedings of the IEEE/CVF International Conference on Computer Vision } ,
pages = { 603--612 } ,
year = { 2019 }
}