Point Transformer - Pytorch
0.1.5
Implementation of the Point Transformer self-attention layer in Pytorch. This simple circuit seemingly allowed their group to outperform all previous methods in point cloud classification and segmentation.
$ pip install point-transformer-pytorch
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4
)

feats = torch.randn(1, 16, 128)
pos = torch.randn(1, 16, 3)
mask = torch.ones(1, 16).bool()

attn(feats, pos, mask = mask) # (1, 16, 128)
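For intuition, here is a minimal, unmasked sketch of the vector attention described in the paper: each pair of points gets a relation vector q_i - k_j + delta (where delta encodes the relative position), an MLP turns that vector into per-channel attention weights, and values are aggregated elementwise. The class and attribute names below (NaiveVectorAttention, pos_mlp, attn_mlp) are illustrative assumptions, not the library's internals.

import torch
from torch import nn

class NaiveVectorAttention(nn.Module):
    def __init__(self, dim, pos_mlp_hidden_dim = 64, attn_mlp_hidden_mult = 4):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        # theta: relative positional encoding delta = theta(p_i - p_j)
        self.pos_mlp = nn.Sequential(
            nn.Linear(3, pos_mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(pos_mlp_hidden_dim, dim)
        )

        # gamma: maps each relation vector to per-channel attention weights
        self.attn_mlp = nn.Sequential(
            nn.Linear(dim, dim * attn_mlp_hidden_mult),
            nn.ReLU(),
            nn.Linear(dim * attn_mlp_hidden_mult, dim)
        )

    def forward(self, feats, pos):
        q, k, v = self.to_qkv(feats).chunk(3, dim = -1)       # each (b, n, dim)

        rel_pos = pos.unsqueeze(2) - pos.unsqueeze(1)         # (b, n, n, 3)
        delta = self.pos_mlp(rel_pos)                         # (b, n, n, dim)

        # subtraction relation q_i - k_j + delta, one vector per point pair
        relation = q.unsqueeze(2) - k.unsqueeze(1) + delta    # (b, n, n, dim)
        attn = self.attn_mlp(relation).softmax(dim = 2)       # normalize over j

        # aggregate values (plus positional encoding) with elementwise weights
        return (attn * (v.unsqueeze(1) + delta)).sum(dim = 2) # (b, n, dim)

layer = NaiveVectorAttention(128)
out = layer(torch.randn(1, 16, 128), torch.randn(1, 16, 3))  # (1, 16, 128)

Because the weights are full vectors rather than scalars, every pair of points carries a dim-sized relation tensor, which is what makes this attention so much heavier than the scalar dot-product kind.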
This type of vector attention is much more expensive than the traditional kind. In the paper, they used k-nearest neighbors on the points to exclude attention to faraway points. You can do the same with a single extra setting.
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16          # only the 16 nearest neighbors are attended to for each point
)

feats = torch.randn(1, 2048, 128)
pos = torch.randn(1, 2048, 3)
mask = torch.ones(1, 2048).bool()

attn(feats, pos, mask = mask) # (1, 2048, 128)
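Under the hood, restricting attention to the nearest neighbors boils down to a k-nearest-neighbor lookup on the coordinates. A rough sketch of that selection step, assuming plain pairwise euclidean distances (not the library's exact code):

import torch

pos = torch.randn(1, 2048, 3)
k = 16

dist = torch.cdist(pos, pos)                          # (1, 2048, 2048) pairwise euclidean distances
_, knn_idx = dist.topk(k, dim = -1, largest = False)  # indices of the k closest points
                                                      # (each point's own index is included, distance 0)
print(knn_idx.shape)                                  # torch.Size([1, 2048, 16])

Relation vectors and attention weights are then computed only over these k indices, so memory scales with n * k instead of n^2.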
Citations

@misc{zhao2020point,
    title         = {Point Transformer},
    author        = {Hengshuang Zhao and Li Jiang and Jiaya Jia and Philip Torr and Vladlen Koltun},
    year          = {2020},
    eprint        = {2012.09164},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}