Implementação de Alphafold 3 em Pytorch
Você pode conversar com outros pesquisadores sobre este trabalho aqui
Revisão do artigo por Sergey
Guia ilustrado por Elana P. Simon
Palestra de Max Jaderberg
Um fork com suporte completo para Lightning + Hydra está sendo mantido por Alex neste repositório
Uma visualização das moléculas de vida usadas no repositório pode ser vista e interagida aqui
Joseph por contribuir com a codificação posicional relativa e a perda suave de LDDT!
Felipe por contribuir com os módulos Weighted Rigid Align, Express Coordinates In Frame, Compute Alignment Error e Center Random Augmentation!
Alex por corrigir vários problemas nos algoritmos transcritos
Heng por apontar inconsistências com o papel e puxar solicitando as soluções
Heng por encontrar um problema com os índices de átomos moleculares para a perda do distograma
Wei Lu por detectar alguns hiperparâmetros errados
Alex pelo script de preparação do conjunto de dados do PDB!
Milot por otimizar o script de clustering do conjunto de dados PDB!
Alex, por basicamente escrever todo o fluxo gigantesco, desde a análise do PDB até a molécula e entradas atômicas para treinamento
Andrei por trabalhar na amostragem ponderada do conjunto de dados do PDB!
Jimin por enviar uma pequena correção para um problema com as coordenadas sendo passadas para WeightedRigidAlign
@xluo233 por contribuir com as medidas de confiança, classificação de penalidades de confronto e amostra de lógica de classificação!
sj900 para integrar e testar o WeightedPDBSampler
dentro do PDBDataset
e para adicionar suporte inicial para MSA e análise de modelo!
@xluo233 novamente por contribuir com a lógica para calcular a pontuação de seleção do modelo, bem como o rasa não resolvido!
Fandi por descobrir algumas inconsistências no módulo de difusão de átomos elucidado com o suplemento
Paolo por propor a hipótese PDB neutral stable molecule
!
Dhuvi por corrigir um bug relacionado à atribuição de ID de molécula de íon metálico para Alphafold3Inputs
!
Dhuvi por assumir a lógica de tradução de Alphafold3Input
para BioMolecule
para salvar em mmCIF!
Tom (do canal Discord) por identificar uma discrepância entre os cálculos do distograma e do vetor unitário do modelo desta base de código e os do OpenFold (e Andrei por ajudar a resolver o problema do distograma)!
Kaihui por identificar um bug na forma como átomos não padronizados eram tratados em resíduos de polímeros!
Andrei por assumir a interface frontend do gradio!
Patrick para jaxtyping, Florian para einx e, claro, Alex para einops
Soumit e à organização Pytorch por me darem a oportunidade de abrir o código-fonte deste trabalho
$ pip install alphafold3-pytorch
import torch
from alphafold3_pytorch import Alphafold3
from alphafold3_pytorch . utils . model_utils import exclusive_cumsum
alphafold3 = Alphafold3 (
dim_atom_inputs = 77 ,
dim_template_feats = 108
)
# mock inputs
seq_len = 16
molecule_atom_indices = torch . randint ( 0 , 2 , ( 2 , seq_len )). long ()
molecule_atom_lens = torch . full (( 2 , seq_len ), 2 ). long ()
atom_seq_len = molecule_atom_lens . sum ( dim = - 1 ). amax ()
atom_offsets = exclusive_cumsum ( molecule_atom_lens )
atom_inputs = torch . randn ( 2 , atom_seq_len , 77 )
atompair_inputs = torch . randn ( 2 , atom_seq_len , atom_seq_len , 5 )
additional_molecule_feats = torch . randint ( 0 , 2 , ( 2 , seq_len , 5 ))
additional_token_feats = torch . randn ( 2 , seq_len , 33 )
is_molecule_types = torch . randint ( 0 , 2 , ( 2 , seq_len , 5 )). bool ()
is_molecule_mod = torch . randint ( 0 , 2 , ( 2 , seq_len , 4 )). bool ()
molecule_ids = torch . randint ( 0 , 32 , ( 2 , seq_len ))
template_feats = torch . randn ( 2 , 2 , seq_len , seq_len , 108 )
template_mask = torch . ones (( 2 , 2 )). bool ()
msa = torch . randn ( 2 , 7 , seq_len , 32 )
msa_mask = torch . ones (( 2 , 7 )). bool ()
additional_msa_feats = torch . randn ( 2 , 7 , seq_len , 2 )
# required for training, but omitted on inference
atom_pos = torch . randn ( 2 , atom_seq_len , 3 )
distogram_atom_indices = molecule_atom_lens - 1
distance_labels = torch . randint ( 0 , 37 , ( 2 , seq_len , seq_len ))
resolved_labels = torch . randint ( 0 , 2 , ( 2 , atom_seq_len ))
# offset indices correctly
distogram_atom_indices += atom_offsets
molecule_atom_indices += atom_offsets
# train
loss = alphafold3 (
num_recycling_steps = 2 ,
atom_inputs = atom_inputs ,
atompair_inputs = atompair_inputs ,
molecule_ids = molecule_ids ,
molecule_atom_lens = molecule_atom_lens ,
additional_molecule_feats = additional_molecule_feats ,
additional_msa_feats = additional_msa_feats ,
additional_token_feats = additional_token_feats ,
is_molecule_types = is_molecule_types ,
is_molecule_mod = is_molecule_mod ,
msa = msa ,
msa_mask = msa_mask ,
templates = template_feats ,
template_mask = template_mask ,
atom_pos = atom_pos ,
distogram_atom_indices = distogram_atom_indices ,
molecule_atom_indices = molecule_atom_indices ,
distance_labels = distance_labels ,
resolved_labels = resolved_labels
)
loss . backward ()
# after much training ...
sampled_atom_pos = alphafold3 (
num_recycling_steps = 4 ,
num_sample_steps = 16 ,
atom_inputs = atom_inputs ,
atompair_inputs = atompair_inputs ,
molecule_ids = molecule_ids ,
molecule_atom_lens = molecule_atom_lens ,
additional_molecule_feats = additional_molecule_feats ,
additional_msa_feats = additional_msa_feats ,
additional_token_feats = additional_token_feats ,
is_molecule_types = is_molecule_types ,
is_molecule_mod = is_molecule_mod ,
msa = msa ,
msa_mask = msa_mask ,
templates = template_feats ,
template_mask = template_mask
)
sampled_atom_pos . shape # (2, , 3)
Um exemplo com manipulação de entrada em nível de molécula
import torch
from alphafold3_pytorch import Alphafold3 , Alphafold3Input
contrived_protein = 'AG'
mock_atompos = [
torch . randn ( 5 , 3 ), # alanine has 5 non-hydrogen atoms
torch . randn ( 4 , 3 ) # glycine has 4 non-hydrogen atoms
]
train_alphafold3_input = Alphafold3Input (
proteins = [ contrived_protein ],
atom_pos = mock_atompos
)
eval_alphafold3_input = Alphafold3Input (
proteins = [ contrived_protein ]
)
# training
alphafold3 = Alphafold3 (
dim_atom_inputs = 3 ,
dim_atompair_inputs = 5 ,
atoms_per_window = 27 ,
dim_template_feats = 108 ,
num_molecule_mods = 0 ,
confidence_head_kwargs = dict (
pairformer_depth = 1
),
template_embedder_kwargs = dict (
pairformer_stack_depth = 1
),
msa_module_kwargs = dict (
depth = 1
),
pairformer_stack = dict (
depth = 2
),
diffusion_module_kwargs = dict (
atom_encoder_depth = 1 ,
token_transformer_depth = 1 ,
atom_decoder_depth = 1 ,
)
)
loss = alphafold3 . forward_with_alphafold3_inputs ([ train_alphafold3_input ])
loss . backward ()
# sampling
alphafold3 . eval ()
sampled_atom_pos = alphafold3 . forward_with_alphafold3_inputs ( eval_alphafold3_input )
assert sampled_atom_pos . shape == ( 1 , ( 5 + 4 ), 3 )
Para adquirir o conjunto de dados AlphaFold 3 PDB, primeiro baixe todos os complexos de primeira montagem (e unidades assimétricas) no Protein Data Bank (PDB) e, em seguida, pré-processe-os com o script mencionado abaixo. O PDB pode ser baixado do RCSB: https://www.wwpdb.org/ftp/pdb-ftp-sites#rcsbpdb. Os dois scripts Python abaixo (ou seja, filter_pdb_{train,val,test}_mmcifs.py
e cluster_pdb_{train,val,test}_mmcifs.py
) assumem que você baixou o PDB no formato de arquivo mmCIF , colocando seu primeiro assembly e arquivos mmCIF de unidade assimétrica em data/pdb_data/unfiltered_assembly_mmcifs/
e data/pdb_data/unfiltered_asym_mmcifs/
, respectivamente.
Para reprodutibilidade, recomendamos fazer download do PDB usando snapshots da AWS (por exemplo, 20240101
). Para fazer isso, consulte a documentação da AWS para configurar a AWS CLI localmente. Alternativamente, no site do RCSB, navegue até “Protocolos de download” e siga as instruções de download dependendo da sua localização.
Por exemplo, pode-se usar os seguintes comandos para baixar o PDB como duas coleções de arquivos mmCIF:
# For `assembly1` complexes, use the PDB's `20240101` AWS snapshot:
aws s3 sync s3://pdbsnapshots/20240101/pub/pdb/data/assemblies/mmCIF/divided/ ./data/pdb_data/unfiltered_assembly_mmcifs
# Or as a fallback, use rsync:
rsync -rlpt -v -z --delete --port=33444
rsync.rcsb.org::ftp_data/assemblies/mmCIF/divided/ ./data/pdb_data/unfiltered_assembly_mmcifs/
# For asymmetric unit complexes, also use the PDB's `20240101` AWS snapshot:
aws s3 sync s3://pdbsnapshots/20240101/pub/pdb/data/structures/divided/mmCIF/ ./data/pdb_data/unfiltered_asym_mmcifs
# Or as a fallback, use rsync:
rsync -rlpt -v -z --delete --port=33444
rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ ./data/pdb_data/unfiltered_asym_mmcifs/
AVISO: O download do PDB pode ocupar até 700 GB de espaço.
NOTA: O PDB hospeda todos os snapshots da AWS disponíveis aqui: https://pdbsnapshots.s3.us-west-2.amazonaws.com/index.html.
Após o download, você deverá ter dois diretórios formatados assim: https://files.rcsb.org/pub/pdb/data/assemblies/mmCIF/divided/ & https://files.rcsb.org/pub/pdb/data /estruturas/divididas/mmCIF/
00/
01/
02/
..
zz/
Para esses diretórios, descompacte todos os arquivos:
find ./data/pdb_data/unfiltered_assembly_mmcifs/ -type f -name " *.gz " -exec gzip -d {} ;
find ./data/pdb_data/unfiltered_asym_mmcifs/ -type f -name " *.gz " -exec gzip -d {} ;
Em seguida, execute os comandos
wget -P ./data/ccd_data/ https://files.wwpdb.org/pub/pdb/data/monomers/components.cif.gz
wget -P ./data/ccd_data/ https://files.wwpdb.org/pub/pdb/data/component-models/complete/chem_comp_model.cif.gz
do diretório raiz do projeto para baixar a versão mais recente do Dicionário de Componentes Químicos (CCD) do PDB e seus modelos estruturais. Extraia cada um desses arquivos usando o seguinte comando:
find data/ccd_data/ -type f -name " *.gz " -exec gzip -d {} ;
Em seguida, execute o seguinte com pdb_assembly_dir
, pdb_asym_dir
, ccd_dir
e mmcif_output_dir
substituídos pelos locais de suas cópias locais do PDB de primeira montagem, PDB de unidade assimétrica, CCD e seu diretório de saída do conjunto de dados desejado (ou seja, ./data/pdb_data/unfiltered_assembly_mmcifs/
, ./data/pdb_data/unfiltered_asym_mmcifs/
, ./data/ccd_data/
e ./data/pdb_data/{train,val,test}_mmcifs/
).
python scripts/filter_pdb_train_mmcifs.py --mmcif_assembly_dir < pdb_assembly_dir > --mmcif_asym_dir < pdb_asym_dir > --ccd_dir < ccd_dir > --output_dir < mmcif_output_dir >
python scripts/filter_pdb_val_mmcifs.py --mmcif_assembly_dir < pdb_assembly_dir > --mmcif_asym_dir < pdb_asym_dir > --output_dir < mmcif_output_dir >
python scripts/filter_pdb_test_mmcifs.py --mmcif_assembly_dir < pdb_assembly_dir > --mmcif_asym_dir < pdb_asym_dir > --output_dir < mmcif_output_dir >
Veja os scripts para mais opções. Cada mmCIF da primeira montagem que passar com êxito em todas as etapas de processamento será gravado em mmcif_output_dir
dentro de um subdiretório nomeado de acordo com o segundo e terceiro caracteres de ID do PDB do mmCIF (por exemplo, 5c
).
Em seguida, execute o seguinte com mmcif_dir
e {train,val,test}_clustering_output_dir
substituídos, respectivamente, pelo diretório de saída local criado usando o script de filtragem do conjunto de dados acima e pelos diretórios de saída de cluster desejados (ou seja, ./data/pdb_data/{train,val,test}_mmcifs/
e ./data/pdb_data/data_caches/{train,val,test}_clusterings/
):
python scripts/cluster_pdb_train_mmcifs.py --mmcif_dir < mmcif_dir > --output_dir < train_clustering_output_dir > --clustering_filtered_pdb_dataset
python scripts/cluster_pdb_val_mmcifs.py --mmcif_dir < mmcif_dir > --reference_clustering_dir < train_clustering_output_dir > --output_dir < val_clustering_output_dir > --clustering_filtered_pdb_dataset
python scripts/cluster_pdb_test_mmcifs.py --mmcif_dir < mmcif_dir > --reference_1_clustering_dir < train_clustering_output_dir > --reference_2_clustering_dir < val_clustering_output_dir > --output_dir < test_clustering_output_dir > --clustering_filtered_pdb_dataset
Nota : O sinalizador --clustering_filtered_pdb_dataset
é recomendado ao agrupar o conjunto de dados PDB filtrado conforme curado usando os scripts acima, pois esse sinalizador permitirá tempos de execução mais rápidos neste contexto (já que a filtragem deixa os IDs de resíduo de cada cadeia baseados em 1). No entanto, esse sinalizador não deve ser fornecido ao agrupar outros conjuntos de dados (ou seja, não PDB) de arquivos mmCIF. Caso contrário, o clustering de interface pode ser executado incorretamente, pois os arquivos mmCIF desses conjuntos de dados podem não usar indexação de resíduo estrita baseada em 1 para cada cadeia.
Nota : Em vez disso, é possível baixar arquivos mmCIF ( train
/ val
/ test
) pré-processados (ou seja, filtrados) (~ 25 GB, compreendendo 148 mil complexos) e arquivos de clustering de cadeia/interface ( train
/ val
/ test
) (~ 3 GB) para o PDB 20240101
Instantâneo da AWS por meio de uma pasta compartilhada do OneDrive. Cada um desses arquivos tar.gz
deve ser descompactado no diretório data/pdb_data/
, por exemplo, via tar -xzf data_caches.tar.gz -C data/pdb_data/
. Também é possível baixar e preparar dados de destilação do PDB usando como referência o script scripts/distillation_data_download.sh
. Depois de baixado, pode-se executar scripts/reduce_uniprot_predictions_to_pdb.py
para filtrar este conjunto de dados apenas para exemplos associados a pelo menos uma entrada do PDB. Além disso, por conveniência, um mapeamento de IDs de acesso UniProt para IDs de PDB para treinamento em dados de destilação de PDB já foi baixado e extraído como data/afdb_data/data_caches/uniprot_to_pdb_id_mapping.dat
.
Na raiz do projeto, execute
$ sh ./contribute.sh
Em seguida, adicione seu módulo a alphafold3_pytorch/alphafold3.py
, adicione seus testes a tests/test_af3.py
e envie uma solicitação pull. Você pode executar os testes localmente com
$ pytest tests/
O Dockerfile
incluído contém as dependências necessárias para executar o pacote e treinar/inferência usando PyTorch com GPUs.
A imagem base padrão é pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime
e instala a versão mais recente deste pacote da ramificação main
do GitHub.
# # Build Docker Container
docker build -t af3 .
Alternativamente, use argumentos de construção para reconstruir a imagem com diferentes versões de software:
PYTORCH_TAG
: altera a imagem base e, portanto, constrói com diferentes versões de PyTorch, CUDA e/ou cuDNN.GIT_TAG
: Altera a tag deste repositório para clonar e instalar o pacote.Por exemplo:
# # Use build argument to change versions
docker build --build-arg " PYTORCH_TAG=2.2.1-cuda12.1-cudnn8-devel " --build-arg " GIT_TAG=0.1.15 " -t af3 .
Em seguida, execute o contêiner com GPUs e monte um volume local (para treinamento) usando o seguinte comando:
# # Run Container
docker run -v .:/data --gpus all -it af3
@article { Abramson2024-fj ,
title = " Accurate structure prediction of biomolecular interactions with
{AlphaFold} 3 " ,
author = " Abramson, Josh and Adler, Jonas and Dunger, Jack and Evans,
Richard and Green, Tim and Pritzel, Alexander and Ronneberger,
Olaf and Willmore, Lindsay and Ballard, Andrew J and Bambrick,
Joshua and Bodenstein, Sebastian W and Evans, David A and Hung,
Chia-Chun and O'Neill, Michael and Reiman, David and
Tunyasuvunakool, Kathryn and Wu, Zachary and {v Z}emgulyt{.e},
Akvil{.e} and Arvaniti, Eirini and Beattie, Charles and
Bertolli, Ottavia and Bridgland, Alex and Cherepanov, Alexey and
Congreve, Miles and Cowen-Rivers, Alexander I and Cowie, Andrew
and Figurnov, Michael and Fuchs, Fabian B and Gladman, Hannah and
Jain, Rishub and Khan, Yousuf A and Low, Caroline M R and Perlin,
Kuba and Potapenko, Anna and Savy, Pascal and Singh, Sukhdeep and
Stecula, Adrian and Thillaisundaram, Ashok and Tong, Catherine
and Yakneen, Sergei and Zhong, Ellen D and Zielinski, Michal and
{v Z}{'i}dek, Augustin and Bapst, Victor and Kohli, Pushmeet
and Jaderberg, Max and Hassabis, Demis and Jumper, John M " ,
journal = " Nature " ,
month = " May " ,
year = 2024
}
@inproceedings { Darcet2023VisionTN ,
title = { Vision Transformers Need Registers } ,
author = { Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski } ,
year = { 2023 } ,
url = { https://api.semanticscholar.org/CorpusID:263134283 }
}
@article { Arora2024SimpleLA ,
title = { Simple linear attention language models balance the recall-throughput tradeoff } ,
author = { Simran Arora and Sabri Eyuboglu and Michael Zhang and Aman Timalsina and Silas Alberti and Dylan Zinsley and James Zou and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2024 } ,
volume = { abs/2402.18668 } ,
url = { https://api.semanticscholar.org/CorpusID:268063190 }
}
@article { Puny2021FrameAF ,
title = { Frame Averaging for Invariant and Equivariant Network Design } ,
author = { Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman } ,
journal = { ArXiv } ,
year = { 2021 } ,
volume = { abs/2110.03336 } ,
url = { https://api.semanticscholar.org/CorpusID:238419638 }
}
@article { Duval2023FAENetFA ,
title = { FAENet: Frame Averaging Equivariant GNN for Materials Modeling } ,
author = { Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick } ,
journal = { ArXiv } ,
year = { 2023 } ,
volume = { abs/2305.05577 } ,
url = { https://api.semanticscholar.org/CorpusID:258564608 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 } ,
url = { https://api.semanticscholar.org/CorpusID:247187905 }
}
@inproceedings { Ainslie2023CoLT5FL ,
title = { CoLT5: Faster Long-Range Transformers with Conditional Computation } ,
author = { Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Ontan'on and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai } ,
year = { 2023 }
}
@article { Ash2019OnTD ,
title = { On the Difficulty of Warm-Starting Neural Network Training } ,
author = { Jordan T. Ash and Ryan P. Adams } ,
journal = { ArXiv } ,
year = { 2019 } ,
volume = { abs/1910.08475 } ,
url = { https://api.semanticscholar.org/CorpusID:204788802 }
}
@ARTICLE { Heinzinger2023.07.23.550085 ,
author = { Michael Heinzinger and Konstantin Weissenow and Joaquin Gomez Sanchez and Adrian Henkel and Martin Steinegger and Burkhard Rost } ,
title = { ProstT5: Bilingual Language Model for Protein Sequence and Structure } ,
year = { 2023 } ,
doi = { 10.1101/2023.07.23.550085 } ,
journal = { bioRxiv }
}
@article { Lin2022.07.20.500902 ,
author = { Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zhongkai and Lu, Wenting and Santos Costa, Allan dos and Fazel-Zarandi, Maryam and Sercu, Tom and Candido, Sal and Rives, Alexander } ,
title = { Language models of protein sequences at the scale of evolution enable accurate structure prediction } ,
elocation-id = { 2022.07.20.500902 } ,
year = { 2022 } ,
doi = { 10.1101/2022.07.20.500902 } ,
publisher = { Cold Spring Harbor Laboratory } ,
URL = { https://www.biorxiv.org/content/early/2022/07/21/2022.07.20.500902 } ,
eprint = { https://www.biorxiv.org/content/early/2022/07/21/2022.07.20.500902.full.pdf } ,
journal = { bioRxiv }
}
@article { Li2024SwitchEA ,
title = { Switch EMA: A Free Lunch for Better Flatness and Sharpness } ,
author = { Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li } ,
journal = { ArXiv } ,
year = { 2024 } ,
volume = { abs/2402.09240 } ,
url = { https://api.semanticscholar.org/CorpusID:267657558 }
}
@article { Nguyen2023MitigatingOI ,
title = { Mitigating Over-smoothing in Transformers via Regularized Nonlocal Functionals } ,
author = { Tam Nguyen and Tan M. Nguyen and Richard G. Baraniuk } ,
journal = { ArXiv } ,
year = { 2023 } ,
volume = { abs/2312.00751 } ,
url = { https://api.semanticscholar.org/CorpusID:264300597 }
}
@inproceedings { Zhou2024ValueRL ,
title = { Value Residual Learning For Alleviating Attention Concentration In Transformers } ,
author = { Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan } ,
year = { 2024 } ,
url = { https://api.semanticscholar.org/CorpusID:273532030 }
}