egnn-pytorch
0.2.8
**A bug was found in neighbor selection in the presence of masking. If you ran any experiments with masking prior to 0.1.12, please rerun them.**
Implementation of E(n)-Equivariant Graph Neural Networks in Pytorch. May eventually be used for Alphafold2 replication. This technique went for simple invariant features, and ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. SOTA in dynamical system models, molecular activity prediction tasks, etc.
$ pip install egnn-pytorch
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors)  # (1, 16, 512), (1, 16, 3)
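The E(n)-equivariance claim can be sanity-checked directly: rotating and translating the input coordinates should leave the output features untouched and move the output coordinates by the same transform. Below is a minimal sketch of such a check; the QR-based random rotation and the numerical tolerance are illustrative choices, not part of the library API.

import torch
from egnn_pytorch import EGNN

layer = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

# random orthogonal transform (rotation or reflection) plus a translation
Q, _ = torch.linalg.qr(torch.randn(3, 3))
t = torch.randn(1, 1, 3)

feats_out, coors_out = layer(feats, coors)
feats_rot, coors_rot = layer(feats, coors @ Q + t)

# loose float32 tolerance, chosen for illustration
assert torch.allclose(feats_out, feats_rot, atol = 1e-4)          # features are E(n)-invariant
assert torch.allclose(coors_out @ Q + t, coors_rot, atol = 1e-4)  # coordinates are E(n)-equivariant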
With edges
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges)  # (1, 16, 512), (1, 16, 3)
A full EGNN network
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,          # unless what you are passing in is an unordered set, set this to the maximum sequence length
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8,
    coor_weights_clamp_value = 2.  # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024))  # (1, 1024)
coors = torch.randn(1, 1024, 3)          # (1, 1024, 3)
mask = torch.ones_like(feats).bool()     # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask)  # (1, 1024, 32), (1, 1024, 3)
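As the num_positions comment above suggests, the absolute positional embedding only makes sense for ordered sequences. For an unordered set such as a point cloud, a sketch of the intended usage would simply leave num_positions out; this assumes the network skips the positional embedding when the argument is not given.

import torch
from egnn_pytorch import EGNN_Network

# no num_positions - inputs are treated as an unordered set (e.g. a point cloud)
net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8
)

feats = torch.randint(0, 21, (1, 256))
coors = torch.randn(1, 256, 3)
mask = torch.ones_like(feats).bool()

feats_out, coors_out = net(feats, coors, mask = mask)  # (1, 256, 32), (1, 256, 3)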
Only attend to sparse neighbors, given to the network as an adjacency matrix.
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat)  # (1, 1024, 32), (1, 1024, 3)
You can also have the network automatically determine the Nth-degree neighbors, and pass in an adjacency embedding (dependent on the degree) to be used as edges, with the two extra keyword arguments below.
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_adj_degrees = 3,           # fetch up to 3rd degree neighbors
    adj_dim = 8,                   # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat)  # (1, 1024, 32), (1, 1024, 3)
If you need to pass in continuous edges
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    edge_dim = 4,
    num_nearest_neighbors = 3
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
continuous_edges = torch.randn(1, 1024, 1024, 4)

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat)  # (1, 1024, 32), (1, 1024, 3)
The initial architecture of EGNN suffered from instability when the number of neighbors was high. Thankfully, there seem to be two solutions that largely mitigate this, shown below.
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 32,
    norm_coors = True,             # normalize the relative coordinates
    coor_weights_clamp_value = 2.  # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024))  # (1, 1024)
coors = torch.randn(1, 1024, 3)          # (1, 1024, 3)
mask = torch.ones_like(feats).bool()     # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask)  # (1, 1024, 32), (1, 1024, 3)
All of the EGNN layer parameters:

import torch
from egnn_pytorch import EGNN

model = EGNN(
    dim = 512,                       # input dimension
    edge_dim = 0,                    # dimension of the edges, if exists, should be > 0
    m_dim = 16,                      # hidden model dimension
    fourier_features = 0,            # number of fourier features for encoding of relative distance - defaults to none as in paper
    num_nearest_neighbors = 0,       # cap the number of neighbors doing message passing by relative distance
    dropout = 0.0,                   # dropout
    norm_feats = False,              # whether to layernorm the features
    norm_coors = False,              # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper
    update_feats = True,             # whether to update features - you can build a layer that only updates one or the other
    update_coors = True,             # whether to update coordinates
    only_sparse_neighbors = False,   # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in
    valid_radius = float('inf'),     # the valid radius each node considers for message passing
    m_pool_method = 'sum',           # whether to mean or sum pool for output node representation
    soft_edges = False,              # extra GLU on the edges, purportedly helps stabilize the network in updated version of the paper
    coor_weights_clamp_value = None  # clamping of the coordinate updates, again, for stabilization purposes
)
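As the update_feats / update_coors comments note, the two updates can be toggled independently, e.g. to build a layer that only refines coordinates or only refines features. A minimal sketch of that usage (the layer names here are just for illustration):

import torch
from egnn_pytorch import EGNN

coors_only = EGNN(dim = 512, update_feats = False)  # refines coordinates, passes features through
feats_only = EGNN(dim = 512, update_coors = False)  # updates features, keeps coordinates fixed

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = coors_only(feats, coors)  # (1, 16, 512), (1, 16, 3)
feats, coors = feats_only(feats, coors)  # (1, 16, 512), (1, 16, 3)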
To run the protein backbone denoising example, first install sidechainnet
$ pip install sidechainnet
Then
$ python denoise_sparse.py
To run the tests, make sure you have pytorch geometric installed locally
$ python setup.py test
@misc{satorras2021en,
    title         = {E(n) Equivariant Graph Neural Networks},
    author        = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year          = {2021},
    eprint        = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}