Implémentation d'Alphafold 3 dans Pytorch
Vous pouvez discuter avec d'autres chercheurs de ce travail ici
Revue de l'article par Sergey
Guide illustré par Elana P. Simon
Conférence de Max Jaderberg
Un fork avec prise en charge complète de Lightning + Hydra est maintenu par Alex dans ce référentiel
Une visualisation des molécules de vie utilisées dans le référentiel peut être vue et interagie ici
Joseph pour sa contribution au codage positionnel relatif et à la perte LDDT fluide !
Felipe pour sa contribution aux modules Weighted Rigid Align, Express Coordonnées dans le cadre, Compute Alignment Error et Center Random Augmentation !
Alex pour avoir résolu divers problèmes dans les algorithmes transcrits
Heng pour avoir signalé des incohérences avec le document et demandé des solutions
Heng pour avoir trouvé un problème avec les indices des atomes moléculaires pour la perte du distogramme
Wei Lu pour avoir détecté quelques hyperparamètres erronés
Alex pour le script de préparation du jeu de données PDB !
Milot pour l'optimisation du script de clustering des jeux de données PDB !
Alex pour avoir essentiellement écrit l'intégralité du flux gargantuesque, depuis l'analyse du PDB jusqu'aux entrées moléculaires et atomiques pour la formation
Andrei pour avoir travaillé sur l'échantillonnage pondéré de l'ensemble de données PDB !
Jimin pour avoir soumis un petit correctif à un problème avec les coordonnées transmises à WeightedRigidAlign
@xluo233 pour avoir contribué aux mesures de confiance, au classement des pénalités de conflit et à un exemple de logique de classement !
sj900 pour intégrer et tester le WeightedPDBSampler
dans le PDBDataset
et pour ajouter la prise en charge initiale de MSA et de l'analyse de modèles !
@xluo233 encore une fois pour avoir contribué à la logique de calcul du score de sélection du modèle ainsi que du rasa non résolu !
Fandi pour avoir découvert quelques incohérences dans le module de diffusion atomique élucidé avec le module supplémentaire
Paolo pour avoir proposé l’hypothèse de la PDB neutral stable molecule
!
Dhuvi pour avoir corrigé un bug lié à l'attribution de l'ID des molécules d'ions métalliques pour Alphafold3Inputs
!
Dhuvi pour avoir repris la logique de traduction d' Alphafold3Input
en BioMolecule
pour l'enregistrer au format mmCIF !
Tom (du canal Discord) pour avoir identifié une divergence entre les calculs de distogramme et de vecteur unitaire de modèle de cette base de code et ceux d'OpenFold (et Andrei pour avoir aidé à résoudre le problème du distogramme) !
Kaihui pour avoir identifié un bug dans la façon dont les atomes non standard étaient traités dans les résidus de polymère !
Andrei pour avoir adopté l'interface frontale de Gradio !
Patrick pour jaxtyping, Florian pour einx, et bien sûr Alex pour einops
Soumith et l'organisation Pytorch pour m'avoir donné l'opportunité d'ouvrir ce travail en source libre
$ pip install alphafold3-pytorch
import torch
from alphafold3_pytorch import Alphafold3
from alphafold3_pytorch . utils . model_utils import exclusive_cumsum
alphafold3 = Alphafold3 (
dim_atom_inputs = 77 ,
dim_template_feats = 108
)
# mock inputs
seq_len = 16
molecule_atom_indices = torch . randint ( 0 , 2 , ( 2 , seq_len )). long ()
molecule_atom_lens = torch . full (( 2 , seq_len ), 2 ). long ()
atom_seq_len = molecule_atom_lens . sum ( dim = - 1 ). amax ()
atom_offsets = exclusive_cumsum ( molecule_atom_lens )
atom_inputs = torch . randn ( 2 , atom_seq_len , 77 )
atompair_inputs = torch . randn ( 2 , atom_seq_len , atom_seq_len , 5 )
additional_molecule_feats = torch . randint ( 0 , 2 , ( 2 , seq_len , 5 ))
additional_token_feats = torch . randn ( 2 , seq_len , 33 )
is_molecule_types = torch . randint ( 0 , 2 , ( 2 , seq_len , 5 )). bool ()
is_molecule_mod = torch . randint ( 0 , 2 , ( 2 , seq_len , 4 )). bool ()
molecule_ids = torch . randint ( 0 , 32 , ( 2 , seq_len ))
template_feats = torch . randn ( 2 , 2 , seq_len , seq_len , 108 )
template_mask = torch . ones (( 2 , 2 )). bool ()
msa = torch . randn ( 2 , 7 , seq_len , 32 )
msa_mask = torch . ones (( 2 , 7 )). bool ()
additional_msa_feats = torch . randn ( 2 , 7 , seq_len , 2 )
# required for training, but omitted on inference
atom_pos = torch . randn ( 2 , atom_seq_len , 3 )
distogram_atom_indices = molecule_atom_lens - 1
distance_labels = torch . randint ( 0 , 37 , ( 2 , seq_len , seq_len ))
resolved_labels = torch . randint ( 0 , 2 , ( 2 , atom_seq_len ))
# offset indices correctly
distogram_atom_indices += atom_offsets
molecule_atom_indices += atom_offsets
# train
loss = alphafold3 (
num_recycling_steps = 2 ,
atom_inputs = atom_inputs ,
atompair_inputs = atompair_inputs ,
molecule_ids = molecule_ids ,
molecule_atom_lens = molecule_atom_lens ,
additional_molecule_feats = additional_molecule_feats ,
additional_msa_feats = additional_msa_feats ,
additional_token_feats = additional_token_feats ,
is_molecule_types = is_molecule_types ,
is_molecule_mod = is_molecule_mod ,
msa = msa ,
msa_mask = msa_mask ,
templates = template_feats ,
template_mask = template_mask ,
atom_pos = atom_pos ,
distogram_atom_indices = distogram_atom_indices ,
molecule_atom_indices = molecule_atom_indices ,
distance_labels = distance_labels ,
resolved_labels = resolved_labels
)
loss . backward ()
# after much training ...
sampled_atom_pos = alphafold3 (
num_recycling_steps = 4 ,
num_sample_steps = 16 ,
atom_inputs = atom_inputs ,
atompair_inputs = atompair_inputs ,
molecule_ids = molecule_ids ,
molecule_atom_lens = molecule_atom_lens ,
additional_molecule_feats = additional_molecule_feats ,
additional_msa_feats = additional_msa_feats ,
additional_token_feats = additional_token_feats ,
is_molecule_types = is_molecule_types ,
is_molecule_mod = is_molecule_mod ,
msa = msa ,
msa_mask = msa_mask ,
templates = template_feats ,
template_mask = template_mask
)
sampled_atom_pos . shape # (2, <atom_seqlen>, 3)
Un exemple avec la gestion des entrées au niveau des molécules
import torch
from alphafold3_pytorch import Alphafold3 , Alphafold3Input
contrived_protein = 'AG'
mock_atompos = [
torch . randn ( 5 , 3 ), # alanine has 5 non-hydrogen atoms
torch . randn ( 4 , 3 ) # glycine has 4 non-hydrogen atoms
]
train_alphafold3_input = Alphafold3Input (
proteins = [ contrived_protein ],
atom_pos = mock_atompos
)
eval_alphafold3_input = Alphafold3Input (
proteins = [ contrived_protein ]
)
# training
alphafold3 = Alphafold3 (
dim_atom_inputs = 3 ,
dim_atompair_inputs = 5 ,
atoms_per_window = 27 ,
dim_template_feats = 108 ,
num_molecule_mods = 0 ,
confidence_head_kwargs = dict (
pairformer_depth = 1
),
template_embedder_kwargs = dict (
pairformer_stack_depth = 1
),
msa_module_kwargs = dict (
depth = 1
),
pairformer_stack = dict (
depth = 2
),
diffusion_module_kwargs = dict (
atom_encoder_depth = 1 ,
token_transformer_depth = 1 ,
atom_decoder_depth = 1 ,
)
)
loss = alphafold3 . forward_with_alphafold3_inputs ([ train_alphafold3_input ])
loss . backward ()
# sampling
alphafold3 . eval ()
sampled_atom_pos = alphafold3 . forward_with_alphafold3_inputs ( eval_alphafold3_input )
assert sampled_atom_pos . shape == ( 1 , ( 5 + 4 ), 3 )
Pour acquérir l'ensemble de données AlphaFold 3 PDB, téléchargez d'abord tous les complexes de premier assemblage (et d'unité asymétrique) dans la banque de données sur les protéines (PDB), puis prétraitez-les avec le script référencé ci-dessous. Le PDB peut être téléchargé depuis le RCSB : https://www.wwpdb.org/ftp/pdb-ftp-sites#rcsbpdb. Les deux scripts Python ci-dessous (c'est-à-dire filter_pdb_{train,val,test}_mmcifs.py
et cluster_pdb_{train,val,test}_mmcifs.py
) supposent que vous avez téléchargé le PDB au format de fichier mmCIF , en plaçant son premier assemblage et fichiers mmCIF d'unités asymétriques dans data/pdb_data/unfiltered_assembly_mmcifs/
et data/pdb_data/unfiltered_asym_mmcifs/
, respectivement.
Pour des raisons de reproductibilité, nous vous recommandons de télécharger le PDB à l'aide d'instantanés AWS (par exemple, 20240101
). Pour ce faire, reportez-vous à la documentation d'AWS pour configurer l'AWS CLI localement. Alternativement, sur le site Web du RCSB, accédez à « Protocoles de téléchargement » et suivez les instructions de téléchargement en fonction de votre emplacement.
Par exemple, on peut utiliser les commandes suivantes pour télécharger le PDB sous forme de deux collections de fichiers mmCIF :
# For `assembly1` complexes, use the PDB's `20240101` AWS snapshot:
aws s3 sync s3://pdbsnapshots/20240101/pub/pdb/data/assemblies/mmCIF/divided/ ./data/pdb_data/unfiltered_assembly_mmcifs
# Or as a fallback, use rsync:
rsync -rlpt -v -z --delete --port=33444
rsync.rcsb.org::ftp_data/assemblies/mmCIF/divided/ ./data/pdb_data/unfiltered_assembly_mmcifs/
# For asymmetric unit complexes, also use the PDB's `20240101` AWS snapshot:
aws s3 sync s3://pdbsnapshots/20240101/pub/pdb/data/structures/divided/mmCIF/ ./data/pdb_data/unfiltered_asym_mmcifs
# Or as a fallback, use rsync:
rsync -rlpt -v -z --delete --port=33444
rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ ./data/pdb_data/unfiltered_asym_mmcifs/
AVERTISSEMENT : le téléchargement du PDB peut prendre jusqu'à 700 Go d'espace.
REMARQUE : La PDB héberge tous les instantanés AWS disponibles ici : https://pdbsnapshots.s3.us-west-2.amazonaws.com/index.html.
Après le téléchargement, vous devriez avoir deux répertoires formatés comme ceci : https://files.rcsb.org/pub/pdb/data/assemblies/mmCIF/divided/ & https://files.rcsb.org/pub/pdb/data /structures/divisé/mmCIF/
00/
01/
02/
..
zz/
Pour ces répertoires, décompressez tous les fichiers :
find ./data/pdb_data/unfiltered_assembly_mmcifs/ -type f -name " *.gz " -exec gzip -d {} ;
find ./data/pdb_data/unfiltered_asym_mmcifs/ -type f -name " *.gz " -exec gzip -d {} ;
Ensuite, exécutez les commandes
wget -P ./data/ccd_data/ https://files.wwpdb.org/pub/pdb/data/monomers/components.cif.gz
wget -P ./data/ccd_data/ https://files.wwpdb.org/pub/pdb/data/component-models/complete/chem_comp_model.cif.gz
à partir du répertoire racine du projet pour télécharger la dernière version du dictionnaire des composants chimiques (CCD) du PDB et de ses modèles structurels. Extrayez chacun de ces fichiers à l'aide de la commande suivante :
find data/ccd_data/ -type f -name " *.gz " -exec gzip -d {} ;
Ensuite, exécutez ce qui suit en remplaçant pdb_assembly_dir
, pdb_asym_dir
, ccd_dir
et mmcif_output_dir
par les emplacements de vos copies locales du PDB du premier assemblage, du PDB de l'unité asymétrique, du CCD et du répertoire de sortie de l'ensemble de données souhaité (c'est-à-dire ./data/pdb_data/unfiltered_assembly_mmcifs/
, ./data/pdb_data/unfiltered_asym_mmcifs/
, ./data/ccd_data/
et ./data/pdb_data/{train,val,test}_mmcifs/
).
python scripts/filter_pdb_train_mmcifs.py --mmcif_assembly_dir < pdb_assembly_dir > --mmcif_asym_dir < pdb_asym_dir > --ccd_dir < ccd_dir > --output_dir < mmcif_output_dir >
python scripts/filter_pdb_val_mmcifs.py --mmcif_assembly_dir < pdb_assembly_dir > --mmcif_asym_dir < pdb_asym_dir > --output_dir < mmcif_output_dir >
python scripts/filter_pdb_test_mmcifs.py --mmcif_assembly_dir < pdb_assembly_dir > --mmcif_asym_dir < pdb_asym_dir > --output_dir < mmcif_output_dir >
Voir les scripts pour plus d'options. Chaque mmCIF de premier assemblage qui réussit toutes les étapes de traitement sera écrit dans mmcif_output_dir
dans un sous-répertoire nommé en fonction des deuxième et troisième caractères d'identification PDB du mmCIF (par exemple 5c
).
Ensuite, exécutez ce qui suit en remplaçant respectivement mmcif_dir
et {train,val,test}_clustering_output_dir
par votre répertoire de sortie local créé à l'aide du script de filtrage d'ensemble de données ci-dessus et par les répertoires de sortie de clustering souhaités (c'est-à-dire ./data/pdb_data/{train,val,test}_mmcifs/
et ./data/pdb_data/data_caches/{train,val,test}_clusterings/
) :
python scripts/cluster_pdb_train_mmcifs.py --mmcif_dir < mmcif_dir > --output_dir < train_clustering_output_dir > --clustering_filtered_pdb_dataset
python scripts/cluster_pdb_val_mmcifs.py --mmcif_dir < mmcif_dir > --reference_clustering_dir < train_clustering_output_dir > --output_dir < val_clustering_output_dir > --clustering_filtered_pdb_dataset
python scripts/cluster_pdb_test_mmcifs.py --mmcif_dir < mmcif_dir > --reference_1_clustering_dir < train_clustering_output_dir > --reference_2_clustering_dir < val_clustering_output_dir > --output_dir < test_clustering_output_dir > --clustering_filtered_pdb_dataset
Remarque : L'indicateur --clustering_filtered_pdb_dataset
est recommandé lors du regroupement de l'ensemble de données PDB filtré tel que organisé à l'aide des scripts ci-dessus, car cet indicateur permettra des exécutions plus rapides dans ce contexte (puisque le filtrage laisse les ID de résidus de chaque chaîne basés sur 1). Cependant, cet indicateur ne doit pas être fourni lors du regroupement d'autres ensembles de données (c'est-à-dire non PDB) de fichiers mmCIF. Sinon, le regroupement d'interfaces risque d'être effectué de manière incorrecte, car les fichiers mmCIF de ces ensembles de données risquent de ne pas utiliser une indexation stricte des résidus de base 1 pour chaque chaîne.
Remarque : il est possible de télécharger à la place des fichiers mmCIF ( train
/ val
/ test
) prétraités (c'est-à-dire filtrés) (~ 25 Go, comprenant 148 000 complexes) et des fichiers de clustering de chaîne/interface ( train
/ val
/ test
) (~ 3 Go) pour le 20240101
du PDB. Instantané AWS via un dossier OneDrive partagé. Chacune de ces archives tar.gz
doit être décompressée dans le répertoire data/pdb_data/
, par exemple via tar -xzf data_caches.tar.gz -C data/pdb_data/
. On peut également télécharger et préparer les données de distillation PDB en utilisant comme référence le script scripts/distillation_data_download.sh
. Une fois téléchargé, on peut exécuter scripts/reduce_uniprot_predictions_to_pdb.py
pour filtrer cet ensemble de données uniquement sur les exemples associés à au moins une entrée PDB. De plus, pour plus de commodité, un mappage des identifiants d'accession UniProt aux identifiants PDB pour la formation sur les données de distillation PDB a déjà été téléchargé et extrait sous data/afdb_data/data_caches/uniprot_to_pdb_id_mapping.dat
.
À la racine du projet, exécutez
$ sh ./contribute.sh
Ensuite, ajoutez votre module à alphafold3_pytorch/alphafold3.py
, ajoutez vos tests à tests/test_af3.py
et soumettez une pull request. Vous pouvez exécuter les tests localement avec
$ pytest tests/
Le Dockerfile
inclus contient les dépendances requises pour exécuter le package et pour entraîner/inférence à l'aide de PyTorch avec des GPU.
L'image de base par défaut est pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime
et installe la dernière version de ce package à partir de la branche main
de GitHub.
# # Build Docker Container
docker build -t af3 .
Vous pouvez également utiliser les arguments de build pour reconstruire l'image avec différentes versions du logiciel :
PYTORCH_TAG
: modifie l'image de base et construit ainsi avec différentes versions de PyTorch, CUDA et/ou cuDNN.GIT_TAG
: Modifie la balise de ce dépôt pour cloner et installer le package.Par exemple:
# # Use build argument to change versions
docker build --build-arg " PYTORCH_TAG=2.2.1-cuda12.1-cudnn8-devel " --build-arg " GIT_TAG=0.1.15 " -t af3 .
Ensuite, exécutez le conteneur avec des GPU et montez un volume local (pour la formation) à l'aide de la commande suivante :
# # Run Container
docker run -v .:/data --gpus all -it af3
@article { Abramson2024-fj ,
title = " Accurate structure prediction of biomolecular interactions with
{AlphaFold} 3 " ,
author = " Abramson, Josh and Adler, Jonas and Dunger, Jack and Evans,
Richard and Green, Tim and Pritzel, Alexander and Ronneberger,
Olaf and Willmore, Lindsay and Ballard, Andrew J and Bambrick,
Joshua and Bodenstein, Sebastian W and Evans, David A and Hung,
Chia-Chun and O'Neill, Michael and Reiman, David and
Tunyasuvunakool, Kathryn and Wu, Zachary and {v Z}emgulyt{.e},
Akvil{.e} and Arvaniti, Eirini and Beattie, Charles and
Bertolli, Ottavia and Bridgland, Alex and Cherepanov, Alexey and
Congreve, Miles and Cowen-Rivers, Alexander I and Cowie, Andrew
and Figurnov, Michael and Fuchs, Fabian B and Gladman, Hannah and
Jain, Rishub and Khan, Yousuf A and Low, Caroline M R and Perlin,
Kuba and Potapenko, Anna and Savy, Pascal and Singh, Sukhdeep and
Stecula, Adrian and Thillaisundaram, Ashok and Tong, Catherine
and Yakneen, Sergei and Zhong, Ellen D and Zielinski, Michal and
{v Z}{'i}dek, Augustin and Bapst, Victor and Kohli, Pushmeet
and Jaderberg, Max and Hassabis, Demis and Jumper, John M " ,
journal = " Nature " ,
month = " May " ,
year = 2024
}
@inproceedings { Darcet2023VisionTN ,
title = { Vision Transformers Need Registers } ,
author = { Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski } ,
year = { 2023 } ,
url = { https://api.semanticscholar.org/CorpusID:263134283 }
}
@article { Arora2024SimpleLA ,
title = { Simple linear attention language models balance the recall-throughput tradeoff } ,
author = { Simran Arora and Sabri Eyuboglu and Michael Zhang and Aman Timalsina and Silas Alberti and Dylan Zinsley and James Zou and Atri Rudra and Christopher R'e } ,
journal = { ArXiv } ,
year = { 2024 } ,
volume = { abs/2402.18668 } ,
url = { https://api.semanticscholar.org/CorpusID:268063190 }
}
@article { Puny2021FrameAF ,
title = { Frame Averaging for Invariant and Equivariant Network Design } ,
author = { Omri Puny and Matan Atzmon and Heli Ben-Hamu and Edward James Smith and Ishan Misra and Aditya Grover and Yaron Lipman } ,
journal = { ArXiv } ,
year = { 2021 } ,
volume = { abs/2110.03336 } ,
url = { https://api.semanticscholar.org/CorpusID:238419638 }
}
@article { Duval2023FAENetFA ,
title = { FAENet: Frame Averaging Equivariant GNN for Materials Modeling } ,
author = { Alexandre Duval and Victor Schmidt and Alex Hernandez Garcia and Santiago Miret and Fragkiskos D. Malliaros and Yoshua Bengio and David Rolnick } ,
journal = { ArXiv } ,
year = { 2023 } ,
volume = { abs/2305.05577 } ,
url = { https://api.semanticscholar.org/CorpusID:258564608 }
}
@article { Wang2022DeepNetST ,
title = { DeepNet: Scaling Transformers to 1, 000 Layers } ,
author = { Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2203.00555 } ,
url = { https://api.semanticscholar.org/CorpusID:247187905 }
}
@inproceedings { Ainslie2023CoLT5FL ,
title = { CoLT5: Faster Long-Range Transformers with Conditional Computation } ,
author = { Joshua Ainslie and Tao Lei and Michiel de Jong and Santiago Ontan'on and Siddhartha Brahma and Yury Zemlyanskiy and David Uthus and Mandy Guo and James Lee-Thorp and Yi Tay and Yun-Hsuan Sung and Sumit Sanghai } ,
year = { 2023 }
}
@article { Ash2019OnTD ,
title = { On the Difficulty of Warm-Starting Neural Network Training } ,
author = { Jordan T. Ash and Ryan P. Adams } ,
journal = { ArXiv } ,
year = { 2019 } ,
volume = { abs/1910.08475 } ,
url = { https://api.semanticscholar.org/CorpusID:204788802 }
}
@ARTICLE { Heinzinger2023.07.23.550085 ,
author = { Michael Heinzinger and Konstantin Weissenow and Joaquin Gomez Sanchez and Adrian Henkel and Martin Steinegger and Burkhard Rost } ,
title = { ProstT5: Bilingual Language Model for Protein Sequence and Structure } ,
year = { 2023 } ,
doi = { 10.1101/2023.07.23.550085 } ,
journal = { bioRxiv }
}
@article { Lin2022.07.20.500902 ,
author = { Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zhongkai and Lu, Wenting and Santos Costa, Allan dos and Fazel-Zarandi, Maryam and Sercu, Tom and Candido, Sal and Rives, Alexander } ,
title = { Language models of protein sequences at the scale of evolution enable accurate structure prediction } ,
elocation-id = { 2022.07.20.500902 } ,
year = { 2022 } ,
doi = { 10.1101/2022.07.20.500902 } ,
publisher = { Cold Spring Harbor Laboratory } ,
URL = { https://www.biorxiv.org/content/early/2022/07/21/2022.07.20.500902 } ,
eprint = { https://www.biorxiv.org/content/early/2022/07/21/2022.07.20.500902.full.pdf } ,
journal = { bioRxiv }
}
@article { Li2024SwitchEA ,
title = { Switch EMA: A Free Lunch for Better Flatness and Sharpness } ,
author = { Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li } ,
journal = { ArXiv } ,
year = { 2024 } ,
volume = { abs/2402.09240 } ,
url = { https://api.semanticscholar.org/CorpusID:267657558 }
}
@article { Nguyen2023MitigatingOI ,
title = { Mitigating Over-smoothing in Transformers via Regularized Nonlocal Functionals } ,
author = { Tam Nguyen and Tan M. Nguyen and Richard G. Baraniuk } ,
journal = { ArXiv } ,
year = { 2023 } ,
volume = { abs/2312.00751 } ,
url = { https://api.semanticscholar.org/CorpusID:264300597 }
}
@inproceedings { Zhou2024ValueRL ,
title = { Value Residual Learning For Alleviating Attention Concentration In Transformers } ,
author = { Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan } ,
year = { 2024 } ,
url = { https://api.semanticscholar.org/CorpusID:273532030 }
}