**A bug was discovered in neighbor selection in the presence of masking. If you ran any experiments with masking on a version prior to 0.1.12, please rerun them.**
Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. May eventually be used for Alphafold2 replication. This technique went for simple invariant features and ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. It reported state-of-the-art results on dynamical system modeling, molecular activity prediction tasks, and more.
$ pip install egnn-pytorch
import torch
from egnn_pytorch import EGNN
layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)
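For intuition, here is a minimal pure-Python sketch of the layer update described in the EGNN paper (Satorras et al., 2021): messages depend only on node features and squared distances, each coordinate moves along its relative position vectors, and features are updated from the summed messages. The functions phi_e, phi_x and phi_h are hypothetical fixed stand-ins for the learned MLPs, and egnn_layer is an illustrative name, not part of egnn_pytorch. The check at the end rotates and translates the input and confirms that features stay the same while coordinates transform along with it.

```python
import math

# Hypothetical fixed stand-ins for the learned MLPs of the paper.
def phi_e(h_i, h_j, dist_sq):
    # Message: depends only on features and the (invariant) squared distance.
    return [math.tanh(a + b - dist_sq) for a, b in zip(h_i, h_j)]

def phi_x(m):
    # Scalar weight for the coordinate update, computed from the message.
    return 0.1 * sum(m)

def phi_h(h, m_agg):
    # Feature update from the node's own features and its summed messages.
    return [a + b for a, b in zip(h, m_agg)]

def egnn_layer(feats, coors):
    n = len(feats)
    new_feats, new_coors = [], []
    for i in range(n):
        m_agg = [0.0] * len(feats[i])
        shift = [0.0] * len(coors[i])
        for j in range(n):
            if i == j:
                continue
            rel = [a - b for a, b in zip(coors[i], coors[j])]
            dist_sq = sum(r * r for r in rel)
            m_ij = phi_e(feats[i], feats[j], dist_sq)
            m_agg = [a + b for a, b in zip(m_agg, m_ij)]
            w = phi_x(m_ij)
            shift = [s + r * w for s, r in zip(shift, rel)]
        # The paper averages the coordinate shifts with C = 1 / (M - 1).
        new_coors.append([c + s / (n - 1) for c, s in zip(coors[i], shift)])
        new_feats.append(phi_h(feats[i], m_agg))
    return new_feats, new_coors

def transform(p, theta=0.7, t=(1.0, -2.0, 0.5)):
    # Rotate about the z axis, then translate.
    c, s = math.cos(theta), math.sin(theta)
    x, y, z = p
    return [a + b for a, b in zip([c * x - s * y, s * x + c * y, z], t)]

feats = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4], [0.0, 0.5]]
coors = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.5], [0.3, -0.7, 1.2]]

f1, c1 = egnn_layer(feats, coors)
f2, c2 = egnn_layer(feats, [transform(p) for p in coors])

# Features are invariant, coordinates are equivariant (up to float error).
feats_invariant = all(abs(a - b) < 1e-9 for r1, r2 in zip(f1, f2) for a, b in zip(r1, r2))
coors_equivariant = all(abs(a - b) < 1e-9 for p, q in zip(c1, c2) for a, b in zip(transform(p), q))
print(feats_invariant, coors_equivariant)  # True True
```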
With edges
import torch
from egnn_pytorch import EGNN
layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)
feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)
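In the EGNN paper, edge attributes enter the model by being concatenated into the input of the edge MLP, alongside the two node feature vectors and their squared distance. The sketch below builds that input for one pair (i, j) from a nested-list stand-in for the (1, 16, 16, 4) edges tensor above. edge_mlp_input is a hypothetical helper, and the concatenation order is an assumption; egnn_pytorch may order the pieces differently.

```python
def edge_mlp_input(h_i, h_j, dist_sq, a_ij=None):
    """Build the edge-MLP input for the pair (i, j).

    Node features, the squared distance and (optionally) the edge
    attributes are concatenated. The ordering is illustrative only.
    """
    parts = list(h_i) + list(h_j) + [dist_sq]
    if a_ij is not None:
        parts += list(a_ij)
    return parts

dim, edge_dim, n = 512, 4, 16
h = [[0.0] * dim for _ in range(n)]
# edges[i][j] plays the role of edges[0, i, j] in the (1, 16, 16, 4) tensor above
edges = [[[float(i), float(j), 0.0, 1.0] for j in range(n)] for i in range(n)]

x = edge_mlp_input(h[2], h[5], dist_sq=1.5, a_ij=edges[2][5])
print(len(x))   # 2 * 512 + 1 + 4 = 1029
print(x[-4:])   # [2.0, 5.0, 0.0, 1.0]
```

This is why edge_dim must be passed to the layer constructor: it changes the input width of the edge MLP.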
A full EGNN network
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,           # unless what you are passing in is an unordered set, set this to the maximum sequence length
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8,
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors
)
feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)
feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
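With num_nearest_neighbors set, each node only exchanges messages with its k closest nodes, and the mask decides which nodes are real. The sketch below shows the property that matters: a masked-out (padding) node must never be picked as a neighbor, even when it is the closest one. It is an illustrative reimplementation with a hypothetical helper name, not the library's code; excluding the node itself is also an assumption of this sketch.

```python
import math

def nearest_neighbors(coors, mask, k):
    """For each node, return the indices of its k nearest valid neighbors.

    Masked-out nodes (mask[j] is False) get an infinite distance so they
    are never selected. Self-pairs are skipped in this sketch.
    """
    n = len(coors)
    result = []
    for i in range(n):
        dists = []
        for j in range(n):
            if j == i:
                continue
            d = math.dist(coors[i], coors[j]) if mask[j] else math.inf
            dists.append((d, j))
        dists.sort()  # ties are broken by the smaller index
        result.append([j for d, j in dists[:k] if d != math.inf])
    return result

coors = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.5, 0.0, 0.0]]
mask = [True, True, True, False]  # node 3 is padding, yet it sits closest to nodes 0 and 1

print(nearest_neighbors(coors, mask, k=2))  # [[1, 2], [0, 2], [1, 0], [0, 1]]
```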
Only attend to sparse neighbors, given to the network as an adjacency matrix.
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    only_sparse_neighbors = True
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
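To see what the chain adjacency expression produces, here is the same condition evaluated in plain Python for a 5-node sequence. The condition is equivalent to |i - j| <= 1, so the matrix is tridiagonal and includes each node itself on the diagonal; apart from that self-loop, the end nodes have one neighbor and interior nodes have two.

```python
n = 5
# Same condition as the torch expression above, for a small n.
adj_mat = [[(i >= j - 1) and (i <= j + 1) for j in range(n)] for i in range(n)]

for row in adj_mat:
    print("".join("1" if x else "." for x in row))
# 11...
# 111..
# .111.
# ..111
# ...11

neighbor_counts = [sum(row) - 1 for row in adj_mat]  # subtract the self-loop
print(neighbor_counts)  # [1, 2, 2, 2, 1]
```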
You can also have the network automatically determine the Nth-degree neighbors, and pass in an adjacency embedding (depending on the degree) to be used as an edge, with two extra keyword arguments
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_adj_degrees = 3,            # fetch up to 3rd degree neighbors
    adj_dim = 8,                    # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
    only_sparse_neighbors = True
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
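One plausible way to derive neighbor degrees from an adjacency matrix is repeated boolean matrix multiplication: label each pair with the first step count at which it becomes reachable, and 0 if it is not reachable within num_adj_degrees steps. The sketch below does this in plain Python on a 6-node chain. adjacency_degrees is a hypothetical helper, and egnn_pytorch's internal labeling may differ. Integer labels 0..num_adj_degrees like these could then index an embedding table with num_adj_degrees + 1 rows of width adj_dim.

```python
def adjacency_degrees(adj, max_degree):
    """Label pair (i, j) with the smallest step count (1..max_degree) at
    which j becomes reachable from i, or 0 if it never does."""
    n = len(adj)
    degrees = [[1 if adj[i][j] else 0 for j in range(n)] for i in range(n)]
    reach = [list(row) for row in adj]
    for degree in range(2, max_degree + 1):
        # Boolean matrix product: i reaches j if i reaches some k adjacent to j.
        reach = [[any(reach[i][k] and adj[k][j] for k in range(n)) for j in range(n)]
                 for i in range(n)]
        for i in range(n):
            for j in range(n):
                if reach[i][j] and degrees[i][j] == 0:
                    degrees[i][j] = degree
    return degrees

n = 6
adj = [[abs(i - j) <= 1 for j in range(n)] for i in range(n)]  # chain with self-loops
degrees = adjacency_degrees(adj, max_degree=3)
print(degrees[0])  # [1, 1, 2, 3, 0, 0]
print(degrees[2])  # [2, 1, 1, 1, 2, 3]
```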
If you need to pass in continuous edges
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    edge_dim = 4,
    num_nearest_neighbors = 3
)
feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
continuous_edges = torch.randn(1, 1024, 1024, 4)
# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
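The continuous_edges tensor above is dense: it holds one edge_dim vector for every ordered pair of nodes, so its size grows quadratically with sequence length. A quick estimate, assuming float32 values (the default dtype of torch.randn):

```python
def dense_edge_bytes(batch, n, edge_dim, bytes_per_value=4):
    """Memory for a dense (batch, n, n, edge_dim) edge tensor, float32 by default."""
    return batch * n * n * edge_dim * bytes_per_value

for n in (1024, 2048, 4096):
    mib = dense_edge_bytes(1, n, edge_dim=4) / 2**20
    print(n, mib)
# 1024 16.0
# 2048 64.0
# 4096 256.0
```

Doubling the sequence length quadruples the memory for the edges alone, before any per-pair activations inside the network are counted.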
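The initial EGNN architecture suffered from instability when the number of neighbors was very large. Fortunately, two settings seem to largely mitigate this: normalizing the relative coordinates (norm_coors) and clamping the coordinate weights (coor_weights_clamp_value).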
import torch
from egnn_pytorch import EGNN_Network
net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 32,
    norm_coors = True,              # normalize the relative coordinates
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors
)
feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)
feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
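Why these two settings help can be seen on a single neighbor's contribution to the coordinate update, which is a weight times a relative position vector. A distant neighbor combined with a large weight produces a huge step; normalizing the vector to unit length and clamping the weight bounds the step. This is a minimal sketch with hypothetical helper names; the library's coordinate normalization may also apply a learned scale, which is not shown here.

```python
import math

def normalize_rel_coors(rel, eps=1e-8):
    """Scale a relative position vector to (approximately) unit length."""
    norm = math.sqrt(sum(r * r for r in rel))
    return [r / (norm + eps) for r in rel]

def clamp(w, value):
    """Clamp a coordinate weight to [-value, value]."""
    return max(-value, min(value, w))

rel = [30.0, 40.0, 0.0]   # a distant neighbor, length 50
w = 7.5                   # an unusually large coordinate weight

naive_shift = [w * r for r in rel]
stable_shift = [clamp(w, 2.0) * r for r in normalize_rel_coors(rel)]
print(naive_shift)   # [225.0, 300.0, 0.0]
print(stable_shift)  # roughly [1.2, 1.6, 0.0]
```

With 32 neighbors, as in the configuration above, up to 32 such terms are combined per node, which is why bounding each term matters.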
All parameters of the EGNN layer:
import torch
from egnn_pytorch import EGNN
model = EGNN(
    dim = 512,                      # input dimension
    edge_dim = 0,                   # dimension of the edges, if exists, should be > 0
    m_dim = 16,                     # hidden model dimension
    fourier_features = 0,           # number of fourier features for encoding of relative distance - defaults to none as in paper
    num_nearest_neighbors = 0,      # cap the number of neighbors doing message passing by relative distance
    dropout = 0.0,                  # dropout
    norm_feats = False,             # whether to layernorm the features
    norm_coors = False,             # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper
    update_feats = True,            # whether to update features - you can build a layer that only updates one or the other
    update_coors = True,            # whether to update coordinates
    only_sparse_neighbors = False,  # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in
    valid_radius = float('inf'),    # the valid radius each node considers for message passing
    m_pool_method = 'sum',          # whether to mean or sum pool for output node representation
    soft_edges = False,             # extra GLU on the edges, purportedly helps stabilize the network in the updated version of the paper
    coor_weights_clamp_value = None # clamping of the coordinate updates, again, for stabilization purposes
)
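The fourier_features option encodes the relative distance with sines and cosines at several scales before it enters the edge MLP, a common trick for letting an MLP resolve fine differences in a scalar input. The sketch below uses geometrically spaced scales 1, 2, 4, ... and appends the raw value; both the scales and the helper name fourier_encode_dist are assumptions for illustration, not a statement of the library's exact recipe.

```python
import math

def fourier_encode_dist(d, num_encodings):
    """Encode a scalar distance as sin/cos at scales 1, 2, 4, ..., plus the raw value."""
    scaled = [d / (2 ** k) for k in range(num_encodings)]
    return [math.sin(s) for s in scaled] + [math.cos(s) for s in scaled] + [d]

enc = fourier_encode_dist(1.5, num_encodings=4)
print(len(enc))  # 2 * 4 + 1 = 9
```

Under this scheme each Fourier feature adds two values (a sine and a cosine) to the edge-MLP input, so fourier_features = 4 would widen it by 8 compared with the raw distance alone.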
To run the protein backbone denoising example, first install sidechainnet
$ pip install sidechainnet
Then
$ python denoise_sparse.py
Make sure you have PyTorch Geometric installed locally, then run the tests
$ python setup.py test
@misc{satorras2021en,
    title         = {E(n) Equivariant Graph Neural Networks},
    author        = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year          = {2021},
    eprint        = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}