การใช้งาน SE3-Transformers เพื่อการเอาใจใส่ตนเองอย่างเท่าเทียมกันใน Pytorch อาจจำเป็นสำหรับการจำลองผลลัพธ์ Alphafold2 และแอปพลิเคชันการค้นคว้ายาอื่นๆ
ตัวอย่างความเท่าเทียมกัน
หากคุณเคยใช้ SE3 Transformers เวอร์ชันก่อนหน้าเวอร์ชัน 0.6.0 โปรดอัปเดต @MattMcPartlon ได้ค้นพบข้อผิดพลาดใหญ่ หากคุณไม่ได้ใช้การตั้งค่าเพื่อนบ้านแบบกระจัดกระจาย adjacency และอาศัยฟังก์ชันการทำงานของเพื่อนบ้านที่ใกล้ที่สุด
อัปเดต: ขอแนะนำให้คุณใช้ Equiformer แทน
$ pip install se3-transformer-pytorch
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 512 ,
heads = 8 ,
depth = 6 ,
dim_head = 64 ,
num_degrees = 4 ,
valid_radius = 10
)
feats = torch . randn ( 1 , 1024 , 512 )
coors = torch . randn ( 1 , 1024 , 3 )
mask = torch . ones ( 1 , 1024 ). bool ()
out = model ( feats , coors , mask ) # (1, 1024, 512)
ตัวอย่างการใช้งานที่เป็นไปได้ใน Alphafold2 ดังที่อธิบายไว้ที่นี่
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 64 ,
depth = 2 ,
input_degrees = 1 ,
num_degrees = 2 ,
output_degrees = 2 ,
reduce_dim_out = True ,
differentiable_coors = True
)
atom_feats = torch . randn ( 2 , 32 , 64 )
coors = torch . randn ( 2 , 32 , 3 )
mask = torch . ones ( 2 , 32 ). bool ()
refined_coors = coors + model ( atom_feats , coors , mask , return_type = 1 ) # (2, 32, 3)
คุณยังสามารถปล่อยให้คลาสหม้อแปลงฐานดูแลการฝังคุณสมบัติประเภท 0 ที่ถูกส่งเข้ามา สมมติว่าพวกมันคืออะตอม
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
num_tokens = 28 , # 28 unique atoms
dim = 64 ,
depth = 2 ,
input_degrees = 1 ,
num_degrees = 2 ,
output_degrees = 2 ,
reduce_dim_out = True
)
atoms = torch . randint ( 0 , 28 , ( 2 , 32 ))
coors = torch . randn ( 2 , 32 , 3 )
mask = torch . ones ( 2 , 32 ). bool ()
refined_coors = coors + model ( atoms , coors , mask , return_type = 1 ) # (2, 32, 3)
หากคุณคิดว่าเน็ตจะได้รับประโยชน์เพิ่มเติมจากการเข้ารหัสตำแหน่ง คุณสามารถระบุตำแหน่งของคุณในอวกาศและส่งต่อไปได้ดังนี้
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 64 ,
depth = 2 ,
input_degrees = 2 ,
num_degrees = 2 ,
output_degrees = 2 ,
reduce_dim_out = True # reduce out the final dimension
)
atom_feats = torch . randn ( 2 , 32 , 64 , 1 ) # b x n x d x type0
coors_feats = torch . randn ( 2 , 32 , 64 , 3 ) # b x n x d x type1
# atom features are type 0, predicted coordinates are type 1
features = { '0' : atom_feats , '1' : coors_feats }
coors = torch . randn ( 2 , 32 , 3 )
mask = torch . ones ( 2 , 32 ). bool ()
refined_coors = coors + model ( features , coors , mask , return_type = 1 ) # (2, 32, 3) - equivariant to input type 1 features and coordinates
หากต้องการนำเสนอข้อมูล Edge ให้กับ SE3 Transformers (เช่น ประเภทพันธะระหว่างอะตอม) คุณเพียงแค่ต้องส่งผ่านอาร์กิวเมนต์คำหลักอีกสองรายการในการเริ่มต้น
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
num_tokens = 28 ,
dim = 64 ,
num_edge_tokens = 4 , # number of edge type, say 4 bond types
edge_dim = 16 , # dimension of edge embedding
depth = 2 ,
input_degrees = 1 ,
num_degrees = 3 ,
output_degrees = 1 ,
reduce_dim_out = True
)
atoms = torch . randint ( 0 , 28 , ( 2 , 32 ))
bonds = torch . randint ( 0 , 4 , ( 2 , 32 , 32 ))
coors = torch . randn ( 2 , 32 , 3 )
mask = torch . ones ( 2 , 32 ). bool ()
pred = model ( atoms , coors , mask , edges = bonds , return_type = 0 ) # (2, 32, 1)
หากคุณต้องการส่งผ่านค่าต่อเนื่องสำหรับ Edge ของคุณ คุณสามารถเลือกที่จะไม่ตั้งค่า num_edge_tokens
เข้ารหัสประเภทบอนด์แยกของคุณ จากนั้นต่อเข้ากับคุณสมบัติโฟร์เรียร์ของค่าต่อเนื่องเหล่านี้
import torch
from se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch . utils import fourier_encode
model = SE3Transformer (
dim = 64 ,
depth = 1 ,
attend_self = True ,
num_degrees = 2 ,
output_degrees = 2 ,
edge_dim = 34 # edge dimension must match the final dimension of the edges being passed in
)
feats = torch . randn ( 1 , 32 , 64 )
coors = torch . randn ( 1 , 32 , 3 )
mask = torch . ones ( 1 , 32 ). bool ()
pairwise_continuous_values = torch . randint ( 0 , 4 , ( 1 , 32 , 32 , 2 )) # say there are 2
edges = fourier_encode (
pairwise_continuous_values ,
num_encodings = 8 ,
include_self = True
) # (1, 32, 32, 34) - {2 * (2 * 8 + 1)}
out = model ( feats , coors , mask , edges = edges , return_type = 1 )
หากคุณทราบการเชื่อมต่อของจุดต่างๆ ของคุณ (เช่น คุณกำลังทำงานกับโมเลกุล) คุณสามารถส่งผ่านเมทริกซ์ adjacency ในรูปแบบของบูลีนมาสก์ (โดยที่ True
แสดงถึงการเชื่อมต่อ)
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 32 ,
heads = 8 ,
depth = 1 ,
dim_head = 64 ,
num_degrees = 2 ,
valid_radius = 10 ,
attend_sparse_neighbors = True , # this must be set to true, in which case it will assert that you pass in the adjacency matrix
num_neighbors = 0 , # if you set this to 0, it will only consider the connected neighbors as defined by the adjacency matrix. but if you set a value greater than 0, it will continue to fetch the closest points up to this many, excluding the ones already specified by the adjacency matrix
max_sparse_neighbors = 8 # you can cap the number of neighbors, sampled from within your sparse set of neighbors as defined by the adjacency matrix, if specified
)
feats = torch . randn ( 1 , 128 , 32 )
coors = torch . randn ( 1 , 128 , 3 )
mask = torch . ones ( 1 , 128 ). bool ()
# placeholder adjacency matrix
# naively assuming the sequence is one long chain (128, 128)
i = torch . arange ( 128 )
adj_mat = ( i [:, None ] <= ( i [ None , :] + 1 )) & ( i [:, None ] >= ( i [ None , :] - 1 ))
out = model ( feats , coors , mask , adj_mat = adj_mat ) # (1, 128, 512)
คุณยังสามารถให้เครือข่ายรับเพื่อนบ้านระดับ N โดยอัตโนมัติด้วยคำหลักพิเศษหนึ่งคำ num_adj_degrees
หากคุณต้องการให้ระบบแยกความแตกต่างระหว่างระดับของเพื่อนบ้านเป็นข้อมูลขอบ ให้ส่งค่า adj_dim
ที่ไม่เป็นศูนย์เพิ่มเติม
import torch
from se3_transformer_pytorch . se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 64 ,
depth = 1 ,
attend_self = True ,
num_degrees = 2 ,
output_degrees = 2 ,
num_neighbors = 0 ,
attend_sparse_neighbors = True ,
num_adj_degrees = 2 , # automatically derive 2nd degree neighbors
adj_dim = 4 # embed 1st and 2nd degree neighbors (as well as null neighbors) with edge embeddings of this dimension
)
feats = torch . randn ( 1 , 32 , 64 )
coors = torch . randn ( 1 , 32 , 3 )
mask = torch . ones ( 1 , 32 ). bool ()
# placeholder adjacency matrix
# naively assuming the sequence is one long chain (128, 128)
i = torch . arange ( 128 )
adj_mat = ( i [:, None ] <= ( i [ None , :] + 1 )) & ( i [:, None ] >= ( i [ None , :] - 1 ))
out = model ( feats , coors , mask , adj_mat = adj_mat , return_type = 1 )
เพื่อให้สามารถควบคุมมิติข้อมูลของแต่ละประเภทได้อย่างละเอียด คุณสามารถใช้คีย์เวิร์ด hidden_fiber_dict
และ out_fiber_dict
เพื่อส่งผ่านพจนานุกรมโดยมีค่าระดับเป็นค่ามิติเป็นคีย์ / ค่า
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
num_tokens = 28 ,
dim = 64 ,
num_edge_tokens = 4 ,
edge_dim = 16 ,
depth = 2 ,
input_degrees = 1 ,
num_degrees = 3 ,
output_degrees = 1 ,
hidden_fiber_dict = { 0 : 16 , 1 : 8 , 2 : 4 },
out_fiber_dict = { 0 : 16 , 1 : 1 },
reduce_dim_out = False
)
atoms = torch . randint ( 0 , 28 , ( 2 , 32 ))
bonds = torch . randint ( 0 , 4 , ( 2 , 32 , 32 ))
coors = torch . randn ( 2 , 32 , 3 )
mask = torch . ones ( 2 , 32 ). bool ()
pred = model ( atoms , coors , mask , edges = bonds )
pred [ '0' ] # (2, 32, 16)
pred [ '1' ] # (2, 32, 1, 3)
คุณสามารถควบคุมเพิ่มเติมได้ว่าโหนดใดบ้างที่สามารถพิจารณาได้โดยการส่งผ่านมาสก์เพื่อนบ้าน ค่า False
ทั้งหมดจะถูกปกปิดโดยไม่พิจารณา
import torch
from se3_transformer_pytorch . se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 16 ,
dim_head = 16 ,
attend_self = True ,
num_degrees = 4 ,
output_degrees = 2 ,
num_edge_tokens = 4 ,
num_neighbors = 8 , # make sure you set this value as the maximum number of neighbors set by your neighbor_mask, or it will throw a warning
edge_dim = 2 ,
depth = 3
)
feats = torch . randn ( 1 , 32 , 16 )
coors = torch . randn ( 1 , 32 , 3 )
mask = torch . ones ( 1 , 32 ). bool ()
bonds = torch . randint ( 0 , 4 , ( 1 , 32 , 32 ))
neighbor_mask = torch . ones ( 1 , 32 , 32 ). bool () # set the nodes you wish to be masked out as False
out = model (
feats ,
coors ,
mask ,
edges = bonds ,
neighbor_mask = neighbor_mask ,
return_type = 1
)
คุณลักษณะนี้ช่วยให้คุณสามารถส่งผ่านเวกเตอร์ที่สามารถดูได้เป็นโหนดส่วนกลางที่โหนดอื่นทั้งหมดมองเห็นได้ แนวคิดนี้คือการรวมกราฟของคุณไว้ในเวกเตอร์ฟีเจอร์ไม่กี่ตัว ซึ่งจะถูกฉายไปที่คีย์ / ค่าทั่วเลเยอร์ความสนใจทั้งหมดในเครือข่าย โหนดทั้งหมดจะสามารถเข้าถึงข้อมูลโหนดทั่วโลกได้อย่างเต็มที่ โดยไม่คำนึงถึงเพื่อนบ้านที่ใกล้ที่สุดหรือการคำนวณที่อยู่ติดกัน
import torch
from torch import nn
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 64 ,
depth = 1 ,
num_degrees = 2 ,
num_neighbors = 4 ,
valid_radius = 10 ,
global_feats_dim = 32 # this must be set to the dimension of the global features, in this example, 32
)
feats = torch . randn ( 1 , 32 , 64 )
coors = torch . randn ( 1 , 32 , 3 )
mask = torch . ones ( 1 , 32 ). bool ()
# naively derive global features
# by pooling features and projecting
global_feats = nn . Linear ( 64 , 32 )( feats . mean ( dim = 1 , keepdim = True )) # (1, 1, 32)
out = model ( feats , coors , mask , return_type = 0 , global_feats = global_feats )
สิ่งที่ต้องทำ:
คุณสามารถใช้ SE3 Transformers แบบถดถอยอัตโนมัติได้โดยใช้แฟล็กพิเศษเพียงแฟล็กเดียว
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 512 ,
heads = 8 ,
depth = 6 ,
dim_head = 64 ,
num_degrees = 4 ,
valid_radius = 10 ,
causal = True # set this to True
)
feats = torch . randn ( 1 , 1024 , 512 )
coors = torch . randn ( 1 , 1024 , 3 )
mask = torch . ones ( 1 , 1024 ). bool ()
out = model ( feats , coors , mask ) # (1, 1024, 512)
ฉันได้ค้นพบว่าการใช้คีย์ที่ฉายเชิงเส้น (แทนที่จะใช้การบิดแบบคู่) ดูเหมือนว่าจะได้ผลดีในงานลดเสียงรบกวนของของเล่น สิ่งนี้นำไปสู่การประหยัดหน่วยความจำ 25% คุณสามารถลองใช้คุณสมบัตินี้ได้โดยการตั้งค่า linear_proj_keys = True
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 64 ,
depth = 1 ,
num_degrees = 4 ,
num_neighbors = 8 ,
valid_radius = 10 ,
splits = 4 ,
linear_proj_keys = True # set this to True
). cuda ()
feats = torch . randn ( 1 , 32 , 64 ). cuda ()
coors = torch . randn ( 1 , 32 , 3 ). cuda ()
mask = torch . ones ( 1 , 32 ). bool (). cuda ()
out = model ( feats , coors , mask , return_type = 0 )
มีเทคนิคที่ค่อนข้างไม่เป็นที่รู้จักสำหรับ Transformers โดยเราสามารถแชร์ส่วนหัวของคีย์/ค่าหนึ่งรายการกับส่วนหัวทั้งหมดของข้อความค้นหาได้ จากประสบการณ์ของฉันใน NLP สิ่งนี้มักจะนำไปสู่ประสิทธิภาพที่แย่ลง แต่ถ้าคุณต้องการแลกหน่วยความจำกับความลึกที่มากขึ้นหรือจำนวนองศาที่มากขึ้น นี่อาจเป็นตัวเลือกที่ดี
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 64 ,
depth = 8 ,
num_degrees = 4 ,
num_neighbors = 8 ,
valid_radius = 10 ,
splits = 4 ,
one_headed_key_values = True # one head of key / values shared across all heads of the queries
). cuda ()
feats = torch . randn ( 1 , 32 , 64 ). cuda ()
coors = torch . randn ( 1 , 32 , 3 ). cuda ()
mask = torch . ones ( 1 , 32 ). bool (). cuda ()
out = model ( feats , coors , mask , return_type = 0 )
คุณยังสามารถผูกคีย์ / ค่า (ให้เหมือนกัน) เพื่อประหยัดหน่วยความจำครึ่งหนึ่ง
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 64 ,
depth = 8 ,
num_degrees = 4 ,
num_neighbors = 8 ,
valid_radius = 10 ,
splits = 4 ,
tie_key_values = True # set this to True
). cuda ()
feats = torch . randn ( 1 , 32 , 64 ). cuda ()
coors = torch . randn ( 1 , 32 , 3 ). cuda ()
mask = torch . ones ( 1 , 32 ). bool (). cuda ()
out = model ( feats , coors , mask , return_type = 0 )
นี่คือ EGNN เวอร์ชันทดลองที่ใช้ได้กับประเภทที่สูงกว่าและมีมิติมากกว่า 1 (สำหรับพิกัด) ชื่อคลาสยังคงเป็น SE3Transformer
เนื่องจากมีการใช้ตรรกะที่มีอยู่แล้วซ้ำ ดังนั้นอย่าเพิ่งสนใจสิ่งนั้นในตอนนี้จนกว่าฉันจะล้างมันในภายหลัง
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 32 ,
num_neighbors = 8 ,
num_edge_tokens = 4 ,
edge_dim = 4 ,
num_degrees = 4 , # number of higher order types - will use basis on a TCN to project to these dimensions
use_egnn = True , # set this to true to use EGNN instead of equivariant attention layers
egnn_hidden_dim = 64 , # egnn hidden dimension
depth = 4 , # depth of EGNN
reduce_dim_out = True # will project the dimension of the higher types to 1
). cuda ()
feats = torch . randn ( 2 , 32 , 32 ). cuda ()
coors = torch . randn ( 2 , 32 , 3 ). cuda ()
bonds = torch . randint ( 0 , 4 , ( 2 , 32 , 32 )). cuda ()
mask = torch . ones ( 2 , 32 ). bool (). cuda ()
refinement = model ( feats , coors , mask , edges = bonds , return_type = 1 ) # (2, 32, 3)
coors = coors + refinement # update coors with refinement
หากคุณต้องการระบุมิติข้อมูลแต่ละรายการสำหรับประเภทที่สูงกว่าแต่ละประเภท เพียงส่งผ่าน hidden_fiber_dict
โดยที่พจนานุกรมอยู่ในรูปแบบ {<degree>:<dim>} แทนที่จะเป็น num_degrees
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
dim = 32 ,
num_neighbors = 8 ,
hidden_fiber_dict = { 0 : 32 , 1 : 16 , 2 : 8 , 3 : 4 },
use_egnn = True ,
depth = 4 ,
egnn_hidden_dim = 64 ,
egnn_weights_clamp_value = 2 ,
reduce_dim_out = True
). cuda ()
feats = torch . randn ( 2 , 32 , 32 ). cuda ()
coors = torch . randn ( 2 , 32 , 3 ). cuda ()
mask = torch . ones ( 2 , 32 ). bool (). cuda ()
refinement = model ( feats , coors , mask , return_type = 1 ) # (2, 32, 3)
coors = coors + refinement # update coors with refinement
ส่วนนี้จะแสดงความพยายามอย่างต่อเนื่องในการทำให้ SE3 Transformer ดีขึ้นเล็กน้อย
ประการแรก ฉันได้เพิ่มเครือข่ายแบบพลิกกลับได้ สิ่งนี้ทำให้ฉันเพิ่มความลึกได้อีกเล็กน้อยก่อนที่จะเข้าสู่อุปสรรคของหน่วยความจำตามปกติ การอนุรักษ์ความเท่าเทียมแสดงให้เห็นในการทดสอบ
import torch
from se3_transformer_pytorch import SE3Transformer
model = SE3Transformer (
num_tokens = 20 ,
dim = 32 ,
dim_head = 32 ,
heads = 4 ,
depth = 12 , # 12 layers
input_degrees = 1 ,
num_degrees = 3 ,
output_degrees = 1 ,
reduce_dim_out = True ,
reversible = True # set reversible to True
). cuda ()
atoms = torch . randint ( 0 , 4 , ( 2 , 32 )). cuda ()
coors = torch . randn ( 2 , 32 , 3 ). cuda ()
mask = torch . ones ( 2 , 32 ). bool (). cuda ()
pred = model ( atoms , coors , mask = mask , return_type = 0 )
loss = pred . sum ()
loss . backward ()
ขั้นแรกให้ติดตั้ง sidechainnet
$ pip install sidechainnet
จากนั้นรันงาน denoising โปรตีนแบ็คโบน
$ python denoise.py
ตามค่าเริ่มต้น เวกเตอร์พื้นฐานจะถูกแคชไว้ อย่างไรก็ตาม หากจำเป็นต้องล้างแคช คุณเพียงแค่ต้องตั้งค่าสถานะสิ่งแวดล้อม CLEAR_CACHE
เป็นค่าบางค่าในการเริ่มต้นสคริปต์
$ CLEAR_CACHE=1 python train.py
หรือคุณสามารถลองลบไดเร็กทอรีแคชซึ่งควรมีอยู่ที่
$ rm -rf ~ /.cache.equivariant_attention
คุณยังสามารถกำหนดไดเร็กทอรีของคุณเองที่คุณต้องการให้แคชเก็บไว้ได้ ในกรณีที่ไดเร็กทอรีเริ่มต้นอาจมีปัญหาเรื่องการอนุญาต
CACHE_PATH=./path/to/my/cache python train.py
$ python setup.py pytest
ห้องสมุดนี้ส่วนใหญ่เป็นท่าเรือของพื้นที่เก็บข้อมูลอย่างเป็นทางการของ Fabian แต่ไม่มีห้องสมุด DGL
@misc { fuchs2020se3transformers ,
title = { SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks } ,
author = { Fabian B. Fuchs and Daniel E. Worrall and Volker Fischer and Max Welling } ,
year = { 2020 } ,
eprint = { 2006.10503 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@misc { satorras2021en ,
title = { E(n) Equivariant Graph Neural Networks } ,
author = { Victor Garcia Satorras and Emiel Hoogeboom and Max Welling } ,
year = { 2021 } ,
eprint = { 2102.09844 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@misc { gomez2017reversible ,
title = { The Reversible Residual Network: Backpropagation Without Storing Activations } ,
author = { Aidan N. Gomez and Mengye Ren and Raquel Urtasun and Roger B. Grosse } ,
year = { 2017 } ,
eprint = { 1707.04585 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@misc { shazeer2019fast ,
title = { Fast Transformer Decoding: One Write-Head is All You Need } ,
author = { Noam Shazeer } ,
year = { 2019 } ,
eprint = { 1911.02150 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.NE }
}