egnn pytorch
**마스킹을 사용하는 경우의 이웃 선택에서 버그가 발견되었습니다. 0.1.12 이전 버전에서 마스킹을 사용한 실험을 수행했다면 다시 실행하세요.**
PyTorch로 구현한 E(n)-등변(equivariant) 그래프 신경망입니다. 향후 AlphaFold2 재현에 사용될 수 있습니다. 이 기법은 단순한 불변(invariant) 특징만을 사용하면서도 정확도와 성능 모두에서 이전의 모든 방법(SE3 Transformer 및 Lie Conv 포함)을 능가했으며, 동적 시스템 모델링, 분자 활성 예측 등의 작업에서 SOTA를 달성했습니다.
$ pip install egnn-pytorch
# Minimal usage: two stacked EGNN layers applied to random features and 3D coordinates.
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim=512)
layer2 = EGNN(dim=512)

feats = torch.randn(1, 16, 512)  # (1, 16, 512) features
coors = torch.randn(1, 16, 3)    # (1, 16, 3) coordinates

# Each layer returns updated (features, coordinates); feed them into the next layer.
for layer in (layer1, layer2):
    feats, coors = layer(feats, coors)
# feats: (1, 16, 512), coors: (1, 16, 3)
엣지(edge) 특성을 함께 사용하는 경우
# Two stacked EGNN layers that also consume a pairwise edge-feature tensor (edge_dim = 4).
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim=512, edge_dim=4)
layer2 = EGNN(dim=512, edge_dim=4)

feats = torch.randn(1, 16, 512)     # (1, 16, 512) features
coors = torch.randn(1, 16, 3)       # (1, 16, 3) coordinates
edges = torch.randn(1, 16, 16, 4)   # (1, 16, 16, 4) edge features; last dim matches edge_dim

# The same edge tensor is passed to every layer; features and coordinates are threaded through.
for layer in (layer1, layer2):
    feats, coors = layer(feats, coors, edges)
# feats: (1, 16, 512), coors: (1, 16, 3)
완전한 EGNN 네트워크
# Full EGNN network over token inputs, restricted to the nearest neighbors.
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens=21,
    num_positions=1024,  # unless what you are passing in is an unordered set, set this to the maximum sequence length
    dim=32,
    depth=3,
    num_nearest_neighbors=8,
    coor_weights_clamp_value=2.,  # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024))  # (1, 1024) token ids in [0, 21)
coors = torch.randn(1, 1024, 3)          # (1, 1024, 3)
mask = torch.ones_like(feats).bool()     # (1, 1024) all positions valid

feats_out, coors_out = net(feats, coors, mask=mask)  # (1, 1024, 32), (1, 1024, 3)
인접 행렬(adjacency matrix)로 전달한 희소 이웃 사이에서만 메시지를 전달하도록 하려면 다음과 같이 합니다.
# Restrict message passing to the sparse neighbors given by an adjacency matrix.
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens=21,
    dim=32,
    depth=3,
    only_sparse_neighbors=True,
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
# adj_mat[i, j] is True when |i - j| <= 1 (each position, plus its predecessor and successor)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask=mask, adj_mat=adj_mat)  # (1, 1024, 32), (1, 1024, 3)
또한 두 개의 추가 키워드 인수(`num_adj_degrees`, `adj_dim`)를 지정하면, 네트워크가 N차 이웃을 자동으로 계산하고 차수(degree)에 따른 인접 임베딩을 엣지 특성으로 사용하도록 할 수 있습니다.
# Let the network derive up to N-th degree neighbors from the adjacency matrix and
# embed the adjacency degree as an edge feature.
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens=21,
    dim=32,
    depth=3,
    num_adj_degrees=3,  # fetch up to 3rd degree neighbors
    adj_dim=8,          # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
    only_sparse_neighbors=True,
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
# adj_mat[i, j] is True when |i - j| <= 1 (each position, plus its predecessor and successor)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask=mask, adj_mat=adj_mat)  # (1, 1024, 32), (1, 1024, 3)
연속값 엣지(edge) 특성을 전달해야 하는 경우
# Pass continuous pairwise edge features (edge_dim = 4) alongside an adjacency matrix.
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens=21,
    dim=32,
    depth=3,
    edge_dim=4,
    num_nearest_neighbors=3,
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()
continuous_edges = torch.randn(1, 1024, 1024, 4)  # last dim matches edge_dim

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors - (1024, 1024)
# adj_mat[i, j] is True when |i - j| <= 1 (each position, plus its predecessor and successor)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, edges=continuous_edges, mask=mask, adj_mat=adj_mat)  # (1, 1024, 32), (1, 1024, 3)
초기 EGNN 아키텍처는 이웃 수가 많을 때 불안정해지는 문제가 있었습니다. 다행히 이 문제를 크게 완화하는 두 가지 방법이 있는 것으로 보입니다. 하나는 상대 좌표 정규화(`norm_coors = True`)이고, 다른 하나는 좌표 가중치 클램핑(`coor_weights_clamp_value`)입니다.
# Many neighbors (32): enable both stabilizers, relative-coordinate normalization
# and coordinate-weight clamping.
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens=21,
    dim=32,
    depth=3,
    num_nearest_neighbors=32,
    norm_coors=True,              # normalize the relative coordinates
    coor_weights_clamp_value=2.,  # absolute clamped value for the coordinate weights, needed if you increase the num nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024))  # (1, 1024)
coors = torch.randn(1, 1024, 3)          # (1, 1024, 3)
mask = torch.ones_like(feats).bool()     # (1, 1024)

feats_out, coors_out = net(feats, coors, mask=mask)  # (1, 1024, 32), (1, 1024, 3)
# All EGNN constructor parameters, shown with the values used in this example.
import torch
from egnn_pytorch import EGNN

dim = 512  # example input dimension; the original snippet referenced `dim` without defining it

model = EGNN(
    dim=dim,                            # input dimension
    edge_dim=0,                         # dimension of the edges, if exists, should be > 0
    m_dim=16,                           # hidden model dimension
    fourier_features=0,                 # number of fourier features for encoding of relative distance - defaults to none as in paper
    num_nearest_neighbors=0,            # cap the number of neighbors doing message passing by relative distance
    dropout=0.0,                        # dropout
    norm_feats=False,                   # whether to layernorm the features
    norm_coors=False,                   # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper
    update_feats=True,                  # whether to update features - you can build a layer that only updates one or the other
    update_coors=True,                  # whether to update coordinates
    only_sparse_neighbors=False,        # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in
    valid_radius=float('inf'),          # the valid radius each node considers for message passing
    m_pool_method='sum',                # whether to mean or sum pool for output node representation
    soft_edges=False,                   # extra GLU on the edges, purportedly helps stabilize the network in updated version of the paper
    coor_weights_clamp_value=None,      # clamping of the coordinate updates, again, for stabilization purposes
)
단백질 백본 노이즈 제거(denoising) 예제를 실행하려면 먼저 sidechainnet을 설치하십시오.
$ pip install sidechainnet
그 다음에
$ python
로컬에 PyTorch Geometric이 설치되어 있는지 확인하십시오.
$ python test
@misc{satorras2021en,
    title         = {E(n) Equivariant Graph Neural Networks},
    author        = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year          = {2021},
    eprint        = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}