diff --git a/README.md b/README.md index 7e66e8cd3..38ab7fe8a 100644 --- a/README.md +++ b/README.md @@ -20,10 +20,17 @@ v2.0. This is a completely new model that was entered in CASP14 and published in Nature. For simplicity, we refer to this model as AlphaFold throughout the rest of this document. -Any publication that discloses findings arising from using this source code or -the model parameters should [cite](#citing-this-work) the -[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2). Please also refer -to the +We also provide an implementation of AlphaFold-Multimer. This represents a work +in progress and AlphaFold-Multimer isn't expected to be as stable as our monomer +AlphaFold system. +[Read the guide](#updating-existing-alphafold-installation-to-include-alphafold-multimers) +for how to upgrade and update code. + +Any publication that discloses findings arising from using this source code or the model parameters should [cite](#citing-this-work) the +[AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and, if +applicable, the [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). + +Please also refer to the [Supplementary Information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-021-03819-2/MediaObjects/41586_2021_3819_MOESM1_ESM.pdf) for a detailed description of the method. @@ -58,18 +65,25 @@ The following steps are required in order to run AlphaFold: or take a look at the following [NVIDIA Docker issue](https://github.com/NVIDIA/nvidia-docker/issues/1447#issuecomment-801479573). +If you wish to run AlphaFold using Singularity (a common containerization platform on HPC systems) we recommend using some of the +third party Singularity setups as linked in +https://github.com/deepmind/alphafold/issues/10 or +https://github.com/deepmind/alphafold/issues/24. + ### Genetic databases This step requires `aria2c` to be installed on your machine. AlphaFold needs multiple genetic (sequence) databases to run: -* [UniRef90](https://www.uniprot.org/help/uniref), -* [MGnify](https://www.ebi.ac.uk/metagenomics/), * [BFD](https://bfd.mmseqs.com/), -* [Uniclust30](https://uniclust.mmseqs.com/), +* [MGnify](https://www.ebi.ac.uk/metagenomics/), * [PDB70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/), -* [PDB](https://www.rcsb.org/) (structures in the mmCIF format). +* [PDB](https://www.rcsb.org/) (structures in the mmCIF format), +* [PDB seqres](https://www.rcsb.org/) – only for AlphaFold-Multimer, +* [Uniclust30](https://uniclust.mmseqs.com/), +* [UniProt](https://www.uniprot.org/uniprot/) – only for AlphaFold-Multimer, +* [UniRef90](https://www.uniprot.org/help/uniref). We provide a script `scripts/download_all_data.sh` that can be used to download and set up all of these databases: @@ -89,9 +103,13 @@ and set up all of these databases: ``` will download a reduced version of the databases to be used with the - `reduced_dbs` preset. + `reduced_dbs` database preset. -We don't provide exactly the versions used in CASP14 -- see the [note on +:ledger: **Note: The download directory `` should _not_ be a +subdirectory in the AlphaFold repository directory.** If it is, the Docker build +will be slow as the large databases will be copied during the image creation. + +We don't provide exactly the database versions used in CASP14 – see the [note on reproducibility](#note-on-reproducibility). Some of the databases are mirrored for speed, see [mirrored databases](#mirrored-databases). @@ -100,8 +118,8 @@ and the total size when unzipped is 2.2 TB. Please make sure you have a large enough hard drive space, bandwidth and time to download. We recommend using an SSD for better genetic search performance.** -This script will also download the model parameter files. Once the script has -finished, you should have the following directory structure: +The `download_all_data.sh` script will also download the model parameter files. +Once the script has finished, you should have the following directory structure: ``` $DOWNLOAD_DIR/ # Total: ~ 2.2 TB (download: 438 GB) @@ -112,24 +130,29 @@ $DOWNLOAD_DIR/ # Total: ~ 2.2 TB (download: 438 GB) params/ # ~ 3.5 GB (download: 3.5 GB) # 5 CASP14 models, # 5 pTM models, + # 5 AlphaFold-Multimer models, # LICENSE, - # = 11 files. + # = 16 files. pdb70/ # ~ 56 GB (download: 19.5 GB) # 9 files. pdb_mmcif/ # ~ 206 GB (download: 46 GB) mmcif_files/ # About 180,000 .cif files. obsolete.dat + pdb_seqres/ # ~ 0.2 GB (download: 0.2 GB) + pdb_seqres.txt small_bfd/ # ~ 17 GB (download: 9.6 GB) bfd-first_non_consensus_sequences.fasta uniclust30/ # ~ 86 GB (download: 24.9 GB) uniclust30_2018_08/ # 13 files. + uniprot/ # ~ 98.3 GB (download: 49 GB) + uniprot.fasta uniref90/ # ~ 58 GB (download: 29.7 GB) uniref90.fasta ``` -`bfd/` is only downloaded if you download the full databasees, and `small_bfd/` +`bfd/` is only downloaded if you download the full databases, and `small_bfd/` is only downloaded if you download the reduced databases. ### Model parameters @@ -140,7 +163,7 @@ CC BY-NC 4.0 license. Please see the [Disclaimer](#license-and-disclaimer) below for more detail. The AlphaFold parameters are available from -https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar, and +https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar, and are downloaded as part of the `scripts/download_all_data.sh` script. This script will download parameters for: @@ -148,8 +171,46 @@ will download parameters for: structure prediction quality (see Jumper et al. 2021, Suppl. Methods 1.12 for details). * 5 pTM models, which were fine-tuned to produce pTM (predicted TM-score) and - predicted aligned error values alongside their structure predictions (see - Jumper et al. 2021, Suppl. Methods 1.9.7 for details). + (PAE) predicted aligned error values alongside their structure predictions + (see Jumper et al. 2021, Suppl. Methods 1.9.7 for details). +* 5 AlphaFold-Multimer models that produce pTM and PAE values alongside their + structure predictions. + +### Updating existing AlphaFold installation to include AlphaFold-Multimers + +If you have AlphaFold v2.0.0 or v2.0.1 you can either reinstall AlphaFold fully +from scratch (remove everything and run the setup from scratch) or you can do an +incremental update that will be significantly faster but will require a bit more +work. Make sure you follow these steps in the exact order they are listed below: + +1. **Update the code.** + * Go to the directory with the cloned AlphaFold repository and run + `git fetch origin main` to get all code updates. +1. **Download the UniProt and PDB seqres databases.** + * Run `scripts/download_uniprot.sh `. + * Remove `/pdb_mmcif`. It is needed to have PDB SeqRes and + PDB from exactly the same date. Failure to do this step will result in + potential errors when searching for templates when running + AlphaFold-Multimer. + * Run `scripts/download_pdb_mmcif.sh `. + * Run `scripts/download_pdb_seqres.sh `. +1. **Update the model parameters.** + * Remove the old model parameters in `/params`. + * Download new model parameters using + `scripts/download_alphafold_params.sh `. +1. **Follow [Running AlphaFold](#running-alphafold).** + +#### API changes between v2.0.0 and v2.1.0 + +We tried to keep the API as much backwards compatible as possible, but we had to +change the following: + +* The `RunModel.predict()` now needs a `random_seed` argument as MSA sampling + happens inside the Multimer model. +* The `preset` flag in `run_alphafold.py` and `run_docker.py` was split into + `db_preset` and `model_preset`. +* Setting the `data_dir` flag is now needed when using `run_docker.py`. + ## Running AlphaFold @@ -164,8 +225,6 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional git clone https://github.com/deepmind/alphafold.git ``` -1. Modify `DOWNLOAD_DIR` in `docker/run_docker.py` to be the path to the - directory containing the downloaded databases. 1. Build the Docker image: ```bash @@ -181,14 +240,19 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional pip3 install -r docker/requirements.txt ``` -1. Run `run_docker.py` pointing to a FASTA file containing the protein sequence - for which you wish to predict the structure. If you are predicting the - structure of a protein that is already in PDB and you wish to avoid using it - as a template, then `max_template_date` must be set to be before the release - date of the structure. For example, for the T1050 CASP14 target: +1. Run `run_docker.py` pointing to a FASTA file containing the protein + sequence(s) for which you wish to predict the structure. If you are + predicting the structure of a protein that is already in PDB and you wish to + avoid using it as a template, then `max_template_date` must be set to be + before the release date of the structure. You must also provide the path to + the directory containing the downloaded databases. For example, for the + T1050 CASP14 target: ```bash - python3 docker/run_docker.py --fasta_paths=T1050.fasta --max_template_date=2020-05-14 + python3 docker/run_docker.py \ + --fasta_paths=T1050.fasta \ + --max_template_date=2020-05-14 \ + --data_dir=$DOWNLOAD_DIR ``` By default, Alphafold will attempt to use all visible GPU devices. To use a @@ -197,33 +261,76 @@ with 12 vCPUs, 85 GB of RAM, a 100 GB boot disk, the databases on an additional [GPU enumeration](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#gpu-enumeration) for more details. -1. You can control AlphaFold speed / quality tradeoff by adding - `--preset=reduced_dbs`, `--preset=full_dbs` or `--preset=casp14` to the run - command. We provide the following presets: +1. You can control which AlphaFold model to run by adding the + `--model_preset=` flag. We provide the following models: + + * **monomer**: This is the original model used at CASP14 with no ensembling. + + * **monomer\_casp14**: This is the original model used at CASP14 with + `num_ensemble=8`, matching our CASP14 configuration. This is largely + provided for reproducibility as it is 8x more computationally + expensive for limited accuracy gain (+0.1 average GDT gain on CASP14 + domains). + + * **monomer\_ptm**: This is the original CASP14 model fine tuned with the + pTM head, providing a pairwise confidence measure. It is slightly less + accurate than the normal monomer model. + + * **multimer**: This is the [AlphaFold-Multimer](#citing-this-work) model. + To use this model, provide a multi-sequence FASTA file. In addition, the + UniProt database should have been downloaded. - * **reduced_dbs**: This preset is optimized for speed and lower hardware - requirements. It runs with a reduced version of the BFD database and - with no ensembling. It requires 8 CPU cores (vCPUs), 8 GB of RAM, and - 600 GB of disk space. - * **full_dbs**: The model in this preset is 8 times faster than the - `casp14` preset with a very minor quality drop (-0.1 average GDT drop on - CASP14 domains). It runs with all genetic databases and with no - ensembling. - * **casp14**: This preset uses the same settings as were used in CASP14. - It runs with all genetic databases and with 8 ensemblings. +1. You can control MSA speed/quality tradeoff by adding + `--db_preset=reduced_dbs` or `--db_preset=full_dbs` to the run command. We + provide the following presets: - Running the command above with the `casp14` preset would look like this: + * **reduced\_dbs**: This preset is optimized for speed and lower hardware + requirements. It runs with a reduced version of the BFD database. + It requires 8 CPU cores (vCPUs), 8 GB of RAM, and 600 GB of disk space. + + * **full\_dbs**: This runs with all genetic databases used at CASP14. + + Running the command above with the `monomer` model preset and the + `reduced_dbs` data preset would look like this: ```bash - python3 docker/run_docker.py --fasta_paths=T1050.fasta --max_template_date=2020-05-14 --preset=casp14 + python3 docker/run_docker.py \ + --fasta_paths=T1050.fasta \ + --max_template_date=2020-05-14 \ + --model_preset=monomer \ + --db_preset=reduced_dbs \ + --data_dir=$DOWNLOAD_DIR ``` +### Running AlphaFold-Multimer + +All steps are the same as when running the monomer system, but you will have to + +* provide an input fasta with multiple sequences, +* set `--model_preset=multimer`, +* optionally set the `--is_prokaryote_list` flag with booleans that determine + whether all input sequences in the given fasta file are prokaryotic. If that + is not the case or the origin is unknown, set to `false` for that fasta. + +An example that folds two protein complexes `multimer1` and `multimer2` where +the first is prokaryotic and the second isn't: + +```bash +python3 docker/run_docker.py \ + --fasta_paths=multimer1.fasta,multimer2.fasta \ + --is_prokaryote_list=true,false \ + --max_template_date=2020-05-14 \ + --model_preset=multimer \ + --data_dir=$DOWNLOAD_DIR +``` + ### AlphaFold output -The outputs will be in a subfolder of `output_dir` in `run_docker.py`. They -include the computed MSAs, unrelaxed structures, relaxed structures, ranked -structures, raw model outputs, prediction metadata, and section timings. The -`output_dir` directory will have the following structure: +The outputs will be saved in a subdirectory of the directory provided via the +`--output_dir` flag of `run_docker.py` (defaults to `/tmp/alphafold/`). The +outputs include the computed MSAs, unrelaxed structures, relaxed structures, +ranked structures, raw model outputs, prediction metadata, and section timings. +The `--output_dir` directory will have the following structure: ``` / @@ -312,7 +419,7 @@ develop on top of the `RunModel.predict` method with a parallel system for precomputing multi-sequence alignments. Alternatively, this script can be run repeatedly with only moderate overhead. -## Note on reproducibility +## Note on CASP14 reproducibility AlphaFold's output for a small number of proteins has high inter-run variance, and may be affected by changes in the input data. The CASP14 target T1064 is a @@ -359,6 +466,21 @@ If you use the code or data in this package, please cite: } ``` +In addition, if you use the AlphaFold-Multimer mode, please cite: + +```bibtex +@article {AlphaFold-Multimer2021, + author = {Evans, Richard and O{\textquoteright}Neill, Michael and Pritzel, Alexander and Antropova, Natasha and Senior, Andrew and Green, Tim and {\v{Z}}{\'\i}dek, Augustin and Bates, Russ and Blackwell, Sam and Yim, Jason and Ronneberger, Olaf and Bodenstein, Sebastian and Zielinski, Michal and Bridgland, Alex and Potapenko, Anna and Cowie, Andrew and Tunyasuvunakool, Kathryn and Jain, Rishub and Clancy, Ellen and Kohli, Pushmeet and Jumper, John and Hassabis, Demis}, + journal = {bioRxiv} + title = {Protein complex prediction with AlphaFold-Multimer}, + year = {2021}, + elocation-id = {2021.10.04.463034}, + doi = {10.1101/2021.10.04.463034}, + URL = {https://www.biorxiv.org/content/early/2021/10/04/2021.10.04.463034}, + eprint = {https://www.biorxiv.org/content/early/2021/10/04/2021.10.04.463034.full.pdf}, +} +``` + ## Community contributions Colab notebooks provided by the community (please note that these notebooks may @@ -391,6 +513,7 @@ and packages: * [NumPy](https://numpy.org) * [OpenMM](https://github.com/openmm/openmm) * [OpenStructure](https://openstructure.org) +* [pandas](https://pandas.pydata.org/) * [pymol3d](https://github.com/avirshup/py3dmol) * [SciPy](https://scipy.org) * [Sonnet](https://github.com/deepmind/sonnet) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 5f1085f75..787ec44dd 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -111,8 +111,10 @@ def compute_predicted_aligned_error( def predicted_tm_score( logits: np.ndarray, breaks: np.ndarray, - residue_weights: Optional[np.ndarray] = None) -> np.ndarray: - """Computes predicted TM alignment score. + residue_weights: Optional[np.ndarray] = None, + asym_id: Optional[np.ndarray] = None, + interface: bool = False) -> np.ndarray: + """Computes predicted TM alignment or predicted interface TM alignment score. Args: logits: [num_res, num_res, num_bins] the logits output from @@ -120,9 +122,12 @@ def predicted_tm_score( breaks: [num_bins] the error bins. residue_weights: [num_res] the per residue weights to use for the expectation. + asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for + ipTM calculation, i.e. when interface=True. + interface: If True, interface predicted TM score is computed. Returns: - ptm_score: the predicted TM alignment score. + ptm_score: The predicted TM alignment or the predicted iTM score. """ # residue_weights has to be in [0, 1], but can be floating-point, i.e. the @@ -132,24 +137,32 @@ def predicted_tm_score( bin_centers = _calculate_bin_centers(breaks) - num_res = np.sum(residue_weights) + num_res = int(np.sum(residue_weights)) # Clip num_res to avoid negative/undefined d0. clipped_num_res = max(num_res, 19) - # Compute d_0(num_res) as defined by TM-score, eqn. (5) in - # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf - # Yang & Skolnick "Scoring function for automated - # assessment of protein structure template quality" 2004 + # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick + # "Scoring function for automated assessment of protein structure template + # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 - # Convert logits to probs + # Convert logits to probs. probs = scipy.special.softmax(logits, axis=-1) - # TM-Score term for every bin + # TM-Score term for every bin. tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0)) - # E_distances tm(distance) + # E_distances tm(distance). predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1) - normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum()) + pair_mask = np.ones(shape=(num_res, num_res), dtype=bool) + if interface: + pair_mask *= asym_id[:, None] != asym_id[None, :] + + predicted_tm_term *= pair_mask + + pair_residue_weights = pair_mask * ( + residue_weights[None, :] * residue_weights[:, None]) + normed_residue_mask = pair_residue_weights / (1e-8 + np.sum( + pair_residue_weights, axis=-1, keepdims=True)) per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1) return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()]) diff --git a/alphafold/common/protein.py b/alphafold/common/protein.py index 2848f5bbc..8faa4c0aa 100644 --- a/alphafold/common/protein.py +++ b/alphafold/common/protein.py @@ -23,6 +23,10 @@ FeatureDict = Mapping[str, np.ndarray] ModelOutput = Mapping[str, Any] # Is a nested dict. +# Complete sequence of chain IDs supported by the PDB format. +PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. + @dataclasses.dataclass(frozen=True) class Protein: @@ -43,11 +47,21 @@ class Protein: # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. residue_index: np.ndarray # [num_res] + # 0-indexed number corresponding to the chain in the protein that this residue + # belongs to. + chain_index: np.ndarray # [num_res] + # B-factors, or temperature factors, of each residue (in sq. angstroms units), # representing the displacement of the residue from its ground truth mean # value. b_factors: np.ndarray # [num_res, num_atom_type] + def __post_init__(self): + if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: + raise ValueError( + f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' + 'because these cannot be written to PDB format.') + def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: """Takes a PDB string and constructs a Protein object. @@ -57,9 +71,8 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: Args: pdb_str: The contents of the pdb file - chain_id: If None, then the pdb file must contain a single chain (which - will be parsed). If chain_id is specified (e.g. A), then only that chain - is parsed. + chain_id: If chain_id is specified (e.g. A), then only that chain + is parsed. Otherwise all chains are parsed. Returns: A new `Protein` parsed from the pdb contents. @@ -73,57 +86,63 @@ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: f'Only single model PDBs are supported. Found {len(models)} models.') model = models[0] - if chain_id is not None: - chain = model[chain_id] - else: - chains = list(model.get_chains()) - if len(chains) != 1: - raise ValueError( - 'Only single chain PDBs are supported when chain_id not specified. ' - f'Found {len(chains)} chains.') - else: - chain = chains[0] - atom_positions = [] aatype = [] atom_mask = [] residue_index = [] + chain_ids = [] b_factors = [] - for res in chain: - if res.id[2] != ' ': - raise ValueError( - f'PDB contains an insertion code at chain {chain.id} and residue ' - f'index {res.id[1]}. These are not supported.') - res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') - restype_idx = residue_constants.restype_order.get( - res_shortname, residue_constants.restype_num) - pos = np.zeros((residue_constants.atom_type_num, 3)) - mask = np.zeros((residue_constants.atom_type_num,)) - res_b_factors = np.zeros((residue_constants.atom_type_num,)) - for atom in res: - if atom.name not in residue_constants.atom_types: - continue - pos[residue_constants.atom_order[atom.name]] = atom.coord - mask[residue_constants.atom_order[atom.name]] = 1. - res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor - if np.sum(mask) < 0.5: - # If no known atom positions are reported for the residue then skip it. + for chain in model: + if chain_id is not None and chain.id != chain_id: continue - aatype.append(restype_idx) - atom_positions.append(pos) - atom_mask.append(mask) - residue_index.append(res.id[1]) - b_factors.append(res_b_factors) + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. These are not supported.') + res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num,)) + res_b_factors = np.zeros((residue_constants.atom_type_num,)) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1. + res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + # Chain IDs are usually characters so map these to ints. + unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) return Protein( atom_positions=np.array(atom_positions), atom_mask=np.array(atom_mask), aatype=np.array(aatype), residue_index=np.array(residue_index), + chain_index=chain_index, b_factors=np.array(b_factors)) +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}') + + def to_pdb(prot: Protein) -> str: """Converts a `Protein` instance to a PDB string. @@ -143,16 +162,33 @@ def to_pdb(prot: Protein) -> str: aatype = prot.aatype atom_positions = prot.atom_positions residue_index = prot.residue_index.astype(np.int32) + chain_index = prot.chain_index.astype(np.int32) b_factors = prot.b_factors if np.any(aatype > residue_constants.restype_num): raise ValueError('Invalid aatypes.') + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') + chain_ids[i] = PDB_CHAIN_IDS[i] + pdb_lines.append('MODEL 1') atom_index = 1 - chain_id = 'A' + last_chain_index = chain_index[0] # Add all atom sites. for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append(_chain_end( + atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]], + residue_index[i - 1])) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + res_name_3 = res_1to3(aatype[i]) for atom_name, pos, mask, b_factor in zip( atom_types, atom_positions[i], atom_mask[i], b_factors[i]): @@ -168,7 +204,7 @@ def to_pdb(prot: Protein) -> str: charge = '' # PDB is a columnar format, every space matters here! atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' - f'{res_name_3:>3} {chain_id:>1}' + f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' f'{residue_index[i]:>4}{insertion_code:>1} ' f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' f'{occupancy:>6.2f}{b_factor:>6.2f} ' @@ -176,17 +212,15 @@ def to_pdb(prot: Protein) -> str: pdb_lines.append(atom_line) atom_index += 1 - # Close the chain. - chain_end = 'TER' - chain_termination_line = ( - f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} ' - f'{chain_id:>1}{residue_index[-1]:>4}') - pdb_lines.append(chain_termination_line) + # Close the final chain. + pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]), + chain_ids[chain_index[-1]], residue_index[-1])) pdb_lines.append('ENDMDL') - pdb_lines.append('END') - pdb_lines.append('') - return '\n'.join(pdb_lines) + + # Pad all lines to 80 characters. + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. def ideal_atom_mask(prot: Protein) -> np.ndarray: @@ -205,25 +239,40 @@ def ideal_atom_mask(prot: Protein) -> np.ndarray: return residue_constants.STANDARD_ATOM_MASK[prot.aatype] -def from_prediction(features: FeatureDict, result: ModelOutput, - b_factors: Optional[np.ndarray] = None) -> Protein: +def from_prediction( + features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None, + remove_leading_feature_dimension: bool = True) -> Protein: """Assembles a protein from a prediction. Args: features: Dictionary holding model inputs. result: Dictionary holding model outputs. b_factors: (Optional) B-factors to use for the protein. + remove_leading_feature_dimension: Whether to remove the leading dimension + of the `features` values. Returns: A protein instance. """ fold_output = result['structure_module'] + + def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray: + return arr[0] if remove_leading_feature_dimension else arr + + if 'asym_id' in features: + chain_index = _maybe_remove_leading_dim(features['asym_id']) + else: + chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype'])) + if b_factors is None: b_factors = np.zeros_like(fold_output['final_atom_mask']) return Protein( - aatype=features['aatype'][0], + aatype=_maybe_remove_leading_dim(features['aatype']), atom_positions=fold_output['final_atom_positions'], atom_mask=fold_output['final_atom_mask'], - residue_index=features['residue_index'][0] + 1, + residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1, + chain_index=chain_index, b_factors=b_factors) diff --git a/alphafold/common/protein_test.py b/alphafold/common/protein_test.py index bfd306d04..281a471e1 100644 --- a/alphafold/common/protein_test.py +++ b/alphafold/common/protein_test.py @@ -38,11 +38,17 @@ def _check_shapes(self, prot, num_res): self.assertEqual((num_res,), prot.aatype.shape) self.assertEqual((num_res, num_atoms), prot.atom_mask.shape) self.assertEqual((num_res,), prot.residue_index.shape) + self.assertEqual((num_res,), prot.chain_index.shape) self.assertEqual((num_res, num_atoms), prot.b_factors.shape) - @parameterized.parameters(('2rbg.pdb', 'A', 282), - ('2rbg.pdb', 'B', 282)) - def test_from_pdb_str(self, pdb_file, chain_id, num_res): + @parameterized.named_parameters( + dict(testcase_name='chain_A', + pdb_file='2rbg.pdb', chain_id='A', num_res=282, num_chains=1), + dict(testcase_name='chain_B', + pdb_file='2rbg.pdb', chain_id='B', num_res=282, num_chains=1), + dict(testcase_name='multichain', + pdb_file='2rbg.pdb', chain_id=None, num_res=564, num_chains=2)) + def test_from_pdb_str(self, pdb_file, chain_id, num_res, num_chains): pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, pdb_file) with open(pdb_file) as f: @@ -52,14 +58,19 @@ def test_from_pdb_str(self, pdb_file, chain_id, num_res): self.assertGreaterEqual(prot.aatype.min(), 0) # Allow equal since unknown restypes have index equal to restype_num. self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num) + self.assertLen(np.unique(prot.chain_index), num_chains) def test_to_pdb(self): with open( os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, '2rbg.pdb')) as f: pdb_string = f.read() - prot = protein.from_pdb_string(pdb_string, chain_id='A') + prot = protein.from_pdb_string(pdb_string) pdb_string_reconstr = protein.to_pdb(prot) + + for line in pdb_string_reconstr.splitlines(): + self.assertLen(line, 80) + prot_reconstr = protein.from_pdb_string(pdb_string_reconstr) np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype) @@ -69,6 +80,8 @@ def test_to_pdb(self): prot_reconstr.atom_mask, prot.atom_mask) np.testing.assert_array_equal( prot_reconstr.residue_index, prot.residue_index) + np.testing.assert_array_equal( + prot_reconstr.chain_index, prot.chain_index) np.testing.assert_array_almost_equal( prot_reconstr.b_factors, prot.b_factors) @@ -77,9 +90,9 @@ def test_ideal_atom_mask(self): os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, '2rbg.pdb')) as f: pdb_string = f.read() - prot = protein.from_pdb_string(pdb_string, chain_id='A') + prot = protein.from_pdb_string(pdb_string) ideal_mask = protein.ideal_atom_mask(prot) - non_ideal_residues = set([102] + list(range(127, 285))) + non_ideal_residues = set([102] + list(range(127, 286))) for i, (res, atom_mask) in enumerate( zip(prot.residue_index, prot.atom_mask)): if res in non_ideal_residues: @@ -87,6 +100,18 @@ def test_ideal_atom_mask(self): else: self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}') + def test_too_many_chains(self): + num_res = protein.PDB_MAX_CHAINS + 1 + num_atom_type = residue_constants.atom_type_num + with self.assertRaises(ValueError): + _ = protein.Protein( + atom_positions=np.random.random([num_res, num_atom_type, 3]), + aatype=np.random.randint(0, 21, [num_res]), + atom_mask=np.random.randint(0, 2, [num_res]).astype(np.float32), + residue_index=np.arange(1, num_res+1), + chain_index=np.arange(num_res), + b_factors=np.random.uniform(1, 100, [num_res])) + if __name__ == '__main__': absltest.main() diff --git a/alphafold/common/residue_constants.py b/alphafold/common/residue_constants.py index 3bb370a79..4318875a9 100644 --- a/alphafold/common/residue_constants.py +++ b/alphafold/common/residue_constants.py @@ -399,13 +399,12 @@ def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], ("residue_virtual_bonds"). Returns: - residue_bonds: dict that maps resname --> list of Bond tuples - residue_virtual_bonds: dict that maps resname --> list of Bond tuples - residue_bond_angles: dict that maps resname --> list of BondAngle tuples + residue_bonds: Dict that maps resname -> list of Bond tuples. + residue_virtual_bonds: Dict that maps resname -> list of Bond tuples. + residue_bond_angles: Dict that maps resname -> list of BondAngle tuples. """ stereo_chemical_props_path = os.path.join( - os.environ.get('ALPHAFOLD_BASE_DIR', default=''), - 'alphafold/common/stereo_chemical_props.txt' + os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt' ) with open(stereo_chemical_props_path, 'rt') as f: stereo_chemical_props = f.read() diff --git a/alphafold/data/feature_processing.py b/alphafold/data/feature_processing.py new file mode 100644 index 000000000..aefd09a19 --- /dev/null +++ b/alphafold/data/feature_processing.py @@ -0,0 +1,231 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Feature processing logic for multimer data pipeline.""" + +from typing import Iterable, MutableMapping, List + +from alphafold.common import residue_constants +from alphafold.data import msa_pairing +from alphafold.data import pipeline +import numpy as np + +REQUIRED_FEATURES = frozenset({ + 'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids', + 'all_crops_all_chains_mask', 'all_crops_all_chains_positions', + 'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id', + 'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean', + 'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments', + 'num_templates', 'queue_size', 'residue_index', 'resolution', + 'seq_length', 'seq_mask', 'sym_id', 'template_aatype', + 'template_all_atom_mask', 'template_all_atom_positions' +}) + +MAX_TEMPLATES = 4 +MSA_CROP_SIZE = 2048 + + +def _is_homomer_or_monomer(chains: Iterable[pipeline.FeatureDict]) -> bool: + """Checks if a list of chains represents a homomer/monomer example.""" + # Note that an entity_id of 0 indicates padding. + num_unique_chains = len(np.unique(np.concatenate( + [np.unique(chain['entity_id'][chain['entity_id'] > 0]) for + chain in chains]))) + return num_unique_chains == 1 + + +def pair_and_merge( + all_chain_features: MutableMapping[str, pipeline.FeatureDict], + is_prokaryote: bool) -> pipeline.FeatureDict: + """Runs processing on features to augment, pair and merge. + + Args: + all_chain_features: A MutableMap of dictionaries of features for each chain. + is_prokaryote: Whether the target complex is from a prokaryotic or + eukaryotic organism. + + Returns: + A dictionary of features. + """ + + process_unmerged_features(all_chain_features) + + np_chains_list = list(all_chain_features.values()) + + pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list) + + if pair_msa_sequences: + np_chains_list = msa_pairing.create_paired_features( + chains=np_chains_list, prokaryotic=is_prokaryote) + np_chains_list = msa_pairing.deduplicate_unpaired_sequences(np_chains_list) + np_chains_list = crop_chains( + np_chains_list, + msa_crop_size=MSA_CROP_SIZE, + pair_msa_sequences=pair_msa_sequences, + max_templates=MAX_TEMPLATES) + np_example = msa_pairing.merge_chain_features( + np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences, + max_templates=MAX_TEMPLATES) + np_example = process_final(np_example) + return np_example + + +def crop_chains( + chains_list: List[pipeline.FeatureDict], + msa_crop_size: int, + pair_msa_sequences: bool, + max_templates: int) -> List[pipeline.FeatureDict]: + """Crops the MSAs for a set of chains. + + Args: + chains_list: A list of chains to be cropped. + msa_crop_size: The total number of sequences to crop from the MSA. + pair_msa_sequences: Whether we are operating in sequence-pairing mode. + max_templates: The maximum templates to use per chain. + + Returns: + The chains cropped. + """ + + # Apply the cropping. + cropped_chains = [] + for chain in chains_list: + cropped_chain = _crop_single_chain( + chain, + msa_crop_size=msa_crop_size, + pair_msa_sequences=pair_msa_sequences, + max_templates=max_templates) + cropped_chains.append(cropped_chain) + + return cropped_chains + + +def _crop_single_chain(chain: pipeline.FeatureDict, + msa_crop_size: int, + pair_msa_sequences: bool, + max_templates: int) -> pipeline.FeatureDict: + """Crops msa sequences to `msa_crop_size`.""" + msa_size = chain['num_alignments'] + + if pair_msa_sequences: + msa_size_all_seq = chain['num_alignments_all_seq'] + msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2) + + # We reduce the number of un-paired sequences, by the number of times a + # sequence from this chain's MSA is included in the paired MSA. This keeps + # the MSA size for each chain roughly constant. + msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :] + num_non_gapped_pairs = np.sum( + np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)) + num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, + msa_crop_size_all_seq) + + # Restrict the unpaired crop size so that paired+unpaired sequences do not + # exceed msa_seqs_per_chain for each chain. + max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0) + msa_crop_size = np.minimum(msa_size, max_msa_crop_size) + else: + msa_crop_size = np.minimum(msa_size, msa_crop_size) + + include_templates = 'template_aatype' in chain and max_templates + if include_templates: + num_templates = chain['template_aatype'].shape[0] + templates_crop_size = np.minimum(num_templates, max_templates) + + for k in chain: + k_split = k.split('_all_seq')[0] + if k_split in msa_pairing.TEMPLATE_FEATURES: + chain[k] = chain[k][:templates_crop_size, :] + elif k_split in msa_pairing.MSA_FEATURES: + if '_all_seq' in k and pair_msa_sequences: + chain[k] = chain[k][:msa_crop_size_all_seq, :] + else: + chain[k] = chain[k][:msa_crop_size, :] + + chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32) + if include_templates: + chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32) + if pair_msa_sequences: + chain['num_alignments_all_seq'] = np.asarray( + msa_crop_size_all_seq, dtype=np.int32) + return chain + + +def process_final(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict: + """Final processing steps in data pipeline, after merging and pairing.""" + np_example = _correct_msa_restypes(np_example) + np_example = _make_seq_mask(np_example) + np_example = _make_msa_mask(np_example) + np_example = _filter_features(np_example) + return np_example + + +def _correct_msa_restypes(np_example): + """Correct MSA restype to have the same order as residue_constants.""" + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + np_example['msa'] = np.take(new_order_list, np_example['msa'], axis=0) + np_example['msa'] = np_example['msa'].astype(np.int32) + return np_example + + +def _make_seq_mask(np_example): + np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32) + return np_example + + +def _make_msa_mask(np_example): + """Mask features are all ones, but will later be zero-padded.""" + + np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.float32) + + seq_mask = (np_example['entity_id'] > 0).astype(np.float32) + np_example['msa_mask'] *= seq_mask[None] + + return np_example + + +def _filter_features(np_example: pipeline.FeatureDict) -> pipeline.FeatureDict: + """Filters features of example to only those requested.""" + return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES} + + +def process_unmerged_features( + all_chain_features: MutableMapping[str, pipeline.FeatureDict]): + """Postprocessing stage for per-chain features before merging.""" + num_chains = len(all_chain_features) + for chain_features in all_chain_features.values(): + # Convert deletion matrices to float. + chain_features['deletion_matrix'] = np.asarray( + chain_features.pop('deletion_matrix_int'), dtype=np.float32) + if 'deletion_matrix_int_all_seq' in chain_features: + chain_features['deletion_matrix_all_seq'] = np.asarray( + chain_features.pop('deletion_matrix_int_all_seq'), dtype=np.float32) + + chain_features['deletion_mean'] = np.mean( + chain_features['deletion_matrix'], axis=0) + + # Add all_atom_mask and dummy all_atom_positions based on aatype. + all_atom_mask = residue_constants.STANDARD_ATOM_MASK[ + chain_features['aatype']] + chain_features['all_atom_mask'] = all_atom_mask + chain_features['all_atom_positions'] = np.zeros( + list(all_atom_mask.shape) + [3]) + + # Add assembly_num_chains. + chain_features['assembly_num_chains'] = np.asarray(num_chains) + + # Add entity_mask. + for chain_features in all_chain_features.values(): + chain_features['entity_mask'] = ( + chain_features['entity_id'] != 0).astype(np.int32) diff --git a/alphafold/data/mmcif_parsing.py b/alphafold/data/mmcif_parsing.py index 18375165a..acb9396b5 100644 --- a/alphafold/data/mmcif_parsing.py +++ b/alphafold/data/mmcif_parsing.py @@ -15,6 +15,7 @@ """Parses the mmCIF file format.""" import collections import dataclasses +import functools import io from typing import Any, Mapping, Optional, Sequence, Tuple @@ -160,6 +161,7 @@ def mmcif_loop_to_dict(prefix: str, return {entry[index]: entry for entry in entries} +@functools.lru_cache(16, typed=False) def parse(*, file_id: str, mmcif_string: str, @@ -314,7 +316,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader: raw_resolution = parsed_info[res_key][0] header['resolution'] = float(raw_resolution) except ValueError: - logging.warning('Invalid resolution format: %s', parsed_info[res_key]) + logging.debug('Invalid resolution format: %s', parsed_info[res_key]) return header diff --git a/alphafold/data/msa_identifiers.py b/alphafold/data/msa_identifiers.py new file mode 100644 index 000000000..00893d126 --- /dev/null +++ b/alphafold/data/msa_identifiers.py @@ -0,0 +1,92 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for extracting identifiers from MSA sequence descriptions.""" + +import dataclasses +import re +from typing import Optional + + +# Sequences coming from UniProtKB database come in the +# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` +# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). +_UNIPROT_PATTERN = re.compile( + r""" + ^ + # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot + (?:tr|sp) + \| + # A primary accession number of the UniProtKB entry. + (?P[A-Za-z0-9]{6,10}) + # Occasionally there is a _0 or _1 isoform suffix, which we ignore. + (?:_\d)? + \| + # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic + # protein ID code. + (?:[A-Za-z0-9]+) + _ + # A mnemonic species identification code. + (?P([A-Za-z0-9]){1,5}) + # Small BFD uses a final value after an underscore, which we ignore. + (?:_\d+)? + $ + """, + re.VERBOSE) + + +@dataclasses.dataclass(frozen=True) +class Identifiers: + uniprot_accession_id: str = '' + species_id: str = '' + + +def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: + """Gets accession id and species from an msa sequence identifier. + + The sequence identifier has the format specified by + _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. + An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` + + Args: + msa_sequence_identifier: a sequence identifier. + + Returns: + An `Identifiers` instance with a uniprot_accession_id and species_id. These + can be empty in the case where no identifier was found. + """ + matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) + if matches: + return Identifiers( + uniprot_accession_id=matches.group('AccessionIdentifier'), + species_id=matches.group('SpeciesIdentifier')) + return Identifiers() + + +def _extract_sequence_identifier(description: str) -> Optional[str]: + """Extracts sequence identifier from description. Returns None if no match.""" + split_description = description.split() + if split_description: + return split_description[0].partition('/')[0] + else: + return None + + +def get_identifiers(description: str) -> Identifiers: + """Computes extra MSA features from the description.""" + sequence_identifier = _extract_sequence_identifier(description) + if sequence_identifier is None: + return Identifiers() + else: + return _parse_sequence_identifier(sequence_identifier) diff --git a/alphafold/data/msa_pairing.py b/alphafold/data/msa_pairing.py new file mode 100644 index 000000000..ddd36ee1e --- /dev/null +++ b/alphafold/data/msa_pairing.py @@ -0,0 +1,638 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pairing logic for multimer data pipeline.""" + +import collections +import functools +import re +import string +from typing import Any, Dict, Iterable, List, Sequence + +from alphafold.common import residue_constants +from alphafold.data import pipeline +import numpy as np +import pandas as pd +import scipy.linalg + +ALPHA_ACCESSION_ID_MAP = {x: y for y, x in enumerate(string.ascii_uppercase)} +ALPHANUM_ACCESSION_ID_MAP = { + chr: num for num, chr in enumerate(string.ascii_uppercase + string.digits) +} # A-Z,0-9 +NUM_ACCESSION_ID_MAP = {str(x): x for x in range(10)} # 0-9 + +MSA_GAP_IDX = residue_constants.restypes_with_x_and_gap.index('-') +SEQUENCE_GAP_CUTOFF = 0.5 +SEQUENCE_SIMILARITY_CUTOFF = 0.9 + +MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX, + 'msa_mask_all_seq': 1, + 'deletion_matrix_all_seq': 0, + 'deletion_matrix_int_all_seq': 0, + 'msa': MSA_GAP_IDX, + 'msa_mask': 1, + 'deletion_matrix': 0, + 'deletion_matrix_int': 0} + +MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int') +SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions', + 'all_atom_mask', 'seq_mask', 'between_segment_residues', + 'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id', + 'sym_id', 'entity_mask', 'deletion_mean', + 'prediction_atom_mask', + 'literature_positions', 'atom_indices_to_group_indices', + 'rigid_group_default_frame') +TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions', + 'template_all_atom_mask') +CHAIN_FEATURES = ('num_alignments', 'seq_length') + + +domain_name_pattern = re.compile( + r'''^(?P[a-z\d]{4}) + \{(?P[\d+(\+\d+)?])\} + (?P[a-zA-Z\d]+) + \{(?P\d+)\}$ + ''', re.VERBOSE) + + +def create_paired_features( + chains: Iterable[pipeline.FeatureDict], + prokaryotic: bool, + ) -> List[pipeline.FeatureDict]: + """Returns the original chains with paired NUM_SEQ features. + + Args: + chains: A list of feature dictionaries for each chain. + prokaryotic: Whether the target complex is from a prokaryotic organism. + Used to determine the distance metric for pairing. + + Returns: + A list of feature dictionaries with sequence features including only + rows to be paired. + """ + chains = list(chains) + chain_keys = chains[0].keys() + + if len(chains) < 2: + return chains + else: + updated_chains = [] + paired_chains_to_paired_row_indices = pair_sequences( + chains, prokaryotic) + paired_rows = reorder_paired_rows( + paired_chains_to_paired_row_indices) + + for chain_num, chain in enumerate(chains): + new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k} + for feature_name in chain_keys: + if feature_name.endswith('_all_seq'): + feats_padded = pad_features(chain[feature_name], feature_name) + new_chain[feature_name] = feats_padded[paired_rows[:, chain_num]] + new_chain['num_alignments_all_seq'] = np.asarray( + len(paired_rows[:, chain_num])) + updated_chains.append(new_chain) + return updated_chains + + +def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: + """Add a 'padding' row at the end of the features list. + + The padding row will be selected as a 'paired' row in the case of partial + alignment - for the chain that doesn't have paired alignment. + + Args: + feature: The feature to be padded. + feature_name: The name of the feature to be padded. + + Returns: + The feature with an additional padding row. + """ + assert feature.dtype != np.dtype(np.string_) + if feature_name in ('msa_all_seq', 'msa_mask_all_seq', + 'deletion_matrix_all_seq', 'deletion_matrix_int_all_seq'): + num_res = feature.shape[1] + padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], + feature.dtype) + elif feature_name in ('msa_uniprot_accession_identifiers_all_seq', + 'msa_species_identifiers_all_seq'): + padding = [b''] + else: + return feature + feats_padded = np.concatenate([feature, padding], axis=0) + return feats_padded + + +def _make_msa_df(chain_features: pipeline.FeatureDict) -> pd.DataFrame: + """Makes dataframe with msa features needed for msa pairing.""" + chain_msa = chain_features['msa_all_seq'] + query_seq = chain_msa[0] + per_seq_similarity = np.sum( + query_seq[None] == chain_msa, axis=-1) / float(len(query_seq)) + per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq)) + msa_df = pd.DataFrame({ + 'msa_species_identifiers': + chain_features['msa_species_identifiers_all_seq'], + 'msa_uniprot_accession_identifiers': + chain_features['msa_uniprot_accession_identifiers_all_seq'], + 'msa_row': + np.arange(len( + chain_features['msa_uniprot_accession_identifiers_all_seq'])), + 'msa_similarity': per_seq_similarity, + 'gap': per_seq_gap + }) + return msa_df + + +def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]: + """Creates mapping from species to msa dataframe of that species.""" + species_lookup = {} + for species, species_df in msa_df.groupby('msa_species_identifiers'): + species_lookup[species] = species_df + return species_lookup + + +@functools.lru_cache(maxsize=65536) +def encode_accession(accession_id: str) -> int: + """Map accession codes to the serial order in which they were assigned.""" + alpha = ALPHA_ACCESSION_ID_MAP # A-Z + alphanum = ALPHANUM_ACCESSION_ID_MAP # A-Z,0-9 + num = NUM_ACCESSION_ID_MAP # 0-9 + + coding = 0 + + # This is based on the uniprot accession id format + # https://www.uniprot.org/help/accession_numbers + if accession_id[0] in {'O', 'P', 'Q'}: + bases = (alpha, num, alphanum, alphanum, alphanum, num) + elif len(accession_id) == 6: + bases = (alpha, num, alpha, alphanum, alphanum, num) + elif len(accession_id) == 10: + bases = (alpha, num, alpha, alphanum, alphanum, num, alpha, alphanum, + alphanum, num) + + product = 1 + for place, base in zip(reversed(accession_id), reversed(bases)): + coding += base[place] * product + product *= len(base) + + return coding + + +def _calc_id_diff(id_a: bytes, id_b: bytes) -> int: + return abs(encode_accession(id_a.decode()) - encode_accession(id_b.decode())) + + +def _find_all_accession_matches(accession_id_lists: List[List[bytes]], + diff_cutoff: int = 20 + ) -> List[List[Any]]: + """Finds accession id matches across the chains based on their difference.""" + all_accession_tuples = [] + current_tuple = [] + tokens_used_in_answer = set() + + def _matches_all_in_current_tuple(inp: bytes, diff_cutoff: int) -> bool: + return all((_calc_id_diff(s, inp) < diff_cutoff for s in current_tuple)) + + def _all_tokens_not_used_before() -> bool: + return all((s not in tokens_used_in_answer for s in current_tuple)) + + def dfs(level, accession_id, diff_cutoff=diff_cutoff) -> None: + if level == len(accession_id_lists) - 1: + if _all_tokens_not_used_before(): + all_accession_tuples.append(list(current_tuple)) + for s in current_tuple: + tokens_used_in_answer.add(s) + return + + if level == -1: + new_list = accession_id_lists[level+1] + else: + new_list = [(_calc_id_diff(accession_id, s), s) for + s in accession_id_lists[level+1]] + new_list = sorted(new_list) + new_list = [s for d, s in new_list] + + for s in new_list: + if (_matches_all_in_current_tuple(s, diff_cutoff) and + s not in tokens_used_in_answer): + current_tuple.append(s) + dfs(level + 1, s) + current_tuple.pop() + dfs(-1, '') + return all_accession_tuples + + +def _accession_row(msa_df: pd.DataFrame, accession_id: bytes) -> pd.Series: + matched_df = msa_df[msa_df.msa_uniprot_accession_identifiers == accession_id] + return matched_df.iloc[0] + + +def _match_rows_by_genetic_distance( + this_species_msa_dfs: List[pd.DataFrame], + cutoff: int = 20) -> List[List[int]]: + """Finds MSA sequence pairings across chains within a genetic distance cutoff. + + The genetic distance between two sequences is approximated by taking the + difference in their UniProt accession ids. + + Args: + this_species_msa_dfs: a list of dataframes containing MSA features for + sequences for a specific species. If species is missing for a chain, the + dataframe is set to None. + cutoff: the genetic distance cutoff. + + Returns: + A list of lists, each containing M indices corresponding to paired MSA rows, + where M is the number of chains. + """ + num_examples = len(this_species_msa_dfs) # N + + accession_id_lists = [] # M + match_index_to_chain_index = {} + for chain_index, species_df in enumerate(this_species_msa_dfs): + if species_df is not None: + accession_id_lists.append( + list(species_df.msa_uniprot_accession_identifiers.values)) + # Keep track of which of the this_species_msa_dfs are not None. + match_index_to_chain_index[len(accession_id_lists) - 1] = chain_index + + all_accession_id_matches = _find_all_accession_matches( + accession_id_lists, cutoff) # [k, M] + + all_paired_msa_rows = [] # [k, N] + for accession_id_match in all_accession_id_matches: + paired_msa_rows = [] + for match_index, accession_id in enumerate(accession_id_match): + # Map back to chain index. + chain_index = match_index_to_chain_index[match_index] + seq_series = _accession_row( + this_species_msa_dfs[chain_index], accession_id) + + if (seq_series.msa_similarity > SEQUENCE_SIMILARITY_CUTOFF or + seq_series.gap > SEQUENCE_GAP_CUTOFF): + continue + else: + paired_msa_rows.append(seq_series.msa_row) + # If a sequence is skipped based on sequence similarity to the respective + # target sequence or a gap cuttoff, the lengths of accession_id_match and + # paired_msa_rows will be different. Skip this match. + if len(paired_msa_rows) == len(accession_id_match): + paired_and_non_paired_msa_rows = np.array([-1] * num_examples) + matched_chain_indices = list(match_index_to_chain_index.values()) + paired_and_non_paired_msa_rows[matched_chain_indices] = paired_msa_rows + all_paired_msa_rows.append(list(paired_and_non_paired_msa_rows)) + return all_paired_msa_rows + + +def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame] + ) -> List[List[int]]: + """Finds MSA sequence pairings across chains based on sequence similarity. + + Each chain's MSA sequences are first sorted by their sequence similarity to + their respective target sequence. The sequences are then paired, starting + from the sequences most similar to their target sequence. + + Args: + this_species_msa_dfs: a list of dataframes containing MSA features for + sequences for a specific species. + + Returns: + A list of lists, each containing M indices corresponding to paired MSA rows, + where M is the number of chains. + """ + all_paired_msa_rows = [] + + num_seqs = [len(species_df) for species_df in this_species_msa_dfs + if species_df is not None] + take_num_seqs = np.min(num_seqs) + + sort_by_similarity = ( + lambda x: x.sort_values('msa_similarity', axis=0, ascending=False)) + + for species_df in this_species_msa_dfs: + if species_df is not None: + species_df_sorted = sort_by_similarity(species_df) + msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values + else: + msa_rows = [-1] * take_num_seqs # take the last 'padding' row + all_paired_msa_rows.append(msa_rows) + all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose()) + return all_paired_msa_rows + + +def pair_sequences(examples: List[pipeline.FeatureDict], + prokaryotic: bool) -> Dict[int, np.ndarray]: + """Returns indices for paired MSA sequences across chains.""" + + num_examples = len(examples) + + all_chain_species_dict = [] + common_species = set() + for chain_features in examples: + msa_df = _make_msa_df(chain_features) + species_dict = _create_species_dict(msa_df) + all_chain_species_dict.append(species_dict) + common_species.update(set(species_dict)) + + common_species = sorted(common_species) + common_species.remove(b'') # Remove target sequence species. + + all_paired_msa_rows = [np.zeros(len(examples), int)] + all_paired_msa_rows_dict = {k: [] for k in range(num_examples)} + all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)] + + for species in common_species: + if not species: + continue + this_species_msa_dfs = [] + species_dfs_present = 0 + for species_dict in all_chain_species_dict: + if species in species_dict: + this_species_msa_dfs.append(species_dict[species]) + species_dfs_present += 1 + else: + this_species_msa_dfs.append(None) + + # Skip species that are present in only one chain. + if species_dfs_present <= 1: + continue + + if np.any( + np.array([len(species_df) for species_df in + this_species_msa_dfs if + isinstance(species_df, pd.DataFrame)]) > 600): + continue + + # In prokaryotes (and some eukaryotes), interacting genes are often + # co-located on the chromosome into operons. Because of that we can assume + # that if two proteins' intergenic distance is less than a threshold, they + # two proteins will form an an interacting pair. + # In most eukaryotes, a single protein's MSA can contain many paralogs. + # Two genes may interact even if they are not close by genomic distance. + # In case of eukaryotes, some methods pair MSA sequences using sequence + # similarity method. + # See Jinbo Xu's work: + # https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6030867/#B28. + if prokaryotic: + paired_msa_rows = _match_rows_by_genetic_distance(this_species_msa_dfs) + + if not paired_msa_rows: + continue + else: + paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs) + all_paired_msa_rows.extend(paired_msa_rows) + all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) + all_paired_msa_rows_dict = { + num_examples: np.array(paired_msa_rows) for + num_examples, paired_msa_rows in all_paired_msa_rows_dict.items() + } + return all_paired_msa_rows_dict + + +def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray] + ) -> np.ndarray: + """Creates a list of indices of paired MSA rows across chains. + + Args: + all_paired_msa_rows_dict: a mapping from the number of paired chains to the + paired indices. + + Returns: + a list of lists, each containing indices of paired MSA rows across chains. + The paired-index lists are ordered by: + 1) the number of chains in the paired alignment, i.e, all-chain pairings + will come first. + 2) e-values + """ + all_paired_msa_rows = [] + + for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True): + paired_rows = all_paired_msa_rows_dict[num_pairings] + paired_rows_product = abs(np.array([np.prod(rows) for rows in paired_rows])) + paired_rows_sort_index = np.argsort(paired_rows_product) + all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index]) + + return np.array(all_paired_msa_rows) + + +def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray: + """Like scipy.linalg.block_diag but with an optional padding value.""" + ones_arrs = [np.ones_like(x) for x in arrs] + off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs) + diag = scipy.linalg.block_diag(*arrs) + diag += (off_diag_mask * pad_value).astype(diag.dtype) + return diag + + +def _correct_post_merged_feats( + np_example: pipeline.FeatureDict, + np_chains_list: Sequence[pipeline.FeatureDict], + pair_msa_sequences: bool) -> pipeline.FeatureDict: + """Adds features that need to be computed/recomputed post merging.""" + + np_example['seq_length'] = np.asarray(np_example['aatype'].shape[0], + dtype=np.int32) + np_example['num_alignments'] = np.asarray(np_example['msa'].shape[0], + dtype=np.int32) + + if not pair_msa_sequences: + # Generate a bias that is 1 for the first row of every block in the + # block diagonal MSA - i.e. make sure the cluster stack always includes + # the query sequences for each chain (since the first row is the query + # sequence). + cluster_bias_masks = [] + for chain in np_chains_list: + mask = np.zeros(chain['msa'].shape[0]) + mask[0] = 1 + cluster_bias_masks.append(mask) + np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks) + + # Initialize Bert mask with masked out off diagonals. + msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) + for x in np_chains_list] + + np_example['bert_mask'] = block_diag( + *msa_masks, pad_value=0) + else: + np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0]) + np_example['cluster_bias_mask'][0] = 1 + + # Initialize Bert mask with masked out off diagonals. + msa_masks = [np.ones(x['msa'].shape, dtype=np.float32) for + x in np_chains_list] + msa_masks_all_seq = [np.ones(x['msa_all_seq'].shape, dtype=np.float32) for + x in np_chains_list] + + msa_mask_block_diag = block_diag( + *msa_masks, pad_value=0) + msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1) + np_example['bert_mask'] = np.concatenate( + [msa_mask_all_seq, msa_mask_block_diag], axis=0) + return np_example + + +def _pad_templates(chains: Sequence[pipeline.FeatureDict], + max_templates: int) -> Sequence[pipeline.FeatureDict]: + """For each chain pad the number of templates to a fixed size. + + Args: + chains: A list of protein chains. + max_templates: Each chain will be padded to have this many templates. + + Returns: + The list of chains, updated to have template features padded to + max_templates. + """ + for chain in chains: + for k, v in chain.items(): + if k in TEMPLATE_FEATURES: + padding = np.zeros_like(v.shape) + padding[0] = max_templates - v.shape[0] + padding = [(0, p) for p in padding] + chain[k] = np.pad(v, padding, mode='constant') + return chains + + +def _merge_features_from_multiple_chains( + chains: Sequence[pipeline.FeatureDict], + pair_msa_sequences: bool) -> pipeline.FeatureDict: + """Merge features from multiple chains. + + Args: + chains: A list of feature dictionaries that we want to merge. + pair_msa_sequences: Whether to concatenate MSA features along the + num_res dimension (if True), or to block diagonalize them (if False). + + Returns: + A feature dictionary for the merged example. + """ + merged_example = {} + for feature_name in chains[0]: + feats = [x[feature_name] for x in chains] + feature_name_split = feature_name.split('_all_seq')[0] + if feature_name_split in MSA_FEATURES: + if pair_msa_sequences or '_all_seq' in feature_name: + merged_example[feature_name] = np.concatenate(feats, axis=1) + else: + merged_example[feature_name] = block_diag( + *feats, pad_value=MSA_PAD_VALUES[feature_name]) + elif feature_name_split in SEQ_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=0) + elif feature_name_split in TEMPLATE_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=1) + elif feature_name_split in CHAIN_FEATURES: + merged_example[feature_name] = np.sum(x for x in feats).astype(np.int32) + else: + merged_example[feature_name] = feats[0] + return merged_example + + +def _merge_homomers_dense_msa( + chains: Iterable[pipeline.FeatureDict]) -> Sequence[pipeline.FeatureDict]: + """Merge all identical chains, making the resulting MSA dense. + + Args: + chains: An iterable of features for each chain. + + Returns: + A list of feature dictionaries. All features with the same entity_id + will be merged - MSA features will be concatenated along the num_res + dimension - making them dense. + """ + entity_chains = collections.defaultdict(list) + for chain in chains: + entity_id = chain['entity_id'][0] + entity_chains[entity_id].append(chain) + + grouped_chains = [] + for entity_id in sorted(entity_chains): + chains = entity_chains[entity_id] + grouped_chains.append(chains) + chains = [ + _merge_features_from_multiple_chains(chains, pair_msa_sequences=True) + for chains in grouped_chains] + return chains + + +def _concatenate_paired_and_unpaired_features( + example: pipeline.FeatureDict) -> pipeline.FeatureDict: + """Merges paired and block-diagonalised features.""" + features = MSA_FEATURES + for feature_name in features: + if feature_name in example: + feat = example[feature_name] + feat_all_seq = example[feature_name + '_all_seq'] + merged_feat = np.concatenate([feat_all_seq, feat], axis=0) + example[feature_name] = merged_feat + example['num_alignments'] = np.array(example['msa'].shape[0], + dtype=np.int32) + return example + + +def merge_chain_features(np_chains_list: List[pipeline.FeatureDict], + pair_msa_sequences: bool, + max_templates: int) -> pipeline.FeatureDict: + """Merges features for multiple chains to single FeatureDict. + + Args: + np_chains_list: List of FeatureDicts for each chain. + pair_msa_sequences: Whether to merge paired MSAs. + max_templates: The maximum number of templates to include. + + Returns: + Single FeatureDict for entire complex. + """ + np_chains_list = _pad_templates( + np_chains_list, max_templates=max_templates) + np_chains_list = _merge_homomers_dense_msa(np_chains_list) + # Unpaired MSA features will be always block-diagonalised; paired MSA + # features will be concatenated. + np_example = _merge_features_from_multiple_chains( + np_chains_list, pair_msa_sequences=False) + if pair_msa_sequences: + np_example = _concatenate_paired_and_unpaired_features(np_example) + np_example = _correct_post_merged_feats( + np_example=np_example, + np_chains_list=np_chains_list, + pair_msa_sequences=pair_msa_sequences) + + return np_example + + +def deduplicate_unpaired_sequences( + np_chains: List[pipeline.FeatureDict]) -> List[pipeline.FeatureDict]: + """Removes unpaired sequences which duplicate a paired sequence.""" + + feature_names = np_chains[0].keys() + msa_features = MSA_FEATURES + + for chain in np_chains: + sequence_set = set(tuple(s) for s in chain['msa_all_seq']) + keep_rows = [] + # Go through unpaired MSA seqs and remove any rows that correspond to the + # sequences that are already present in the paired MSA. + for row_num, seq in enumerate(chain['msa']): + if tuple(seq) not in sequence_set: + keep_rows.append(row_num) + for feature_name in feature_names: + if feature_name in msa_features: + if keep_rows: + chain[feature_name] = chain[feature_name][keep_rows] + else: + new_shape = list(chain[feature_name].shape) + new_shape[0] = 0 + chain[feature_name] = np.zeros(new_shape, + dtype=chain[feature_name].dtype) + chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32) + return np_chains diff --git a/alphafold/data/parsers.py b/alphafold/data/parsers.py index edc21bbeb..cbb58f23a 100644 --- a/alphafold/data/parsers.py +++ b/alphafold/data/parsers.py @@ -15,20 +15,47 @@ """Functions for parsing various file formats.""" import collections import dataclasses +import itertools import re import string -from typing import Dict, Iterable, List, Optional, Sequence, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set DeletionMatrix = Sequence[Sequence[int]] +@dataclasses.dataclass(frozen=True) +class Msa: + """Class representing a parsed MSA file.""" + sequences: Sequence[str] + deletion_matrix: DeletionMatrix + descriptions: Sequence[str] + + def __post_init__(self): + if not (len(self.sequences) == + len(self.deletion_matrix) == + len(self.descriptions)): + raise ValueError( + 'All fields for an MSA must have the same length. ' + f'Got {len(self.sequences)} sequences, ' + f'{len(self.deletion_matrix)} rows in the deletion matrix and ' + f'{len(self.descriptions)} descriptions.') + + def __len__(self): + return len(self.sequences) + + def truncate(self, max_seqs: int): + return Msa(sequences=self.sequences[:max_seqs], + deletion_matrix=self.deletion_matrix[:max_seqs], + descriptions=self.descriptions[:max_seqs]) + + @dataclasses.dataclass(frozen=True) class TemplateHit: """Class representing a template hit.""" index: int name: str aligned_cols: int - sum_probs: float + sum_probs: Optional[float] query: str hit_sequence: str indices_query: List[int] @@ -64,9 +91,7 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: return sequences, descriptions -def parse_stockholm( - stockholm_string: str -) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]: +def parse_stockholm(stockholm_string: str) -> Msa: """Parses sequences and deletion matrix from stockholm format alignment. Args: @@ -121,10 +146,12 @@ def parse_stockholm( deletion_count = 0 deletion_matrix.append(deletion_vec) - return msa, deletion_matrix, list(name_to_sequence.keys()) + return Msa(sequences=msa, + deletion_matrix=deletion_matrix, + descriptions=list(name_to_sequence.keys())) -def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: +def parse_a3m(a3m_string: str) -> Msa: """Parses sequences and deletion matrix from a3m format alignment. Args: @@ -138,8 +165,9 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: * The deletion matrix for the alignment as a list of lists. The element at `deletion_matrix[i][j]` is the number of residues deleted from the aligned sequence i at residue position j. + * A list of descriptions, one per sequence, from the a3m file. """ - sequences, _ = parse_fasta(a3m_string) + sequences, descriptions = parse_fasta(a3m_string) deletion_matrix = [] for msa_sequence in sequences: deletion_vec = [] @@ -155,7 +183,9 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: # Make the MSA matrix out of aligned (deletion-free) sequences. deletion_table = str.maketrans('', '', string.ascii_lowercase) aligned_sequences = [s.translate(deletion_table) for s in sequences] - return aligned_sequences, deletion_matrix + return Msa(sequences=aligned_sequences, + deletion_matrix=deletion_matrix, + descriptions=descriptions) def _convert_sto_seq_to_a3m( @@ -168,7 +198,8 @@ def _convert_sto_seq_to_a3m( def convert_stockholm_to_a3m(stockholm_format: str, - max_sequences: Optional[int] = None) -> str: + max_sequences: Optional[int] = None, + remove_first_row_gaps: bool = True) -> str: """Converts MSA in Stockholm format to the A3M format.""" descriptions = {} sequences = {} @@ -203,18 +234,138 @@ def convert_stockholm_to_a3m(stockholm_format: str, # Convert sto format to a3m line by line a3m_sequences = {} - # query_sequence is assumed to be the first sequence - query_sequence = next(iter(sequences.values())) - query_non_gaps = [res != '-' for res in query_sequence] + if remove_first_row_gaps: + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + query_non_gaps = [res != '-' for res in query_sequence] for seqname, sto_sequence in sequences.items(): - a3m_sequences[seqname] = ''.join( - _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)) + # Dots are optional in a3m format and are commonly removed. + out_sequence = sto_sequence.replace('.', '') + if remove_first_row_gaps: + out_sequence = ''.join( + _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)) + a3m_sequences[seqname] = out_sequence fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" for k in a3m_sequences) return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. +def _keep_line(line: str, seqnames: Set[str]) -> bool: + """Function to decide which lines to keep.""" + if not line.strip(): + return True + if line.strip() == '//': # End tag + return True + if line.startswith('# STOCKHOLM'): # Start tag + return True + if line.startswith('#=GC RF'): # Reference Annotation Line + return True + if line[:4] == '#=GS': # Description lines - keep if sequence in list. + _, seqname, _ = line.split(maxsplit=2) + return seqname in seqnames + elif line.startswith('#'): # Other markup - filter out + return False + else: # Alignment data - keep if sequence in list. + seqname = line.partition(' ')[0] + return seqname in seqnames + + +def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str: + """Truncates a stockholm file to a maximum number of sequences.""" + seqnames = set() + filtered_lines = [] + for line in stockholm_msa.splitlines(): + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + seqname = line.partition(' ')[0] + seqnames.add(seqname) + if len(seqnames) >= max_sequences: + break + + for line in stockholm_msa.splitlines(): + if _keep_line(line, seqnames): + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + '\n' + + +def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str: + """Removes empty columns (dashes-only) from a Stockholm MSA.""" + processed_lines = {} + unprocessed_lines = {} + for i, line in enumerate(stockholm_msa.splitlines()): + if line.startswith('#=GC RF'): + reference_annotation_i = i + reference_annotation_line = line + # Reached the end of this chunk of the alignment. Process chunk. + _, _, first_alignment = line.rpartition(' ') + mask = [] + for j in range(len(first_alignment)): + for _, unprocessed_line in unprocessed_lines.items(): + prefix, _, alignment = unprocessed_line.rpartition(' ') + if alignment[j] != '-': + mask.append(True) + break + else: # Every row contained a hyphen - empty column. + mask.append(False) + # Add reference annotation for processing with mask. + unprocessed_lines[reference_annotation_i] = reference_annotation_line + + if not any(mask): # All columns were empty. Output empty lines for chunk. + for line_index in unprocessed_lines: + processed_lines[line_index] = '' + else: + for line_index, unprocessed_line in unprocessed_lines.items(): + prefix, _, alignment = unprocessed_line.rpartition(' ') + masked_alignment = ''.join(itertools.compress(alignment, mask)) + processed_lines[line_index] = f'{prefix} {masked_alignment}' + + # Clear raw_alignments. + unprocessed_lines = {} + elif line.strip() and not line.startswith(('#', '//')): + unprocessed_lines[i] = line + else: + processed_lines[i] = line + return '\n'.join((processed_lines[i] for i in range(len(processed_lines)))) + + +def deduplicate_stockholm_msa(stockholm_msa: str) -> str: + """Remove duplicate sequences (ignoring insertions wrt query).""" + sequence_dict = collections.defaultdict(str) + + # First we must extract all sequences from the MSA. + for line in stockholm_msa.splitlines(): + # Only consider the alignments - ignore reference annotation, empty lines, + # descriptions or markup. + if line.strip() and not line.startswith(('#', '//')): + line = line.strip() + seqname, alignment = line.split() + sequence_dict[seqname] += alignment + + seen_sequences = set() + seqnames = set() + # First alignment is the query. + query_align = next(iter(sequence_dict.values())) + mask = [c != '-' for c in query_align] # Mask is False for insertions. + for seqname, alignment in sequence_dict.items(): + # Apply mask to remove all insertions from the string. + masked_alignment = ''.join(itertools.compress(alignment, mask)) + if masked_alignment in seen_sequences: + continue + else: + seen_sequences.add(masked_alignment) + seqnames.add(seqname) + + filtered_lines = [] + for line in stockholm_msa.splitlines(): + if _keep_line(line, seqnames): + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + '\n' + + def _get_hhr_line_regex_groups( regex_pattern: str, line: str) -> Sequence[Optional[str]]: match = re.match(regex_pattern, line) @@ -264,8 +415,8 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: raise RuntimeError( 'Could not parse section: %s. Expected this: \n%s to contain summary.' % (detailed_lines, detailed_lines[2])) - (prob_true, e_value, _, aligned_cols, _, _, sum_probs, - neff) = [float(x) for x in match.groups()] + (_, _, _, aligned_cols, _, _, sum_probs, _) = [float(x) + for x in match.groups()] # The next section reads the detailed comparisons. These are in a 'human # readable' format which has a fixed length. The strategy employed is to @@ -362,3 +513,95 @@ def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: target_name = fields[0] e_values[target_name] = float(e_value) return e_values + + +def _get_indices(sequence: str, start: int) -> List[int]: + """Returns indices for non-gap/insert residues starting at the given index.""" + indices = [] + counter = start + for symbol in sequence: + # Skip gaps but add a placeholder so that the alignment is preserved. + if symbol == '-': + indices.append(-1) + # Skip deleted residues, but increase the counter. + elif symbol.islower(): + counter += 1 + # Normal aligned residue. Increase the counter and append to indices. + else: + indices.append(counter) + counter += 1 + return indices + + +@dataclasses.dataclass(frozen=True) +class HitMetadata: + pdb_id: str + chain: str + start: int + end: int + length: int + text: str + + +def _parse_hmmsearch_description(description: str) -> HitMetadata: + """Parses the hmmsearch A3M sequence description line.""" + # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text + # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352 + match = re.match( + r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$', + description.strip()) + + if not match: + raise ValueError(f'Could not parse description: "{description}".') + + return HitMetadata( + pdb_id=match[1], + chain=match[2], + start=int(match[3]), + end=int(match[4]), + length=int(match[5]), + text=match[6]) + + +def parse_hmmsearch_a3m(query_sequence: str, + a3m_string: str, + skip_first: bool = True) -> Sequence[TemplateHit]: + """Parses an a3m string produced by hmmsearch. + + Args: + query_sequence: The query sequence. + a3m_string: The a3m string produced by hmmsearch. + skip_first: Whether to skip the first sequence in the a3m string. + + Returns: + A sequence of `TemplateHit` results. + """ + # Zip the descriptions and MSAs together, skip the first query sequence. + parsed_a3m = list(zip(*parse_fasta(a3m_string))) + if skip_first: + parsed_a3m = parsed_a3m[1:] + + indices_query = _get_indices(query_sequence, start=0) + + hits = [] + for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1): + if 'mol:protein' not in hit_description: + continue # Skip non-protein chains. + metadata = _parse_hmmsearch_description(hit_description) + # Aligned columns are only the match states. + aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence]) + indices_hit = _get_indices(hit_sequence, start=metadata.start - 1) + + hit = TemplateHit( + index=i, + name=f'{metadata.pdb_id}_{metadata.chain}', + aligned_cols=aligned_cols, + sum_probs=None, + query=query_sequence, + hit_sequence=hit_sequence.upper(), + indices_query=indices_query, + indices_hit=indices_hit, + ) + hits.append(hit) + + return hits diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py index 461bce875..1f643dad8 100644 --- a/alphafold/data/pipeline.py +++ b/alphafold/data/pipeline.py @@ -15,19 +15,22 @@ """Functions for building the input features for the AlphaFold model.""" import os -from typing import Mapping, Optional, Sequence +from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union from absl import logging from alphafold.common import residue_constants +from alphafold.data import msa_identifiers from alphafold.data import parsers from alphafold.data import templates from alphafold.data.tools import hhblits from alphafold.data.tools import hhsearch +from alphafold.data.tools import hmmsearch from alphafold.data.tools import jackhmmer import numpy as np # Internal import (7716). -FeatureDict = Mapping[str, np.ndarray] +FeatureDict = MutableMapping[str, np.ndarray] +TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] def make_sequence_features( @@ -47,55 +50,78 @@ def make_sequence_features( return features -def make_msa_features( - msas: Sequence[Sequence[str]], - deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict: +def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: """Constructs a feature dict of MSA features.""" if not msas: raise ValueError('At least one MSA must be provided.') int_msa = [] deletion_matrix = [] + uniprot_accession_ids = [] + species_ids = [] seen_sequences = set() for msa_index, msa in enumerate(msas): if not msa: raise ValueError(f'MSA {msa_index} must contain at least one sequence.') - for sequence_index, sequence in enumerate(msa): + for sequence_index, sequence in enumerate(msa.sequences): if sequence in seen_sequences: continue seen_sequences.add(sequence) int_msa.append( [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) - deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) - - num_res = len(msas[0][0]) + deletion_matrix.append(msa.deletion_matrix[sequence_index]) + identifiers = msa_identifiers.get_identifiers( + msa.descriptions[sequence_index]) + uniprot_accession_ids.append( + identifiers.uniprot_accession_id.encode('utf-8')) + species_ids.append(identifiers.species_id.encode('utf-8')) + + num_res = len(msas[0].sequences[0]) num_alignments = len(int_msa) features = {} features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) features['msa'] = np.array(int_msa, dtype=np.int32) features['num_alignments'] = np.array( [num_alignments] * num_res, dtype=np.int32) + features['msa_uniprot_accession_identifiers'] = np.array( + uniprot_accession_ids, dtype=np.object_) + features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_) return features +def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str, + msa_format: str, use_precomputed_msas: bool, + ) -> Mapping[str, Any]: + """Runs an MSA tool, checking if output already exists first.""" + if not use_precomputed_msas or not os.path.exists(msa_out_path): + result = msa_runner.query(input_fasta_path)[0] + with open(msa_out_path, 'w') as f: + f.write(result[msa_format]) + else: + logging.warning('Reading MSA from file %s', msa_out_path) + with open(msa_out_path, 'r') as f: + result = {msa_format: f.read()} + return result + + class DataPipeline: """Runs the alignment tools and assembles the input features.""" def __init__(self, jackhmmer_binary_path: str, hhblits_binary_path: str, - hhsearch_binary_path: str, uniref90_database_path: str, mgnify_database_path: str, bfd_database_path: Optional[str], uniclust30_database_path: Optional[str], small_bfd_database_path: Optional[str], - pdb70_database_path: str, + template_searcher: TemplateSearcher, template_featurizer: templates.TemplateHitFeaturizer, use_small_bfd: bool, mgnify_max_hits: int = 501, - uniref_max_hits: int = 10000): - """Constructs a feature dict for a given FASTA file.""" + uniref_max_hits: int = 10000, + use_precomputed_msas: bool = False): + """Initializes the data pipeline.""" self._use_small_bfd = use_small_bfd self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( binary_path=jackhmmer_binary_path, @@ -111,12 +137,11 @@ def __init__(self, self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( binary_path=jackhmmer_binary_path, database_path=mgnify_database_path) - self.hhsearch_pdb70_runner = hhsearch.HHSearch( - binary_path=hhsearch_binary_path, - databases=[pdb70_database_path]) + self.template_searcher = template_searcher self.template_featurizer = template_featurizer self.mgnify_max_hits = mgnify_max_hits self.uniref_max_hits = uniref_max_hits + self.use_precomputed_msas = use_precomputed_msas def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: """Runs alignment tools on the input sequence and creates features.""" @@ -130,72 +155,68 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: input_description = input_descs[0] num_res = len(input_sequence) - jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query( - input_fasta_path)[0] - jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query( - input_fasta_path)[0] - - uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( - jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits) - hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m) - uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') - with open(uniref90_out_path, 'w') as f: - f.write(jackhmmer_uniref90_result['sto']) - + jackhmmer_uniref90_result = run_msa_tool( + self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path, + 'sto', self.use_precomputed_msas) mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') - with open(mgnify_out_path, 'w') as f: - f.write(jackhmmer_mgnify_result['sto']) - - pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr') - with open(pdb70_out_path, 'w') as f: - f.write(hhsearch_result) + jackhmmer_mgnify_result = run_msa_tool( + self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto', + self.use_precomputed_msas) + + msa_for_templates = jackhmmer_uniref90_result['sto'] + msa_for_templates = parsers.truncate_stockholm_msa( + msa_for_templates, max_sequences=self.uniref_max_hits) + msa_for_templates = parsers.deduplicate_stockholm_msa( + msa_for_templates) + msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa( + msa_for_templates) + + if self.template_searcher.input_format == 'sto': + pdb_templates_result = self.template_searcher.query(msa_for_templates) + elif self.template_searcher.input_format == 'a3m': + uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates) + pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m) + else: + raise ValueError('Unrecognized template input format: ' + f'{self.template_searcher.input_format}') - uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm( - jackhmmer_uniref90_result['sto']) - mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm( - jackhmmer_mgnify_result['sto']) - hhsearch_hits = parsers.parse_hhr(hhsearch_result) - mgnify_msa = mgnify_msa[:self.mgnify_max_hits] - mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits] + pdb_hits_out_path = os.path.join( + msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}') + with open(pdb_hits_out_path, 'w') as f: + f.write(pdb_templates_result) - if self._use_small_bfd: - jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query( - input_fasta_path)[0] + uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto']) + uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) + mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) + mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) - bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.a3m') - with open(bfd_out_path, 'w') as f: - f.write(jackhmmer_small_bfd_result['sto']) + pdb_template_hits = self.template_searcher.get_template_hits( + output_string=pdb_templates_result, input_sequence=input_sequence) - bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm( - jackhmmer_small_bfd_result['sto']) + if self._use_small_bfd: + bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') + jackhmmer_small_bfd_result = run_msa_tool( + self.jackhmmer_small_bfd_runner, input_fasta_path, bfd_out_path, + 'sto', self.use_precomputed_msas) + bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) else: - hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query( - input_fasta_path) - bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') - with open(bfd_out_path, 'w') as f: - f.write(hhblits_bfd_uniclust_result['a3m']) - - bfd_msa, bfd_deletion_matrix = parsers.parse_a3m( - hhblits_bfd_uniclust_result['a3m']) + hhblits_bfd_uniclust_result = run_msa_tool( + self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path, + 'a3m', self.use_precomputed_msas) + bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) templates_result = self.template_featurizer.get_templates( query_sequence=input_sequence, - query_pdb_code=None, - query_release_date=None, - hits=hhsearch_hits) + hits=pdb_template_hits) sequence_features = make_sequence_features( sequence=input_sequence, description=input_description, num_res=num_res) - msa_features = make_msa_features( - msas=(uniref90_msa, bfd_msa, mgnify_msa), - deletion_matrices=(uniref90_deletion_matrix, - bfd_deletion_matrix, - mgnify_deletion_matrix)) + msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa)) logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) diff --git a/alphafold/data/pipeline_multimer.py b/alphafold/data/pipeline_multimer.py new file mode 100644 index 000000000..75bc1a52a --- /dev/null +++ b/alphafold/data/pipeline_multimer.py @@ -0,0 +1,288 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functions for building the features for the AlphaFold multimer model.""" + +import collections +import contextlib +import copy +import dataclasses +import json +import os +import tempfile +from typing import Mapping, MutableMapping, Sequence + +from absl import logging +from alphafold.common import protein +from alphafold.common import residue_constants +from alphafold.data import feature_processing +from alphafold.data import msa_pairing +from alphafold.data import parsers +from alphafold.data import pipeline +from alphafold.data.tools import jackhmmer +import numpy as np + +# Internal import (7716). + + +@dataclasses.dataclass(frozen=True) +class _FastaChain: + sequence: str + description: str + + +def _make_chain_id_map(*, + sequences: Sequence[str], + descriptions: Sequence[str], + ) -> Mapping[str, _FastaChain]: + """Makes a mapping from PDB-format chain ID to sequence and description.""" + if len(sequences) != len(descriptions): + raise ValueError('sequences and descriptions must have equal length. ' + f'Got {len(sequences)} != {len(descriptions)}.') + if len(sequences) > protein.PDB_MAX_CHAINS: + raise ValueError('Cannot process more chains than the PDB format supports. ' + f'Got {len(sequences)} chains.') + chain_id_map = {} + for chain_id, sequence, description in zip( + protein.PDB_CHAIN_IDS, sequences, descriptions): + chain_id_map[chain_id] = _FastaChain( + sequence=sequence, description=description) + return chain_id_map + + +@contextlib.contextmanager +def temp_fasta_file(fasta_str: str): + with tempfile.NamedTemporaryFile('w', suffix='.fasta') as fasta_file: + fasta_file.write(fasta_str) + fasta_file.seek(0) + yield fasta_file.name + + +def convert_monomer_features( + monomer_features: pipeline.FeatureDict, + chain_id: str) -> pipeline.FeatureDict: + """Reshapes and modifies monomer features for multimer models.""" + converted = {} + converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) + unnecessary_leading_dim_feats = { + 'sequence', 'domain_name', 'num_alignments', 'seq_length'} + for feature_name, feature in monomer_features.items(): + if feature_name in unnecessary_leading_dim_feats: + # asarray ensures it's a np.ndarray. + feature = np.asarray(feature[0], dtype=feature.dtype) + elif feature_name == 'aatype': + # The multimer model performs the one-hot operation itself. + feature = np.argmax(feature, axis=-1).astype(np.int32) + elif feature_name == 'template_aatype': + feature = np.argmax(feature, axis=-1).astype(np.int32) + new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + feature = np.take(new_order_list, feature.astype(np.int32), axis=0) + elif feature_name == 'template_all_atom_masks': + feature_name = 'template_all_atom_mask' + converted[feature_name] = feature + return converted + + +def int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. + """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. + output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +def add_assembly_features( + all_chain_features: MutableMapping[str, pipeline.FeatureDict], + ) -> MutableMapping[str, pipeline.FeatureDict]: + """Add features to distinguish between chains. + + Args: + all_chain_features: A dictionary which maps chain_id to a dictionary of + features for each chain. + + Returns: + all_chain_features: A dictionary which maps strings of the form + `_` to the corresponding chain features. E.g. two + chains from a homodimer would have keys A_1 and A_2. Two chains from a + heterodimer would have keys A_1 and B_1. + """ + # Group the chains by sequence + seq_to_entity_id = {} + grouped_chains = collections.defaultdict(list) + for chain_id, chain_features in all_chain_features.items(): + seq = str(chain_features['sequence']) + if seq not in seq_to_entity_id: + seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 + grouped_chains[seq_to_entity_id[seq]].append(chain_features) + + new_all_chain_features = {} + chain_id = 1 + for entity_id, group_chain_features in grouped_chains.items(): + for sym_id, chain_features in enumerate(group_chain_features, start=1): + new_all_chain_features[ + f'{int_id_to_str_id(entity_id)}_{sym_id}'] = chain_features + seq_length = chain_features['seq_length'] + chain_features['asym_id'] = chain_id * np.ones(seq_length) + chain_features['sym_id'] = sym_id * np.ones(seq_length) + chain_features['entity_id'] = entity_id * np.ones(seq_length) + chain_id += 1 + + return new_all_chain_features + + +def pad_msa(np_example, min_num_seq): + np_example = dict(np_example) + num_seq = np_example['msa'].shape[0] + if num_seq < min_num_seq: + for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask'): + np_example[feat] = np.pad( + np_example[feat], ((0, min_num_seq - num_seq), (0, 0))) + np_example['cluster_bias_mask'] = np.pad( + np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq),)) + return np_example + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__(self, + monomer_data_pipeline: pipeline.DataPipeline, + jackhmmer_binary_path: str, + uniprot_database_path: str, + max_uniprot_hits: int = 50000, + use_precomputed_msas: bool = False): + """Initializes the data pipeline. + + Args: + monomer_data_pipeline: An instance of pipeline.DataPipeline - that runs + the data pipeline for the monomer AlphaFold system. + jackhmmer_binary_path: Location of the jackhmmer binary. + uniprot_database_path: Location of the unclustered uniprot sequences, that + will be searched with jackhmmer and used for MSA pairing. + max_uniprot_hits: The maximum number of hits to return from uniprot. + use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold. + """ + self._monomer_data_pipeline = monomer_data_pipeline + self._uniprot_msa_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniprot_database_path) + self._max_uniprot_hits = max_uniprot_hits + self.use_precomputed_msas = use_precomputed_msas + + def _process_single_chain( + self, + chain_id: str, + sequence: str, + description: str, + msa_output_dir: str, + is_homomer_or_monomer: bool) -> pipeline.FeatureDict: + """Runs the monomer pipeline on a single chain.""" + chain_fasta_str = f'>{description}\n{sequence}\n' + chain_msa_output_dir = os.path.join(msa_output_dir, chain_id) + if not os.path.exists(chain_msa_output_dir): + os.makedirs(chain_msa_output_dir) + with temp_fasta_file(chain_fasta_str) as chain_fasta_path: + logging.info('Running monomer pipeline on chain %s: %s', + chain_id, description) + chain_features = self._monomer_data_pipeline.process( + input_fasta_path=chain_fasta_path, + msa_output_dir=chain_msa_output_dir) + + # We only construct the pairing features if there are 2 or more unique + # sequences. + if not is_homomer_or_monomer: + all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path, + chain_msa_output_dir) + chain_features.update(all_seq_msa_features) + return chain_features + + def _all_seq_msa_features(self, input_fasta_path, msa_output_dir): + """Get MSA features for unclustered uniprot, for pairing.""" + out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') + result = pipeline.run_msa_tool( + self._uniprot_msa_runner, input_fasta_path, out_path, 'sto', + self.use_precomputed_msas) + msa = parsers.parse_stockholm(result['sto']) + msa = msa.truncate(max_seqs=self._max_uniprot_hits) + all_seq_features = pipeline.make_msa_features([msa]) + valid_feats = msa_pairing.MSA_FEATURES + ( + 'msa_uniprot_accession_identifiers', + 'msa_species_identifiers', + ) + feats = {f'{k}_all_seq': v for k, v in all_seq_features.items() + if k in valid_feats} + return feats + + def process(self, + input_fasta_path: str, + msa_output_dir: str, + is_prokaryote: bool = False) -> pipeline.FeatureDict: + """Runs alignment tools on the input sequences and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + + chain_id_map = _make_chain_id_map(sequences=input_seqs, + descriptions=input_descs) + chain_id_map_path = os.path.join(msa_output_dir, 'chain_id_map.json') + with open(chain_id_map_path, 'w') as f: + chain_id_map_dict = {chain_id: dataclasses.asdict(fasta_chain) + for chain_id, fasta_chain in chain_id_map.items()} + json.dump(chain_id_map_dict, f, indent=4, sort_keys=True) + + all_chain_features = {} + sequence_features = {} + is_homomer_or_monomer = len(set(input_seqs)) == 1 + for chain_id, fasta_chain in chain_id_map.items(): + if fasta_chain.sequence in sequence_features: + all_chain_features[chain_id] = copy.deepcopy( + sequence_features[fasta_chain.sequence]) + continue + chain_features = self._process_single_chain( + chain_id=chain_id, + sequence=fasta_chain.sequence, + description=fasta_chain.description, + msa_output_dir=msa_output_dir, + is_homomer_or_monomer=is_homomer_or_monomer) + + chain_features = convert_monomer_features(chain_features, + chain_id=chain_id) + all_chain_features[chain_id] = chain_features + sequence_features[fasta_chain.sequence] = chain_features + + all_chain_features = add_assembly_features(all_chain_features) + + np_example = feature_processing.pair_and_merge( + all_chain_features=all_chain_features, + is_prokaryote=is_prokaryote, + ) + + # Pad MSA to avoid zero-sized extra_msa. + np_example = pad_msa(np_example, 512) + + return np_example diff --git a/alphafold/data/templates.py b/alphafold/data/templates.py index 9c1f20773..d37598711 100644 --- a/alphafold/data/templates.py +++ b/alphafold/data/templates.py @@ -13,8 +13,10 @@ # limitations under the License. """Functions for getting templates and calculating template features.""" +import abc import dataclasses import datetime +import functools import glob import os import re @@ -71,10 +73,6 @@ class DateError(PrefilterError): """An error indicating that the hit date was after the max allowed date.""" -class PdbIdError(PrefilterError): - """An error indicating that the hit PDB ID was identical to the query.""" - - class AlignRatioError(PrefilterError): """An error indicating that the hit align ratio to the query was too small.""" @@ -128,7 +126,6 @@ def _is_after_cutoff( else: # Since this is just a quick prefilter to reduce the number of mmCIF files # we need to parse, we don't have to worry about returning True here. - logging.warning('Template structure not in release dates dict: %s', pdb_id) return False @@ -177,7 +174,6 @@ def _assess_hhsearch_hit( hit: parsers.TemplateHit, hit_pdb_code: str, query_sequence: str, - query_pdb_code: Optional[str], release_dates: Mapping[str, datetime.datetime], release_date_cutoff: datetime.datetime, max_subsequence_ratio: float = 0.95, @@ -190,7 +186,6 @@ def _assess_hhsearch_hit( different from the value in the actual hit since the original pdb might have become obsolete. query_sequence: Amino acid sequence of the query. - query_pdb_code: 4 letter pdb code of the query. release_dates: Dictionary mapping pdb codes to their structure release dates. release_date_cutoff: Max release date that is valid for this query. @@ -202,7 +197,6 @@ def _assess_hhsearch_hit( Raises: DateError: If the hit date was after the max allowed date. - PdbIdError: If the hit PDB ID was identical to the query. AlignRatioError: If the hit align ratio to the query was too small. DuplicateError: If the hit was an exact subsequence of the query. LengthError: If the hit was too short. @@ -222,10 +216,6 @@ def _assess_hhsearch_hit( raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date ' f'({release_date_cutoff}).') - if query_pdb_code is not None: - if query_pdb_code.lower() == hit_pdb_code.lower(): - raise PdbIdError('PDB code identical to Query PDB code.') - if align_ratio <= min_align_ratio: raise AlignRatioError('Proportion of residues aligned to query too small. ' f'Align ratio: {align_ratio}.') @@ -368,8 +358,9 @@ def _realign_pdb_template_to_query( 'protein chain.') try: - (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m( + parsed_a3m = parsers.parse_a3m( aligner.align([old_template_sequence, new_template_sequence])) + old_aligned_template, new_aligned_template = parsed_a3m.sequences except Exception as e: raise QueryToTemplateAlignError( 'Could not align old template %s to template %s (%s_%s). Error: %s' % @@ -472,6 +463,18 @@ def _get_atom_positions( pos[residue_constants.atom_order['SD']] = [x, y, z] mask[residue_constants.atom_order['SD']] = 1.0 + # Fix naming errors in arginine residues where NH2 is incorrectly + # assigned to be closer to CD than NH1. + cd = residue_constants.atom_order['CD'] + nh1 = residue_constants.atom_order['NH1'] + nh2 = residue_constants.atom_order['NH2'] + if (res.get_resname() == 'ARG' and + all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and + (np.linalg.norm(pos[nh1] - pos[cd]) > + np.linalg.norm(pos[nh2] - pos[cd]))): + pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy() + mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy() + all_positions[res_index] = pos all_positions_mask[res_index] = mask _check_residue_distances( @@ -673,9 +676,15 @@ class SingleHitResult: warning: Optional[str] +@functools.lru_cache(16, typed=False) +def _read_file(path): + with open(path, 'r') as f: + file_data = f.read() + return file_data + + def _process_single_hit( query_sequence: str, - query_pdb_code: Optional[str], hit: parsers.TemplateHit, mmcif_dir: str, max_template_date: datetime.datetime, @@ -702,14 +711,12 @@ def _process_single_hit( hit=hit, hit_pdb_code=hit_pdb_code, query_sequence=query_sequence, - query_pdb_code=query_pdb_code, release_dates=release_dates, release_date_cutoff=max_template_date) except PrefilterError as e: msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}' - logging.info('%s: %s', query_pdb_code, msg) - if strict_error_check and isinstance( - e, (DateError, PdbIdError, DuplicateError)): + logging.info(msg) + if strict_error_check and isinstance(e, (DateError, DuplicateError)): # In strict mode we treat some prefilter cases as errors. return SingleHitResult(features=None, error=msg, warning=None) @@ -724,11 +731,10 @@ def _process_single_hit( template_sequence = hit.hit_sequence.replace('-', '') cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif') - logging.info('Reading PDB entry from %s. Query: %s, template: %s', - cif_path, query_sequence, template_sequence) + logging.debug('Reading PDB entry from %s. Query: %s, template: %s', cif_path, + query_sequence, template_sequence) # Fail if we can't find the mmCIF file. - with open(cif_path, 'r') as cif_file: - cif_string = cif_file.read() + cif_string = _read_file(cif_path) parsing_result = mmcif_parsing.parse( file_id=hit_pdb_code, mmcif_string=cif_string) @@ -742,7 +748,7 @@ def _process_single_hit( if strict_error_check: return SingleHitResult(features=None, error=error, warning=None) else: - logging.warning(error) + logging.debug(error) return SingleHitResult(features=None, error=None, warning=None) try: @@ -754,7 +760,10 @@ def _process_single_hit( query_sequence=query_sequence, template_chain_id=hit_chain_id, kalign_binary_path=kalign_binary_path) - features['template_sum_probs'] = [hit.sum_probs] + if hit.sum_probs is None: + features['template_sum_probs'] = [0] + else: + features['template_sum_probs'] = [hit.sum_probs] # It is possible there were some errors when parsing the other chains in the # mmCIF file, but the template features for the chain we want were still @@ -765,7 +774,7 @@ def _process_single_hit( TemplateAtomMaskAllZerosError) as e: # These 3 errors indicate missing mmCIF experimental data rather than a # problem with the template search, so turn them into warnings. - warning = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + warning = ('%s_%s (sum_probs: %s, rank: %s): feature extracting errors: ' '%s, mmCIF parsing errors: %s' % (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index, str(e), parsing_result.errors)) @@ -788,8 +797,8 @@ class TemplateSearchResult: warnings: Sequence[str] -class TemplateHitFeaturizer: - """A class for turning hhr hits to template features.""" +class TemplateHitFeaturizer(abc.ABC): + """An abstract base class for turning template hits to template features.""" def __init__( self, @@ -850,29 +859,28 @@ def __init__( else: self._obsolete_pdbs = {} + @abc.abstractmethod + def get_templates( + self, + query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence.""" + + +class HhsearchHitFeaturizer(TemplateHitFeaturizer): + """A class for turning a3m hits from hhsearch to template features.""" + def get_templates( self, query_sequence: str, - query_pdb_code: Optional[str], - query_release_date: Optional[datetime.datetime], hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: """Computes the templates for given query sequence (more details above).""" - logging.info('Searching for template for: %s', query_pdb_code) + logging.info('Searching for template for: %s', query_sequence) template_features = {} for template_feature_name in TEMPLATE_FEATURES: template_features[template_feature_name] = [] - # Always use a max_template_date. Set to query_release_date minus 60 days - # if that's earlier. - template_cutoff_date = self._max_template_date - if query_release_date: - delta = datetime.timedelta(days=60) - if query_release_date - delta < template_cutoff_date: - template_cutoff_date = query_release_date - delta - assert template_cutoff_date < query_release_date - assert template_cutoff_date <= self._max_template_date - num_hits = 0 errors = [] warnings = [] @@ -884,10 +892,9 @@ def get_templates( result = _process_single_hit( query_sequence=query_sequence, - query_pdb_code=query_pdb_code, hit=hit, mmcif_dir=self._mmcif_dir, - max_template_date=template_cutoff_date, + max_template_date=self._max_template_date, release_dates=self._release_dates, obsolete_pdbs=self._obsolete_pdbs, strict_error_check=self._strict_error_check, @@ -920,3 +927,84 @@ def get_templates( return TemplateSearchResult( features=template_features, errors=errors, warnings=warnings) + + +class HmmsearchHitFeaturizer(TemplateHitFeaturizer): + """A class for turning a3m hits from hmmsearch to template features.""" + + def get_templates( + self, + query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_sequence) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + already_seen = set() + errors = [] + warnings = [] + + if not hits or hits[0].sum_probs is None: + sorted_hits = hits + else: + sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True) + + for hit in sorted_hits: + # We got all the templates we wanted, stop processing hits. + if len(already_seen) >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=self._max_template_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.debug('Skipped invalid hit %s, error: %s, warning: %s', + hit.name, result.error, result.warning) + else: + already_seen_key = result.features['template_sequence'] + if already_seen_key in already_seen: + continue + # Increment the hit counter, since we got features out of this hit. + already_seen.add(already_seen_key) + for k in template_features: + template_features[k].append(result.features[k]) + + if already_seen: + for name in template_features: + template_features[name] = np.stack( + template_features[name], axis=0).astype(TEMPLATE_FEATURES[name]) + else: + num_res = len(query_sequence) + # Construct a default template with all zeros. + template_features = { + 'template_aatype': np.zeros( + (1, num_res, len(residue_constants.restypes_with_x_and_gap)), + np.float32), + 'template_all_atom_masks': np.zeros( + (1, num_res, residue_constants.atom_type_num), np.float32), + 'template_all_atom_positions': np.zeros( + (1, num_res, residue_constants.atom_type_num, 3), np.float32), + 'template_domain_names': np.array([''.encode()], dtype=np.object), + 'template_sequence': np.array([''.encode()], dtype=np.object), + 'template_sum_probs': np.array([0], dtype=np.float32) + } + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings) diff --git a/alphafold/data/tools/hhblits.py b/alphafold/data/tools/hhblits.py index f068ac919..ae5e1e1fe 100644 --- a/alphafold/data/tools/hhblits.py +++ b/alphafold/data/tools/hhblits.py @@ -17,7 +17,7 @@ import glob import os import subprocess -from typing import Any, Mapping, Optional, Sequence +from typing import Any, List, Mapping, Optional, Sequence from absl import logging from alphafold.data.tools import utils @@ -94,9 +94,9 @@ def __init__(self, self.p = p self.z = z - def query(self, input_fasta_path: str) -> Mapping[str, Any]: + def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]: """Queries the database using HHblits.""" - with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + with utils.tmpdir_manager() as query_tmp_dir: a3m_path = os.path.join(query_tmp_dir, 'output.a3m') db_cmd = [] @@ -152,4 +152,4 @@ def query(self, input_fasta_path: str) -> Mapping[str, Any]: stderr=stderr, n_iter=self.n_iter, e_value=self.e_value) - return raw_output + return [raw_output] diff --git a/alphafold/data/tools/hhsearch.py b/alphafold/data/tools/hhsearch.py index f61f0367c..7fd1134b0 100644 --- a/alphafold/data/tools/hhsearch.py +++ b/alphafold/data/tools/hhsearch.py @@ -21,6 +21,7 @@ from absl import logging +from alphafold.data import parsers from alphafold.data.tools import utils # Internal import (7716). @@ -57,9 +58,17 @@ def __init__(self, logging.error('Could not find HHsearch database %s', database_path) raise ValueError(f'Could not find HHsearch database {database_path}') + @property + def output_format(self) -> str: + return 'hhr' + + @property + def input_format(self) -> str: + return 'a3m' + def query(self, a3m: str) -> str: """Queries the database using HHsearch using a given a3m.""" - with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + with utils.tmpdir_manager() as query_tmp_dir: input_path = os.path.join(query_tmp_dir, 'query.a3m') hhr_path = os.path.join(query_tmp_dir, 'output.hhr') with open(input_path, 'w') as f: @@ -92,3 +101,10 @@ def query(self, a3m: str) -> str: with open(hhr_path) as f: hhr = f.read() return hhr + + def get_template_hits(self, + output_string: str, + input_sequence: str) -> Sequence[parsers.TemplateHit]: + """Gets parsed template hits from the raw string output by the tool.""" + del input_sequence # Used by hmmseach but not needed for hhsearch. + return parsers.parse_hhr(output_string) diff --git a/alphafold/data/tools/hmmbuild.py b/alphafold/data/tools/hmmbuild.py index f3c573047..f8f331da0 100644 --- a/alphafold/data/tools/hmmbuild.py +++ b/alphafold/data/tools/hmmbuild.py @@ -98,7 +98,7 @@ def _build_profile(self, msa: str, model_construction: str = 'fast') -> str: raise ValueError(f'Invalid model_construction {model_construction} - only' 'hand and fast supported.') - with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + with utils.tmpdir_manager() as query_tmp_dir: input_query = os.path.join(query_tmp_dir, 'query.msa') output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') diff --git a/alphafold/data/tools/hmmsearch.py b/alphafold/data/tools/hmmsearch.py index a60d3e760..08f0b8d47 100644 --- a/alphafold/data/tools/hmmsearch.py +++ b/alphafold/data/tools/hmmsearch.py @@ -19,6 +19,8 @@ from typing import Optional, Sequence from absl import logging +from alphafold.data import parsers +from alphafold.data.tools import hmmbuild from alphafold.data.tools import utils # Internal import (7716). @@ -29,12 +31,15 @@ class Hmmsearch(object): def __init__(self, *, binary_path: str, + hmmbuild_binary_path: str, database_path: str, flags: Optional[Sequence[str]] = None): """Initializes the Python hmmsearch wrapper. Args: binary_path: The path to the hmmsearch executable. + hmmbuild_binary_path: The path to the hmmbuild executable. Used to build + an hmm from an input a3m. database_path: The path to the hmmsearch database (FASTA format). flags: List of flags to be used by hmmsearch. @@ -42,18 +47,42 @@ def __init__(self, RuntimeError: If hmmsearch binary not found within the path. """ self.binary_path = binary_path + self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path) self.database_path = database_path + if flags is None: + # Default hmmsearch run settings. + flags = ['--F1', '0.1', + '--F2', '0.1', + '--F3', '0.1', + '--incE', '100', + '-E', '100', + '--domE', '100', + '--incdomE', '100'] self.flags = flags if not os.path.exists(self.database_path): logging.error('Could not find hmmsearch database %s', database_path) raise ValueError(f'Could not find hmmsearch database {database_path}') - def query(self, hmm: str) -> str: + @property + def output_format(self) -> str: + return 'sto' + + @property + def input_format(self) -> str: + return 'sto' + + def query(self, msa_sto: str) -> str: + """Queries the database using hmmsearch using a given stockholm msa.""" + hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto, + model_construction='hand') + return self.query_with_hmm(hmm) + + def query_with_hmm(self, hmm: str) -> str: """Queries the database using hmmsearch using a given hmm.""" - with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + with utils.tmpdir_manager() as query_tmp_dir: hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') - a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m') + out_path = os.path.join(query_tmp_dir, 'output.sto') with open(hmm_input_path, 'w') as f: f.write(hmm) @@ -66,7 +95,7 @@ def query(self, hmm: str) -> str: if self.flags: cmd.extend(self.flags) cmd.extend([ - '-A', a3m_out_path, + '-A', out_path, hmm_input_path, self.database_path, ]) @@ -84,7 +113,19 @@ def query(self, hmm: str) -> str: 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( stdout.decode('utf-8'), stderr.decode('utf-8'))) - with open(a3m_out_path) as f: - a3m_out = f.read() + with open(out_path) as f: + out_msa = f.read() + + return out_msa - return a3m_out + def get_template_hits(self, + output_string: str, + input_sequence: str) -> Sequence[parsers.TemplateHit]: + """Gets parsed template hits from the raw string output by the tool.""" + a3m_string = parsers.convert_stockholm_to_a3m(output_string, + remove_first_row_gaps=False) + template_hits = parsers.parse_hmmsearch_a3m( + query_sequence=input_sequence, + a3m_string=a3m_string, + skip_first=False) + return template_hits diff --git a/alphafold/data/tools/jackhmmer.py b/alphafold/data/tools/jackhmmer.py index 3c6c8ba25..cb03324f9 100644 --- a/alphafold/data/tools/jackhmmer.py +++ b/alphafold/data/tools/jackhmmer.py @@ -89,7 +89,7 @@ def __init__(self, def _query_chunk(self, input_fasta_path: str, database_path: str ) -> Mapping[str, Any]: """Queries the database chunk using Jackhmmer.""" - with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + with utils.tmpdir_manager() as query_tmp_dir: sto_path = os.path.join(query_tmp_dir, 'output.sto') # The F1/F2/F3 are the expected proportion to pass each of the filtering @@ -192,7 +192,10 @@ def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: # Remove the local copy of the chunk os.remove(db_local_chunk(i)) - future = next_future + # Do not set next_future for the last chunk so that this works even for + # databases with only 1 chunk. + if i < self.num_streamed_chunks: + future = next_future if self.streaming_callback: self.streaming_callback(i) return chunked_output diff --git a/alphafold/data/tools/kalign.py b/alphafold/data/tools/kalign.py index fc4e58a43..21ce1a361 100644 --- a/alphafold/data/tools/kalign.py +++ b/alphafold/data/tools/kalign.py @@ -70,7 +70,7 @@ def align(self, sequences: Sequence[str]) -> str: raise ValueError('Kalign requires all sequences to be at least 6 ' 'residues long. Got %s (%d residues).' % (s, len(s))) - with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: + with utils.tmpdir_manager() as query_tmp_dir: input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') diff --git a/alphafold/model/all_atom_multimer.py b/alphafold/model/all_atom_multimer.py new file mode 100644 index 000000000..f56c5cf10 --- /dev/null +++ b/alphafold/model/all_atom_multimer.py @@ -0,0 +1,966 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ops for all atom representations.""" + +from typing import Dict, Text + +from alphafold.common import residue_constants +from alphafold.model import geometry +from alphafold.model import utils +import jax +import jax.numpy as jnp +import numpy as np + + +def squared_difference(x, y): + return jnp.square(x - y) + + +def _make_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return np.array(chi_atom_indices) + + +def _make_renaming_matrices(): + """Matrices to map atoms to symmetry partners in ambiguous case.""" + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative groundtruth coordinates where the naming is swapped + restype_3 = [ + residue_constants.restype_1to3[res] for res in residue_constants.restypes + ] + restype_3 += ['UNK'] + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = residue_constants.restype_name_to_atom14_names[ + resname].index(source_atom_swap) + target_index = residue_constants.restype_name_to_atom14_names[ + resname].index(target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14), dtype=np.float32) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1. + all_matrices[resname] = renaming_matrix.astype(np.float32) + renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +def _make_restype_atom37_mask(): + """Mask of which atoms are present for which residue type in atom37.""" + # create the corresponding mask + restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) + for restype, restype_letter in enumerate(residue_constants.restypes): + restype_name = residue_constants.restype_1to3[restype_letter] + atom_names = residue_constants.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = residue_constants.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + return restype_atom37_mask + + +def _make_restype_atom14_mask(): + """Mask of which atoms are present for which residue type in atom14.""" + restype_atom14_mask = [] + + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) + + restype_atom14_mask.append([0.] * 14) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + return restype_atom14_mask + + +def _make_restype_atom37_to_atom14(): + """Map from atom37 to atom14 per residue type.""" + restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in residue_constants.atom_types + ]) + + restype_atom37_to_atom14.append([0] * 37) + restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) + return restype_atom37_to_atom14 + + +def _make_restype_atom14_to_atom37(): + """Map from atom14 to atom37 per residue type.""" + restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 + for rt in residue_constants.restypes: + atom_names = residue_constants.restype_name_to_atom14_names[ + residue_constants.restype_1to3[rt]] + restype_atom14_to_atom37.append([ + (residue_constants.atom_order[name] if name else 0) + for name in atom_names + ]) + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) + return restype_atom14_to_atom37 + + +def _make_restype_atom14_is_ambiguous(): + """Mask which atoms are ambiguous in atom14.""" + # create an ambiguous atoms mask. shape: (21, 14) + restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) + for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + atom_idx1 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name1) + atom_idx2 = residue_constants.restype_name_to_atom14_names[resname].index( + atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + + return restype_atom14_is_ambiguous + + +def _make_restype_rigidgroup_base_atom37_idx(): + """Create Map from rigidgroups to atom37 indices.""" + # Create an array with the atom names. + # shape (num_restypes, num_rigidgroups, 3_atoms): (21, 8, 3) + base_atom_names = np.full([21, 8, 3], '', dtype=object) + + # 0: backbone frame + base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + + # 3: 'psi-group' + base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate(residue_constants.restypes): + resname = residue_constants.restype_1to3[restype_letter] + for chi_idx in range(4): + if residue_constants.chi_angles_mask[restype][chi_idx]: + atom_names = residue_constants.chi_angles_atoms[resname][chi_idx] + base_atom_names[restype, chi_idx + 4, :] = atom_names[1:] + + # Translate atom names into atom37 indices. + lookuptable = residue_constants.atom_order.copy() + lookuptable[''] = 0 + restype_rigidgroup_base_atom37_idx = np.vectorize(lambda x: lookuptable[x])( + base_atom_names) + return restype_rigidgroup_base_atom37_idx + + +CHI_ATOM_INDICES = _make_chi_atom_indices() +RENAMING_MATRICES = _make_renaming_matrices() +RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37() +RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14() +RESTYPE_ATOM37_MASK = _make_restype_atom37_mask() +RESTYPE_ATOM14_MASK = _make_restype_atom14_mask() +RESTYPE_ATOM14_IS_AMBIGUOUS = _make_restype_atom14_is_ambiguous() +RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX = _make_restype_rigidgroup_base_atom37_idx() + +# Create mask for existing rigid groups. +RESTYPE_RIGIDGROUP_MASK = np.zeros([21, 8], dtype=np.float32) +RESTYPE_RIGIDGROUP_MASK[:, 0] = 1 +RESTYPE_RIGIDGROUP_MASK[:, 3] = 1 +RESTYPE_RIGIDGROUP_MASK[:20, 4:] = residue_constants.chi_angles_mask + + +def get_atom37_mask(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_MASK), aatype) + + +def get_atom14_mask(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype) + + +def get_atom14_is_ambiguous(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_IS_AMBIGUOUS), aatype) + + +def get_atom14_to_atom37_map(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype) + + +def get_atom37_to_atom14_map(aatype): + return utils.batched_gather(jnp.asarray(RESTYPE_ATOM37_TO_ATOM14), aatype) + + +def atom14_to_atom37(atom14_data: jnp.ndarray, # (N, 14, ...) + aatype: jnp.ndarray + ) -> jnp.ndarray: # (N, 37, ...) + """Convert atom14 to atom37 representation.""" + assert len(atom14_data.shape) in [2, 3] + idx_atom37_to_atom14 = get_atom37_to_atom14_map(aatype) + atom37_data = utils.batched_gather( + atom14_data, idx_atom37_to_atom14, batch_dims=1) + atom37_mask = get_atom37_mask(aatype) + if len(atom14_data.shape) == 2: + atom37_data *= atom37_mask + elif len(atom14_data.shape) == 3: + atom37_data *= atom37_mask[:, :, None].astype(atom37_data.dtype) + return atom37_data + + +def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): + """Convert Atom37 positions to Atom14 positions.""" + residx_atom14_to_atom37 = utils.batched_gather( + jnp.asarray(RESTYPE_ATOM14_TO_ATOM37), aatype) + atom14_mask = utils.batched_gather( + all_atom_mask, residx_atom14_to_atom37, batch_dims=1).astype(jnp.float32) + # create a mask for known groundtruth positions + atom14_mask *= utils.batched_gather(jnp.asarray(RESTYPE_ATOM14_MASK), aatype) + # gather the groundtruth positions + atom14_positions = jax.tree_map( + lambda x: utils.batched_gather(x, residx_atom14_to_atom37, batch_dims=1), + all_atom_pos) + atom14_positions = atom14_mask * atom14_positions + return atom14_positions, atom14_mask + + +def get_alt_atom14(aatype, positions: geometry.Vec3Array, mask): + """Get alternative atom14 positions.""" + # pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14) + renaming_transform = utils.batched_gather( + jnp.asarray(RENAMING_MATRICES), aatype) + + alternative_positions = jax.tree_map( + lambda x: jnp.sum(x, axis=1), positions[:, :, None] * renaming_transform) + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position) + alternative_mask = jnp.sum(mask[..., None] * renaming_transform, axis=1) + + return alternative_positions, alternative_mask + + +def atom37_to_frames( + aatype: jnp.ndarray, # (...) + all_atom_positions: geometry.Vec3Array, # (..., 37) + all_atom_mask: jnp.ndarray, # (..., 37) +) -> Dict[Text, jnp.ndarray]: + """Computes the frames for the up to 8 rigid groups for each residue.""" + # 0: 'backbone group', + # 1: 'pre-omega-group', (empty) + # 2: 'phi-group', (currently empty, because it defines only hydrogens) + # 3: 'psi-group', + # 4,5,6,7: 'chi1,2,3,4-group' + aatype_in_shape = aatype.shape + + # If there is a batch axis, just flatten it away, and reshape everything + # back at the end of the function. + aatype = jnp.reshape(aatype, [-1]) + all_atom_positions = jax.tree_map(lambda x: jnp.reshape(x, [-1, 37]), + all_atom_positions) + all_atom_mask = jnp.reshape(all_atom_mask, [-1, 37]) + + # Compute the gather indices for all residues in the chain. + # shape (N, 8, 3) + residx_rigidgroup_base_atom37_idx = utils.batched_gather( + RESTYPE_RIGIDGROUP_BASE_ATOM37_IDX, aatype) + + # Gather the base atom positions for each rigid group. + base_atom_pos = jax.tree_map( + lambda x: utils.batched_gather( # pylint: disable=g-long-lambda + x, residx_rigidgroup_base_atom37_idx, batch_dims=1), + all_atom_positions) + + # Compute the Rigids. + point_on_neg_x_axis = base_atom_pos[:, :, 0] + origin = base_atom_pos[:, :, 1] + point_on_xy_plane = base_atom_pos[:, :, 2] + gt_rotation = geometry.Rot3Array.from_two_vectors( + origin - point_on_neg_x_axis, point_on_xy_plane - origin) + + gt_frames = geometry.Rigid3Array(gt_rotation, origin) + + # Compute a mask whether the group exists. + # (N, 8) + group_exists = utils.batched_gather(RESTYPE_RIGIDGROUP_MASK, aatype) + + # Compute a mask whether ground truth exists for the group + gt_atoms_exist = utils.batched_gather( # shape (N, 8, 3) + all_atom_mask.astype(jnp.float32), + residx_rigidgroup_base_atom37_idx, + batch_dims=1) + gt_exists = jnp.min(gt_atoms_exist, axis=-1) * group_exists # (N, 8) + + # Adapt backbone frame to old convention (mirror x-axis and z-axis). + rots = np.tile(np.eye(3, dtype=np.float32), [8, 1, 1]) + rots[0, 0, 0] = -1 + rots[0, 2, 2] = -1 + gt_frames = gt_frames.compose_rotation( + geometry.Rot3Array.from_array(rots)) + + # The frames for ambiguous rigid groups are just rotated by 180 degree around + # the x-axis. The ambiguous group is always the last chi-group. + restype_rigidgroup_is_ambiguous = np.zeros([21, 8], dtype=np.float32) + restype_rigidgroup_rots = np.tile(np.eye(3, dtype=np.float32), [21, 8, 1, 1]) + + for resname, _ in residue_constants.residue_atom_renaming_swaps.items(): + restype = residue_constants.restype_order[ + residue_constants.restype_3to1[resname]] + chi_idx = int(sum(residue_constants.chi_angles_mask[restype]) - 1) + restype_rigidgroup_is_ambiguous[restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[restype, chi_idx + 4, 2, 2] = -1 + + # Gather the ambiguity information for each residue. + residx_rigidgroup_is_ambiguous = utils.batched_gather( + restype_rigidgroup_is_ambiguous, aatype) + ambiguity_rot = utils.batched_gather(restype_rigidgroup_rots, aatype) + ambiguity_rot = geometry.Rot3Array.from_array(ambiguity_rot) + + # Create the alternative ground truth frames. + alt_gt_frames = gt_frames.compose_rotation(ambiguity_rot) + + fix_shape = lambda x: jnp.reshape(x, aatype_in_shape + (8,)) + + # reshape back to original residue layout + gt_frames = jax.tree_map(fix_shape, gt_frames) + gt_exists = fix_shape(gt_exists) + group_exists = fix_shape(group_exists) + residx_rigidgroup_is_ambiguous = fix_shape(residx_rigidgroup_is_ambiguous) + alt_gt_frames = jax.tree_map(fix_shape, alt_gt_frames) + + return { + 'rigidgroups_gt_frames': gt_frames, # Rigid (..., 8) + 'rigidgroups_gt_exists': gt_exists, # (..., 8) + 'rigidgroups_group_exists': group_exists, # (..., 8) + 'rigidgroups_group_is_ambiguous': + residx_rigidgroup_is_ambiguous, # (..., 8) + 'rigidgroups_alt_gt_frames': alt_gt_frames, # Rigid (..., 8) + } + + +def torsion_angles_to_frames( + aatype: jnp.ndarray, # (N) + backb_to_global: geometry.Rigid3Array, # (N) + torsion_angles_sin_cos: jnp.ndarray # (N, 7, 2) +) -> geometry.Rigid3Array: # (N, 8) + """Compute rigid group frames from torsion angles.""" + assert len(aatype.shape) == 1, ( + f'Expected array of rank 1, got array with shape: {aatype.shape}.') + assert len(backb_to_global.rotation.shape) == 1, ( + f'Expected array of rank 1, got array with shape: ' + f'{backb_to_global.rotation.shape}') + assert len(torsion_angles_sin_cos.shape) == 3, ( + f'Expected array of rank 3, got array with shape: ' + f'{torsion_angles_sin_cos.shape}') + assert torsion_angles_sin_cos.shape[1] == 7, ( + f'wrong shape {torsion_angles_sin_cos.shape}') + assert torsion_angles_sin_cos.shape[2] == 2, ( + f'wrong shape {torsion_angles_sin_cos.shape}') + + # Gather the default frames for all rigid groups. + # geometry.Rigid3Array with shape (N, 8) + m = utils.batched_gather(residue_constants.restype_rigid_group_default_frame, + aatype) + default_frames = geometry.Rigid3Array.from_array4x4(m) + + # Create the rotation matrices according to the given angles (each frame is + # defined such that its rotation is around the x-axis). + sin_angles = torsion_angles_sin_cos[..., 0] + cos_angles = torsion_angles_sin_cos[..., 1] + + # insert zero rotation for backbone group. + num_residues, = aatype.shape + sin_angles = jnp.concatenate([jnp.zeros([num_residues, 1]), sin_angles], + axis=-1) + cos_angles = jnp.concatenate([jnp.ones([num_residues, 1]), cos_angles], + axis=-1) + zeros = jnp.zeros_like(sin_angles) + ones = jnp.ones_like(sin_angles) + + # all_rots are geometry.Rot3Array with shape (N, 8) + all_rots = geometry.Rot3Array(ones, zeros, zeros, + zeros, cos_angles, -sin_angles, + zeros, sin_angles, cos_angles) + + # Apply rotations to the frames. + all_frames = default_frames.compose_rotation(all_rots) + + # chi2, chi3, and chi4 frames do not transform to the backbone frame but to + # the previous frame. So chain them up accordingly. + + chi1_frame_to_backb = all_frames[:, 4] + chi2_frame_to_backb = chi1_frame_to_backb @ all_frames[:, 5] + chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6] + chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7] + + all_frames_to_backb = jax.tree_multimap( + lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5], + chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None], + chi4_frame_to_backb[:, None]) + + # Create the global frames. + # shape (N, 8) + all_frames_to_global = backb_to_global[:, None] @ all_frames_to_backb + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + aatype: jnp.ndarray, # (N) + all_frames_to_global: geometry.Rigid3Array # (N, 8) +) -> geometry.Vec3Array: # (N, 14) + """Put atom literature positions (atom14 encoding) in each rigid group.""" + + # Pick the appropriate transform for every atom. + residx_to_group_idx = utils.batched_gather( + residue_constants.restype_atom14_to_rigid_group, aatype) + group_mask = jax.nn.one_hot( + residx_to_group_idx, num_classes=8) # shape (N, 14, 8) + + # geometry.Rigid3Array with shape (N, 14) + map_atoms_to_global = jax.tree_map( + lambda x: jnp.sum(x[:, None, :] * group_mask, axis=-1), + all_frames_to_global) + + # Gather the literature atom positions for each residue. + # geometry.Vec3Array with shape (N, 14) + lit_positions = geometry.Vec3Array.from_array( + utils.batched_gather( + residue_constants.restype_atom14_rigid_group_positions, aatype)) + + # Transform each atom from its local frame to the global frame. + # geometry.Vec3Array with shape (N, 14) + pred_positions = map_atoms_to_global.apply_to_point(lit_positions) + + # Mask out non-existing atoms. + mask = utils.batched_gather(residue_constants.restype_atom14_mask, aatype) + pred_positions = pred_positions * mask + + return pred_positions + + +def extreme_ca_ca_distance_violations( + positions: geometry.Vec3Array, # (N, 37(14)) + mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + max_angstrom_tolerance=1.5 + ) -> jnp.ndarray: + """Counts residues whose Ca is a large distance from its neighbor.""" + this_ca_pos = positions[:-1, 1] # (N - 1,) + this_ca_mask = mask[:-1, 1] # (N - 1) + next_ca_pos = positions[1:, 1] # (N - 1,) + next_ca_mask = mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + ca_ca_distance = geometry.euclidean_distance(this_ca_pos, next_ca_pos, 1e-6) + violations = (ca_ca_distance - + residue_constants.ca_ca) > max_angstrom_tolerance + mask = this_ca_mask * next_ca_mask * has_no_gap_mask + return utils.mask_mean(mask=mask, value=violations) + + +def between_residue_bond_loss( + pred_atom_positions: geometry.Vec3Array, # (N, 37(14)) + pred_atom_mask: jnp.ndarray, # (N, 37(14)) + residue_index: jnp.ndarray, # (N) + aatype: jnp.ndarray, # (N) + tolerance_factor_soft=12.0, + tolerance_factor_hard=12.0) -> Dict[Text, jnp.ndarray]: + """Flat-bottom loss to penalize structural violations between residues.""" + assert len(pred_atom_positions.shape) == 2 + assert len(pred_atom_mask.shape) == 2 + assert len(residue_index.shape) == 1 + assert len(aatype.shape) == 1 + + # Get the positions of the relevant backbone atoms. + this_ca_pos = pred_atom_positions[:-1, 1] # (N - 1) + this_ca_mask = pred_atom_mask[:-1, 1] # (N - 1) + this_c_pos = pred_atom_positions[:-1, 2] # (N - 1) + this_c_mask = pred_atom_mask[:-1, 2] # (N - 1) + next_n_pos = pred_atom_positions[1:, 0] # (N - 1) + next_n_mask = pred_atom_mask[1:, 0] # (N - 1) + next_ca_pos = pred_atom_positions[1:, 1] # (N - 1) + next_ca_mask = pred_atom_mask[1:, 1] # (N - 1) + has_no_gap_mask = ((residue_index[1:] - residue_index[:-1]) == 1.0).astype( + jnp.float32) + + # Compute loss for the C--N bond. + c_n_bond_length = geometry.euclidean_distance(this_c_pos, next_n_pos, 1e-6) + + # The C-N bond to proline has slightly different length because of the ring. + next_is_proline = ( + aatype[1:] == residue_constants.restype_order['P']).astype(jnp.float32) + gt_length = ( + (1. - next_is_proline) * residue_constants.between_res_bond_length_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_c_n[1]) + gt_stddev = ( + (1. - next_is_proline) * + residue_constants.between_res_bond_length_stddev_c_n[0] + + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[1]) + c_n_bond_length_error = jnp.sqrt(1e-6 + + jnp.square(c_n_bond_length - gt_length)) + c_n_loss_per_residue = jax.nn.relu( + c_n_bond_length_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * has_no_gap_mask + c_n_loss = jnp.sum(mask * c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_violation_mask = mask * ( + c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)) + + # Compute loss for the angles. + c_ca_unit_vec = (this_ca_pos - this_c_pos).normalized(1e-6) + c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length + n_ca_unit_vec = (next_ca_pos - next_n_pos).normalized(1e-6) + + ca_c_n_cos_angle = c_ca_unit_vec.dot(c_n_unit_vec) + gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0] + gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0] + ca_c_n_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(ca_c_n_cos_angle - gt_angle)) + ca_c_n_loss_per_residue = jax.nn.relu( + ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask + ca_c_n_loss = jnp.sum(mask * ca_c_n_loss_per_residue) / (jnp.sum(mask) + 1e-6) + ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error > + (tolerance_factor_hard * gt_stddev)) + + c_n_ca_cos_angle = (-c_n_unit_vec).dot(n_ca_unit_vec) + gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0] + gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1] + c_n_ca_cos_angle_error = jnp.sqrt( + 1e-6 + jnp.square(c_n_ca_cos_angle - gt_angle)) + c_n_ca_loss_per_residue = jax.nn.relu( + c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev) + mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask + c_n_ca_loss = jnp.sum(mask * c_n_ca_loss_per_residue) / (jnp.sum(mask) + 1e-6) + c_n_ca_violation_mask = mask * ( + c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)) + + # Compute a per residue loss (equally distribute the loss to both + # neighbouring residues). + per_residue_loss_sum = (c_n_loss_per_residue + + ca_c_n_loss_per_residue + + c_n_ca_loss_per_residue) + per_residue_loss_sum = 0.5 * (jnp.pad(per_residue_loss_sum, [[0, 1]]) + + jnp.pad(per_residue_loss_sum, [[1, 0]])) + + # Compute hard violations. + violation_mask = jnp.max( + jnp.stack([c_n_violation_mask, + ca_c_n_violation_mask, + c_n_ca_violation_mask]), axis=0) + violation_mask = jnp.maximum( + jnp.pad(violation_mask, [[0, 1]]), + jnp.pad(violation_mask, [[1, 0]])) + + return {'c_n_loss_mean': c_n_loss, # shape () + 'ca_c_n_loss_mean': ca_c_n_loss, # shape () + 'c_n_ca_loss_mean': c_n_ca_loss, # shape () + 'per_residue_loss_sum': per_residue_loss_sum, # shape (N) + 'per_residue_violation_mask': violation_mask # shape (N) + } + + +def between_residue_clash_loss( + pred_positions: geometry.Vec3Array, # (N, 14) + atom_exists: jnp.ndarray, # (N, 14) + atom_radius: jnp.ndarray, # (N, 14) + residue_index: jnp.ndarray, # (N) + overlap_tolerance_soft=1.5, + overlap_tolerance_hard=1.5) -> Dict[Text, jnp.ndarray]: + """Loss to penalize steric clashes between residues.""" + assert len(pred_positions.shape) == 2 + assert len(atom_exists.shape) == 2 + assert len(atom_radius.shape) == 2 + assert len(residue_index.shape) == 1 + + # Create the distance matrix. + # (N, N, 14, 14) + dists = geometry.euclidean_distance(pred_positions[:, None, :, None], + pred_positions[None, :, None, :], 1e-10) + + # Create the mask for valid distances. + # shape (N, N, 14, 14) + dists_mask = (atom_exists[:, None, :, None] * atom_exists[None, :, None, :]) + + # Mask out all the duplicate entries in the lower triangular matrix. + # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms + # are handled separately. + dists_mask *= ( + residue_index[:, None, None, None] < residue_index[None, :, None, None]) + + # Backbone C--N bond between subsequent residues is no clash. + c_one_hot = jax.nn.one_hot(2, num_classes=14) + n_one_hot = jax.nn.one_hot(0, num_classes=14) + neighbour_mask = ((residue_index[:, None, None, None] + + 1) == residue_index[None, :, None, None]) + c_n_bonds = neighbour_mask * c_one_hot[None, None, :, + None] * n_one_hot[None, None, None, :] + dists_mask *= (1. - c_n_bonds) + + # Disulfide bridge between two cysteines is no clash. + cys_sg_idx = residue_constants.restype_name_to_atom14_names['CYS'].index('SG') + cys_sg_one_hot = jax.nn.one_hot(cys_sg_idx, num_classes=14) + disulfide_bonds = (cys_sg_one_hot[None, None, :, None] * + cys_sg_one_hot[None, None, None, :]) + dists_mask *= (1. - disulfide_bonds) + + # Compute the lower bound for the allowed distances. + # shape (N, N, 14, 14) + dists_lower_bound = dists_mask * ( + atom_radius[:, None, :, None] + atom_radius[None, :, None, :]) + + # Compute the error. + # shape (N, N, 14, 14) + dists_to_low_error = dists_mask * jax.nn.relu( + dists_lower_bound - overlap_tolerance_soft - dists) + + # Compute the mean loss. + # shape () + mean_loss = (jnp.sum(dists_to_low_error) + / (1e-6 + jnp.sum(dists_mask))) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(dists_to_low_error, axis=[0, 2]) + + jnp.sum(dists_to_low_error, axis=[1, 3])) + + # Compute the hard clash mask. + # shape (N, N, 14, 14) + clash_mask = dists_mask * ( + dists < (dists_lower_bound - overlap_tolerance_hard)) + + # Compute the per atom clash. + # shape (N, 14) + per_atom_clash_mask = jnp.maximum( + jnp.max(clash_mask, axis=[0, 2]), + jnp.max(clash_mask, axis=[1, 3])) + + return {'mean_loss': mean_loss, # shape () + 'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_clash_mask': per_atom_clash_mask # shape (N, 14) + } + + +def within_residue_violations( + pred_positions: geometry.Vec3Array, # (N, 14) + atom_exists: jnp.ndarray, # (N, 14) + dists_lower_bound: jnp.ndarray, # (N, 14, 14) + dists_upper_bound: jnp.ndarray, # (N, 14, 14) + tighten_bounds_for_loss=0.0, +) -> Dict[Text, jnp.ndarray]: + """Find within-residue violations.""" + assert len(pred_positions.shape) == 2 + assert len(atom_exists.shape) == 2 + assert len(dists_lower_bound.shape) == 3 + assert len(dists_upper_bound.shape) == 3 + + # Compute the mask for each residue. + # shape (N, 14, 14) + dists_masks = (1. - jnp.eye(14, 14)[None]) + dists_masks *= (atom_exists[:, :, None] * atom_exists[:, None, :]) + + # Distance matrix + # shape (N, 14, 14) + dists = geometry.euclidean_distance(pred_positions[:, :, None], + pred_positions[:, None, :], 1e-10) + + # Compute the loss. + # shape (N, 14, 14) + dists_to_low_error = jax.nn.relu( + dists_lower_bound + tighten_bounds_for_loss - dists) + dists_to_high_error = jax.nn.relu( + dists + tighten_bounds_for_loss - dists_upper_bound) + loss = dists_masks * (dists_to_low_error + dists_to_high_error) + + # Compute the per atom loss sum. + # shape (N, 14) + per_atom_loss_sum = (jnp.sum(loss, axis=1) + + jnp.sum(loss, axis=2)) + + # Compute the violations mask. + # shape (N, 14, 14) + violations = dists_masks * ((dists < dists_lower_bound) | + (dists > dists_upper_bound)) + + # Compute the per atom violations. + # shape (N, 14) + per_atom_violations = jnp.maximum( + jnp.max(violations, axis=1), jnp.max(violations, axis=2)) + + return {'per_atom_loss_sum': per_atom_loss_sum, # shape (N, 14) + 'per_atom_violations': per_atom_violations # shape (N, 14) + } + + +def find_optimal_renaming( + gt_positions: geometry.Vec3Array, # (N, 14) + alt_gt_positions: geometry.Vec3Array, # (N, 14) + atom_is_ambiguous: jnp.ndarray, # (N, 14) + gt_exists: jnp.ndarray, # (N, 14) + pred_positions: geometry.Vec3Array, # (N, 14) +) -> jnp.ndarray: # (N): + """Find optimal renaming for ground truth that maximizes LDDT.""" + assert len(gt_positions.shape) == 2 + assert len(alt_gt_positions.shape) == 2 + assert len(atom_is_ambiguous.shape) == 2 + assert len(gt_exists.shape) == 2 + assert len(pred_positions.shape) == 2 + + # Create the pred distance matrix. + # shape (N, N, 14, 14) + pred_dists = geometry.euclidean_distance(pred_positions[:, None, :, None], + pred_positions[None, :, None, :], + 1e-10) + + # Compute distances for ground truth with original and alternative names. + # shape (N, N, 14, 14) + gt_dists = geometry.euclidean_distance(gt_positions[:, None, :, None], + gt_positions[None, :, None, :], 1e-10) + + alt_gt_dists = geometry.euclidean_distance(alt_gt_positions[:, None, :, None], + alt_gt_positions[None, :, None, :], + 1e-10) + + # Compute LDDT's. + # shape (N, N, 14, 14) + lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, gt_dists)) + alt_lddt = jnp.sqrt(1e-10 + squared_difference(pred_dists, alt_gt_dists)) + + # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms + # in cols. + # shape (N ,N, 14, 14) + mask = ( + gt_exists[:, None, :, None] * # rows + atom_is_ambiguous[:, None, :, None] * # rows + gt_exists[None, :, None, :] * # cols + (1. - atom_is_ambiguous[None, :, None, :])) # cols + + # Aggregate distances for each residue to the non-amibuguous atoms. + # shape (N) + per_res_lddt = jnp.sum(mask * lddt, axis=[1, 2, 3]) + alt_per_res_lddt = jnp.sum(mask * alt_lddt, axis=[1, 2, 3]) + + # Decide for each residue, whether alternative naming is better. + # shape (N) + alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).astype(jnp.float32) + + return alt_naming_is_better # shape (N) + + +def frame_aligned_point_error( + pred_frames: geometry.Rigid3Array, # shape (num_frames) + target_frames: geometry.Rigid3Array, # shape (num_frames) + frames_mask: jnp.ndarray, # shape (num_frames) + pred_positions: geometry.Vec3Array, # shape (num_positions) + target_positions: geometry.Vec3Array, # shape (num_positions) + positions_mask: jnp.ndarray, # shape (num_positions) + pair_mask: jnp.ndarray, # shape (num_frames, num_posiitons) + l1_clamp_distance: float, + length_scale=20., + epsilon=1e-4) -> jnp.ndarray: # shape () + """Measure point error under different alignements. + + Computes error between two structures with B points + under A alignments derived form the given pairs of frames. + Args: + pred_frames: num_frames reference frames for 'pred_positions'. + target_frames: num_frames reference frames for 'target_positions'. + frames_mask: Mask for frame pairs to use. + pred_positions: num_positions predicted positions of the structure. + target_positions: num_positions target positions of the structure. + positions_mask: Mask on which positions to score. + pair_mask: A (num_frames, num_positions) mask to use in the loss, useful + for separating intra from inter chain losses. + l1_clamp_distance: Distance cutoff on error beyond which gradients will + be zero. + length_scale: length scale to divide loss by. + epsilon: small value used to regularize denominator for masked average. + Returns: + Masked Frame aligned point error. + """ + # For now we do not allow any batch dimensions. + assert len(pred_frames.rotation.shape) == 1 + assert len(target_frames.rotation.shape) == 1 + assert frames_mask.ndim == 1 + assert pred_positions.x.ndim == 1 + assert target_positions.x.ndim == 1 + assert positions_mask.ndim == 1 + + # Compute array of predicted positions in the predicted frames. + # geometry.Vec3Array (num_frames, num_positions) + local_pred_pos = pred_frames[:, None].inverse().apply_to_point( + pred_positions[None, :]) + + # Compute array of target positions in the target frames. + # geometry.Vec3Array (num_frames, num_positions) + local_target_pos = target_frames[:, None].inverse().apply_to_point( + target_positions[None, :]) + + # Compute errors between the structures. + # jnp.ndarray (num_frames, num_positions) + error_dist = geometry.euclidean_distance(local_pred_pos, local_target_pos, + epsilon) + + clipped_error_dist = jnp.clip(error_dist, 0, l1_clamp_distance) + + normed_error = clipped_error_dist / length_scale + normed_error *= jnp.expand_dims(frames_mask, axis=-1) + normed_error *= jnp.expand_dims(positions_mask, axis=-2) + if pair_mask is not None: + normed_error *= pair_mask + + mask = (jnp.expand_dims(frames_mask, axis=-1) * + jnp.expand_dims(positions_mask, axis=-2)) + if pair_mask is not None: + mask *= pair_mask + normalization_factor = jnp.sum(mask, axis=(-1, -2)) + return (jnp.sum(normed_error, axis=(-2, -1)) / + (epsilon + normalization_factor)) + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in residue_constants.restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in residue_constants.restypes: + residue_name = residue_constants.restype_1to3[residue_name] + residue_chi_angles = residue_constants.chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append( + [residue_constants.atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. + + return jnp.asarray(chi_atom_indices) + + +def compute_chi_angles(positions: geometry.Vec3Array, + mask: geometry.Vec3Array, + aatype: geometry.Vec3Array): + """Computes the chi angles given all atom positions and the amino acid type. + + Args: + positions: A Vec3Array of shape + [num_res, residue_constants.atom_type_num], with positions of + atoms needed to calculate chi angles. Supports up to 1 batch dimension. + mask: An optional tensor of shape + [num_res, residue_constants.atom_type_num] that masks which atom + positions are set for each residue. If given, then the chi mask will be + set to 1 for a chi angle only if the amino acid has that chi angle and all + the chi atoms needed to calculate that chi angle are set. If not given + (set to None), the chi mask will be set to 1 for a chi angle if the amino + acid has that chi angle and whether the actual atoms needed to calculate + it were set will be ignored. + aatype: A tensor of shape [num_res] with amino acid type integer + code (0 to 21). Supports up to 1 batch dimension. + + Returns: + A tuple of tensors (chi_angles, mask), where both have shape + [num_res, 4]. The mask masks out unused chi angles for amino acid + types that have less than 4 chi angles. If atom_positions_mask is set, the + chi mask will also mask out uncomputable chi angles. + """ + + # Don't assert on the num_res and batch dimensions as they might be unknown. + assert positions.shape[-1] == residue_constants.atom_type_num + assert mask.shape[-1] == residue_constants.atom_type_num + + # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4]. + chi_atom_indices = get_chi_atom_indices() + # Select atoms to compute chis. Shape: [num_res, chis=4, atoms=4]. + atom_indices = utils.batched_gather( + params=chi_atom_indices, indices=aatype, axis=0) + # Gather atom positions. Shape: [num_res, chis=4, atoms=4, xyz=3]. + chi_angle_atoms = jax.tree_map( + lambda x: utils.batched_gather( # pylint: disable=g-long-lambda + params=x, indices=atom_indices, axis=-1, batch_dims=1), positions) + a, b, c, d = [chi_angle_atoms[..., i] for i in range(4)] + + chi_angles = geometry.dihedral_angle(a, b, c, d) + + # Copy the chi angle mask, add the UNKNOWN residue. Shape: [restypes, 4]. + chi_angles_mask = list(residue_constants.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = jnp.asarray(chi_angles_mask) + # Compute the chi angle mask. Shape [num_res, chis=4]. + chi_mask = utils.batched_gather(params=chi_angles_mask, indices=aatype, + axis=0) + + # The chi_mask is set to 1 only when all necessary chi angle atoms were set. + # Gather the chi angle atoms mask. Shape: [num_res, chis=4, atoms=4]. + chi_angle_atoms_mask = utils.batched_gather( + params=mask, indices=atom_indices, axis=-1, batch_dims=1) + # Check if all 4 chi angle atoms were set. Shape: [num_res, chis=4]. + chi_angle_atoms_mask = jnp.prod(chi_angle_atoms_mask, axis=[-1]) + chi_mask = chi_mask * chi_angle_atoms_mask.astype(jnp.float32) + + return chi_angles, chi_mask + + +def make_transform_from_reference( + a_xyz: geometry.Vec3Array, + b_xyz: geometry.Vec3Array, + c_xyz: geometry.Vec3Array) -> geometry.Rigid3Array: + """Returns rotation and translation matrices to convert from reference. + + Note that this method does not take care of symmetries. If you provide the + coordinates in the non-standard way, the A atom will end up in the negative + y-axis rather than in the positive y-axis. You need to take care of such + cases in your code. + + Args: + a_xyz: A Vec3Array. + b_xyz: A Vec3Array. + c_xyz: A Vec3Array. + + Returns: + A Rigid3Array which, when applied to coordinates in a canonicalized + reference frame, will give coordinates approximately equal + the original coordinates (in the global frame). + """ + rotation = geometry.Rot3Array.from_two_vectors(c_xyz - b_xyz, + a_xyz - b_xyz) + return geometry.Rigid3Array(rotation, b_xyz) diff --git a/alphafold/model/common_modules.py b/alphafold/model/common_modules.py index f239c870b..08776a7f0 100644 --- a/alphafold/model/common_modules.py +++ b/alphafold/model/common_modules.py @@ -13,72 +13,118 @@ # limitations under the License. """A collection of common Haiku modules for use in protein folding.""" +import numbers +from typing import Union, Sequence + import haiku as hk import jax.numpy as jnp +import numpy as np + + +# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) +TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978, + dtype=np.float32) + + +def get_initializer_scale(initializer_name, input_shape): + """Get Initializer for weights and scale to multiply activations by.""" + + if initializer_name == 'zeros': + w_init = hk.initializers.Constant(0.0) + else: + # fan-in scaling + scale = 1. + for channel_dim in input_shape: + scale /= channel_dim + if initializer_name == 'relu': + scale *= 2 + + noise_scale = scale + + stddev = np.sqrt(noise_scale) + # Adjust stddev for truncation. + stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR + w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev) + + return w_init class Linear(hk.Module): - """Protein folding specific Linear Module. + """Protein folding specific Linear module. This differs from the standard Haiku Linear in a few ways: - * It supports inputs of arbitrary rank + * It supports inputs and outputs of arbitrary rank * Initializers are specified by strings """ def __init__(self, - num_output: int, + num_output: Union[int, Sequence[int]], initializer: str = 'linear', + num_input_dims: int = 1, use_bias: bool = True, bias_init: float = 0., + precision = None, name: str = 'linear'): """Constructs Linear Module. Args: - num_output: number of output channels. + num_output: Number of output channels. Can be tuple when outputting + multiple dimensions. initializer: What initializer to use, should be one of {'linear', 'relu', 'zeros'} + num_input_dims: Number of dimensions from the end to project. use_bias: Whether to include trainable bias bias_init: Value used to initialize bias. - name: name of module, used for name scopes. + precision: What precision to use for matrix multiplication, defaults + to None. + name: Name of module, used for name scopes. """ - super().__init__(name=name) - self.num_output = num_output + if isinstance(num_output, numbers.Integral): + self.output_shape = (num_output,) + else: + self.output_shape = tuple(num_output) self.initializer = initializer self.use_bias = use_bias self.bias_init = bias_init + self.num_input_dims = num_input_dims + self.num_output_dims = len(self.output_shape) + self.precision = precision - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + def __call__(self, inputs): """Connects Module. Args: - inputs: Tensor of shape [..., num_channel] + inputs: Tensor with at least num_input_dims dimensions. Returns: - output of shape [..., num_output] + output of shape [...] + num_output. """ - n_channels = int(inputs.shape[-1]) - weight_shape = [n_channels, self.num_output] - if self.initializer == 'linear': - weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.) - elif self.initializer == 'relu': - weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.) - elif self.initializer == 'zeros': - weight_init = hk.initializers.Constant(0.0) + num_input_dims = self.num_input_dims + + if self.num_input_dims > 0: + in_shape = inputs.shape[-self.num_input_dims:] + else: + in_shape = () + + weight_init = get_initializer_scale(self.initializer, in_shape) + in_letters = 'abcde'[:self.num_input_dims] + out_letters = 'hijkl'[:self.num_output_dims] + + weight_shape = in_shape + self.output_shape weights = hk.get_parameter('weights', weight_shape, inputs.dtype, weight_init) - # this is equivalent to einsum('...c,cd->...d', inputs, weights) - # but turns out to be slightly faster - inputs = jnp.swapaxes(inputs, -1, -2) - output = jnp.einsum('...cb,cd->...db', inputs, weights) - output = jnp.swapaxes(output, -1, -2) + equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}' + + output = jnp.einsum(equation, inputs, weights, precision=self.precision) if self.use_bias: - bias = hk.get_parameter('bias', [self.num_output], inputs.dtype, + bias = hk.get_parameter('bias', self.output_shape, inputs.dtype, hk.initializers.Constant(self.bias_init)) output += bias return output + diff --git a/alphafold/model/config.py b/alphafold/model/config.py index 03d494f71..d1aba86fd 100644 --- a/alphafold/model/config.py +++ b/alphafold/model/config.py @@ -17,7 +17,6 @@ from alphafold.model.tf import shape_placeholders import ml_collections - NUM_RES = shape_placeholders.NUM_RES NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ @@ -27,6 +26,9 @@ def model_config(name: str) -> ml_collections.ConfigDict: """Get the ConfigDict of a CASP14 model.""" + if 'multimer' in name: + return CONFIG_MULTIMER + if name not in CONFIG_DIFFS: raise ValueError(f'Invalid model name {name}.') cfg = copy.deepcopy(CONFIG) @@ -34,6 +36,32 @@ def model_config(name: str) -> ml_collections.ConfigDict: return cfg +MODEL_PRESETS = { + 'monomer': ( + 'model_1', + 'model_2', + 'model_3', + 'model_4', + 'model_5', + ), + 'monomer_ptm': ( + 'model_1_ptm', + 'model_2_ptm', + 'model_3_ptm', + 'model_4_ptm', + 'model_5_ptm', + ), + 'multimer': ( + 'model_1_multimer', + 'model_2_multimer', + 'model_3_multimer', + 'model_4_multimer', + 'model_5_multimer', + ), +} +MODEL_PRESETS['monomer_casp14'] = MODEL_PRESETS['monomer'] + + CONFIG_DIFFS = { 'model_1': { # Jumper et al. (2021) Suppl. Table 5, Model 1.1.1 @@ -206,6 +234,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'shared_dropout': True }, 'outer_product_mean': { + 'first': False, 'chunk_size': 128, 'dropout_rate': 0.0, 'num_outer_channel': 32, @@ -322,6 +351,7 @@ def model_config(name: str) -> ml_collections.ConfigDict: }, 'global_config': { 'deterministic': False, + 'multimer_mode': False, 'subbatch_size': 4, 'use_remat': False, 'zero_init': True @@ -400,3 +430,228 @@ def model_config(name: str) -> ml_collections.ConfigDict: 'resample_msa_in_recycling': True }, }) + + +CONFIG_MULTIMER = ml_collections.ConfigDict({ + 'model': { + 'embeddings_and_evoformer': { + 'evoformer_num_block': 48, + 'evoformer': { + 'msa_column_attention': { + 'dropout_rate': 0.0, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'msa_row_attention_with_pair_bias': { + 'dropout_rate': 0.15, + 'gating': True, + 'num_head': 8, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'msa_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'outer_product_mean': { + 'chunk_size': 128, + 'dropout_rate': 0.0, + 'first': True, + 'num_outer_channel': 32, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 128, + 'orientation': 'per_row', + 'shared_dropout': True + } + }, + 'extra_msa_channel': 64, + 'extra_msa_stack_num_block': 4, + 'num_msa': 252, + 'num_extra_msa': 1152, + 'masked_msa': { + 'profile_prob': 0.1, + 'replace_fraction': 0.15, + 'same_prob': 0.1, + 'uniform_prob': 0.1 + }, + 'use_chain_relative': True, + 'max_relative_chain': 2, + 'max_relative_idx': 32, + 'seq_channel': 384, + 'msa_channel': 256, + 'pair_channel': 128, + 'prev_pos': { + 'max_bin': 20.75, + 'min_bin': 3.25, + 'num_bins': 15 + }, + 'recycle_features': True, + 'recycle_pos': True, + 'template': { + 'attention': { + 'gating': False, + 'num_head': 4 + }, + 'dgram_features': { + 'max_bin': 50.75, + 'min_bin': 3.25, + 'num_bins': 39 + }, + 'enabled': True, + 'max_templates': 4, + 'num_channels': 64, + 'subbatch_size': 128, + 'template_pair_stack': { + 'num_block': 2, + 'pair_transition': { + 'dropout_rate': 0.0, + 'num_intermediate_factor': 2, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_attention_ending_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_column', + 'shared_dropout': True + }, + 'triangle_attention_starting_node': { + 'dropout_rate': 0.25, + 'gating': True, + 'num_head': 4, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_incoming': { + 'dropout_rate': 0.25, + 'equation': 'kjc,kic->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True + }, + 'triangle_multiplication_outgoing': { + 'dropout_rate': 0.25, + 'equation': 'ikc,jkc->ijc', + 'num_intermediate_channel': 64, + 'orientation': 'per_row', + 'shared_dropout': True + } + } + }, + }, + 'global_config': { + 'deterministic': False, + 'multimer_mode': True, + 'subbatch_size': 4, + 'use_remat': False, + 'zero_init': True + }, + 'heads': { + 'distogram': { + 'first_break': 2.3125, + 'last_break': 21.6875, + 'num_bins': 64, + 'weight': 0.3 + }, + 'experimentally_resolved': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'weight': 0.01 + }, + 'masked_msa': { + 'weight': 2.0 + }, + 'predicted_aligned_error': { + 'filter_by_resolution': True, + 'max_error_bin': 31.0, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 64, + 'num_channels': 128, + 'weight': 0.1 + }, + 'predicted_lddt': { + 'filter_by_resolution': True, + 'max_resolution': 3.0, + 'min_resolution': 0.1, + 'num_bins': 50, + 'num_channels': 128, + 'weight': 0.01 + }, + 'structure_module': { + 'angle_norm_weight': 0.01, + 'chi_weight': 0.5, + 'clash_overlap_tolerance': 1.5, + 'dropout': 0.1, + 'interface_fape': { + 'atom_clamp_distance': 1000.0, + 'loss_unit_distance': 20.0 + }, + 'intra_chain_fape': { + 'atom_clamp_distance': 10.0, + 'loss_unit_distance': 10.0 + }, + 'num_channel': 384, + 'num_head': 12, + 'num_layer': 8, + 'num_layer_in_transition': 3, + 'num_point_qk': 4, + 'num_point_v': 8, + 'num_scalar_qk': 16, + 'num_scalar_v': 16, + 'position_scale': 20.0, + 'sidechain': { + 'atom_clamp_distance': 10.0, + 'loss_unit_distance': 10.0, + 'num_channel': 128, + 'num_residual_block': 2, + 'weight_frac': 0.5 + }, + 'structural_violation_loss_weight': 1.0, + 'violation_tolerance_factor': 12.0, + 'weight': 1.0 + } + }, + 'num_ensemble_eval': 1, + 'num_recycle': 3, + 'resample_msa_in_recycling': True + } +}) diff --git a/alphafold/model/features.py b/alphafold/model/features.py index b31b277e0..c261cef19 100644 --- a/alphafold/model/features.py +++ b/alphafold/model/features.py @@ -15,8 +15,10 @@ """Code to generate processed features.""" import copy from typing import List, Mapping, Tuple + from alphafold.model.tf import input_pipeline from alphafold.model.tf import proteins_dataset + import ml_collections import numpy as np import tensorflow.compat.v1 as tf diff --git a/alphafold/model/folding_multimer.py b/alphafold/model/folding_multimer.py new file mode 100644 index 000000000..6bdc6f163 --- /dev/null +++ b/alphafold/model/folding_multimer.py @@ -0,0 +1,1160 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Modules and utilities for the structure module in the multimer system.""" + +import functools +import numbers +from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union + +from alphafold.common import residue_constants +from alphafold.model import all_atom_multimer +from alphafold.model import common_modules +from alphafold.model import geometry +from alphafold.model import modules +from alphafold.model import prng +from alphafold.model import utils +from alphafold.model.geometry import utils as geometry_utils +import haiku as hk +import jax +import jax.numpy as jnp +import ml_collections +import numpy as np + + +EPSILON = 1e-8 +Float = Union[float, jnp.ndarray] + + +def squared_difference(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Computes Squared difference between two arrays.""" + return jnp.square(x - y) + + +def make_backbone_affine( + positions: geometry.Vec3Array, + mask: jnp.ndarray, + aatype: jnp.ndarray, + ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]: + """Make backbone Rigid3Array and mask.""" + del aatype + a = residue_constants.atom_order['N'] + b = residue_constants.atom_order['CA'] + c = residue_constants.atom_order['C'] + + rigid_mask = (mask[:, a] * mask[:, b] * mask[:, c]).astype( + jnp.float32) + + rigid = all_atom_multimer.make_transform_from_reference( + a_xyz=positions[:, a], b_xyz=positions[:, b], c_xyz=positions[:, c]) + + return rigid, rigid_mask + + +class QuatRigid(hk.Module): + """Module for projecting Rigids via a quaternion.""" + + def __init__(self, + global_config: ml_collections.ConfigDict, + rigid_shape: Union[int, Iterable[int]] = tuple(), + full_quat: bool = False, + init: str = 'zeros', + name: str = 'quat_rigid'): + """Module projecting a Rigid Object. + + For this Module the Rotation is parametrized as a quaternion, + If 'full_quat' is True a 4 vector is produced for the rotation which is + normalized and treated as a quaternion. + When 'full_quat' is False a 3 vector is produced and the 1st component of + the quaternion is set to 1. + + Args: + global_config: Global Config, used to set certain properties of underlying + Linear module, see common_modules.Linear for details. + rigid_shape: Shape of Rigids relative to shape of activations, e.g. when + activations have shape (n,) and this is (m,) output will be (n, m) + full_quat: Whether to parametrize rotation using full quaternion. + init: initializer to use, see common_modules.Linear for details + name: Name to use for module. + """ + self.init = init + self.global_config = global_config + if isinstance(rigid_shape, int): + self.rigid_shape = (rigid_shape,) + else: + self.rigid_shape = tuple(rigid_shape) + self.full_quat = full_quat + super(QuatRigid, self).__init__(name=name) + + def __call__(self, activations: jnp.ndarray) -> geometry.Rigid3Array: + """Executes Module. + + This returns a set of rigid with the same shape as activations, projecting + the channel dimension, rigid_shape controls the trailing dimensions. + For example when activations is shape (12, 5) and rigid_shape is (3, 2) + then the shape of the output rigids will be (12, 3, 2). + This also supports passing in an empty tuple for rigid shape, in that case + the example would produce a rigid of shape (12,). + + Args: + activations: Activations to use for projection, shape [..., num_channel] + Returns: + Rigid transformations with shape [...] + rigid_shape + """ + if self.full_quat: + rigid_dim = 7 + else: + rigid_dim = 6 + linear_dims = self.rigid_shape + (rigid_dim,) + rigid_flat = common_modules.Linear( + linear_dims, + initializer=self.init, + precision=jax.lax.Precision.HIGHEST, + name='rigid')( + activations) + rigid_flat = geometry_utils.unstack(rigid_flat) + if self.full_quat: + qw, qx, qy, qz = rigid_flat[:4] + translation = rigid_flat[4:] + else: + qx, qy, qz = rigid_flat[:3] + qw = jnp.ones_like(qx) + translation = rigid_flat[3:] + rotation = geometry.Rot3Array.from_quaternion( + qw, qx, qy, qz, normalize=True) + translation = geometry.Vec3Array(*translation) + return geometry.Rigid3Array(rotation, translation) + + +class PointProjection(hk.Module): + """Given input reprensentation and frame produces points in global frame.""" + + def __init__(self, + num_points: Union[Iterable[int], int], + global_config: ml_collections.ConfigDict, + return_local_points: bool = False, + name: str = 'point_projection'): + """Constructs Linear Module. + + Args: + num_points: number of points to project. Can be tuple when outputting + multiple dimensions + global_config: Global Config, passed through to underlying Linear + return_local_points: Whether to return points in local frame as well. + name: name of module, used for name scopes. + """ + if isinstance(num_points, numbers.Integral): + self.num_points = (num_points,) + else: + self.num_points = tuple(num_points) + + self.return_local_points = return_local_points + + self.global_config = global_config + + super().__init__(name=name) + + def __call__( + self, activations: jnp.ndarray, rigids: geometry.Rigid3Array + ) -> Union[geometry.Vec3Array, Tuple[geometry.Vec3Array, geometry.Vec3Array]]: + output_shape = self.num_points + output_shape = output_shape[:-1] + (3 * output_shape[-1],) + points_local = common_modules.Linear( + output_shape, + precision=jax.lax.Precision.HIGHEST, + name='point_projection')( + activations) + points_local = jnp.split(points_local, 3, axis=-1) + points_local = geometry.Vec3Array(*points_local) + rigids = rigids[(...,) + (None,) * len(output_shape)] + points_global = rigids.apply_to_point(points_local) + if self.return_local_points: + return points_global, points_local + else: + return points_global + + +class InvariantPointAttention(hk.Module): + """Covariant attention module. + + The high-level idea is that this attention module works over a set of points + and associated orientations in 3D space (e.g. protein residues). + + Each residue outputs a set of queries and keys as points in their local + reference frame. The attention is then defined as the euclidean distance + between the queries and keys in the global frame. + """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + dist_epsilon: float = 1e-8, + name: str = 'invariant_point_attention'): + """Initialize. + + Args: + config: iterative Fold Head Config + global_config: Global Config of Model. + dist_epsilon: Small value to avoid NaN in distance calculation. + name: Sonnet name. + """ + super().__init__(name=name) + + self._dist_epsilon = dist_epsilon + self._zero_initialize_last = global_config.zero_init + + self.config = config + + self.global_config = global_config + + def __call__( + self, + inputs_1d: jnp.ndarray, + inputs_2d: jnp.ndarray, + mask: jnp.ndarray, + rigid: geometry.Rigid3Array, + ) -> jnp.ndarray: + """Compute geometric aware attention. + + Given a set of query residues (defined by affines and associated scalar + features), this function computes geometric aware attention between the + query residues and target residues. + + The residues produce points in their local reference frame, which + are converted into the global frame to get attention via euclidean distance. + + Equivalently the target residues produce points in their local frame to be + used as attention values, which are converted into the query residues local + frames. + + Args: + inputs_1d: (N, C) 1D input embedding that is the basis for the + scalar queries. + inputs_2d: (N, M, C') 2D input embedding, used for biases values in the + attention between query_inputs_1d and target_inputs_1d. + mask: (N, 1) mask to indicate query_inputs_1d that participate in + the attention. + rigid: Rigid object describing the position and orientation of + every element in query_inputs_1d. + + Returns: + Transformation of the input embedding. + """ + + num_head = self.config.num_head + + attn_logits = 0. + + num_point_qk = self.config.num_point_qk + # Each point pair (q, k) contributes Var [0.5 ||q||^2 - ] = 9 / 2 + point_variance = max(num_point_qk, 1) * 9. / 2 + point_weights = np.sqrt(1.0 / point_variance) + + # This is equivalent to jax.nn.softplus, but avoids a bug in the test... + softplus = lambda x: jnp.logaddexp(x, jnp.zeros_like(x)) + raw_point_weights = hk.get_parameter( + 'trainable_point_weights', + shape=[num_head], + # softplus^{-1} (1) + init=hk.initializers.Constant(np.log(np.exp(1.) - 1.))) + + # Trainable per-head weights for points. + trainable_point_weights = softplus(raw_point_weights) + point_weights *= trainable_point_weights + q_point = PointProjection([num_head, num_point_qk], + self.global_config, + name='q_point_projection')(inputs_1d, + rigid) + + k_point = PointProjection([num_head, num_point_qk], + self.global_config, + name='k_point_projection')(inputs_1d, + rigid) + + dist2 = geometry.square_euclidean_distance( + q_point[:, None, :, :], k_point[None, :, :, :], epsilon=0.) + attn_qk_point = -0.5 * jnp.sum(point_weights[:, None] * dist2, axis=-1) + attn_logits += attn_qk_point + + num_scalar_qk = self.config.num_scalar_qk + # We assume that all queries and keys come iid from N(0, 1) distribution + # and compute the variances of the attention logits. + # Each scalar pair (q, k) contributes Var q*k = 1 + scalar_variance = max(num_scalar_qk, 1) * 1. + scalar_weights = np.sqrt(1.0 / scalar_variance) + q_scalar = common_modules.Linear([num_head, num_scalar_qk], + use_bias=False, + name='q_scalar_projection')( + inputs_1d) + + k_scalar = common_modules.Linear([num_head, num_scalar_qk], + use_bias=False, + name='k_scalar_projection')( + inputs_1d) + q_scalar *= scalar_weights + attn_logits += jnp.einsum('qhc,khc->qkh', q_scalar, k_scalar) + + attention_2d = common_modules.Linear( + num_head, name='attention_2d')(inputs_2d) + attn_logits += attention_2d + + mask_2d = mask * jnp.swapaxes(mask, -1, -2) + attn_logits -= 1e5 * (1. - mask_2d[..., None]) + + attn_logits *= np.sqrt(1. / 3) # Normalize by number of logit terms (3) + attn = jax.nn.softmax(attn_logits, axis=-2) + + num_scalar_v = self.config.num_scalar_v + + v_scalar = common_modules.Linear([num_head, num_scalar_v], + use_bias=False, + name='v_scalar_projection')( + inputs_1d) + + # [num_query_residues, num_head, num_scalar_v] + result_scalar = jnp.einsum('qkh, khc->qhc', attn, v_scalar) + + num_point_v = self.config.num_point_v + v_point = PointProjection([num_head, num_point_v], + self.global_config, + name='v_point_projection')(inputs_1d, + rigid) + + result_point_global = jax.tree_map( + lambda x: jnp.sum(attn[..., None] * x, axis=-3), v_point[None]) + + # Features used in the linear output projection. Should have the size + # [num_query_residues, ?] + output_features = [] + num_query_residues, _ = inputs_1d.shape + + flat_shape = [num_query_residues, -1] + + result_scalar = jnp.reshape(result_scalar, flat_shape) + output_features.append(result_scalar) + + result_point_global = jax.tree_map(lambda r: jnp.reshape(r, flat_shape), + result_point_global) + result_point_local = rigid[..., None].apply_inverse_to_point( + result_point_global) + output_features.extend( + [result_point_local.x, result_point_local.y, result_point_local.z]) + + point_norms = result_point_local.norm(self._dist_epsilon) + output_features.append(point_norms) + + # Dimensions: h = heads, i and j = residues, + # c = inputs_2d channels + # Contraction happens over the second residue dimension, similarly to how + # the usual attention is performed. + result_attention_over_2d = jnp.einsum('ijh, ijc->ihc', attn, inputs_2d) + output_features.append(jnp.reshape(result_attention_over_2d, flat_shape)) + + final_init = 'zeros' if self._zero_initialize_last else 'linear' + + final_act = jnp.concatenate(output_features, axis=-1) + + return common_modules.Linear( + self.config.num_channel, + initializer=final_init, + name='output_projection')(final_act) + + +class FoldIteration(hk.Module): + """A single iteration of iterative folding. + + First, each residue attends to all residues using InvariantPointAttention. + Then, we apply transition layers to update the hidden representations. + Finally, we use the hidden representations to produce an update to the + affine of each residue. + """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'fold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__( + self, + activations: Mapping[str, Any], + aatype: jnp.ndarray, + sequence_mask: jnp.ndarray, + update_rigid: bool, + is_training: bool, + initial_act: jnp.ndarray, + safe_key: Optional[prng.SafeKey] = None, + static_feat_2d: Optional[jnp.ndarray] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + + c = self.config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + def safe_dropout_fn(tensor, safe_key): + return modules.apply_dropout( + tensor=tensor, + safe_key=safe_key, + rate=0.0 if self.global_config.deterministic else c.dropout, + is_training=is_training) + + rigid = activations['rigid'] + + act = activations['act'] + attention_module = InvariantPointAttention( + self.config, self.global_config) + # Attention + act += attention_module( + inputs_1d=act, + inputs_2d=static_feat_2d, + mask=sequence_mask, + rigid=rigid) + + safe_key, *sub_keys = safe_key.split(3) + sub_keys = iter(sub_keys) + act = safe_dropout_fn(act, next(sub_keys)) + act = hk.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='attention_layer_norm')( + act) + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Transition + input_act = act + for i in range(c.num_layer_in_transition): + init = 'relu' if i < c.num_layer_in_transition - 1 else final_init + act = common_modules.Linear( + c.num_channel, + initializer=init, + name='transition')( + act) + if i < c.num_layer_in_transition - 1: + act = jax.nn.relu(act) + act += input_act + act = safe_dropout_fn(act, next(sub_keys)) + act = hk.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='transition_layer_norm')(act) + if update_rigid: + # Rigid update + rigid_update = QuatRigid( + self.global_config, init=final_init)( + act) + rigid = rigid @ rigid_update + + sc = MultiRigidSidechain(c.sidechain, self.global_config)( + rigid.scale_translation(c.position_scale), [act, initial_act], aatype) + + outputs = {'rigid': rigid, 'sc': sc} + + rotation = jax.tree_map(jax.lax.stop_gradient, rigid.rotation) + rigid = geometry.Rigid3Array(rotation, rigid.translation) + + new_activations = { + 'act': act, + 'rigid': rigid + } + return new_activations, outputs + + +def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray], + batch: Mapping[str, jnp.ndarray], + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + is_training: bool, + safe_key: prng.SafeKey + ) -> Dict[str, Any]: + """Generate predicted Rigid's for a single chain. + + This is the main part of the iterative fold head - it iteratively applies + folding to produce a set of predicted residue positions. + + Args: + representations: Embeddings dictionary. + batch: Batch dictionary. + config: config for the iterative fold head. + global_config: global config. + is_training: is training. + safe_key: A prng.SafeKey object that wraps a PRNG key. + + Returns: + A dictionary containing residue Rigid's and sidechain positions. + """ + c = config + sequence_mask = batch['seq_mask'][:, None] + act = hk.LayerNorm( + axis=-1, create_scale=True, create_offset=True, name='single_layer_norm')( + representations['single']) + + initial_act = act + act = common_modules.Linear( + c.num_channel, name='initial_projection')(act) + + # Sequence Mask has extra 1 at the end. + rigid = geometry.Rigid3Array.identity(sequence_mask.shape[:-1]) + + fold_iteration = FoldIteration( + c, global_config, name='fold_iteration') + + assert len(batch['seq_mask'].shape) == 1 + + activations = { + 'act': + act, + 'rigid': + rigid + } + + act_2d = hk.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name='pair_layer_norm')( + representations['pair']) + + safe_keys = safe_key.split(c.num_layer) + outputs = [] + for key in safe_keys: + + activations, output = fold_iteration( + activations, + initial_act=initial_act, + static_feat_2d=act_2d, + aatype=batch['aatype'], + safe_key=key, + sequence_mask=sequence_mask, + update_rigid=True, + is_training=is_training, + ) + outputs.append(output) + + output = jax.tree_multimap(lambda *x: jnp.stack(x), *outputs) + # Pass along for LDDT-Head. + output['act'] = activations['act'] + + return output + + +class StructureModule(hk.Module): + """StructureModule as a network head. + + Jumper et al. (2021) Suppl. Alg. 20 "StructureModule" + """ + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'structure_module'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + representations: Mapping[str, jnp.ndarray], + batch: Mapping[str, Any], + is_training: bool, + safe_key: Optional[prng.SafeKey] = None, + compute_loss: bool = False + ) -> Dict[str, Any]: + c = self.config + ret = {} + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + output = generate_monomer_rigids( + representations=representations, + batch=batch, + config=self.config, + global_config=self.global_config, + is_training=is_training, + safe_key=safe_key) + + ret['traj'] = output['rigid'].scale_translation(c.position_scale).to_array() + ret['sidechains'] = output['sc'] + ret['sidechains']['atom_pos'] = ret['sidechains']['atom_pos'].to_array() + ret['sidechains']['frames'] = ret['sidechains']['frames'].to_array() + if 'local_atom_pos' in ret['sidechains']: + ret['sidechains']['local_atom_pos'] = ret['sidechains'][ + 'local_atom_pos'].to_array() + ret['sidechains']['local_frames'] = ret['sidechains'][ + 'local_frames'].to_array() + + aatype = batch['aatype'] + seq_mask = batch['seq_mask'] + + atom14_pred_mask = all_atom_multimer.get_atom14_mask( + aatype) * seq_mask[:, None] + atom14_pred_positions = output['sc']['atom_pos'][-1] + ret['final_atom14_positions'] = atom14_pred_positions # (N, 14, 3) + ret['final_atom14_mask'] = atom14_pred_mask # (N, 14) + + atom37_mask = all_atom_multimer.get_atom37_mask(aatype) * seq_mask[:, None] + atom37_pred_positions = all_atom_multimer.atom14_to_atom37( + atom14_pred_positions, aatype) + atom37_pred_positions *= atom37_mask[:, :, None] + ret['final_atom_positions'] = atom37_pred_positions # (N, 37, 3) + ret['final_atom_mask'] = atom37_mask # (N, 37) + ret['final_rigids'] = ret['traj'][-1] + + ret['act'] = output['act'] + + if compute_loss: + return ret + else: + no_loss_features = ['final_atom_positions', 'final_atom_mask', 'act'] + no_loss_ret = {k: ret[k] for k in no_loss_features} + return no_loss_ret + + def loss(self, + value: Mapping[str, Any], + batch: Mapping[str, Any] + ) -> Dict[str, Any]: + + raise NotImplementedError( + 'This function should be called on a batch with reordered chains (see ' + 'Evans et al (2021) Section 7.3. Multi-Chain Permutation Alignment.') + + ret = {'loss': 0.} + + ret['metrics'] = {} + + aatype = batch['aatype'] + all_atom_positions = batch['all_atom_positions'] + all_atom_positions = geometry.Vec3Array.from_array(all_atom_positions) + all_atom_mask = batch['all_atom_mask'] + seq_mask = batch['seq_mask'] + residue_index = batch['residue_index'] + + gt_rigid, gt_affine_mask = make_backbone_affine(all_atom_positions, + all_atom_mask, + aatype) + + chi_angles, chi_mask = all_atom_multimer.compute_chi_angles( + all_atom_positions, all_atom_mask, aatype) + + pred_mask = all_atom_multimer.get_atom14_mask(aatype) + pred_mask *= seq_mask[:, None] + pred_positions = value['final_atom14_positions'] + pred_positions = geometry.Vec3Array.from_array(pred_positions) + + gt_positions, gt_mask, alt_naming_is_better = compute_atom14_gt( + aatype, all_atom_positions, all_atom_mask, pred_positions) + + violations = find_structural_violations( + aatype=aatype, + residue_index=residue_index, + mask=pred_mask, + pred_positions=pred_positions, + config=self.config) + + sidechains = value['sidechains'] + + gt_chi_angles = get_renamed_chi_angles(aatype, chi_angles, + alt_naming_is_better) + + # Several violation metrics: + violation_metrics = compute_violation_metrics( + residue_index=residue_index, + mask=pred_mask, + seq_mask=seq_mask, + pred_positions=pred_positions, + violations=violations) + ret['metrics'].update(violation_metrics) + + target_rigid = geometry.Rigid3Array.from_array(value['traj']) + gt_frames_mask = gt_affine_mask + + # Split the loss into within-chain and between-chain components. + intra_chain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] + intra_chain_bb_loss, intra_chain_fape = backbone_loss( + gt_rigid=gt_rigid, + gt_frames_mask=gt_frames_mask, + gt_positions_mask=gt_affine_mask, + target_rigid=target_rigid, + config=self.config.intra_chain_fape, + pair_mask=intra_chain_mask) + interface_bb_loss, interface_fape = backbone_loss( + gt_rigid=gt_rigid, + gt_frames_mask=gt_frames_mask, + gt_positions_mask=gt_affine_mask, + target_rigid=target_rigid, + config=self.config.interface_fape, + pair_mask=1. - intra_chain_mask) + + bb_loss = intra_chain_bb_loss + interface_bb_loss + ret['fape'] = intra_chain_fape + interface_fape + ret['bb_loss'] = bb_loss + ret['loss'] += bb_loss + + pred_frames = geometry.Rigid3Array.from_array(sidechains['frames']) + pred_positions = geometry.Vec3Array.from_array(sidechains['atom_pos']) + gt_sc_frames, gt_sc_frames_mask = compute_frames( + aatype=aatype, + all_atom_positions=all_atom_positions, + all_atom_mask=all_atom_mask, + use_alt=alt_naming_is_better) + + sc_loss = sidechain_loss( + gt_frames=gt_sc_frames, + gt_frames_mask=gt_sc_frames_mask, + gt_positions=gt_positions, + gt_mask=gt_mask, + pred_frames=pred_frames, + pred_positions=pred_positions, + config=self.config) + + ret['loss'] = ((1 - self.config.sidechain.weight_frac) * ret['loss'] + + self.config.sidechain.weight_frac * sc_loss['loss']) + ret['sidechain_fape'] = sc_loss['fape'] + + unnormed_angles = sidechains['unnormalized_angles_sin_cos'] + pred_angles = sidechains['angles_sin_cos'] + + sup_chi_loss, ret['chi_loss'], ret[ + 'angle_norm_loss'] = supervised_chi_loss( + sequence_mask=seq_mask, + target_chi_mask=chi_mask, + target_chi_angles=gt_chi_angles, + aatype=aatype, + pred_angles=pred_angles, + unnormed_angles=unnormed_angles, + config=self.config) + ret['loss'] += sup_chi_loss + + if self.config.structural_violation_loss_weight: + + ret['loss'] += structural_violation_loss( + mask=pred_mask, violations=violations, config=self.config) + + return ret + + +def compute_atom14_gt( + aatype: jnp.ndarray, + all_atom_positions: geometry.Vec3Array, + all_atom_mask: jnp.ndarray, + pred_pos: geometry.Vec3Array +) -> Tuple[geometry.Vec3Array, jnp.ndarray, jnp.ndarray]: + """Find atom14 positions, this includes finding the correct renaming.""" + gt_positions, gt_mask = all_atom_multimer.atom37_to_atom14( + aatype, all_atom_positions, + all_atom_mask) + alt_gt_positions, alt_gt_mask = all_atom_multimer.get_alt_atom14( + aatype, gt_positions, gt_mask) + atom_is_ambiguous = all_atom_multimer.get_atom14_is_ambiguous(aatype) + + alt_naming_is_better = all_atom_multimer.find_optimal_renaming( + gt_positions=gt_positions, + alt_gt_positions=alt_gt_positions, + atom_is_ambiguous=atom_is_ambiguous, + gt_exists=gt_mask, + pred_positions=pred_pos) + + use_alt = alt_naming_is_better[:, None] + + gt_mask = (1. - use_alt) * gt_mask + use_alt * alt_gt_mask + gt_positions = (1. - use_alt) * gt_positions + use_alt * alt_gt_positions + + return gt_positions, alt_gt_mask, alt_naming_is_better + + +def backbone_loss(gt_rigid: geometry.Rigid3Array, + gt_frames_mask: jnp.ndarray, + gt_positions_mask: jnp.ndarray, + target_rigid: geometry.Rigid3Array, + config: ml_collections.ConfigDict, + pair_mask: jnp.ndarray + ) -> Tuple[Float, jnp.ndarray]: + """Backbone FAPE Loss.""" + loss_fn = functools.partial( + all_atom_multimer.frame_aligned_point_error, + l1_clamp_distance=config.atom_clamp_distance, + loss_unit_distance=config.loss_unit_distance) + + loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None, None)) + fape = loss_fn(target_rigid, gt_rigid, gt_frames_mask, + target_rigid.translation, gt_rigid.translation, + gt_positions_mask, pair_mask) + + return jnp.mean(fape), fape[-1] + + +def compute_frames( + aatype: jnp.ndarray, + all_atom_positions: geometry.Vec3Array, + all_atom_mask: jnp.ndarray, + use_alt: jnp.ndarray + ) -> Tuple[geometry.Rigid3Array, jnp.ndarray]: + """Compute Frames from all atom positions. + + Args: + aatype: array of aatypes, int of [N] + all_atom_positions: Vector of all atom positions, shape [N, 37] + all_atom_mask: mask, shape [N] + use_alt: whether to use alternative orientation for ambiguous aatypes + shape [N] + Returns: + Rigid corresponding to Frames w shape [N, 8], + mask which Rigids are present w shape [N, 8] + """ + frames_batch = all_atom_multimer.atom37_to_frames(aatype, all_atom_positions, + all_atom_mask) + gt_frames = frames_batch['rigidgroups_gt_frames'] + alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames'] + use_alt = use_alt[:, None] + + renamed_gt_frames = jax.tree_multimap( + lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames) + + return renamed_gt_frames, frames_batch['rigidgroups_gt_exists'] + + +def sidechain_loss(gt_frames: geometry.Rigid3Array, + gt_frames_mask: jnp.ndarray, + gt_positions: geometry.Vec3Array, + gt_mask: jnp.ndarray, + pred_frames: geometry.Rigid3Array, + pred_positions: geometry.Vec3Array, + config: ml_collections.ConfigDict + ) -> Dict[str, jnp.ndarray]: + """Sidechain Loss using cleaned up rigids.""" + + flat_gt_frames = jax.tree_map(jnp.ravel, gt_frames) + flat_frames_mask = jnp.ravel(gt_frames_mask) + + flat_gt_positions = jax.tree_map(jnp.ravel, gt_positions) + flat_positions_mask = jnp.ravel(gt_mask) + + # Compute frame_aligned_point_error score for the final layer. + def _slice_last_layer_and_flatten(x): + return jnp.ravel(x[-1]) + + flat_pred_frames = jax.tree_map(_slice_last_layer_and_flatten, pred_frames) + flat_pred_positions = jax.tree_map(_slice_last_layer_and_flatten, + pred_positions) + fape = all_atom_multimer.frame_aligned_point_error( + pred_frames=flat_pred_frames, + target_frames=flat_gt_frames, + frames_mask=flat_frames_mask, + pred_positions=flat_pred_positions, + target_positions=flat_gt_positions, + positions_mask=flat_positions_mask, + pair_mask=None, + length_scale=config.sidechain.loss_unit_distance, + l1_clamp_distance=config.sidechain.atom_clamp_distance) + + return { + 'fape': fape, + 'loss': fape} + + +def structural_violation_loss(mask: jnp.ndarray, + violations: Mapping[str, Float], + config: ml_collections.ConfigDict + ) -> Float: + """Computes Loss for structural Violations.""" + # Put all violation losses together to one large loss. + num_atoms = jnp.sum(mask).astype(jnp.float32) + 1e-6 + between_residues = violations['between_residues'] + within_residues = violations['within_residues'] + return (config.structural_violation_loss_weight * + (between_residues['bonds_c_n_loss_mean'] + + between_residues['angles_ca_c_n_loss_mean'] + + between_residues['angles_c_n_ca_loss_mean'] + + jnp.sum(between_residues['clashes_per_atom_loss_sum'] + + within_residues['per_atom_loss_sum']) / num_atoms + )) + + +def find_structural_violations( + aatype: jnp.ndarray, + residue_index: jnp.ndarray, + mask: jnp.ndarray, + pred_positions: geometry.Vec3Array, # (N, 14) + config: ml_collections.ConfigDict + ) -> Dict[str, Any]: + """Computes several checks for structural Violations.""" + + # Compute between residue backbone violations of bonds and angles. + connection_violations = all_atom_multimer.between_residue_bond_loss( + pred_atom_positions=pred_positions, + pred_atom_mask=mask.astype(jnp.float32), + residue_index=residue_index.astype(jnp.float32), + aatype=aatype, + tolerance_factor_soft=config.violation_tolerance_factor, + tolerance_factor_hard=config.violation_tolerance_factor) + + # Compute the van der Waals radius for every atom + # (the first letter of the atom name is the element type). + # shape (N, 14) + atomtype_radius = jnp.array([ + residue_constants.van_der_waals_radius[name[0]] + for name in residue_constants.atom_types + ]) + residx_atom14_to_atom37 = all_atom_multimer.get_atom14_to_atom37_map(aatype) + atom_radius = mask * utils.batched_gather(atomtype_radius, + residx_atom14_to_atom37) + + # Compute the between residue clash loss. + between_residue_clashes = all_atom_multimer.between_residue_clash_loss( + pred_positions=pred_positions, + atom_exists=mask, + atom_radius=atom_radius, + residue_index=residue_index, + overlap_tolerance_soft=config.clash_overlap_tolerance, + overlap_tolerance_hard=config.clash_overlap_tolerance) + + # Compute all within-residue violations (clashes, + # bond length and angle violations). + restype_atom14_bounds = residue_constants.make_atom14_dists_bounds( + overlap_tolerance=config.clash_overlap_tolerance, + bond_length_tolerance_factor=config.violation_tolerance_factor) + dists_lower_bound = utils.batched_gather(restype_atom14_bounds['lower_bound'], + aatype) + dists_upper_bound = utils.batched_gather(restype_atom14_bounds['upper_bound'], + aatype) + within_residue_violations = all_atom_multimer.within_residue_violations( + pred_positions=pred_positions, + atom_exists=mask, + dists_lower_bound=dists_lower_bound, + dists_upper_bound=dists_upper_bound, + tighten_bounds_for_loss=0.0) + + # Combine them to a single per-residue violation mask (used later for LDDT). + per_residue_violations_mask = jnp.max(jnp.stack([ + connection_violations['per_residue_violation_mask'], + jnp.max(between_residue_clashes['per_atom_clash_mask'], axis=-1), + jnp.max(within_residue_violations['per_atom_violations'], + axis=-1)]), axis=0) + + return { + 'between_residues': { + 'bonds_c_n_loss_mean': + connection_violations['c_n_loss_mean'], # () + 'angles_ca_c_n_loss_mean': + connection_violations['ca_c_n_loss_mean'], # () + 'angles_c_n_ca_loss_mean': + connection_violations['c_n_ca_loss_mean'], # () + 'connections_per_residue_loss_sum': + connection_violations['per_residue_loss_sum'], # (N) + 'connections_per_residue_violation_mask': + connection_violations['per_residue_violation_mask'], # (N) + 'clashes_mean_loss': + between_residue_clashes['mean_loss'], # () + 'clashes_per_atom_loss_sum': + between_residue_clashes['per_atom_loss_sum'], # (N, 14) + 'clashes_per_atom_clash_mask': + between_residue_clashes['per_atom_clash_mask'], # (N, 14) + }, + 'within_residues': { + 'per_atom_loss_sum': + within_residue_violations['per_atom_loss_sum'], # (N, 14) + 'per_atom_violations': + within_residue_violations['per_atom_violations'], # (N, 14), + }, + 'total_per_residue_violations_mask': + per_residue_violations_mask, # (N) + } + + +def compute_violation_metrics( + residue_index: jnp.ndarray, + mask: jnp.ndarray, + seq_mask: jnp.ndarray, + pred_positions: geometry.Vec3Array, # (N, 14) + violations: Mapping[str, jnp.ndarray], +) -> Dict[str, jnp.ndarray]: + """Compute several metrics to assess the structural violations.""" + ret = {} + between_residues = violations['between_residues'] + within_residues = violations['within_residues'] + extreme_ca_ca_violations = all_atom_multimer.extreme_ca_ca_distance_violations( + positions=pred_positions, + mask=mask.astype(jnp.float32), + residue_index=residue_index.astype(jnp.float32)) + ret['violations_extreme_ca_ca_distance'] = extreme_ca_ca_violations + ret['violations_between_residue_bond'] = utils.mask_mean( + mask=seq_mask, + value=between_residues['connections_per_residue_violation_mask']) + ret['violations_between_residue_clash'] = utils.mask_mean( + mask=seq_mask, + value=jnp.max(between_residues['clashes_per_atom_clash_mask'], axis=-1)) + ret['violations_within_residue'] = utils.mask_mean( + mask=seq_mask, + value=jnp.max(within_residues['per_atom_violations'], axis=-1)) + ret['violations_per_residue'] = utils.mask_mean( + mask=seq_mask, value=violations['total_per_residue_violations_mask']) + return ret + + +def supervised_chi_loss( + sequence_mask: jnp.ndarray, + target_chi_mask: jnp.ndarray, + aatype: jnp.ndarray, + target_chi_angles: jnp.ndarray, + pred_angles: jnp.ndarray, + unnormed_angles: jnp.ndarray, + config: ml_collections.ConfigDict) -> Tuple[Float, Float, Float]: + """Computes loss for direct chi angle supervision.""" + eps = 1e-6 + chi_mask = target_chi_mask.astype(jnp.float32) + + pred_angles = pred_angles[:, :, 3:] + + residue_type_one_hot = jax.nn.one_hot( + aatype, residue_constants.restype_num + 1, dtype=jnp.float32)[None] + chi_pi_periodic = jnp.einsum('ijk, kl->ijl', residue_type_one_hot, + jnp.asarray(residue_constants.chi_pi_periodic)) + + true_chi = target_chi_angles[None] + sin_true_chi = jnp.sin(true_chi) + cos_true_chi = jnp.cos(true_chi) + sin_cos_true_chi = jnp.stack([sin_true_chi, cos_true_chi], axis=-1) + + # This is -1 if chi is pi periodic and +1 if it's 2 pi periodic + shifted_mask = (1 - 2 * chi_pi_periodic)[..., None] + sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi + + sq_chi_error = jnp.sum( + squared_difference(sin_cos_true_chi, pred_angles), -1) + sq_chi_error_shifted = jnp.sum( + squared_difference(sin_cos_true_chi_shifted, pred_angles), -1) + sq_chi_error = jnp.minimum(sq_chi_error, sq_chi_error_shifted) + + sq_chi_loss = utils.mask_mean(mask=chi_mask[None], value=sq_chi_error) + angle_norm = jnp.sqrt(jnp.sum(jnp.square(unnormed_angles), axis=-1) + eps) + norm_error = jnp.abs(angle_norm - 1.) + angle_norm_loss = utils.mask_mean(mask=sequence_mask[None, :, None], + value=norm_error) + loss = (config.chi_weight * sq_chi_loss + + config.angle_norm_weight * angle_norm_loss) + return loss, sq_chi_loss, angle_norm_loss + + +def l2_normalize(x: jnp.ndarray, + axis: int = -1, + epsilon: float = 1e-12 + ) -> jnp.ndarray: + return x / jnp.sqrt( + jnp.maximum(jnp.sum(x**2, axis=axis, keepdims=True), epsilon)) + + +def get_renamed_chi_angles(aatype: jnp.ndarray, + chi_angles: jnp.ndarray, + alt_is_better: jnp.ndarray + ) -> jnp.ndarray: + """Return renamed chi angles.""" + chi_angle_is_ambiguous = utils.batched_gather( + jnp.array(residue_constants.chi_pi_periodic, dtype=jnp.float32), aatype) + alt_chi_angles = chi_angles + np.pi * chi_angle_is_ambiguous + # Map back to [-pi, pi]. + alt_chi_angles = alt_chi_angles - 2 * np.pi * (alt_chi_angles > np.pi).astype( + jnp.float32) + alt_is_better = alt_is_better[:, None] + return (1. - alt_is_better) * chi_angles + alt_is_better * alt_chi_angles + + +class MultiRigidSidechain(hk.Module): + """Class to make side chain atoms.""" + + def __init__(self, + config: ml_collections.ConfigDict, + global_config: ml_collections.ConfigDict, + name: str = 'rigid_sidechain'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + rigid: geometry.Rigid3Array, + representations_list: Iterable[jnp.ndarray], + aatype: jnp.ndarray + ) -> Dict[str, Any]: + """Predict sidechains using multi-rigid representations. + + Args: + rigid: The Rigid's for each residue (translations in angstoms) + representations_list: A list of activations to predict sidechains from. + aatype: amino acid types. + + Returns: + dict containing atom positions and frames (in angstrom) + """ + act = [ + common_modules.Linear( # pylint: disable=g-complex-comprehension + self.config.num_channel, + name='input_projection')(jax.nn.relu(x)) + for x in representations_list] + # Sum the activation list (equivalent to concat then Conv1D) + act = sum(act) + + final_init = 'zeros' if self.global_config.zero_init else 'linear' + + # Mapping with some residual blocks. + for _ in range(self.config.num_residual_block): + old_act = act + act = common_modules.Linear( + self.config.num_channel, + initializer='relu', + name='resblock1')( + jax.nn.relu(act)) + act = common_modules.Linear( + self.config.num_channel, + initializer=final_init, + name='resblock2')( + jax.nn.relu(act)) + act += old_act + + # Map activations to torsion angles. + # [batch_size, num_res, 14] + num_res = act.shape[0] + unnormalized_angles = common_modules.Linear( + 14, name='unnormalized_angles')( + jax.nn.relu(act)) + unnormalized_angles = jnp.reshape( + unnormalized_angles, [num_res, 7, 2]) + angles = l2_normalize(unnormalized_angles, axis=-1) + + outputs = { + 'angles_sin_cos': angles, # jnp.ndarray (N, 7, 2) + 'unnormalized_angles_sin_cos': + unnormalized_angles, # jnp.ndarray (N, 7, 2) + } + + # Map torsion angles to frames. + # geometry.Rigid3Array with shape (N, 8) + all_frames_to_global = all_atom_multimer.torsion_angles_to_frames( + aatype, + rigid, + angles) + + # Use frames and literature positions to create the final atom coordinates. + # geometry.Vec3Array with shape (N, 14) + pred_positions = all_atom_multimer.frames_and_literature_positions_to_atom14_pos( + aatype, all_frames_to_global) + + outputs.update({ + 'atom_pos': pred_positions, # geometry.Vec3Array (N, 14) + 'frames': all_frames_to_global, # geometry.Rigid3Array (N, 8) + }) + return outputs + diff --git a/alphafold/model/geometry/__init__.py b/alphafold/model/geometry/__init__.py new file mode 100644 index 000000000..671d07eed --- /dev/null +++ b/alphafold/model/geometry/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Geometry Module.""" + +from alphafold.model.geometry import rigid_matrix_vector +from alphafold.model.geometry import rotation_matrix +from alphafold.model.geometry import struct_of_array +from alphafold.model.geometry import vector + +Rot3Array = rotation_matrix.Rot3Array +Rigid3Array = rigid_matrix_vector.Rigid3Array + +StructOfArray = struct_of_array.StructOfArray + +Vec3Array = vector.Vec3Array +square_euclidean_distance = vector.square_euclidean_distance +euclidean_distance = vector.euclidean_distance +dihedral_angle = vector.dihedral_angle +dot = vector.dot +cross = vector.cross diff --git a/alphafold/model/geometry/rigid_matrix_vector.py b/alphafold/model/geometry/rigid_matrix_vector.py new file mode 100644 index 000000000..299f64017 --- /dev/null +++ b/alphafold/model/geometry/rigid_matrix_vector.py @@ -0,0 +1,106 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rigid3Array Transformations represented by a Matrix and a Vector.""" + +from __future__ import annotations +from typing import Union + +from alphafold.model.geometry import rotation_matrix +from alphafold.model.geometry import struct_of_array +from alphafold.model.geometry import vector +import jax +import jax.numpy as jnp + +Float = Union[float, jnp.ndarray] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rigid3Array: + """Rigid Transformation, i.e. element of special euclidean group.""" + + rotation: rotation_matrix.Rot3Array + translation: vector.Vec3Array + + def __matmul__(self, other: Rigid3Array) -> Rigid3Array: + new_rotation = self.rotation @ other.rotation + new_translation = self.apply_to_point(other.translation) + return Rigid3Array(new_rotation, new_translation) + + def inverse(self) -> Rigid3Array: + """Return Rigid3Array corresponding to inverse transform.""" + inv_rotation = self.rotation.inverse() + inv_translation = inv_rotation.apply_to_point(-self.translation) + return Rigid3Array(inv_rotation, inv_translation) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply Rigid3Array transform to point.""" + return self.rotation.apply_to_point(point) + self.translation + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply inverse Rigid3Array transform to point.""" + new_point = point - self.translation + return self.rotation.apply_inverse_to_point(new_point) + + def compose_rotation(self, other_rotation): + rot = self.rotation @ other_rotation + trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape), + self.translation) + return Rigid3Array(rot, trans) + + @classmethod + def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array: + """Return identity Rigid3Array of given shape.""" + return cls( + rotation_matrix.Rot3Array.identity(shape, dtype=dtype), + vector.Vec3Array.zeros(shape, dtype=dtype)) + + def scale_translation(self, factor: Float) -> Rigid3Array: + """Scale translation in Rigid3Array by 'factor'.""" + return Rigid3Array(self.rotation, self.translation * factor) + + def to_array(self): + rot_array = self.rotation.to_array() + vec_array = self.translation.to_array() + return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1) + + @classmethod + def from_array(cls, array): + rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) + vec = vector.Vec3Array.from_array(array[..., -1]) + return cls(rot, vec) + + @classmethod + def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array: + """Construct Rigid3Array from homogeneous 4x4 array.""" + assert array.shape[-1] == 4 + assert array.shape[-2] == 4 + rotation = rotation_matrix.Rot3Array( + array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], + array[..., 1, 0], array[..., 1, 1], array[..., 1, 2], + array[..., 2, 0], array[..., 2, 1], array[..., 2, 2] + ) + translation = vector.Vec3Array( + array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]) + return cls(rotation, translation) + + def __getstate__(self): + return (VERSION, (self.rotation, self.translation)) + + def __setstate__(self, state): + version, (rot, trans) = state + del version + object.__setattr__(self, 'rotation', rot) + object.__setattr__(self, 'translation', trans) diff --git a/alphafold/model/geometry/rotation_matrix.py b/alphafold/model/geometry/rotation_matrix.py new file mode 100644 index 000000000..322232994 --- /dev/null +++ b/alphafold/model/geometry/rotation_matrix.py @@ -0,0 +1,157 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rot3Array Matrix Class.""" + +from __future__ import annotations +import dataclasses + +from alphafold.model.geometry import struct_of_array +from alphafold.model.geometry import utils +from alphafold.model.geometry import vector +import jax +import jax.numpy as jnp +import numpy as np + +COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rot3Array: + """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" + + xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32}) + xy: jnp.ndarray + xz: jnp.ndarray + yx: jnp.ndarray + yy: jnp.ndarray + yz: jnp.ndarray + zx: jnp.ndarray + zy: jnp.ndarray + zz: jnp.ndarray + + __array_ufunc__ = None + + def inverse(self) -> Rot3Array: + """Returns inverse of Rot3Array.""" + return Rot3Array(self.xx, self.yx, self.zx, + self.xy, self.yy, self.zy, + self.xz, self.yz, self.zz) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies Rot3Array to point.""" + return vector.Vec3Array( + self.xx * point.x + self.xy * point.y + self.xz * point.z, + self.yx * point.x + self.yy * point.y + self.yz * point.z, + self.zx * point.x + self.zy * point.y + self.zz * point.z) + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies inverse Rot3Array to point.""" + return self.inverse().apply_to_point(point) + + def __matmul__(self, other: Rot3Array) -> Rot3Array: + """Composes two Rot3Arrays.""" + c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) + c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) + c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) + return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) + + @classmethod + def identity(cls, shape, dtype=jnp.float32) -> Rot3Array: + """Returns identity of given shape.""" + ones = jnp.ones(shape, dtype=dtype) + zeros = jnp.zeros(shape, dtype=dtype) + return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) + + @classmethod + def from_two_vectors(cls, e0: vector.Vec3Array, + e1: vector.Vec3Array) -> Rot3Array: + """Construct Rot3Array from two Vectors. + + Rot3Array is constructed such that in the corresponding frame 'e0' lies on + the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. + + Args: + e0: Vector + e1: Vector + Returns: + Rot3Array + """ + # Normalize the unit vector for the x-axis, e0. + e0 = e0.normalized() + # make e1 perpendicular to e0. + c = e1.dot(e0) + e1 = (e1 - c * e0).normalized() + # Compute e2 as cross product of e0 and e1. + e2 = e0.cross(e1) + return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) + + @classmethod + def from_array(cls, array: jnp.ndarray) -> Rot3Array: + """Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" + unstacked = utils.unstack(array, axis=-2) + unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], []) + return cls(*unstacked) + + def to_array(self) -> jnp.ndarray: + """Convert Rot3Array to array of shape [..., 3, 3].""" + return jnp.stack( + [jnp.stack([self.xx, self.xy, self.xz], axis=-1), + jnp.stack([self.yx, self.yy, self.yz], axis=-1), + jnp.stack([self.zx, self.zy, self.zz], axis=-1)], + axis=-2) + + @classmethod + def from_quaternion(cls, + w: jnp.ndarray, + x: jnp.ndarray, + y: jnp.ndarray, + z: jnp.ndarray, + normalize: bool = True, + epsilon: float = 1e-6) -> Rot3Array: + """Construct Rot3Array from components of quaternion.""" + if normalize: + inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2)) + w *= inv_norm + x *= inv_norm + y *= inv_norm + z *= inv_norm + xx = 1 - 2 * (jnp.square(y) + jnp.square(z)) + xy = 2 * (x * y - w * z) + xz = 2 * (x * z + w * y) + yx = 2 * (x * y + w * z) + yy = 1 - 2 * (jnp.square(x) + jnp.square(z)) + yz = 2 * (y * z - w * x) + zx = 2 * (x * z - w * y) + zy = 2 * (y * z + w * x) + zz = 1 - 2 * (jnp.square(x) + jnp.square(y)) + return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) + + @classmethod + def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array: + """Samples uniform random Rot3Array according to Haar Measure.""" + quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype) + quats = utils.unstack(quat_array) + return cls.from_quaternion(*quats) + + def __getstate__(self): + return (VERSION, + [np.asarray(getattr(self, field)) for field in COMPONENTS]) + + def __setstate__(self, state): + version, state = state + del version + for i, field in enumerate(COMPONENTS): + object.__setattr__(self, field, state[i]) diff --git a/alphafold/model/geometry/struct_of_array.py b/alphafold/model/geometry/struct_of_array.py new file mode 100644 index 000000000..97a89fd4a --- /dev/null +++ b/alphafold/model/geometry/struct_of_array.py @@ -0,0 +1,220 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Class decorator to represent (nested) struct of arrays.""" + +import dataclasses + +import jax + + +def get_item(instance, key): + sliced = {} + for field in get_array_fields(instance): + num_trailing_dims = field.metadata.get('num_trailing_dims', 0) + this_key = key + if isinstance(key, tuple) and Ellipsis in this_key: + this_key += (slice(None),) * num_trailing_dims + sliced[field.name] = getattr(instance, field.name)[this_key] + return dataclasses.replace(instance, **sliced) + + +@property +def get_shape(instance): + """Returns Shape for given instance of dataclass.""" + first_field = dataclasses.fields(instance)[0] + num_trailing_dims = first_field.metadata.get('num_trailing_dims', None) + value = getattr(instance, first_field.name) + if num_trailing_dims: + return value.shape[:-num_trailing_dims] + else: + return value.shape + + +def get_len(instance): + """Returns length for given instance of dataclass.""" + shape = instance.shape + if shape: + return shape[0] + else: + raise TypeError('len() of unsized object') # Match jax.numpy behavior. + + +@property +def get_dtype(instance): + """Returns Dtype for given instance of dataclass.""" + fields = dataclasses.fields(instance) + sets_dtype = [ + field.name for field in fields if field.metadata.get('sets_dtype', False) + ] + if sets_dtype: + assert len(sets_dtype) == 1, 'at most field can set dtype' + field_value = getattr(instance, sets_dtype[0]) + elif instance.same_dtype: + field_value = getattr(instance, fields[0].name) + else: + # Should this be Value Error? + raise AttributeError('Trying to access Dtype on Struct of Array without' + 'either "same_dtype" or field setting dtype') + + if hasattr(field_value, 'dtype'): + return field_value.dtype + else: + # Should this be Value Error? + raise AttributeError(f'field_value {field_value} does not have dtype') + + +def replace(instance, **kwargs): + return dataclasses.replace(instance, **kwargs) + + +def post_init(instance): + """Validate instance has same shapes & dtypes.""" + array_fields = get_array_fields(instance) + arrays = list(get_array_fields(instance, return_values=True).values()) + first_field = array_fields[0] + # These slightly weird constructions about checking whether the leaves are + # actual arrays is since e.g. vmap internally relies on being able to + # construct pytree's with object() as leaves, this would break the checking + # as such we are only validating the object when the entries in the dataclass + # Are arrays or other dataclasses of arrays. + try: + dtype = instance.dtype + except AttributeError: + dtype = None + if dtype is not None: + first_shape = instance.shape + for array, field in zip(arrays, array_fields): + field_shape = array.shape + num_trailing_dims = field.metadata.get('num_trailing_dims', None) + if num_trailing_dims: + array_shape = array.shape + field_shape = array_shape[:-num_trailing_dims] + msg = (f'field {field} should have number of trailing dims' + ' {num_trailing_dims}') + assert len(array_shape) == len(first_shape) + num_trailing_dims, msg + else: + field_shape = array.shape + + shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't " + f"match shape {first_shape} of field {first_field}") + assert field_shape == first_shape, shape_msg + + field_dtype = array.dtype + + allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', []) + if allowed_metadata_dtypes: + msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}' + assert field_dtype in allowed_metadata_dtypes, msg + + if 'dtype' in field.metadata: + target_dtype = field.metadata['dtype'] + else: + target_dtype = dtype + + msg = f'Dtype is {field_dtype} but must be {target_dtype}' + assert field_dtype == target_dtype, msg + + +def flatten(instance): + """Flatten Struct of Array instance.""" + array_likes = list(get_array_fields(instance, return_values=True).values()) + flat_array_likes = [] + inner_treedefs = [] + num_arrays = [] + for array_like in array_likes: + flat_array_like, inner_treedef = jax.tree_flatten(array_like) + inner_treedefs.append(inner_treedef) + flat_array_likes += flat_array_like + num_arrays.append(len(flat_array_like)) + metadata = get_metadata_fields(instance, return_values=True) + metadata = type(instance).metadata_cls(**metadata) + return flat_array_likes, (inner_treedefs, metadata, num_arrays) + + +def make_metadata_class(cls): + metadata_fields = get_fields(cls, + lambda x: x.metadata.get('is_metadata', False)) + metadata_cls = dataclasses.make_dataclass( + cls_name='Meta' + cls.__name__, + fields=[(field.name, field.type, field) for field in metadata_fields], + frozen=True, + eq=True) + return metadata_cls + + +def get_fields(cls_or_instance, filterfn, return_values=False): + fields = dataclasses.fields(cls_or_instance) + fields = [field for field in fields if filterfn(field)] + if return_values: + return { + field.name: getattr(cls_or_instance, field.name) for field in fields + } + else: + return fields + + +def get_array_fields(cls, return_values=False): + return get_fields( + cls, + lambda x: not x.metadata.get('is_metadata', False), + return_values=return_values) + + +def get_metadata_fields(cls, return_values=False): + return get_fields( + cls, + lambda x: x.metadata.get('is_metadata', False), + return_values=return_values) + + +class StructOfArray: + """Class Decorator for Struct Of Arrays.""" + + def __init__(self, same_dtype=True): + self.same_dtype = same_dtype + + def __call__(self, cls): + cls.__array_ufunc__ = None + cls.replace = replace + cls.same_dtype = self.same_dtype + cls.dtype = get_dtype + cls.shape = get_shape + cls.__len__ = get_len + cls.__getitem__ = get_item + cls.__post_init__ = post_init + new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) # pytype: disable=wrong-keyword-args + # pytree claims to require metadata to be hashable, not sure why, + # But making derived dataclass that can just hold metadata + new_cls.metadata_cls = make_metadata_class(new_cls) + + def unflatten(aux, data): + inner_treedefs, metadata, num_arrays = aux + array_fields = [field.name for field in get_array_fields(new_cls)] + value_dict = {} + array_start = 0 + for num_array, inner_treedef, array_field in zip(num_arrays, + inner_treedefs, + array_fields): + value_dict[array_field] = jax.tree_unflatten( + inner_treedef, data[array_start:array_start + num_array]) + array_start += num_array + metadata_fields = get_metadata_fields(new_cls) + for field in metadata_fields: + value_dict[field.name] = getattr(metadata, field.name) + + return new_cls(**value_dict) + + jax.tree_util.register_pytree_node( + nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten) + return new_cls diff --git a/alphafold/model/geometry/test_utils.py b/alphafold/model/geometry/test_utils.py new file mode 100644 index 000000000..32a68400d --- /dev/null +++ b/alphafold/model/geometry/test_utils.py @@ -0,0 +1,98 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared utils for tests.""" + +import dataclasses + +from alphafold.model.geometry import rigid_matrix_vector +from alphafold.model.geometry import rotation_matrix +from alphafold.model.geometry import vector +import jax.numpy as jnp +import numpy as np + + +def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array, + matrix2: rotation_matrix.Rot3Array): + for field in dataclasses.fields(rotation_matrix.Rot3Array): + field = field.name + np.testing.assert_array_equal( + getattr(matrix1, field), getattr(matrix2, field)) + + +def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array, + mat2: rotation_matrix.Rot3Array): + np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6) + + +def assert_array_equal_to_rotation_matrix(array: jnp.ndarray, + matrix: rotation_matrix.Rot3Array): + """Check that array and Matrix match.""" + np.testing.assert_array_equal(matrix.xx, array[..., 0, 0]) + np.testing.assert_array_equal(matrix.xy, array[..., 0, 1]) + np.testing.assert_array_equal(matrix.xz, array[..., 0, 2]) + np.testing.assert_array_equal(matrix.yx, array[..., 1, 0]) + np.testing.assert_array_equal(matrix.yy, array[..., 1, 1]) + np.testing.assert_array_equal(matrix.yz, array[..., 1, 2]) + np.testing.assert_array_equal(matrix.zx, array[..., 2, 0]) + np.testing.assert_array_equal(matrix.zy, array[..., 2, 1]) + np.testing.assert_array_equal(matrix.zz, array[..., 2, 2]) + + +def assert_array_close_to_rotation_matrix(array: jnp.ndarray, + matrix: rotation_matrix.Rot3Array): + np.testing.assert_array_almost_equal(matrix.to_array(), array, 6) + + +def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + np.testing.assert_array_equal(vec1.x, vec2.x) + np.testing.assert_array_equal(vec1.y, vec2.y) + np.testing.assert_array_equal(vec1.z, vec2.z) + + +def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array): + np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.) + np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.) + np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.) + + +def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array): + np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.) + + +def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array): + np.testing.assert_array_equal(vec.to_array(), array) + + +def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, + rigid2: rigid_matrix_vector.Rigid3Array): + assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2) + + +def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, + rigid2: rigid_matrix_vector.Rigid3Array): + assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2) + + +def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array, + trans: vector.Vec3Array, + rigid: rigid_matrix_vector.Rigid3Array): + assert_rotation_matrix_equal(rot, rigid.rotation) + assert_vectors_equal(trans, rigid.translation) + + +def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array, + trans: vector.Vec3Array, + rigid: rigid_matrix_vector.Rigid3Array): + assert_rotation_matrix_close(rot, rigid.rotation) + assert_vectors_close(trans, rigid.translation) diff --git a/alphafold/model/geometry/utils.py b/alphafold/model/geometry/utils.py new file mode 100644 index 000000000..64c4a649d --- /dev/null +++ b/alphafold/model/geometry/utils.py @@ -0,0 +1,23 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for geometry library.""" + +from typing import List + +import jax.numpy as jnp + + +def unstack(value: jnp.ndarray, axis: int = -1) -> List[jnp.ndarray]: + return [jnp.squeeze(v, axis=axis) + for v in jnp.split(value, value.shape[axis], axis=axis)] diff --git a/alphafold/model/geometry/vector.py b/alphafold/model/geometry/vector.py new file mode 100644 index 000000000..99dcb50f7 --- /dev/null +++ b/alphafold/model/geometry/vector.py @@ -0,0 +1,217 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Vec3Array Class.""" + +from __future__ import annotations +import dataclasses +from typing import Union + +from alphafold.model.geometry import struct_of_array +from alphafold.model.geometry import utils +import jax +import jax.numpy as jnp +import numpy as np + +Float = Union[float, jnp.ndarray] + +VERSION = '0.1' + + +@struct_of_array.StructOfArray(same_dtype=True) +class Vec3Array: + """Vec3Array in 3 dimensional Space implemented as struct of arrays. + + This is done in order to improve performance and precision. + On TPU small matrix multiplications are very suboptimal and will waste large + compute ressources, furthermore any matrix multiplication on tpu happen in + mixed bfloat16/float32 precision, which is often undesirable when handling + physical coordinates. + In most cases this will also be faster on cpu's/gpu's since it allows for + easier use of vector instructions. + """ + + x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32}) + y: jnp.ndarray + z: jnp.ndarray + + def __post_init__(self): + if hasattr(self.x, 'dtype'): + assert self.x.dtype == self.y.dtype + assert self.x.dtype == self.z.dtype + assert all([x == y for x, y in zip(self.x.shape, self.y.shape)]) + assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) + + def __add__(self, other: Vec3Array) -> Vec3Array: + return jax.tree_multimap(lambda x, y: x + y, self, other) + + def __sub__(self, other: Vec3Array) -> Vec3Array: + return jax.tree_multimap(lambda x, y: x - y, self, other) + + def __mul__(self, other: Float) -> Vec3Array: + return jax.tree_map(lambda x: x * other, self) + + def __rmul__(self, other: Float) -> Vec3Array: + return self * other + + def __truediv__(self, other: Float) -> Vec3Array: + return jax.tree_map(lambda x: x / other, self) + + def __neg__(self) -> Vec3Array: + return jax.tree_map(lambda x: -x, self) + + def __pos__(self) -> Vec3Array: + return jax.tree_map(lambda x: x, self) + + def cross(self, other: Vec3Array) -> Vec3Array: + """Compute cross product between 'self' and 'other'.""" + new_x = self.y * other.z - self.z * other.y + new_y = self.z * other.x - self.x * other.z + new_z = self.x * other.y - self.y * other.x + return Vec3Array(new_x, new_y, new_z) + + def dot(self, other: Vec3Array) -> Float: + """Compute dot product between 'self' and 'other'.""" + return self.x * other.x + self.y * other.y + self.z * other.z + + def norm(self, epsilon: float = 1e-6) -> Float: + """Compute Norm of Vec3Array, clipped to epsilon.""" + # To avoid NaN on the backward pass, we must use maximum before the sqrt + norm2 = self.dot(self) + if epsilon: + norm2 = jnp.maximum(norm2, epsilon**2) + return jnp.sqrt(norm2) + + def norm2(self): + return self.dot(self) + + def normalized(self, epsilon: float = 1e-6) -> Vec3Array: + """Return unit vector with optional clipping.""" + return self / self.norm(epsilon) + + @classmethod + def zeros(cls, shape, dtype=jnp.float32): + """Return Vec3Array corresponding to zeros of given shape.""" + return cls( + jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), + jnp.zeros(shape, dtype)) + + def to_array(self) -> jnp.ndarray: + return jnp.stack([self.x, self.y, self.z], axis=-1) + + @classmethod + def from_array(cls, array): + return cls(*utils.unstack(array)) + + def __getstate__(self): + return (VERSION, + [np.asarray(self.x), + np.asarray(self.y), + np.asarray(self.z)]) + + def __setstate__(self, state): + version, state = state + del version + for i, letter in enumerate('xyz'): + object.__setattr__(self, letter, state[i]) + + +def square_euclidean_distance(vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6) -> Float: + """Computes square of euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute distance to + vec2: Vec3Array to compute distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of square euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + difference = vec1 - vec2 + distance = difference.dot(difference) + if epsilon: + distance = jnp.maximum(distance, epsilon) + return distance + + +def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.dot(vector2) + + +def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: + return vector1.cross(vector2) + + +def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: + return vector.norm(epsilon) + + +def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: + return vector.normalized(epsilon) + + +def euclidean_distance(vec1: Vec3Array, + vec2: Vec3Array, + epsilon: float = 1e-6) -> Float: + """Computes euclidean distance between 'vec1' and 'vec2'. + + Args: + vec1: Vec3Array to compute euclidean distance to + vec2: Vec3Array to compute euclidean distance from, should be + broadcast compatible with 'vec1' + epsilon: distance is clipped from below to be at least epsilon + + Returns: + Array of euclidean distances; + shape will be result of broadcasting 'vec1' and 'vec2' + """ + distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2) + distance = jnp.sqrt(distance_sq) + return distance + + +def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, + d: Vec3Array) -> Float: + """Computes torsion angle for a quadruple of points. + + For points (a, b, c, d), this is the angle between the planes defined by + points (a, b, c) and (b, c, d). It is also known as the dihedral angle. + + Arguments: + a: A Vec3Array of coordinates. + b: A Vec3Array of coordinates. + c: A Vec3Array of coordinates. + d: A Vec3Array of coordinates. + + Returns: + A tensor of angles in radians: [-pi, pi]. + """ + v1 = a - b + v2 = b - c + v3 = d - c + + c1 = v1.cross(v2) + c2 = v3.cross(v2) + c3 = c2.cross(c1) + + v2_mag = v2.norm() + return jnp.arctan2(c3.dot(v2), v2_mag * c1.dot(c2)) + + +def random_gaussian_vector(shape, key, dtype=jnp.float32): + vec_array = jax.random.normal(key, shape + (3,), dtype) + return Vec3Array.from_array(vec_array) diff --git a/alphafold/model/model.py b/alphafold/model/model.py index 66addeb6e..5a77b3c89 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -19,6 +19,7 @@ from alphafold.common import confidence from alphafold.model import features from alphafold.model import modules +from alphafold.model import modules_multimer import haiku as hk import jax import ml_collections @@ -28,19 +29,34 @@ def get_confidence_metrics( - prediction_result: Mapping[str, Any]) -> Mapping[str, Any]: + prediction_result: Mapping[str, Any], + multimer_mode: bool) -> Mapping[str, Any]: """Post processes prediction_result to get confidence metrics.""" - confidence_metrics = {} confidence_metrics['plddt'] = confidence.compute_plddt( prediction_result['predicted_lddt']['logits']) if 'predicted_aligned_error' in prediction_result: confidence_metrics.update(confidence.compute_predicted_aligned_error( - prediction_result['predicted_aligned_error']['logits'], - prediction_result['predicted_aligned_error']['breaks'])) + logits=prediction_result['predicted_aligned_error']['logits'], + breaks=prediction_result['predicted_aligned_error']['breaks'])) confidence_metrics['ptm'] = confidence.predicted_tm_score( - prediction_result['predicted_aligned_error']['logits'], - prediction_result['predicted_aligned_error']['breaks']) + logits=prediction_result['predicted_aligned_error']['logits'], + breaks=prediction_result['predicted_aligned_error']['breaks'], + asym_id=None) + if multimer_mode: + # Compute the ipTM only for the multimer model. + confidence_metrics['iptm'] = confidence.predicted_tm_score( + logits=prediction_result['predicted_aligned_error']['logits'], + breaks=prediction_result['predicted_aligned_error']['breaks'], + asym_id=prediction_result['predicted_aligned_error']['asym_id'], + interface=True) + confidence_metrics['ranking_confidence'] = ( + 0.8 * confidence_metrics['iptm'] + 0.2 * confidence_metrics['ptm']) + + if not multimer_mode: + # Monomer models use mean pLDDT for model ranking. + confidence_metrics['ranking_confidence'] = np.mean( + confidence_metrics['plddt']) return confidence_metrics @@ -53,14 +69,22 @@ def __init__(self, params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None): self.config = config self.params = params - - def _forward_fn(batch): - model = modules.AlphaFold(self.config.model) - return model( - batch, - is_training=False, - compute_loss=False, - ensemble_representations=True) + self.multimer_mode = config.model.global_config.multimer_mode + + if self.multimer_mode: + def _forward_fn(batch): + model = modules_multimer.AlphaFold(self.config.model) + return model( + batch, + is_training=False) + else: + def _forward_fn(batch): + model = modules.AlphaFold(self.config.model) + return model( + batch, + is_training=False, + compute_loss=False, + ensemble_representations=True) self.apply = jax.jit(hk.transform(_forward_fn).apply) self.init = jax.jit(hk.transform(_forward_fn).init) @@ -98,6 +122,11 @@ def process_features( Returns: A dict of NumPy feature arrays suitable for feeding into the model. """ + + if self.multimer_mode: + return raw_features + + # Single-chain mode. if isinstance(raw_features, dict): return features.np_example_to_features( np_example=raw_features, @@ -117,12 +146,17 @@ def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct: logging.info('Output shape was %s', shape) return shape - def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]: + def predict(self, + feat: features.FeatureDict, + random_seed: int, + ) -> Mapping[str, Any]: """Makes a prediction by inferencing the model on the provided features. Args: feat: A dictionary of NumPy feature arrays as output by RunModel.process_features. + random_seed: The random seed to use when running the model. In the + multimer model this controls the MSA sampling. Returns: A dictionary of model outputs. @@ -130,12 +164,14 @@ def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]: self.init_params(feat) logging.info('Running predict with shape(feat) = %s', tree.map_structure(lambda x: x.shape, feat)) - result = self.apply(self.params, jax.random.PRNGKey(0), feat) + result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat) + # This block is to ensure benchmark timings are accurate. Some blocking is # already happening when computing get_confidence_metrics, and this ensures # all outputs are blocked on. jax.tree_map(lambda x: x.block_until_ready(), result) - result.update(get_confidence_metrics(result)) + result.update( + get_confidence_metrics(result, multimer_mode=self.multimer_mode)) logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) return result diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py index 794597ff1..031cd1d31 100644 --- a/alphafold/model/modules.py +++ b/alphafold/model/modules.py @@ -965,6 +965,11 @@ def __init__(self, config, global_config, name='masked_msa_head'): self.config = config self.global_config = global_config + if global_config.multimer_mode: + self.num_output = len(residue_constants.restypes_with_x_and_gap) + else: + self.num_output = config.num_output + def __call__(self, representations, batch, is_training): """Builds MaskedMsaHead module. @@ -981,7 +986,7 @@ def __call__(self, representations, batch, is_training): """ del batch logits = common_modules.Linear( - self.config.num_output, + self.num_output, initializer=utils.final_init(self.global_config), name='logits')( representations['msa']) @@ -989,7 +994,7 @@ def __call__(self, representations, batch, is_training): def loss(self, value, batch): errors = softmax_cross_entropy( - labels=jax.nn.one_hot(batch['true_msa'], num_classes=23), + labels=jax.nn.one_hot(batch['true_msa'], num_classes=self.num_output), logits=value['logits']) loss = (jnp.sum(errors * batch['bert_mask'], axis=(-2, -1)) / (1e-8 + jnp.sum(batch['bert_mask'], axis=(-2, -1)))) @@ -1009,7 +1014,7 @@ def __init__(self, config, global_config, name='predicted_lddt_head'): self.global_config = global_config def __call__(self, representations, batch, is_training): - """Builds ExperimentallyResolvedHead module. + """Builds PredictedLDDTHead module. Arguments: representations: Dictionary of representations, must contain: @@ -1071,7 +1076,7 @@ def loss(self, value, batch): # Shape (batch_size, num_res, 1) true_points_mask=all_atom_mask[None, :, 1:2].astype(jnp.float32), cutoff=15., - per_residue=True)[0] + per_residue=True) lddt_ca = jax.lax.stop_gradient(lddt_ca) num_bins = self.config.num_bins @@ -1597,6 +1602,19 @@ def __call__(self, activations, masks, is_training=True, safe_key=None): safe_key, *sub_keys = safe_key.split(10) sub_keys = iter(sub_keys) + outer_module = OuterProductMean( + config=c.outer_product_mean, + global_config=self.global_config, + num_output_channel=int(pair_act.shape[-1]), + name='outer_product_mean') + if c.outer_product_mean.first: + pair_act = dropout_wrapper_fn( + outer_module, + msa_act, + msa_mask, + safe_key=next(sub_keys), + output_act=pair_act) + msa_act = dropout_wrapper_fn( MSARowAttentionWithPairBias( c.msa_row_attention_with_pair_bias, gc, @@ -1624,16 +1642,13 @@ def __call__(self, activations, masks, is_training=True, safe_key=None): msa_mask, safe_key=next(sub_keys)) - pair_act = dropout_wrapper_fn( - OuterProductMean( - config=c.outer_product_mean, - global_config=self.global_config, - num_output_channel=int(pair_act.shape[-1]), - name='outer_product_mean'), - msa_act, - msa_mask, - safe_key=next(sub_keys), - output_act=pair_act) + if not c.outer_product_mean.first: + pair_act = dropout_wrapper_fn( + outer_module, + msa_act, + msa_mask, + safe_key=next(sub_keys), + output_act=pair_act) pair_act = dropout_wrapper_fn( TriangleMultiplication(c.triangle_multiplication_outgoing, gc, @@ -1730,8 +1745,7 @@ def __call__(self, batch, is_training, safe_key=None): True, name='prev_msa_first_row_norm')( batch['prev_msa_first_row']) - msa_activations = jax.ops.index_add(msa_activations, 0, - prev_msa_first_row) + msa_activations = msa_activations.at[0].add(prev_msa_first_row) if 'prev_pair' in batch: pair_activations += hk.LayerNorm([-1], diff --git a/alphafold/model/modules_multimer.py b/alphafold/model/modules_multimer.py new file mode 100644 index 000000000..4e76d41e3 --- /dev/null +++ b/alphafold/model/modules_multimer.py @@ -0,0 +1,1129 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core modules, which have been refactored in AlphaFold-Multimer. + +The main difference is that MSA sampling pipeline is moved inside the JAX model +for easier implementation of recycling and ensembling. + +Lower-level modules up to EvoformerIteration are reused from modules.py. +""" + +import functools +from typing import Sequence + +from alphafold.common import residue_constants +from alphafold.model import all_atom_multimer +from alphafold.model import common_modules +from alphafold.model import folding_multimer +from alphafold.model import geometry +from alphafold.model import layer_stack +from alphafold.model import modules +from alphafold.model import prng +from alphafold.model import utils + +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + + +def reduce_fn(x, mode): + if mode == 'none' or mode is None: + return jnp.asarray(x) + elif mode == 'sum': + return jnp.asarray(x).sum() + elif mode == 'mean': + return jnp.mean(jnp.asarray(x)) + else: + raise ValueError('Unsupported reduction option.') + + +def gumbel_noise(key: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray: + """Generate Gumbel Noise of given Shape. + + This generates samples from Gumbel(0, 1). + + Args: + key: Jax random number key. + shape: Shape of noise to return. + + Returns: + Gumbel noise of given shape. + """ + epsilon = 1e-6 + uniform = utils.padding_consistent_rng(jax.random.uniform) + uniform_noise = uniform( + key, shape=shape, dtype=jnp.float32, minval=0., maxval=1.) + gumbel = -jnp.log(-jnp.log(uniform_noise + epsilon) + epsilon) + return gumbel + + +def gumbel_max_sample(key: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray: + """Samples from a probability distribution given by 'logits'. + + This uses Gumbel-max trick to implement the sampling in an efficient manner. + + Args: + key: prng key. + logits: Logarithm of probabilities to sample from, probabilities can be + unnormalized. + + Returns: + Sample from logprobs in one-hot form. + """ + z = gumbel_noise(key, logits.shape) + return jax.nn.one_hot( + jnp.argmax(logits + z, axis=-1), + logits.shape[-1], + dtype=logits.dtype) + + +def gumbel_argsort_sample_idx(key: jnp.ndarray, + logits: jnp.ndarray) -> jnp.ndarray: + """Samples with replacement from a distribution given by 'logits'. + + This uses Gumbel trick to implement the sampling an efficient manner. For a + distribution over k items this samples k times without replacement, so this + is effectively sampling a random permutation with probabilities over the + permutations derived from the logprobs. + + Args: + key: prng key. + logits: Logarithm of probabilities to sample from, probabilities can be + unnormalized. + + Returns: + Sample from logprobs in one-hot form. + """ + z = gumbel_noise(key, logits.shape) + # This construction is equivalent to jnp.argsort, but using a non stable sort, + # since stable sort's aren't supported by jax2tf. + axis = len(logits.shape) - 1 + iota = jax.lax.broadcasted_iota(jnp.int64, logits.shape, axis) + _, perm = jax.lax.sort_key_val( + logits + z, iota, dimension=-1, is_stable=False) + return perm[::-1] + + +def make_masked_msa(batch, key, config, epsilon=1e-6): + """Create data for BERT on raw MSA.""" + # Add a random amino acid uniformly. + random_aa = jnp.array([0.05] * 20 + [0., 0.], dtype=jnp.float32) + + categorical_probs = ( + config.uniform_prob * random_aa + + config.profile_prob * batch['msa_profile'] + + config.same_prob * jax.nn.one_hot(batch['msa'], 22)) + + # Put all remaining probability on [MASK] which is a new column. + pad_shapes = [[0, 0] for _ in range(len(categorical_probs.shape))] + pad_shapes[-1][1] = 1 + mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob + assert mask_prob >= 0. + categorical_probs = jnp.pad( + categorical_probs, pad_shapes, constant_values=mask_prob) + sh = batch['msa'].shape + key, mask_subkey, gumbel_subkey = key.split(3) + uniform = utils.padding_consistent_rng(jax.random.uniform) + mask_position = uniform(mask_subkey.get(), sh) < config.replace_fraction + mask_position *= batch['msa_mask'] + + logits = jnp.log(categorical_probs + epsilon) + bert_msa = gumbel_max_sample(gumbel_subkey.get(), logits) + bert_msa = jnp.where(mask_position, + jnp.argmax(bert_msa, axis=-1), batch['msa']) + bert_msa *= batch['msa_mask'] + + # Mix real and masked MSA. + if 'bert_mask' in batch: + batch['bert_mask'] *= mask_position.astype(jnp.float32) + else: + batch['bert_mask'] = mask_position.astype(jnp.float32) + batch['true_msa'] = batch['msa'] + batch['msa'] = bert_msa + + return batch + + +def nearest_neighbor_clusters(batch, gap_agreement_weight=0.): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask. + + weights = jnp.array( + [1.] * 21 + [gap_agreement_weight] + [0.], dtype=jnp.float32) + + msa_mask = batch['msa_mask'] + msa_one_hot = jax.nn.one_hot(batch['msa'], 23) + + extra_mask = batch['extra_msa_mask'] + extra_one_hot = jax.nn.one_hot(batch['extra_msa'], 23) + + msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot + extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot + + agreement = jnp.einsum('mrc, nrc->nm', extra_one_hot_masked, + weights * msa_one_hot_masked) + + cluster_assignment = jax.nn.softmax(1e3 * agreement, axis=0) + cluster_assignment *= jnp.einsum('mr, nr->mn', msa_mask, extra_mask) + + cluster_count = jnp.sum(cluster_assignment, axis=-1) + cluster_count += 1. # We always include the sequence itself. + + msa_sum = jnp.einsum('nm, mrc->nrc', cluster_assignment, extra_one_hot_masked) + msa_sum += msa_one_hot_masked + + cluster_profile = msa_sum / cluster_count[:, None, None] + + extra_deletion_matrix = batch['extra_deletion_matrix'] + deletion_matrix = batch['deletion_matrix'] + + del_sum = jnp.einsum('nm, mc->nc', cluster_assignment, + extra_mask * extra_deletion_matrix) + del_sum += deletion_matrix # Original sequence. + cluster_deletion_mean = del_sum / cluster_count[:, None] + + return cluster_profile, cluster_deletion_mean + + +def create_msa_feat(batch): + """Create and concatenate MSA features.""" + msa_1hot = jax.nn.one_hot(batch['msa'], 23) + deletion_matrix = batch['deletion_matrix'] + has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None] + deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None] + + deletion_mean_value = (jnp.arctan(batch['cluster_deletion_mean'] / 3.) * + (2. / jnp.pi))[..., None] + + msa_feat = [ + msa_1hot, + has_deletion, + deletion_value, + batch['cluster_profile'], + deletion_mean_value + ] + + return jnp.concatenate(msa_feat, axis=-1) + + +def create_extra_msa_feature(batch, num_extra_msa): + """Expand extra_msa into 1hot and concat with other extra msa features. + + We do this as late as possible as the one_hot extra msa can be very large. + + Args: + batch: a dictionary with the following keys: + * 'extra_msa': [num_seq, num_res] MSA that wasn't selected as a cluster + centre. Note - This isn't one-hotted. + * 'extra_deletion_matrix': [num_seq, num_res] Number of deletions at given + position. + num_extra_msa: Number of extra msa to use. + + Returns: + Concatenated tensor of extra MSA features. + """ + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + extra_msa = batch['extra_msa'][:num_extra_msa] + deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa] + msa_1hot = jax.nn.one_hot(extra_msa, 23) + has_deletion = jnp.clip(deletion_matrix, 0., 1.)[..., None] + deletion_value = (jnp.arctan(deletion_matrix / 3.) * (2. / jnp.pi))[..., None] + extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa] + return jnp.concatenate([msa_1hot, has_deletion, deletion_value], + axis=-1), extra_msa_mask + + +def sample_msa(key, batch, max_seq): + """Sample MSA randomly, remaining sequences are stored as `extra_*`. + + Args: + key: safe key for random number generation. + batch: batch to sample msa from. + max_seq: number of sequences to sample. + Returns: + Protein with sampled msa. + """ + # Sample uniformly among sequences with at least one non-masked position. + logits = (jnp.clip(jnp.sum(batch['msa_mask'], axis=-1), 0., 1.) - 1.) * 1e6 + # The cluster_bias_mask can be used to preserve the first row (target + # sequence) for each chain, for example. + if 'cluster_bias_mask' not in batch: + cluster_bias_mask = jnp.pad( + jnp.zeros(batch['msa'].shape[0] - 1), (1, 0), constant_values=1.) + else: + cluster_bias_mask = batch['cluster_bias_mask'] + + logits += cluster_bias_mask * 1e6 + index_order = gumbel_argsort_sample_idx(key.get(), logits) + sel_idx = index_order[:max_seq] + extra_idx = index_order[max_seq:] + + for k in ['msa', 'deletion_matrix', 'msa_mask', 'bert_mask']: + if k in batch: + batch['extra_' + k] = batch[k][extra_idx] + batch[k] = batch[k][sel_idx] + + return batch + + +def make_msa_profile(batch): + """Compute the MSA profile.""" + + # Compute the profile for every residue (over all MSA sequences). + return utils.mask_mean( + batch['msa_mask'][:, :, None], jax.nn.one_hot(batch['msa'], 22), axis=0) + + +class AlphaFoldIteration(hk.Module): + """A single recycling iteration of AlphaFold architecture. + + Computes ensembled (averaged) representations from the provided features. + These representations are then passed to the various heads + that have been requested by the configuration file. + """ + + def __init__(self, config, global_config, name='alphafold_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, + batch, + is_training, + return_representations=False, + safe_key=None): + + if is_training: + num_ensemble = np.asarray(self.config.num_ensemble_train) + else: + num_ensemble = np.asarray(self.config.num_ensemble_eval) + + # Compute representations for each MSA sample and average. + embedding_module = EmbeddingsAndEvoformer( + self.config.embeddings_and_evoformer, self.global_config) + repr_shape = hk.eval_shape( + lambda: embedding_module(batch, is_training)) + representations = { + k: jnp.zeros(v.shape, v.dtype) for (k, v) in repr_shape.items() + } + + def ensemble_body(x, unused_y): + """Add into representations ensemble.""" + del unused_y + representations, safe_key = x + safe_key, safe_subkey = safe_key.split() + representations_update = embedding_module( + batch, is_training, safe_key=safe_subkey) + + for k in representations: + if k not in {'msa', 'true_msa', 'bert_mask'}: + representations[k] += representations_update[k] * ( + 1. / num_ensemble).astype(representations[k].dtype) + else: + representations[k] = representations_update[k] + + return (representations, safe_key), None + + (representations, _), _ = hk.scan( + ensemble_body, (representations, safe_key), None, length=num_ensemble) + + self.representations = representations + self.batch = batch + self.heads = {} + for head_name, head_config in sorted(self.config.heads.items()): + if not head_config.weight: + continue # Do not instantiate zero-weight heads. + + head_factory = { + 'masked_msa': + modules.MaskedMsaHead, + 'distogram': + modules.DistogramHead, + 'structure_module': + folding_multimer.StructureModule, + 'predicted_aligned_error': + modules.PredictedAlignedErrorHead, + 'predicted_lddt': + modules.PredictedLDDTHead, + 'experimentally_resolved': + modules.ExperimentallyResolvedHead, + }[head_name] + self.heads[head_name] = (head_config, + head_factory(head_config, self.global_config)) + + structure_module_output = None + if 'entity_id' in batch and 'all_atom_positions' in batch: + _, fold_module = self.heads['structure_module'] + structure_module_output = fold_module(representations, batch, is_training) + + ret = {} + ret['representations'] = representations + + for name, (head_config, module) in self.heads.items(): + if name == 'structure_module' and structure_module_output is not None: + ret[name] = structure_module_output + representations['structure_module'] = structure_module_output.pop('act') + # Skip confidence heads until StructureModule is executed. + elif name in {'predicted_lddt', 'predicted_aligned_error', + 'experimentally_resolved'}: + continue + else: + ret[name] = module(representations, batch, is_training) + + # Add confidence heads after StructureModule is executed. + if self.config.heads.get('predicted_lddt.weight', 0.0): + name = 'predicted_lddt' + head_config, module = self.heads[name] + ret[name] = module(representations, batch, is_training) + + if self.config.heads.experimentally_resolved.weight: + name = 'experimentally_resolved' + head_config, module = self.heads[name] + ret[name] = module(representations, batch, is_training) + + if self.config.heads.get('predicted_aligned_error.weight', 0.0): + name = 'predicted_aligned_error' + head_config, module = self.heads[name] + ret[name] = module(representations, batch, is_training) + # Will be used for ipTM computation. + ret[name]['asym_id'] = batch['asym_id'] + + return ret + + +class AlphaFold(hk.Module): + """AlphaFold-Multimer model with recycling. + """ + + def __init__(self, config, name='alphafold'): + super().__init__(name=name) + self.config = config + self.global_config = config.global_config + + def __call__( + self, + batch, + is_training, + return_representations=False, + safe_key=None): + + c = self.config + impl = AlphaFoldIteration(c, self.global_config) + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + elif isinstance(safe_key, jnp.ndarray): + safe_key = prng.SafeKey(safe_key) + + assert isinstance(batch, dict) + num_res = batch['aatype'].shape[0] + + def get_prev(ret): + new_prev = { + 'prev_pos': + ret['structure_module']['final_atom_positions'], + 'prev_msa_first_row': ret['representations']['msa_first_row'], + 'prev_pair': ret['representations']['pair'], + } + return jax.tree_map(jax.lax.stop_gradient, new_prev) + + def apply_network(prev, safe_key): + recycled_batch = {**batch, **prev} + return impl( + batch=recycled_batch, + is_training=is_training, + safe_key=safe_key) + + if self.config.num_recycle: + emb_config = self.config.embeddings_and_evoformer + prev = { + 'prev_pos': + jnp.zeros([num_res, residue_constants.atom_type_num, 3]), + 'prev_msa_first_row': + jnp.zeros([num_res, emb_config.msa_channel]), + 'prev_pair': + jnp.zeros([num_res, num_res, emb_config.pair_channel]), + } + + if 'num_iter_recycling' in batch: + # Training time: num_iter_recycling is in batch. + # Value for each ensemble batch is the same, so arbitrarily taking 0-th. + num_iter = batch['num_iter_recycling'][0] + + # Add insurance that even when ensembling, we will not run more + # recyclings than the model is configured to run. + num_iter = jnp.minimum(num_iter, c.num_recycle) + else: + # Eval mode or tests: use the maximum number of iterations. + num_iter = c.num_recycle + + def recycle_body(i, x): + del i + prev, safe_key = x + safe_key1, safe_key2 = safe_key.split() if c.resample_msa_in_recycling else safe_key.duplicate() # pylint: disable=line-too-long + ret = apply_network(prev=prev, safe_key=safe_key2) + return get_prev(ret), safe_key1 + + prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key)) + else: + prev = {} + + # Run extra iteration. + ret = apply_network(prev=prev, safe_key=safe_key) + + if not return_representations: + del ret['representations'] + return ret + + +class EmbeddingsAndEvoformer(hk.Module): + """Embeds the input data and runs Evoformer. + + Produces the MSA, single and pair representations. + """ + + def __init__(self, config, global_config, name='evoformer'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def _relative_encoding(self, batch): + """Add relative position encodings. + + For position (i, j), the value is (i-j) clipped to [-k, k] and one-hotted. + + When not using 'use_chain_relative' the residue indices are used as is, e.g. + for heteromers relative positions will be computed using the positions in + the corresponding chains. + + When using 'use_chain_relative' we add an extra bin that denotes + 'different chain'. Furthermore we also provide the relative chain index + (i.e. sym_id) clipped and one-hotted to the network. And an extra feature + which denotes whether they belong to the same chain type, i.e. it's 0 if + they are in different heteromer chains and 1 otherwise. + + Args: + batch: batch. + Returns: + Feature embedding using the features as described before. + """ + c = self.config + rel_feats = [] + pos = batch['residue_index'] + asym_id = batch['asym_id'] + asym_id_same = jnp.equal(asym_id[:, None], asym_id[None, :]) + offset = pos[:, None] - pos[None, :] + + clipped_offset = jnp.clip( + offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx) + + if c.use_chain_relative: + + final_offset = jnp.where(asym_id_same, clipped_offset, + (2 * c.max_relative_idx + 1) * + jnp.ones_like(clipped_offset)) + + rel_pos = jax.nn.one_hot(final_offset, 2 * c.max_relative_idx + 2) + + rel_feats.append(rel_pos) + + entity_id = batch['entity_id'] + entity_id_same = jnp.equal(entity_id[:, None], entity_id[None, :]) + rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None]) + + sym_id = batch['sym_id'] + rel_sym_id = sym_id[:, None] - sym_id[None, :] + + max_rel_chain = c.max_relative_chain + + clipped_rel_chain = jnp.clip( + rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain) + + final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain, + (2 * max_rel_chain + 1) * + jnp.ones_like(clipped_rel_chain)) + rel_chain = jax.nn.one_hot(final_rel_chain, 2 * c.max_relative_chain + 2) + + rel_feats.append(rel_chain) + + else: + rel_pos = jax.nn.one_hot(clipped_offset, 2 * c.max_relative_idx + 1) + rel_feats.append(rel_pos) + + rel_feat = jnp.concatenate(rel_feats, axis=-1) + + return common_modules.Linear( + c.pair_channel, + name='position_activations')( + rel_feat) + + def __call__(self, batch, is_training, safe_key=None): + + c = self.config + gc = self.global_config + + batch = dict(batch) + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + output = {} + + batch['msa_profile'] = make_msa_profile(batch) + + target_feat = jax.nn.one_hot(batch['aatype'], 21) + + preprocess_1d = common_modules.Linear( + c.msa_channel, name='preprocess_1d')( + target_feat) + + safe_key, sample_key, mask_key = safe_key.split(3) + batch = sample_msa(sample_key, batch, c.num_msa) + batch = make_masked_msa(batch, mask_key, c.masked_msa) + + (batch['cluster_profile'], + batch['cluster_deletion_mean']) = nearest_neighbor_clusters(batch) + + msa_feat = create_msa_feat(batch) + + preprocess_msa = common_modules.Linear( + c.msa_channel, name='preprocess_msa')( + msa_feat) + + msa_activations = jnp.expand_dims(preprocess_1d, axis=0) + preprocess_msa + + left_single = common_modules.Linear( + c.pair_channel, name='left_single')( + target_feat) + right_single = common_modules.Linear( + c.pair_channel, name='right_single')( + target_feat) + pair_activations = left_single[:, None] + right_single[None] + mask_2d = batch['seq_mask'][:, None] * batch['seq_mask'][None, :] + mask_2d = mask_2d.astype(jnp.float32) + + if c.recycle_pos and 'prev_pos' in batch: + prev_pseudo_beta = modules.pseudo_beta_fn( + batch['aatype'], batch['prev_pos'], None) + + dgram = modules.dgram_from_positions( + prev_pseudo_beta, **self.config.prev_pos) + pair_activations += common_modules.Linear( + c.pair_channel, name='prev_pos_linear')( + dgram) + + if c.recycle_features: + if 'prev_msa_first_row' in batch: + prev_msa_first_row = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='prev_msa_first_row_norm')( + batch['prev_msa_first_row']) + msa_activations = msa_activations.at[0].add(prev_msa_first_row) + + if 'prev_pair' in batch: + pair_activations += hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='prev_pair_norm')( + batch['prev_pair']) + + if c.max_relative_idx: + pair_activations += self._relative_encoding(batch) + + if c.template.enabled: + template_module = TemplateEmbedding(c.template, gc) + template_batch = { + 'template_aatype': batch['template_aatype'], + 'template_all_atom_positions': batch['template_all_atom_positions'], + 'template_all_atom_mask': batch['template_all_atom_mask'] + } + # Construct a mask such that only intra-chain template features are + # computed, since all templates are for each chain individually. + multichain_mask = batch['asym_id'][:, None] == batch['asym_id'][None, :] + safe_key, safe_subkey = safe_key.split() + template_act = template_module( + query_embedding=pair_activations, + template_batch=template_batch, + padding_mask_2d=mask_2d, + multichain_mask_2d=multichain_mask, + is_training=is_training, + safe_key=safe_subkey) + pair_activations += template_act + + # Extra MSA stack. + (extra_msa_feat, + extra_msa_mask) = create_extra_msa_feature(batch, c.num_extra_msa) + extra_msa_activations = common_modules.Linear( + c.extra_msa_channel, + name='extra_msa_activations')( + extra_msa_feat) + extra_msa_mask = extra_msa_mask.astype(jnp.float32) + + extra_evoformer_input = { + 'msa': extra_msa_activations, + 'pair': pair_activations, + } + extra_masks = {'msa': extra_msa_mask, 'pair': mask_2d} + + extra_evoformer_iteration = modules.EvoformerIteration( + c.evoformer, gc, is_extra_msa=True, name='extra_msa_stack') + + def extra_evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + extra_evoformer_output = extra_evoformer_iteration( + activations=act, + masks=extra_masks, + is_training=is_training, + safe_key=safe_subkey) + return (extra_evoformer_output, safe_key) + + if gc.use_remat: + extra_evoformer_fn = hk.remat(extra_evoformer_fn) + + safe_key, safe_subkey = safe_key.split() + extra_evoformer_stack = layer_stack.layer_stack( + c.extra_msa_stack_num_block)( + extra_evoformer_fn) + extra_evoformer_output, safe_key = extra_evoformer_stack( + (extra_evoformer_input, safe_subkey)) + + pair_activations = extra_evoformer_output['pair'] + + # Get the size of the MSA before potentially adding templates, so we + # can crop out the templates later. + num_msa_sequences = msa_activations.shape[0] + evoformer_input = { + 'msa': msa_activations, + 'pair': pair_activations, + } + evoformer_masks = {'msa': batch['msa_mask'].astype(jnp.float32), + 'pair': mask_2d} + + if c.template.enabled: + template_features, template_masks = ( + template_embedding_1d(batch=batch, num_channel=c.msa_channel)) + + evoformer_input['msa'] = jnp.concatenate( + [evoformer_input['msa'], template_features], axis=0) + evoformer_masks['msa'] = jnp.concatenate( + [evoformer_masks['msa'], template_masks], axis=0) + + evoformer_iteration = modules.EvoformerIteration( + c.evoformer, gc, is_extra_msa=False, name='evoformer_iteration') + + def evoformer_fn(x): + act, safe_key = x + safe_key, safe_subkey = safe_key.split() + evoformer_output = evoformer_iteration( + activations=act, + masks=evoformer_masks, + is_training=is_training, + safe_key=safe_subkey) + return (evoformer_output, safe_key) + + if gc.use_remat: + evoformer_fn = hk.remat(evoformer_fn) + + safe_key, safe_subkey = safe_key.split() + evoformer_stack = layer_stack.layer_stack(c.evoformer_num_block)( + evoformer_fn) + + def run_evoformer(evoformer_input): + evoformer_output, _ = evoformer_stack((evoformer_input, safe_subkey)) + return evoformer_output + + evoformer_output = run_evoformer(evoformer_input) + + msa_activations = evoformer_output['msa'] + pair_activations = evoformer_output['pair'] + + single_activations = common_modules.Linear( + c.seq_channel, name='single_activations')( + msa_activations[0]) + + output.update({ + 'single': + single_activations, + 'pair': + pair_activations, + # Crop away template rows such that they are not used in MaskedMsaHead. + 'msa': + msa_activations[:num_msa_sequences, :, :], + 'msa_first_row': + msa_activations[0], + }) + + return output + + +class TemplateEmbedding(hk.Module): + """Embed a set of templates.""" + + def __init__(self, config, global_config, name='template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, template_batch, padding_mask_2d, + multichain_mask_2d, is_training, + safe_key=None): + """Generate an embedding for a set of templates. + + Args: + query_embedding: [num_res, num_res, num_channel] a query tensor that will + be used to attend over the templates to remove the num_templates + dimension. + template_batch: A dictionary containing: + `template_aatype`: [num_templates, num_res] aatype for each template. + `template_all_atom_positions`: [num_templates, num_res, 37, 3] atom + positions for all templates. + `template_all_atom_mask`: [num_templates, num_res, 37] mask for each + template. + padding_mask_2d: [num_res, num_res] Pair mask for attention operations. + multichain_mask_2d: [num_res, num_res] Mask indicating which residue pairs + are intra-chain, used to mask out residue distance based features + between chains. + is_training: bool indicating where we are running in training mode. + safe_key: random key generator. + + Returns: + An embedding of size [num_res, num_res, num_channels] + """ + c = self.config + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + num_templates = template_batch['template_aatype'].shape[0] + num_res, _, query_num_channels = query_embedding.shape + + # Embed each template separately. + template_embedder = SingleTemplateEmbedding(self.config, self.global_config) + def partial_template_embedder(template_aatype, + template_all_atom_positions, + template_all_atom_mask, + unsafe_key): + safe_key = prng.SafeKey(unsafe_key) + return template_embedder(query_embedding, + template_aatype, + template_all_atom_positions, + template_all_atom_mask, + padding_mask_2d, + multichain_mask_2d, + is_training, + safe_key) + + safe_key, unsafe_key = safe_key.split() + unsafe_keys = jax.random.split(unsafe_key._key, num_templates) + + def scan_fn(carry, x): + return carry + partial_template_embedder(*x), None + + scan_init = jnp.zeros((num_res, num_res, c.num_channels), + dtype=query_embedding.dtype) + summed_template_embeddings, _ = hk.scan( + scan_fn, scan_init, + (template_batch['template_aatype'], + template_batch['template_all_atom_positions'], + template_batch['template_all_atom_mask'], unsafe_keys)) + + embedding = summed_template_embeddings / num_templates + embedding = jax.nn.relu(embedding) + embedding = common_modules.Linear( + query_num_channels, + initializer='relu', + name='output_linear')(embedding) + + return embedding + + +class SingleTemplateEmbedding(hk.Module): + """Embed a single template.""" + + def __init__(self, config, global_config, name='single_template_embedding'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, query_embedding, template_aatype, + template_all_atom_positions, template_all_atom_mask, + padding_mask_2d, multichain_mask_2d, is_training, + safe_key): + """Build the single template embedding graph. + + Args: + query_embedding: (num_res, num_res, num_channels) - embedding of the + query sequence/msa. + template_aatype: [num_res] aatype for each template. + template_all_atom_positions: [num_res, 37, 3] atom positions for all + templates. + template_all_atom_mask: [num_res, 37] mask for each template. + padding_mask_2d: Padding mask (Note: this doesn't care if a template + exists, unlike the template_pseudo_beta_mask). + multichain_mask_2d: A mask indicating intra-chain residue pairs, used + to mask out between chain distances/features when templates are for + single chains. + is_training: Are we in training mode. + safe_key: Random key generator. + + Returns: + A template embedding (num_res, num_res, num_channels). + """ + gc = self.global_config + c = self.config + assert padding_mask_2d.dtype == query_embedding.dtype + dtype = query_embedding.dtype + num_channels = self.config.num_channels + + def construct_input(query_embedding, template_aatype, + template_all_atom_positions, template_all_atom_mask, + multichain_mask_2d): + + # Compute distogram feature for the template. + template_positions, pseudo_beta_mask = modules.pseudo_beta_fn( + template_aatype, template_all_atom_positions, template_all_atom_mask) + pseudo_beta_mask_2d = (pseudo_beta_mask[:, None] * + pseudo_beta_mask[None, :]) + pseudo_beta_mask_2d *= multichain_mask_2d + template_dgram = modules.dgram_from_positions( + template_positions, **self.config.dgram_features) + template_dgram *= pseudo_beta_mask_2d[..., None] + template_dgram = template_dgram.astype(dtype) + pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype) + to_concat = [(template_dgram, 1), (pseudo_beta_mask_2d, 0)] + + aatype = jax.nn.one_hot(template_aatype, 22, axis=-1, dtype=dtype) + to_concat.append((aatype[None, :, :], 1)) + to_concat.append((aatype[:, None, :], 1)) + + # Compute a feature representing the normalized vector between each + # backbone affine - i.e. in each residues local frame, what direction are + # each of the other residues. + raw_atom_pos = template_all_atom_positions + + atom_pos = geometry.Vec3Array.from_array(raw_atom_pos) + rigid, backbone_mask = folding_multimer.make_backbone_affine( + atom_pos, + template_all_atom_mask, + template_aatype) + points = rigid.translation + rigid_vec = rigid[:, None].inverse().apply_to_point(points) + unit_vector = rigid_vec.normalized() + unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z] + + backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :] + backbone_mask_2d *= multichain_mask_2d + unit_vector = [x*backbone_mask_2d for x in unit_vector] + + # Note that the backbone_mask takes into account C, CA and N (unlike + # pseudo beta mask which just needs CB) so we add both masks as features. + to_concat.extend([(x, 0) for x in unit_vector]) + to_concat.append((backbone_mask_2d, 0)) + + query_embedding = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='query_embedding_norm')( + query_embedding) + # Allow the template embedder to see the query embedding. Note this + # contains the position relative feature, so this is how the network knows + # which residues are next to each other. + to_concat.append((query_embedding, 1)) + + act = 0 + + for i, (x, n_input_dims) in enumerate(to_concat): + + act += common_modules.Linear( + num_channels, + num_input_dims=n_input_dims, + initializer='relu', + name=f'template_pair_embedding_{i}')(x) + return act + + act = construct_input(query_embedding, template_aatype, + template_all_atom_positions, template_all_atom_mask, + multichain_mask_2d) + + template_iteration = TemplateEmbeddingIteration( + c.template_pair_stack, gc, name='template_embedding_iteration') + + def template_iteration_fn(x): + act, safe_key = x + + safe_key, safe_subkey = safe_key.split() + act = template_iteration( + act=act, + pair_mask=padding_mask_2d, + is_training=is_training, + safe_key=safe_subkey) + return (act, safe_key) + + if gc.use_remat: + template_iteration_fn = hk.remat(template_iteration_fn) + + safe_key, safe_subkey = safe_key.split() + template_stack = layer_stack.layer_stack( + c.template_pair_stack.num_block)( + template_iteration_fn) + act, safe_key = template_stack((act, safe_subkey)) + + act = hk.LayerNorm( + axis=[-1], + create_scale=True, + create_offset=True, + name='output_layer_norm')( + act) + return act + + +class TemplateEmbeddingIteration(hk.Module): + """Single Iteration of Template Embedding.""" + + def __init__(self, config, global_config, + name='template_embedding_iteration'): + super().__init__(name=name) + self.config = config + self.global_config = global_config + + def __call__(self, act, pair_mask, is_training=True, + safe_key=None): + """Build a single iteration of the template embedder. + + Args: + act: [num_res, num_res, num_channel] Input pairwise activations. + pair_mask: [num_res, num_res] padding mask. + is_training: Whether to run in training mode. + safe_key: Safe pseudo-random generator key. + + Returns: + [num_res, num_res, num_channel] tensor of activations. + """ + c = self.config + gc = self.global_config + + if safe_key is None: + safe_key = prng.SafeKey(hk.next_rng_key()) + + dropout_wrapper_fn = functools.partial( + modules.dropout_wrapper, + is_training=is_training, + global_config=gc) + + safe_key, *sub_keys = safe_key.split(20) + sub_keys = iter(sub_keys) + + act = dropout_wrapper_fn( + modules.TriangleMultiplication(c.triangle_multiplication_outgoing, gc, + name='triangle_multiplication_outgoing'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleMultiplication(c.triangle_multiplication_incoming, gc, + name='triangle_multiplication_incoming'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleAttention(c.triangle_attention_starting_node, gc, + name='triangle_attention_starting_node'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.TriangleAttention(c.triangle_attention_ending_node, gc, + name='triangle_attention_ending_node'), + act, + pair_mask, + safe_key=next(sub_keys)) + + act = dropout_wrapper_fn( + modules.Transition(c.pair_transition, gc, + name='pair_transition'), + act, + pair_mask, + safe_key=next(sub_keys)) + + return act + + +def template_embedding_1d(batch, num_channel): + """Embed templates into an (num_res, num_templates, num_channels) embedding. + + Args: + batch: A batch containing: + template_aatype, (num_templates, num_res) aatype for the templates. + template_all_atom_positions, (num_templates, num_residues, 37, 3) atom + positions for the templates. + template_all_atom_mask, (num_templates, num_residues, 37) atom mask for + each template. + num_channel: The number of channels in the output. + + Returns: + An embedding of shape (num_templates, num_res, num_channels) and a mask of + shape (num_templates, num_res). + """ + + # Embed the templates aatypes. + aatype_one_hot = jax.nn.one_hot(batch['template_aatype'], 22, axis=-1) + + num_templates = batch['template_aatype'].shape[0] + all_chi_angles = [] + all_chi_masks = [] + for i in range(num_templates): + atom_pos = geometry.Vec3Array.from_array( + batch['template_all_atom_positions'][i, :, :, :]) + template_chi_angles, template_chi_mask = all_atom_multimer.compute_chi_angles( + atom_pos, + batch['template_all_atom_mask'][i, :, :], + batch['template_aatype'][i, :]) + all_chi_angles.append(template_chi_angles) + all_chi_masks.append(template_chi_mask) + chi_angles = jnp.stack(all_chi_angles, axis=0) + chi_mask = jnp.stack(all_chi_masks, axis=0) + + template_features = jnp.concatenate([ + aatype_one_hot, + jnp.sin(chi_angles) * chi_mask, + jnp.cos(chi_angles) * chi_mask, + chi_mask], axis=-1) + + template_mask = chi_mask[:, :, 0] + + template_activations = common_modules.Linear( + num_channel, + initializer='relu', + name='template_single_embedding')( + template_features) + template_activations = jax.nn.relu(template_activations) + template_activations = common_modules.Linear( + num_channel, + initializer='relu', + name='template_projection')( + template_activations) + return template_activations, template_mask diff --git a/alphafold/model/utils.py b/alphafold/model/utils.py index 8ed5361e8..40ca1683e 100644 --- a/alphafold/model/utils.py +++ b/alphafold/model/utils.py @@ -15,6 +15,7 @@ """A collection of JAX utility functions for use in protein folding.""" import collections +import functools import numbers from typing import Mapping @@ -79,3 +80,52 @@ def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params: hk_params[scope][name] = jnp.array(array) return hk_params + + +def padding_consistent_rng(f): + """Modify any element-wise random function to be consistent with padding. + + Normally if you take a function like jax.random.normal and generate an array, + say of size (10,10), you will get a different set of random numbers to if you + add padding and take the first (10,10) sub-array. + + This function makes a random function that is consistent regardless of the + amount of padding added. + + Note: The padding-consistent function is likely to be slower to compile and + run than the function it is wrapping, but these slowdowns are likely to be + negligible in a large network. + + Args: + f: Any element-wise function that takes (PRNG key, shape) as the first 2 + arguments. + + Returns: + An equivalent function to f, that is now consistent for different amounts of + padding. + """ + def grid_keys(key, shape): + """Generate a grid of rng keys that is consistent with different padding. + + Generate random keys such that the keys will be identical, regardless of + how much padding is added to any dimension. + + Args: + key: A PRNG key. + shape: The shape of the output array of keys that will be generated. + + Returns: + An array of shape `shape` consisting of random keys. + """ + if not shape: + return key + new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))( + jnp.arange(shape[0])) + return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys) + + def inner(key, shape, **kwargs): + return jnp.vectorize( + lambda key: f(key, shape=(), **kwargs), + signature='(2)->()')( + grid_keys(key, shape)) + return inner diff --git a/alphafold/notebooks/__init__.py b/alphafold/notebooks/__init__.py new file mode 100644 index 000000000..cea1fea36 --- /dev/null +++ b/alphafold/notebooks/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""AlphaFold Colab notebook.""" diff --git a/alphafold/notebooks/notebook_utils.py b/alphafold/notebooks/notebook_utils.py new file mode 100644 index 000000000..3344b71f3 --- /dev/null +++ b/alphafold/notebooks/notebook_utils.py @@ -0,0 +1,182 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper methods for the AlphaFold Colab notebook.""" +import enum +import json +from typing import Any, Mapping, Optional, Sequence, Tuple + +from alphafold.common import residue_constants +from alphafold.data import parsers +from matplotlib import pyplot as plt +import numpy as np + + +@enum.unique +class ModelType(enum.Enum): + MONOMER = 0 + MULTIMER = 1 + + +def clean_and_validate_sequence( + input_sequence: str, min_length: int, max_length: int) -> str: + """Checks that the input sequence is ok and returns a clean version of it.""" + # Remove all whitespaces, tabs and end lines; upper-case. + clean_sequence = input_sequence.translate( + str.maketrans('', '', ' \n\t')).upper() + aatypes = set(residue_constants.restypes) # 20 standard aatypes. + if not set(clean_sequence).issubset(aatypes): + raise ValueError( + f'Input sequence contains non-amino acid letters: ' + f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard ' + 'amino acids as inputs.') + if len(clean_sequence) < min_length: + raise ValueError( + f'Input sequence is too short: {len(clean_sequence)} amino acids, ' + f'while the minimum is {min_length}') + if len(clean_sequence) > max_length: + raise ValueError( + f'Input sequence is too long: {len(clean_sequence)} amino acids, while ' + f'the maximum is {max_length}. You may be able to run it with the full ' + f'AlphaFold system depending on your resources (system memory, ' + f'GPU memory).') + return clean_sequence + + +def validate_input( + input_sequences: Sequence[str], + min_length: int, + max_length: int, + max_multimer_length: int) -> Tuple[Sequence[str], ModelType]: + """Validates and cleans input sequences and determines which model to use.""" + sequences = [] + + for input_sequence in input_sequences: + if input_sequence.strip(): + input_sequence = clean_and_validate_sequence( + input_sequence=input_sequence, + min_length=min_length, + max_length=max_length) + sequences.append(input_sequence) + + if len(sequences) == 1: + print('Using the single-chain model.') + return sequences, ModelType.MONOMER + + elif len(sequences) > 1: + total_multimer_length = sum([len(seq) for seq in sequences]) + if total_multimer_length > max_multimer_length: + raise ValueError(f'The total length of multimer sequences is too long: ' + f'{total_multimer_length}, while the maximum is ' + f'{max_multimer_length}. Please use the full AlphaFold ' + f'system for long multimers.') + elif total_multimer_length > 1536: + print('WARNING: The accuracy of the system has not been fully validated ' + 'above 1536 residues, and you may experience long running times or ' + f'run out of memory for your complex with {total_multimer_length} ' + 'residues.') + print(f'Using the multimer model with {len(sequences)} sequences.') + return sequences, ModelType.MULTIMER + + else: + raise ValueError('No input amino acid sequence provided, please provide at ' + 'least one sequence.') + + +def merge_chunked_msa( + results: Sequence[Mapping[str, Any]], + max_hits: Optional[int] = None + ) -> parsers.Msa: + """Merges chunked database hits together into hits for the full database.""" + unsorted_results = [] + for chunk_index, chunk in enumerate(results): + msa = parsers.parse_stockholm(chunk['sto']) + e_values_dict = parsers.parse_e_values_from_tblout(chunk['tbl']) + # Jackhmmer lists sequences as /-. + e_values = [e_values_dict[t.partition('/')[0]] for t in msa.descriptions] + chunk_results = zip( + msa.sequences, msa.deletion_matrix, msa.descriptions, e_values) + if chunk_index != 0: + next(chunk_results) # Only take query (first hit) from the first chunk. + unsorted_results.extend(chunk_results) + + sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[-1]) + merged_sequences, merged_deletion_matrix, merged_descriptions, _ = zip( + *sorted_by_evalue) + merged_msa = parsers.Msa(sequences=merged_sequences, + deletion_matrix=merged_deletion_matrix, + descriptions=merged_descriptions) + if max_hits is not None: + merged_msa = merged_msa.truncate(max_seqs=max_hits) + + return merged_msa + + +def show_msa_info( + single_chain_msas: Sequence[parsers.Msa], + sequence_index: int): + """Prints info and shows a plot of the deduplicated single chain MSA.""" + full_single_chain_msa = [] + for single_chain_msa in single_chain_msas: + full_single_chain_msa.extend(single_chain_msa.sequences) + + # Deduplicate but preserve order (hence can't use set). + deduped_full_single_chain_msa = list(dict.fromkeys(full_single_chain_msa)) + total_msa_size = len(deduped_full_single_chain_msa) + print(f'\n{total_msa_size} unique sequences found in total for sequence ' + f'{sequence_index}\n') + + aa_map = {res: i for i, res in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')} + msa_arr = np.array( + [[aa_map[aa] for aa in seq] for seq in deduped_full_single_chain_msa]) + + plt.figure(figsize=(12, 3)) + plt.title(f'Per-Residue Count of Non-Gap Amino Acids in the MSA for Sequence ' + f'{sequence_index}') + plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black') + plt.ylabel('Non-Gap Count') + plt.yticks(range(0, total_msa_size + 1, max(1, int(total_msa_size / 3)))) + plt.show() + + +def empty_placeholder_template_features( + num_templates: int, num_res: int) -> Mapping[str, np.ndarray]: + return { + 'template_aatype': np.zeros( + (num_templates, num_res, + len(residue_constants.restypes_with_x_and_gap)), dtype=np.float32), + 'template_all_atom_masks': np.zeros( + (num_templates, num_res, residue_constants.atom_type_num), + dtype=np.float32), + 'template_all_atom_positions': np.zeros( + (num_templates, num_res, residue_constants.atom_type_num, 3), + dtype=np.float32), + 'template_domain_names': np.zeros([num_templates], dtype=np.object), + 'template_sequence': np.zeros([num_templates], dtype=np.object), + 'template_sum_probs': np.zeros([num_templates], dtype=np.float32), + } + + +def get_pae_json(pae: np.ndarray, max_pae: float) -> str: + """Returns the PAE in the same format as is used in the AFDB.""" + rounded_errors = np.round(pae.astype(np.float64), decimals=1) + indices = np.indices((len(rounded_errors), len(rounded_errors))) + 1 + indices_1 = indices[0].flatten().tolist() + indices_2 = indices[1].flatten().tolist() + return json.dumps( + [{'residue1': indices_1, + 'residue2': indices_2, + 'distance': rounded_errors.flatten().tolist(), + 'max_predicted_aligned_error': max_pae}], + indent=None, separators=(',', ':')) diff --git a/alphafold/notebooks/notebook_utils_test.py b/alphafold/notebooks/notebook_utils_test.py new file mode 100644 index 000000000..bb76a95b9 --- /dev/null +++ b/alphafold/notebooks/notebook_utils_test.py @@ -0,0 +1,203 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for notebook_utils.""" +import io + +from absl.testing import absltest +from absl.testing import parameterized +from alphafold.data import parsers +from alphafold.data import templates +from alphafold.notebooks import notebook_utils + +import mock +import numpy as np + + +ONLY_QUERY_HIT = { + 'sto': ( + '# STOCKHOLM 1.0\n' + '#=GF ID query-i1\n' + 'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEH\n' + '//\n'), + 'tbl': '', + 'stderr': b'', + 'n_iter': 1, + 'e_value': 0.0001} + +# pylint: disable=line-too-long +MULTI_SEQUENCE_HIT_1 = { + 'sto': ( + '# STOCKHOLM 1.0\n' + '#=GF ID query-i1\n' + '#=GS ERR1700680_4602609/41-109 DE [subseq from] ERR1700680_4602609\n' + '#=GS ERR1019366_5760491/40-105 DE [subseq from] ERR1019366_5760491\n' + '#=GS SRR5580704_12853319/61-125 DE [subseq from] SRR5580704_12853319\n' + 'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH\n' + 'ERR1700680_4602609/41-109 --INKGAEYHKKAAEHHELAAKHHREAAKHHEAGSHEKAAHHSEIAAGHGLTAVHHTEEATK-HHPEEHTEK--\n' + 'ERR1019366_5760491/40-105 ---RSGAQHHDAAAQHYEEAARHHRMAAKQYQASHHEKAAHYAQLAYAHHMYAEQHAAEAAK-AHAKNHG----\n' + 'SRR5580704_12853319/61-125 ----PAADHHMKAAEHHEEAAKHHRAAAEHHTAGDHQKAGHHAHVANGHHVNAVHHAEEASK-HHATDHS----\n' + '//\n'), + 'tbl': ( + 'ERR1700680_4602609 - query - 7.7e-09 47.7 33.8 1.1e-08 47.2 33.8 1.2 1 0 0 1 1 1 1 -\n' + 'ERR1019366_5760491 - query - 1.7e-08 46.6 33.1 2.5e-08 46.1 33.1 1.3 1 0 0 1 1 1 1 -\n' + 'SRR5580704_12853319 - query - 1.1e-07 44.0 41.6 2e-07 43.1 41.6 1.4 1 0 0 1 1 1 1 -\n'), + 'stderr': b'', + 'n_iter': 1, + 'e_value': 0.0001} + +MULTI_SEQUENCE_HIT_2 = { + 'sto': ( + '# STOCKHOLM 1.0\n' + '#=GF ID query-i1\n' + '#=GS ERR1700719_3476944/70-137 DE [subseq from] ERR1700719_3476944\n' + '#=GS ERR1700761_4254522/72-138 DE [subseq from] ERR1700761_4254522\n' + '#=GS SRR5438477_9761204/64-132 DE [subseq from] SRR5438477_9761204\n' + 'query MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH\n' + 'ERR1700719_3476944/70-137 ---KQAAEHHHQAAEHHEHAARHHREAAKHHEAGDHESAAHHAHTAQGHLHQATHHASEAAKLHVEHHGQK--\n' + 'ERR1700761_4254522/72-138 ----QASEHHNLAAEHHEHAARHHRDAAKHHKAGDHEKAAHHAHVAHGHHLHATHHATEAAKHHVEAHGEK--\n' + 'SRR5438477_9761204/64-132 MPKHEGAEHHKKAAEHNEHAARHHKEAARHHEEGSHEKVGHHAHIAHGHHLHATHHAEEAAKTHSNQHE----\n' + '//\n'), + 'tbl': ( + 'ERR1700719_3476944 - query - 2e-07 43.2 47.5 3.5e-07 42.4 47.5 1.4 1 0 0 1 1 1 1 -\n' + 'ERR1700761_4254522 - query - 6.1e-07 41.6 48.1 8.1e-07 41.3 48.1 1.2 1 0 0 1 1 1 1 -\n' + 'SRR5438477_9761204 - query - 1.8e-06 40.2 46.9 2.3e-06 39.8 46.9 1.2 1 0 0 1 1 1 1 -\n'), + 'stderr': b'', + 'n_iter': 1, + 'e_value': 0.0001} +# pylint: enable=line-too-long + + +class NotebookUtilsTest(parameterized.TestCase): + + @parameterized.parameters( + ('DeepMind', 'DEEPMIND'), ('A ', 'A'), ('\tA', 'A'), (' A\t\n', 'A'), + ('ACDEFGHIKLMNPQRSTVWY', 'ACDEFGHIKLMNPQRSTVWY')) + def test_clean_and_validate_sequence_ok(self, sequence, exp_clean): + clean = notebook_utils.clean_and_validate_sequence( + sequence, min_length=1, max_length=100) + self.assertEqual(clean, exp_clean) + + @parameterized.named_parameters( + ('too_short', 'AA', 'too short'), + ('too_long', 'AAAAAAAAAA', 'too long'), + ('bad_amino_acids_B', 'BBBB', 'non-amino acid'), + ('bad_amino_acids_J', 'JJJJ', 'non-amino acid'), + ('bad_amino_acids_O', 'OOOO', 'non-amino acid'), + ('bad_amino_acids_U', 'UUUU', 'non-amino acid'), + ('bad_amino_acids_X', 'XXXX', 'non-amino acid'), + ('bad_amino_acids_Z', 'ZZZZ', 'non-amino acid')) + def test_clean_and_validate_sequence_bad(self, sequence, exp_error): + with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'): + notebook_utils.clean_and_validate_sequence( + sequence, min_length=4, max_length=8) + + @parameterized.parameters( + (['A', '', '', ' ', '\t', ' \t\n', '', ''], ['A'], + notebook_utils.ModelType.MONOMER), + (['', 'A'], ['A'], + notebook_utils.ModelType.MONOMER), + (['A', 'C ', ''], ['A', 'C'], + notebook_utils.ModelType.MULTIMER), + (['', 'A', '', 'C '], ['A', 'C'], + notebook_utils.ModelType.MULTIMER)) + def test_validate_input_ok( + self, input_sequences, exp_sequences, exp_model_type): + sequences, model_type = notebook_utils.validate_input( + input_sequences=input_sequences, + min_length=1, max_length=100, max_multimer_length=100) + self.assertSequenceEqual(sequences, exp_sequences) + self.assertEqual(model_type, exp_model_type) + + @parameterized.named_parameters( + ('no_input_sequence', ['', '\t', '\n'], 'No input amino acid sequence'), + ('too_long_single', ['AAAAAAAAA', 'AAAA'], 'Input sequence is too long'), + ('too_long_multimer', ['AAAA', 'AAAAA'], 'The total length of multimer')) + def test_validate_input_bad(self, input_sequences, exp_error): + with self.assertRaisesRegex(ValueError, f'.*{exp_error}.*'): + notebook_utils.validate_input( + input_sequences=input_sequences, + min_length=4, max_length=8, max_multimer_length=6) + + def test_merge_chunked_msa_no_hits(self): + results = [ONLY_QUERY_HIT, ONLY_QUERY_HIT] + merged_msa = notebook_utils.merge_chunked_msa( + results=results) + self.assertSequenceEqual( + merged_msa.sequences, + ('MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEH',)) + self.assertSequenceEqual(merged_msa.deletion_matrix, ([0] * 56,)) + + def test_merge_chunked_msa(self): + results = [MULTI_SEQUENCE_HIT_1, MULTI_SEQUENCE_HIT_2] + merged_msa = notebook_utils.merge_chunked_msa( + results=results) + self.assertLen(merged_msa.sequences, 7) + # The 1st one is the query. + self.assertEqual( + merged_msa.sequences[0], + 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAP' + 'KPH') + # The 2nd one is the one with the lowest e-value: ERR1700680_4602609. + self.assertEqual( + merged_msa.sequences[1], + '--INKGAEYHKKAAEHHELAAKHHREAAKHHEAGSHEKAAHHSEIAAGHGLTAVHHTEEATK-HHPEEHT' + 'EK-') + # The last one is the one with the largest e-value: SRR5438477_9761204. + self.assertEqual( + merged_msa.sequences[-1], + 'MPKHEGAEHHKKAAEHNEHAARHHKEAARHHEEGSHEKVGHHAHIAHGHHLHATHHAEEAAKTHSNQHE-' + '---') + self.assertLen(merged_msa.deletion_matrix, 7) + + @mock.patch('sys.stdout', new_callable=io.StringIO) + def test_show_msa_info(self, mocked_stdout): + single_chain_msas = [ + parsers.Msa(sequences=['A', 'B', 'C', 'C'], + deletion_matrix=[None] * 4, + descriptions=[''] * 4), + parsers.Msa(sequences=['A', 'A', 'A', 'D'], + deletion_matrix=[None] * 4, + descriptions=[''] * 4) + ] + notebook_utils.show_msa_info( + single_chain_msas=single_chain_msas, sequence_index=1) + self.assertEqual(mocked_stdout.getvalue(), + '\n4 unique sequences found in total for sequence 1\n\n') + + @parameterized.named_parameters( + ('some_templates', 4), ('no_templates', 0)) + def test_empty_placeholder_template_features(self, num_templates): + template_features = notebook_utils.empty_placeholder_template_features( + num_templates=num_templates, num_res=16) + self.assertCountEqual(template_features.keys(), + templates.TEMPLATE_FEATURES.keys()) + self.assertSameElements( + [v.shape[0] for v in template_features.values()], [num_templates]) + self.assertSequenceEqual( + [t.dtype for t in template_features.values()], + [np.array([], dtype=templates.TEMPLATE_FEATURES[feat_name]).dtype + for feat_name in template_features]) + + def test_get_pae_json(self): + pae = np.array([[0.01, 13.12345], [20.0987, 0.0]]) + pae_json = notebook_utils.get_pae_json(pae=pae, max_pae=31.75) + self.assertEqual( + pae_json, + '[{"residue1":[1,1,2,2],"residue2":[1,2,1,2],"distance":' + '[0.0,13.1,20.1,0.0],"max_predicted_aligned_error":31.75}]') + + +if __name__ == '__main__': + absltest.main() diff --git a/docker/run_docker.py b/docker/run_docker.py index 8d2d3fc7f..4eec39c9e 100644 --- a/docker/run_docker.py +++ b/docker/run_docker.py @@ -15,6 +15,7 @@ """Docker launch script for Alphafold docker image.""" import os +import pathlib import signal from typing import Tuple @@ -25,87 +26,54 @@ from docker import types -#### USER CONFIGURATION #### - -# Set to target of scripts/download_all_databases.sh -DOWNLOAD_DIR = 'SET ME' - -# Name of the AlphaFold Docker image. -docker_image_name = 'alphafold' - -# Path to a directory that will store the results. -output_dir = '/tmp/alphafold' - -# Names of models to use. -model_names = [ - 'model_1', - 'model_2', - 'model_3', - 'model_4', - 'model_5', -] - -# You can individually override the following paths if you have placed the -# data in locations other than the DOWNLOAD_DIR. - -# Path to directory of supporting data, contains 'params' dir. -data_dir = DOWNLOAD_DIR - -# Path to the Uniref90 database for use by JackHMMER. -uniref90_database_path = os.path.join( - DOWNLOAD_DIR, 'uniref90', 'uniref90.fasta') - -# Path to the MGnify database for use by JackHMMER. -mgnify_database_path = os.path.join( - DOWNLOAD_DIR, 'mgnify', 'mgy_clusters_2018_12.fa') - -# Path to the BFD database for use by HHblits. -bfd_database_path = os.path.join( - DOWNLOAD_DIR, 'bfd', - 'bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt') - -# Path to the Small BFD database for use by JackHMMER. -small_bfd_database_path = os.path.join( - DOWNLOAD_DIR, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta') - -# Path to the Uniclust30 database for use by HHblits. -uniclust30_database_path = os.path.join( - DOWNLOAD_DIR, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08') - -# Path to the PDB70 database for use by HHsearch. -pdb70_database_path = os.path.join(DOWNLOAD_DIR, 'pdb70', 'pdb70') - -# Path to a directory with template mmCIF structures, each named .cif') -template_mmcif_dir = os.path.join(DOWNLOAD_DIR, 'pdb_mmcif', 'mmcif_files') - -# Path to a file mapping obsolete PDB IDs to their replacements. -obsolete_pdbs_path = os.path.join(DOWNLOAD_DIR, 'pdb_mmcif', 'obsolete.dat') - -#### END OF USER CONFIGURATION #### - - -flags.DEFINE_bool('use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.') -flags.DEFINE_string('gpu_devices', 'all', 'Comma separated list of devices to ' - 'pass to NVIDIA_VISIBLE_DEVICES.') -flags.DEFINE_list('fasta_paths', None, 'Paths to FASTA files, each containing ' - 'one sequence. Paths should be separated by commas. ' - 'All FASTA paths must have a unique basename as the ' - 'basename is used to name the output directories for ' - 'each prediction.') -flags.DEFINE_string('max_template_date', None, 'Maximum template release date ' - 'to consider (ISO-8601 format - i.e. YYYY-MM-DD). ' - 'Important if folding historical test sets.') -flags.DEFINE_enum('preset', 'full_dbs', - ['reduced_dbs', 'full_dbs', 'casp14'], - 'Choose preset model configuration - no ensembling and ' - 'smaller genetic database config (reduced_dbs), no ' - 'ensembling and full genetic database config (full_dbs) or ' - 'full genetic database config and 8 model ensemblings ' - '(casp14).') -flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations ' - 'to obtain a timing that excludes the compilation time, ' - 'which should be more indicative of the time required for ' - 'inferencing many proteins.') +flags.DEFINE_bool( + 'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.') +flags.DEFINE_string( + 'gpu_devices', 'all', + 'Comma separated list of devices to pass to NVIDIA_VISIBLE_DEVICES.') +flags.DEFINE_list( + 'fasta_paths', None, + 'Paths to FASTA files, each containing one sequence. Paths should be ' + 'separated by commas. All FASTA paths must have a unique basename as the ' + 'basename is used to name the output directories for each prediction.') +flags.DEFINE_list('is_prokaryote_list', None, 'Optional for multimer system, ' + 'not used by the single chain system. ' + 'This list should contain a boolean for each fasta ' + 'specifying true where the target complex is from a ' + 'prokaryote, and false where it is not, or where the ' + 'origin is unknown. These values determine the pairing ' + 'method for the MSA.') +flags.DEFINE_string( + 'output_dir', '/tmp/alphafold', + 'Path to a directory that will store the results.') +flags.DEFINE_string( + 'data_dir', None, + 'Path to directory with supporting data: AlphaFold parameters and genetic ' + 'and template databases. Set to the target of download_all_databases.sh.') +flags.DEFINE_string( + 'docker_image_name', 'alphafold', 'Name of the AlphaFold Docker image.') +flags.DEFINE_string( + 'max_template_date', None, + 'Maximum template release date to consider (ISO-8601 format: YYYY-MM-DD). ' + 'Important if folding historical test sets.') +flags.DEFINE_enum( + 'db_preset', 'full_dbs', ['full_dbs', 'reduced_dbs'], + 'Choose preset MSA database configuration - smaller genetic database ' + 'config (reduced_dbs) or full genetic database config (full_dbs)') +flags.DEFINE_enum( + 'model_preset', 'monomer', + ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'], + 'Choose preset model configuration - the monomer model, the monomer model ' + 'with extra ensembling, monomer model with pTM head, or multimer model') +flags.DEFINE_boolean( + 'benchmark', False, + 'Run multiple JAX model evaluations to obtain a timing that excludes the ' + 'compilation time, which should be more indicative of the time required ' + 'for inferencing many proteins.') +flags.DEFINE_boolean( + 'use_precomputed_msas', False, + 'Whether to read MSAs that have been written to disk. WARNING: This will ' + 'not check if the sequence, database or configuration have changed.') FLAGS = flags.FLAGS @@ -125,6 +93,55 @@ def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') + # You can individually override the following paths if you have placed the + # data in locations other than the FLAGS.data_dir. + + # Path to the Uniref90 database for use by JackHMMER. + uniref90_database_path = os.path.join( + FLAGS.data_dir, 'uniref90', 'uniref90.fasta') + + # Path to the Uniprot database for use by JackHMMER. + uniprot_database_path = os.path.join( + FLAGS.data_dir, 'uniprot', 'uniprot.fasta') + + # Path to the MGnify database for use by JackHMMER. + mgnify_database_path = os.path.join( + FLAGS.data_dir, 'mgnify', 'mgy_clusters_2018_12.fa') + + # Path to the BFD database for use by HHblits. + bfd_database_path = os.path.join( + FLAGS.data_dir, 'bfd', + 'bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt') + + # Path to the Small BFD database for use by JackHMMER. + small_bfd_database_path = os.path.join( + FLAGS.data_dir, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta') + + # Path to the Uniclust30 database for use by HHblits. + uniclust30_database_path = os.path.join( + FLAGS.data_dir, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08') + + # Path to the PDB70 database for use by HHsearch. + pdb70_database_path = os.path.join(FLAGS.data_dir, 'pdb70', 'pdb70') + + # Path to the PDB seqres database for use by hmmsearch. + pdb_seqres_database_path = os.path.join( + FLAGS.data_dir, 'pdb_seqres', 'pdb_seqres.txt') + + # Path to a directory with template mmCIF structures, each named .cif. + template_mmcif_dir = os.path.join(FLAGS.data_dir, 'pdb_mmcif', 'mmcif_files') + + # Path to a file mapping obsolete PDB IDs to their replacements. + obsolete_pdbs_path = os.path.join(FLAGS.data_dir, 'pdb_mmcif', 'obsolete.dat') + + alphafold_path = pathlib.Path(__file__).parent.parent + data_dir_path = pathlib.Path(FLAGS.data_dir) + if alphafold_path == data_dir_path or alphafold_path in data_dir_path.parents: + raise app.UsageError( + f'The download directory {FLAGS.data_dir} should not be a subdirectory ' + f'in the AlphaFold repository directory. If it is, the Docker build is ' + f'slow since the large databases are copied during the image creation.') + mounts = [] command_args = [] @@ -139,12 +156,19 @@ def main(argv): database_paths = [ ('uniref90_database_path', uniref90_database_path), ('mgnify_database_path', mgnify_database_path), - ('pdb70_database_path', pdb70_database_path), - ('data_dir', data_dir), + ('data_dir', FLAGS.data_dir), ('template_mmcif_dir', template_mmcif_dir), ('obsolete_pdbs_path', obsolete_pdbs_path), ] - if FLAGS.preset == 'reduced_dbs': + + if FLAGS.model_preset == 'multimer': + database_paths.append(('uniprot_database_path', uniprot_database_path)) + database_paths.append(('pdb_seqres_database_path', + pdb_seqres_database_path)) + else: + database_paths.append(('pdb70_database_path', pdb70_database_path)) + + if FLAGS.db_preset == 'reduced_dbs': database_paths.append(('small_bfd_database_path', small_bfd_database_path)) else: database_paths.extend([ @@ -158,20 +182,25 @@ def main(argv): command_args.append(f'--{name}={target_path}') output_target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, 'output') - mounts.append(types.Mount(output_target_path, output_dir, type='bind')) + mounts.append(types.Mount(output_target_path, FLAGS.output_dir, type='bind')) command_args.extend([ f'--output_dir={output_target_path}', - f'--model_names={",".join(model_names)}', f'--max_template_date={FLAGS.max_template_date}', - f'--preset={FLAGS.preset}', + f'--db_preset={FLAGS.db_preset}', + f'--model_preset={FLAGS.model_preset}', f'--benchmark={FLAGS.benchmark}', + f'--use_precomputed_msas={FLAGS.use_precomputed_msas}', '--logtostderr', ]) + if FLAGS.is_prokaryote_list: + command_args.append( + f'--is_prokaryote_list={",".join(FLAGS.is_prokaryote_list)}') + client = docker.from_env() container = client.containers.run( - image=docker_image_name, + image=FLAGS.docker_image_name, command=command_args, runtime='nvidia' if FLAGS.use_gpu else None, remove=True, @@ -195,6 +224,7 @@ def main(argv): if __name__ == '__main__': flags.mark_flags_as_required([ + 'data_dir', 'fasta_paths', 'max_template_date', ]) diff --git a/notebooks/AlphaFold.ipynb b/notebooks/AlphaFold.ipynb index 6736838c0..9e20dee00 100644 --- a/notebooks/AlphaFold.ipynb +++ b/notebooks/AlphaFold.ipynb @@ -1,694 +1,794 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "pc5-mbsX9PZC" - }, - "source": [ - "# AlphaFold Colab\n", - "\n", - "This Colab notebook allows you to easily predict the structure of a protein using a slightly simplified version of [AlphaFold v2.0](https://doi.org/10.1038/s41586-021-03819-2). \n", - "\n", - "**Differences to AlphaFold v2.0**\n", - "\n", - "In comparison to AlphaFold v2.0, this Colab notebook uses **no templates (homologous structures)** and a selected portion of the [BFD database](https://bfd.mmseqs.com/). We have validated these changes on several thousand recent PDB structures. While accuracy will be near-identical to the full AlphaFold system on many targets, a small fraction have a large drop in accuracy due to the smaller MSA and lack of templates. For best reliability, we recommend instead using the [full open source AlphaFold](https://github.com/deepmind/alphafold/), or the [AlphaFold Protein Structure Database](https://alphafold.ebi.ac.uk/).\n", - "\n", - "Please note that this Colab notebook is provided as an early-access prototype and is not a finished product. It is provided for theoretical modelling only and caution should be exercised in its use. \n", - "\n", - "**Citing this work**\n", - "\n", - "Any publication that discloses findings arising from using this notebook should [cite](https://github.com/deepmind/alphafold/#citing-this-work) the [AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2).\n", - "\n", - "**Licenses**\n", - "\n", - "This Colab uses the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license) and its outputs are thus for non-commercial use only, under the Creative Commons Attribution-NonCommercial 4.0 International ([CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n", - "\n", - "**More information**\n", - "\n", - "You can find more information about how AlphaFold works in our two Nature papers:\n", - "\n", - "* [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2)\n", - "* [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1)\n", - "\n", - "FAQ on how to interpret AlphaFold predictions are [here](https://alphafold.ebi.ac.uk/faq)." - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "pc5-mbsX9PZC" + }, + "source": [ + "# AlphaFold Colab\n", + "\n", + "This Colab notebook allows you to easily predict the structure of a protein using a slightly simplified version of [AlphaFold v2.1.0](https://doi.org/10.1038/s41586-021-03819-2). \n", + "\n", + "**Differences to AlphaFold v2.1.0**\n", + "\n", + "In comparison to AlphaFold v2.1.0, this Colab notebook uses **no templates (homologous structures)** and a selected portion of the [BFD database](https://bfd.mmseqs.com/). We have validated these changes on several thousand recent PDB structures. While accuracy will be near-identical to the full AlphaFold system on many targets, a small fraction have a large drop in accuracy due to the smaller MSA and lack of templates. For best reliability, we recommend instead using the [full open source AlphaFold](https://github.com/deepmind/alphafold/), or the [AlphaFold Protein Structure Database](https://alphafold.ebi.ac.uk/).\n", + "\n", + "**This Colab has an small drop in average accuracy for multimers compared to local AlphaFold installation, for full multimer accuracy it is highly recommended to run [AlphaFold locally](https://github.com/deepmind/alphafold#running-alphafold).** Moreover, the AlphaFold-Multimer requires searching for MSA for every unique sequence in the complex, hence it is substantially slower. If your notebook times-out due to slow multimer MSA search, we recommend either using Colab Pro or running AlphaFold locally.\n", + "\n", + "Please note that this Colab notebook is provided as an early-access prototype and is not a finished product. It is provided for theoretical modelling only and caution should be exercised in its use. \n", + "\n", + "**Citing this work**\n", + "\n", + "Any publication that discloses findings arising from using this notebook should [cite](https://github.com/deepmind/alphafold/#citing-this-work) the [AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2).\n", + "\n", + "**Licenses**\n", + "\n", + "This Colab uses the [AlphaFold model parameters](https://github.com/deepmind/alphafold/#model-parameters-license) and its outputs are thus for non-commercial use only, under the Creative Commons Attribution-NonCommercial 4.0 International ([CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/legalcode)) license. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0). See the full license statement below.\n", + "\n", + "**More information**\n", + "\n", + "You can find more information about how AlphaFold works in the following papers:\n", + "\n", + "* [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2)\n", + "* [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1)\n", + "* [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1)\n", + "\n", + "FAQ on how to interpret AlphaFold predictions are [here](https://alphafold.ebi.ac.uk/faq)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "woIxeCPygt7K" + }, + "outputs": [], + "source": [ + "#@title Install third-party software\n", + "\n", + "#@markdown Please execute this cell by pressing the _Play_ button \n", + "#@markdown on the left to download and import third-party software \n", + "#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/deepmind/alphafold/#acknowledgements) in our readme.)\n", + "\n", + "#@markdown **Note**: This installs the software on the Colab \n", + "#@markdown notebook in the cloud and not on your computer.\n", + "\n", + "from IPython.utils import io\n", + "import os\n", + "import subprocess\n", + "import tqdm.notebook\n", + "\n", + "TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", + "\n", + "try:\n", + " with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n", + " with io.capture_output() as captured:\n", + " # Uninstall default Colab version of TF.\n", + " %shell pip uninstall -y tensorflow\n", + "\n", + " %shell sudo apt install --quiet --yes hmmer\n", + " pbar.update(6)\n", + "\n", + " # Install py3dmol.\n", + " %shell pip install py3dmol\n", + " pbar.update(2)\n", + "\n", + " # Install OpenMM and pdbfixer.\n", + " %shell rm -rf /opt/conda\n", + " %shell wget -q -P /tmp \\\n", + " https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n", + " \u0026\u0026 bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n", + " \u0026\u0026 rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n", + " pbar.update(9)\n", + "\n", + " PATH=%env PATH\n", + " %env PATH=/opt/conda/bin:{PATH}\n", + " %shell conda update -qy conda \\\n", + " \u0026\u0026 conda install -qy -c conda-forge \\\n", + " python=3.7 \\\n", + " openmm=7.5.1 \\\n", + " pdbfixer\n", + " pbar.update(80)\n", + "\n", + " # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n", + " %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n", + " %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n", + " pbar.update(2)\n", + "\n", + " %shell wget -q -P /content \\\n", + " https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n", + " pbar.update(1)\n", + "except subprocess.CalledProcessError:\n", + " print(captured)\n", + " raise" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "VzJ5iMjTtoZw" + }, + "outputs": [], + "source": [ + "#@title Download AlphaFold\n", + "\n", + "#@markdown Please execute this cell by pressing the *Play* button on \n", + "#@markdown the left.\n", + "\n", + "GIT_REPO = 'https://github.com/deepmind/alphafold'\n", + "\n", + "SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_colab_2021-10-27.tar'\n", + "PARAMS_DIR = './alphafold/data/params'\n", + "PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))\n", + "\n", + "try:\n", + " with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n", + " with io.capture_output() as captured:\n", + " %shell rm -rf alphafold\n", + " %shell git clone --branch main {GIT_REPO} alphafold\n", + " pbar.update(8)\n", + " # Install the required versions of all dependencies.\n", + " %shell pip3 install -r ./alphafold/requirements.txt\n", + " # Run setup.py to install only AlphaFold.\n", + " %shell pip3 install --no-dependencies ./alphafold\n", + " pbar.update(10)\n", + "\n", + " # Apply OpenMM patch.\n", + " %shell pushd /opt/conda/lib/python3.7/site-packages/ \u0026\u0026 \\\n", + " patch -p0 \u003c /content/alphafold/docker/openmm.patch \u0026\u0026 \\\n", + " popd\n", + "\n", + " # Make sure stereo_chemical_props.txt is in all locations where it could be searched for.\n", + " %shell mkdir -p /content/alphafold/alphafold/common\n", + " %shell cp -f /content/stereo_chemical_props.txt /content/alphafold/alphafold/common\n", + " %shell mkdir -p /opt/conda/lib/python3.7/site-packages/alphafold/common/\n", + " %shell cp -f /content/stereo_chemical_props.txt /opt/conda/lib/python3.7/site-packages/alphafold/common/\n", + "\n", + " %shell mkdir --parents \"{PARAMS_DIR}\"\n", + " %shell wget -O \"{PARAMS_PATH}\" \"{SOURCE_URL}\"\n", + " pbar.update(27)\n", + "\n", + " %shell tar --extract --verbose --file=\"{PARAMS_PATH}\" \\\n", + " --directory=\"{PARAMS_DIR}\" --preserve-permissions\n", + " %shell rm \"{PARAMS_PATH}\"\n", + " pbar.update(55)\n", + "except subprocess.CalledProcessError:\n", + " print(captured)\n", + " raise\n", + "\n", + "import jax\n", + "if jax.local_devices()[0].platform == 'tpu':\n", + " raise RuntimeError('Colab TPU runtime not supported. Change it to GPU via Runtime -\u003e Change Runtime Type -\u003e Hardware accelerator -\u003e GPU.')\n", + "elif jax.local_devices()[0].platform == 'cpu':\n", + " raise RuntimeError('Colab CPU runtime not supported. Change it to GPU via Runtime -\u003e Change Runtime Type -\u003e Hardware accelerator -\u003e GPU.')\n", + "else:\n", + " print(f'Running with {jax.local_devices()[0].device_kind} GPU')\n", + "\n", + "# Make sure everything we need is on the path.\n", + "import sys\n", + "sys.path.append('/opt/conda/lib/python3.7/site-packages')\n", + "sys.path.append('/content/alphafold')\n", + "\n", + "# Make sure all necessary environment variables are set.\n", + "import os\n", + "os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'\n", + "os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "W4JpOs6oA-QS" + }, + "source": [ + "## Making a prediction\n", + "\n", + "Please paste the sequence of your protein in the text box below, then run the remaining cells via _Runtime_ \u003e _Run after_. You can also run the cells individually by pressing the _Play_ button on the left.\n", + "\n", + "Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "rowN0bVYLe9n" + }, + "outputs": [], + "source": [ + "#@title Enter the amino acid sequence(s) to fold ⬇️\n", + "#@markdown Enter the amino acid sequence(s) to fold:\n", + "#@markdown * If you enter only a single sequence, the monomer model will be used.\n", + "#@markdown * If you enter multiple sequences, the multimer model will be used.\n", + "\n", + "from alphafold.notebooks import notebook_utils\n", + "\n", + "sequence_1 = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n", + "sequence_2 = '' #@param {type:\"string\"}\n", + "sequence_3 = '' #@param {type:\"string\"}\n", + "sequence_4 = '' #@param {type:\"string\"}\n", + "sequence_5 = '' #@param {type:\"string\"}\n", + "sequence_6 = '' #@param {type:\"string\"}\n", + "sequence_7 = '' #@param {type:\"string\"}\n", + "sequence_8 = '' #@param {type:\"string\"}\n", + "\n", + "input_sequences = (sequence_1, sequence_2, sequence_3, sequence_4,\n", + " sequence_5, sequence_6, sequence_7, sequence_8)\n", + "\n", + "#@markdown If folding a complex target and all the input sequences are\n", + "#@markdown prokaryotic then set `is_prokaryotic` to `True`. Set to `False`\n", + "#@markdown otherwise or if the origin is unknown.\n", + "\n", + "is_prokaryote = False #@param {type:\"boolean\"}\n", + "\n", + "MIN_SINGLE_SEQUENCE_LENGTH = 16\n", + "MAX_SINGLE_SEQUENCE_LENGTH = 2500\n", + "MAX_MULTIMER_LENGTH = 2500\n", + "\n", + "# Validate the input.\n", + "sequences, model_type_to_use = notebook_utils.validate_input(\n", + " input_sequences=input_sequences,\n", + " min_length=MIN_SINGLE_SEQUENCE_LENGTH,\n", + " max_length=MAX_SINGLE_SEQUENCE_LENGTH,\n", + " max_multimer_length=MAX_MULTIMER_LENGTH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "2tTeTTsLKPjB" + }, + "outputs": [], + "source": [ + "#@title Search against genetic databases\n", + "\n", + "#@markdown Once this cell has been executed, you will see\n", + "#@markdown statistics about the multiple sequence alignment \n", + "#@markdown (MSA) that will be used by AlphaFold. In particular, \n", + "#@markdown you’ll see how well each residue is covered by similar \n", + "#@markdown sequences in the MSA.\n", + "\n", + "# --- Python imports ---\n", + "import collections\n", + "import copy\n", + "from concurrent import futures\n", + "import json\n", + "import random\n", + "\n", + "from urllib import request\n", + "from google.colab import files\n", + "from matplotlib import gridspec\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import py3Dmol\n", + "\n", + "from alphafold.model import model\n", + "from alphafold.model import config\n", + "from alphafold.model import data\n", + "\n", + "from alphafold.data import feature_processing\n", + "from alphafold.data import msa_pairing\n", + "from alphafold.data import parsers\n", + "from alphafold.data import pipeline\n", + "from alphafold.data import pipeline_multimer\n", + "from alphafold.data.tools import jackhmmer\n", + "\n", + "from alphafold.common import protein\n", + "\n", + "from alphafold.relax import relax\n", + "from alphafold.relax import utils\n", + "\n", + "from IPython import display\n", + "from ipywidgets import GridspecLayout\n", + "from ipywidgets import Output\n", + "\n", + "# Color bands for visualizing plddt\n", + "PLDDT_BANDS = [(0, 50, '#FF7D45'),\n", + " (50, 70, '#FFDB13'),\n", + " (70, 90, '#65CBF3'),\n", + " (90, 100, '#0053D6')]\n", + "\n", + "# --- Find the closest source ---\n", + "test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n", + "ex = futures.ThreadPoolExecutor(3)\n", + "def fetch(source):\n", + " request.urlretrieve(test_url_pattern.format(source))\n", + " return source\n", + "fs = [ex.submit(fetch, source) for source in ['', '-europe', '-asia']]\n", + "source = None\n", + "for f in futures.as_completed(fs):\n", + " source = f.result()\n", + " ex.shutdown()\n", + " break\n", + "\n", + "JACKHMMER_BINARY_PATH = '/usr/bin/jackhmmer'\n", + "DB_ROOT_PATH = f'https://storage.googleapis.com/alphafold-colab{source}/latest/'\n", + "# The z_value is the number of sequences in a database.\n", + "MSA_DATABASES = [\n", + " {'db_name': 'uniref90',\n", + " 'db_path': f'{DB_ROOT_PATH}uniref90_2021_03.fasta',\n", + " 'num_streamed_chunks': 59,\n", + " 'z_value': 135_301_051},\n", + " {'db_name': 'smallbfd',\n", + " 'db_path': f'{DB_ROOT_PATH}bfd-first_non_consensus_sequences.fasta',\n", + " 'num_streamed_chunks': 17,\n", + " 'z_value': 65_984_053},\n", + " {'db_name': 'mgnify',\n", + " 'db_path': f'{DB_ROOT_PATH}mgy_clusters_2019_05.fasta',\n", + " 'num_streamed_chunks': 71,\n", + " 'z_value': 304_820_129},\n", + "]\n", + "\n", + "# Search UniProt and construct the all_seq features only for heteromers, not homomers.\n", + "if model_type_to_use == notebook_utils.ModelType.MULTIMER and len(set(sequences)) \u003e 1:\n", + " MSA_DATABASES.extend([\n", + " # Swiss-Prot and TrEMBL are concatenated together as UniProt.\n", + " {'db_name': 'uniprot',\n", + " 'db_path': f'{DB_ROOT_PATH}uniprot_2021_03.fasta',\n", + " 'num_streamed_chunks': 98,\n", + " 'z_value': 219_174_961 + 565_254},\n", + " ])\n", + "\n", + "TOTAL_JACKHMMER_CHUNKS = sum([cfg['num_streamed_chunks'] for cfg in MSA_DATABASES])\n", + "\n", + "MAX_HITS = {\n", + " 'uniref90': 10_000,\n", + " 'smallbfd': 5_000,\n", + " 'mgnify': 501,\n", + " 'uniprot': 50_000,\n", + "}\n", + "\n", + "\n", + "def get_msa(fasta_path):\n", + " \"\"\"Searches for MSA for the given sequence using chunked Jackhmmer search.\"\"\"\n", + "\n", + " # Run the search against chunks of genetic databases (since the genetic\n", + " # databases don't fit in Colab disk).\n", + " raw_msa_results = collections.defaultdict(list)\n", + " with tqdm.notebook.tqdm(total=TOTAL_JACKHMMER_CHUNKS, bar_format=TQDM_BAR_FORMAT) as pbar:\n", + " def jackhmmer_chunk_callback(i):\n", + " pbar.update(n=1)\n", + "\n", + " for db_config in MSA_DATABASES:\n", + " db_name = db_config['db_name']\n", + " pbar.set_description(f'Searching {db_name}')\n", + " jackhmmer_runner = jackhmmer.Jackhmmer(\n", + " binary_path=JACKHMMER_BINARY_PATH,\n", + " database_path=db_config['db_path'],\n", + " get_tblout=True,\n", + " num_streamed_chunks=db_config['num_streamed_chunks'],\n", + " streaming_callback=jackhmmer_chunk_callback,\n", + " z_value=db_config['z_value'])\n", + " # Group the results by database name.\n", + " raw_msa_results[db_name].extend(jackhmmer_runner.query(fasta_path))\n", + "\n", + " return raw_msa_results\n", + "\n", + "\n", + "features_for_chain = {}\n", + "raw_msa_results_for_sequence = {}\n", + "for sequence_index, sequence in enumerate(sequences, start=1):\n", + " print(f'\\nGetting MSA for sequence {sequence_index}')\n", + "\n", + " fasta_path = f'target_{sequence_index}.fasta'\n", + " with open(fasta_path, 'wt') as f:\n", + " f.write(f'\u003equery\\n{sequence}')\n", + "\n", + " # Don't do redundant work for multiple copies of the same chain in the multimer.\n", + " if sequence not in raw_msa_results_for_sequence:\n", + " raw_msa_results = get_msa(fasta_path=fasta_path)\n", + " raw_msa_results_for_sequence[sequence] = raw_msa_results\n", + " else:\n", + " raw_msa_results = copy.deepcopy(raw_msa_results_for_sequence[sequence])\n", + "\n", + " # Extract the MSAs from the Stockholm files.\n", + " # NB: deduplication happens later in pipeline.make_msa_features.\n", + " single_chain_msas = []\n", + " uniprot_msa = None\n", + " for db_name, db_results in raw_msa_results.items():\n", + " merged_msa = notebook_utils.merge_chunked_msa(\n", + " results=db_results, max_hits=MAX_HITS.get(db_name))\n", + " if merged_msa.sequences and db_name != 'uniprot':\n", + " single_chain_msas.append(merged_msa)\n", + " msa_size = len(set(merged_msa.sequences))\n", + " print(f'{msa_size} unique sequences found in {db_name} for sequence {sequence_index}')\n", + " elif merged_msa.sequences and db_name == 'uniprot':\n", + " uniprot_msa = merged_msa\n", + "\n", + " notebook_utils.show_msa_info(single_chain_msas=single_chain_msas, sequence_index=sequence_index)\n", + "\n", + " # Turn the raw data into model features.\n", + " feature_dict = {}\n", + " feature_dict.update(pipeline.make_sequence_features(\n", + " sequence=sequence, description='query', num_res=len(sequence)))\n", + " feature_dict.update(pipeline.make_msa_features(msas=single_chain_msas))\n", + " # We don't use templates in AlphaFold Colab notebook, add only empty placeholder features.\n", + " feature_dict.update(notebook_utils.empty_placeholder_template_features(\n", + " num_templates=0, num_res=len(sequence)))\n", + "\n", + " # Construct the all_seq features only for heteromers, not homomers.\n", + " if model_type_to_use == notebook_utils.ModelType.MULTIMER and len(set(sequences)) \u003e 1:\n", + " valid_feats = msa_pairing.MSA_FEATURES + (\n", + " 'msa_uniprot_accession_identifiers',\n", + " 'msa_species_identifiers',\n", + " )\n", + " all_seq_features = {\n", + " f'{k}_all_seq': v for k, v in pipeline.make_msa_features([uniprot_msa]).items()\n", + " if k in valid_feats}\n", + " feature_dict.update(all_seq_features)\n", + "\n", + " features_for_chain[protein.PDB_CHAIN_IDS[sequence_index - 1]] = feature_dict\n", + "\n", + "\n", + "# Do further feature post-processing depending on the model type.\n", + "if model_type_to_use == notebook_utils.ModelType.MONOMER:\n", + " np_example = features_for_chain[protein.PDB_CHAIN_IDS[0]]\n", + "\n", + "elif model_type_to_use == notebook_utils.ModelType.MULTIMER:\n", + " all_chain_features = {}\n", + " for chain_id, chain_features in features_for_chain.items():\n", + " all_chain_features[chain_id] = pipeline_multimer.convert_monomer_features(\n", + " chain_features, chain_id)\n", + "\n", + " all_chain_features = pipeline_multimer.add_assembly_features(all_chain_features)\n", + "\n", + " np_example = feature_processing.pair_and_merge(\n", + " all_chain_features=all_chain_features, is_prokaryote=is_prokaryote)\n", + "\n", + " # Pad MSA to avoid zero-sized extra_msa.\n", + " np_example = pipeline_multimer.pad_msa(np_example, min_num_seq=512)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "XUo6foMQxwS2" + }, + "outputs": [], + "source": [ + "#@title Run AlphaFold and download prediction\n", + "\n", + "#@markdown Once this cell has been executed, a zip-archive with\n", + "#@markdown the obtained prediction will be automatically downloaded\n", + "#@markdown to your computer.\n", + "\n", + "#@markdown In case you are having issues with the relaxation stage, you can disable it below.\n", + "#@markdown Warning: This means that the prediction might have distracting\n", + "#@markdown small stereochemical violations.\n", + "\n", + "run_relax = True #@param {type:\"boolean\"}\n", + "\n", + "# --- Run the model ---\n", + "if model_type_to_use == notebook_utils.ModelType.MONOMER:\n", + " model_names = config.MODEL_PRESETS['monomer'] + ('model_2_ptm',)\n", + "elif model_type_to_use == notebook_utils.ModelType.MULTIMER:\n", + " model_names = config.MODEL_PRESETS['multimer']\n", + "\n", + "output_dir = 'prediction'\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "plddts = {}\n", + "ranking_confidences = {}\n", + "pae_outputs = {}\n", + "unrelaxed_proteins = {}\n", + "\n", + "with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n", + " for model_name in model_names:\n", + " pbar.set_description(f'Running {model_name}')\n", + "\n", + " cfg = config.model_config(model_name)\n", + " if model_type_to_use == notebook_utils.ModelType.MONOMER:\n", + " cfg.data.eval.num_ensemble = 1\n", + " elif model_type_to_use == notebook_utils.ModelType.MULTIMER:\n", + " cfg.model.num_ensemble_eval = 1\n", + " params = data.get_model_haiku_params(model_name, './alphafold/data')\n", + " model_runner = model.RunModel(cfg, params)\n", + " processed_feature_dict = model_runner.process_features(np_example, random_seed=0)\n", + " prediction = model_runner.predict(processed_feature_dict, random_seed=random.randrange(sys.maxsize))\n", + "\n", + " mean_plddt = prediction['plddt'].mean()\n", + "\n", + " if model_type_to_use == notebook_utils.ModelType.MONOMER:\n", + " if 'predicted_aligned_error' in prediction:\n", + " pae_outputs[model_name] = (prediction['predicted_aligned_error'],\n", + " prediction['max_predicted_aligned_error'])\n", + " else:\n", + " # Monomer models are sorted by mean pLDDT. Do not put monomer pTM models here as they\n", + " # should never get selected.\n", + " ranking_confidences[model_name] = prediction['ranking_confidence']\n", + " plddts[model_name] = prediction['plddt']\n", + " elif model_type_to_use == notebook_utils.ModelType.MULTIMER:\n", + " # Multimer models are sorted by pTM+ipTM.\n", + " ranking_confidences[model_name] = prediction['ranking_confidence']\n", + " plddts[model_name] = prediction['plddt']\n", + " pae_outputs[model_name] = (prediction['predicted_aligned_error'],\n", + " prediction['max_predicted_aligned_error'])\n", + "\n", + " # Set the b-factors to the per-residue plddt.\n", + " final_atom_mask = prediction['structure_module']['final_atom_mask']\n", + " b_factors = prediction['plddt'][:, None] * final_atom_mask\n", + " unrelaxed_protein = protein.from_prediction(\n", + " processed_feature_dict,\n", + " prediction,\n", + " b_factors=b_factors,\n", + " remove_leading_feature_dimension=(\n", + " model_type_to_use == notebook_utils.ModelType.MONOMER))\n", + " unrelaxed_proteins[model_name] = unrelaxed_protein\n", + "\n", + " # Delete unused outputs to save memory.\n", + " del model_runner\n", + " del params\n", + " del prediction\n", + " pbar.update(n=1)\n", + "\n", + " # --- AMBER relax the best model ---\n", + "\n", + " # Find the best model according to the mean pLDDT.\n", + " best_model_name = max(ranking_confidences.keys(), key=lambda x: ranking_confidences[x])\n", + "\n", + " if run_relax:\n", + " pbar.set_description(f'AMBER relaxation')\n", + " amber_relaxer = relax.AmberRelaxation(\n", + " max_iterations=0,\n", + " tolerance=2.39,\n", + " stiffness=10.0,\n", + " exclude_residues=[],\n", + " max_outer_iterations=3)\n", + " relaxed_pdb, _, _ = amber_relaxer.process(prot=unrelaxed_proteins[best_model_name])\n", + " else:\n", + " print('Warning: Running without the relaxation stage.')\n", + " relaxed_pdb = protein.to_pdb(unrelaxed_proteins[best_model_name])\n", + " pbar.update(n=1) # Finished AMBER relax.\n", + "\n", + "# Construct multiclass b-factors to indicate confidence bands\n", + "# 0=very low, 1=low, 2=confident, 3=very high\n", + "banded_b_factors = []\n", + "for plddt in plddts[best_model_name]:\n", + " for idx, (min_val, max_val, _) in enumerate(PLDDT_BANDS):\n", + " if plddt \u003e= min_val and plddt \u003c= max_val:\n", + " banded_b_factors.append(idx)\n", + " break\n", + "banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n", + "to_visualize_pdb = utils.overwrite_b_factors(relaxed_pdb, banded_b_factors)\n", + "\n", + "\n", + "# Write out the prediction\n", + "pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n", + "with open(pred_output_path, 'w') as f:\n", + " f.write(relaxed_pdb)\n", + "\n", + "\n", + "# --- Visualise the prediction \u0026 confidence ---\n", + "show_sidechains = True\n", + "def plot_plddt_legend():\n", + " \"\"\"Plots the legend for pLDDT.\"\"\"\n", + " thresh = ['Very low (pLDDT \u003c 50)',\n", + " 'Low (70 \u003e pLDDT \u003e 50)',\n", + " 'Confident (90 \u003e pLDDT \u003e 70)',\n", + " 'Very high (pLDDT \u003e 90)']\n", + "\n", + " colors = [x[2] for x in PLDDT_BANDS]\n", + "\n", + " plt.figure(figsize=(2, 2))\n", + " for c in colors:\n", + " plt.bar(0, 0, color=c)\n", + " plt.legend(thresh, frameon=False, loc='center', fontsize=20)\n", + " plt.xticks([])\n", + " plt.yticks([])\n", + " ax = plt.gca()\n", + " ax.spines['right'].set_visible(False)\n", + " ax.spines['top'].set_visible(False)\n", + " ax.spines['left'].set_visible(False)\n", + " ax.spines['bottom'].set_visible(False)\n", + " plt.title('Model Confidence', fontsize=20, pad=20)\n", + " return plt\n", + "\n", + "# Show the structure coloured by chain if the multimer model has been used.\n", + "if model_type_to_use == notebook_utils.ModelType.MULTIMER:\n", + " multichain_view = py3Dmol.view(width=800, height=600)\n", + " multichain_view.addModelsAsFrames(to_visualize_pdb)\n", + " multichain_style = {'cartoon': {'colorscheme': 'chain'}}\n", + " multichain_view.setStyle({'model': -1}, multichain_style)\n", + " multichain_view.zoomTo()\n", + " multichain_view.show()\n", + "\n", + "# Color the structure by per-residue pLDDT\n", + "color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n", + "view = py3Dmol.view(width=800, height=600)\n", + "view.addModelsAsFrames(to_visualize_pdb)\n", + "style = {'cartoon': {'colorscheme': {'prop': 'b', 'map': color_map}}}\n", + "if show_sidechains:\n", + " style['stick'] = {}\n", + "view.setStyle({'model': -1}, style)\n", + "view.zoomTo()\n", + "\n", + "grid = GridspecLayout(1, 2)\n", + "out = Output()\n", + "with out:\n", + " view.show()\n", + "grid[0, 0] = out\n", + "\n", + "out = Output()\n", + "with out:\n", + " plot_plddt_legend().show()\n", + "grid[0, 1] = out\n", + "\n", + "display.display(grid)\n", + "\n", + "# Display pLDDT and predicted aligned error (if output by the model).\n", + "if pae_outputs:\n", + " num_plots = 2\n", + "else:\n", + " num_plots = 1\n", + "\n", + "plt.figure(figsize=[8 * num_plots, 6])\n", + "plt.subplot(1, num_plots, 1)\n", + "plt.plot(plddts[best_model_name])\n", + "plt.title('Predicted LDDT')\n", + "plt.xlabel('Residue')\n", + "plt.ylabel('pLDDT')\n", + "\n", + "if num_plots == 2:\n", + " plt.subplot(1, 2, 2)\n", + " pae, max_pae = list(pae_outputs.values())[0]\n", + " plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n", + " plt.colorbar(fraction=0.046, pad=0.04)\n", + "\n", + " # Display lines at chain boundaries.\n", + " best_unrelaxed_prot = unrelaxed_proteins[best_model_name]\n", + " total_num_res = best_unrelaxed_prot.residue_index.shape[-1]\n", + " chain_ids = best_unrelaxed_prot.chain_index\n", + " for chain_boundary in np.nonzero(chain_ids[:-1] - chain_ids[1:]):\n", + " plt.plot([0, total_num_res], [chain_boundary, chain_boundary], color='red')\n", + " plt.plot([chain_boundary, chain_boundary], [0, total_num_res], color='red')\n", + "\n", + " plt.title('Predicted Aligned Error')\n", + " plt.xlabel('Scored residue')\n", + " plt.ylabel('Aligned residue')\n", + "\n", + "# Save the predicted aligned error (if it exists).\n", + "pae_output_path = os.path.join(output_dir, 'predicted_aligned_error.json')\n", + "if pae_outputs:\n", + " # Save predicted aligned error in the same format as the AF EMBL DB.\n", + " pae_data = notebook_utils.get_pae_json(pae=pae, max_pae=max_pae.item())\n", + " with open(pae_output_path, 'w') as f:\n", + " f.write(pae_data)\n", + "\n", + "# --- Download the predictions ---\n", + "!zip -q -r {output_dir}.zip {output_dir}\n", + "files.download(f'{output_dir}.zip')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lUQAn5LYC5n4" + }, + "source": [ + "### Interpreting the prediction\n", + "\n", + "In general predicted LDDT (pLDDT) is best used for intra-domain confidence, whereas Predicted Aligned Error (PAE) is best used for determining between domain or between chain confidence.\n", + "\n", + "Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2), the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), and the [AlphaFold-Multimer paper](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) as well as [our FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold predictions." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jeb2z8DIA4om" + }, + "source": [ + "## FAQ \u0026 Troubleshooting\n", + "\n", + "\n", + "* How do I get a predicted protein structure for my protein?\n", + " * Click on the _Connect_ button on the top right to get started.\n", + " * Paste the amino acid sequence of your protein (without any headers) into the “Enter the amino acid sequence to fold”.\n", + " * Run all cells in the Colab, either by running them individually (with the play button on the left side) or via _Runtime_ \u003e _Run all._\n", + " * The predicted protein structure will be downloaded once all cells have been executed. Note: This can take minutes to hours - see below.\n", + "* How long will this take?\n", + " * Downloading the AlphaFold source code can take up to a few minutes.\n", + " * Downloading and installing the third-party software can take up to a few minutes.\n", + " * The search against genetic databases can take minutes to hours.\n", + " * Running AlphaFold and generating the prediction can take minutes to hours, depending on the length of your protein and on which GPU-type Colab has assigned you.\n", + "* My Colab no longer seems to be doing anything, what should I do?\n", + " * Some steps may take minutes to hours to complete.\n", + " * If nothing happens or if you receive an error message, try restarting your Colab runtime via _Runtime_ \u003e _Restart runtime_.\n", + " * If this doesn’t help, try resetting your Colab runtime via _Runtime_ \u003e _Factory reset runtime_.\n", + "* How does this compare to the open-source version of AlphaFold?\n", + " * This Colab version of AlphaFold searches a selected portion of the BFD dataset and currently doesn’t use templates, so its accuracy is reduced in comparison to the full version of AlphaFold that is described in the [AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and [Github repo](https://github.com/deepmind/alphafold/) (the full version is available via the inference script).\n", + "* What is a Colab?\n", + " * See the [Colab FAQ](https://research.google.com/colaboratory/faq.html).\n", + "* I received a warning “Notebook requires high RAM”, what do I do?\n", + " * The resources allocated to your Colab vary. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n", + " * You can execute the Colab nonetheless.\n", + "* I received an error “Colab CPU runtime not supported” or “No GPU/TPU found”, what do I do?\n", + " * Colab CPU runtime is not supported. Try changing your runtime via _Runtime_ \u003e _Change runtime type_ \u003e _Hardware accelerator_ \u003e _GPU_.\n", + " * The type of GPU allocated to your Colab varies. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n", + " * If you receive “Cannot connect to GPU backend”, you can try again later to see if Colab allocates you a GPU.\n", + " * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs.\n", + "* I received an error “ModuleNotFoundError: No module named ...”, even though I ran the cell that imports it, what do I do?\n", + " * Colab notebooks on the free tier time out after a certain amount of time. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html#idle-timeouts). Try rerunning the whole notebook from the beginning.\n", + "* Does this tool install anything on my computer?\n", + " * No, everything happens in the cloud on Google Colab.\n", + " * At the end of the Colab execution a zip-archive with the obtained prediction will be automatically downloaded to your computer.\n", + "* How should I share feedback and bug reports?\n", + " * Please share any feedback and bug reports as an [issue](https://github.com/deepmind/alphafold/issues) on Github.\n", + "\n", + "\n", + "## Related work\n", + "\n", + "Take a look at these Colab notebooks provided by the community (please note that these notebooks may vary from our validated AlphaFold system and we cannot guarantee their accuracy):\n", + "\n", + "* The [ColabFold AlphaFold2 notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) by Sergey Ovchinnikov, Milot Mirdita and Martin Steinegger, which uses an API hosted at the Södinglab based on the MMseqs2 server ([Mirdita et al. 2019, Bioinformatics](https://academic.oup.com/bioinformatics/article/35/16/2856/5280135)) for the multiple sequence alignment creation.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YfPhvYgKC81B" + }, + "source": [ + "# License and Disclaimer\n", + "\n", + "This is not an officially-supported Google product.\n", + "\n", + "This Colab notebook and other information provided is for theoretical modelling only, caution should be exercised in its use. It is provided ‘as-is’ without any warranty of any kind, whether expressed or implied. Information is not intended to be a substitute for professional medical advice, diagnosis, or treatment, and does not constitute medical or other professional advice.\n", + "\n", + "Copyright 2021 DeepMind Technologies Limited.\n", + "\n", + "\n", + "## AlphaFold Code License\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0.\n", + "\n", + "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n", + "\n", + "## Model Parameters License\n", + "\n", + "The AlphaFold parameters are made available for non-commercial use only, under the terms of the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. You can find details at: https://creativecommons.org/licenses/by-nc/4.0/legalcode\n", + "\n", + "\n", + "## Third-party software\n", + "\n", + "Use of the third-party software, libraries or code referred to in the [Acknowledgements section](https://github.com/deepmind/alphafold/#acknowledgements) in the AlphaFold README may be governed by separate terms and conditions or license provisions. Your use of the third-party software, libraries or code is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use.\n", + "\n", + "\n", + "## Mirrored Databases\n", + "\n", + "The following databases have been mirrored by DeepMind, and are available with reference to the following:\n", + "* UniProt: v2021\\_03 (unmodified), by The UniProt Consortium, available under a [Creative Commons Attribution-NoDerivatives 4.0 International License](http://creativecommons.org/licenses/by-nd/4.0/).\n", + "* UniRef90: v2021\\_03 (unmodified), by The UniProt Consortium, available under a [Creative Commons Attribution-NoDerivatives 4.0 International License](http://creativecommons.org/licenses/by-nd/4.0/).\n", + "* MGnify: v2019\\_05 (unmodified), by Mitchell AL et al., available free of all copyright restrictions and made fully and freely available for both non-commercial and commercial use under [CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).\n", + "* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "AlphaFold.ipynb", + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "form", - "id": "woIxeCPygt7K" - }, - "outputs": [], - "source": [ - "#@title Install third-party software\n", - "\n", - "#@markdown Please execute this cell by pressing the _Play_ button \n", - "#@markdown on the left to download and import third-party software \n", - "#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/deepmind/alphafold/#acknowledgements) in our readme.)\n", - "\n", - "#@markdown **Note**: This installs the software on the Colab \n", - "#@markdown notebook in the cloud and not on your computer.\n", - "\n", - "from IPython.utils import io\n", - "import os\n", - "import subprocess\n", - "import tqdm.notebook\n", - "\n", - "TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n", - "\n", - "try:\n", - " with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n", - " with io.capture_output() as captured:\n", - " # Uninstall default Colab version of TF.\n", - " %shell pip uninstall -y tensorflow\n", - "\n", - " %shell sudo apt install --quiet --yes hmmer\n", - " pbar.update(6)\n", - "\n", - " # Install py3dmol.\n", - " %shell pip install py3dmol\n", - " pbar.update(2)\n", - "\n", - " # Install OpenMM and pdbfixer.\n", - " %shell rm -rf /opt/conda\n", - " %shell wget -q -P /tmp \\\n", - " https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \\\n", - " && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \\\n", - " && rm /tmp/Miniconda3-latest-Linux-x86_64.sh\n", - " pbar.update(9)\n", - "\n", - " PATH=%env PATH\n", - " %env PATH=/opt/conda/bin:{PATH}\n", - " %shell conda update -qy conda \\\n", - " && conda install -qy -c conda-forge \\\n", - " python=3.7 \\\n", - " openmm=7.5.1 \\\n", - " pdbfixer\n", - " pbar.update(80)\n", - "\n", - " # Create a ramdisk to store a database chunk to make Jackhmmer run fast.\n", - " %shell sudo mkdir -m 777 --parents /tmp/ramdisk\n", - " %shell sudo mount -t tmpfs -o size=9G ramdisk /tmp/ramdisk\n", - " pbar.update(2)\n", - "\n", - " %shell wget -q -P /content \\\n", - " https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt\n", - " pbar.update(1)\n", - "except subprocess.CalledProcessError:\n", - " print(captured)\n", - " raise" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "form", - "id": "VzJ5iMjTtoZw" - }, - "outputs": [], - "source": [ - "#@title Download AlphaFold\n", - "\n", - "#@markdown Please execute this cell by pressing the *Play* button on \n", - "#@markdown the left.\n", - "\n", - "GIT_REPO = 'https://github.com/deepmind/alphafold'\n", - "\n", - "SOURCE_URL = 'https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar'\n", - "PARAMS_DIR = './alphafold/data/params'\n", - "PARAMS_PATH = os.path.join(PARAMS_DIR, os.path.basename(SOURCE_URL))\n", - "\n", - "try:\n", - " with tqdm.notebook.tqdm(total=100, bar_format=TQDM_BAR_FORMAT) as pbar:\n", - " with io.capture_output() as captured:\n", - " %shell rm -rf alphafold\n", - " %shell git clone {GIT_REPO} alphafold\n", - " pbar.update(8)\n", - " # Install the required versions of all dependencies.\n", - " %shell pip3 install -r ./alphafold/requirements.txt\n", - " # Run setup.py to install only AlphaFold.\n", - " %shell pip3 install --no-dependencies ./alphafold\n", - " pbar.update(10)\n", - "\n", - " # Apply OpenMM patch.\n", - " %shell pushd /opt/conda/lib/python3.7/site-packages/ && \\\n", - " patch -p0 < /content/alphafold/docker/openmm.patch && \\\n", - " popd\n", - " \n", - " %shell mkdir -p /content/alphafold/common\n", - " %shell cp -f /content/stereo_chemical_props.txt /content/alphafold/common\n", - "\n", - " %shell mkdir --parents \"{PARAMS_DIR}\"\n", - " %shell wget -O \"{PARAMS_PATH}\" \"{SOURCE_URL}\"\n", - " pbar.update(27)\n", - "\n", - " %shell tar --extract --verbose --file=\"{PARAMS_PATH}\" \\\n", - " --directory=\"{PARAMS_DIR}\" --preserve-permissions\n", - " %shell rm \"{PARAMS_PATH}\"\n", - " pbar.update(55)\n", - "except subprocess.CalledProcessError:\n", - " print(captured)\n", - " raise\n", - "\n", - "import jax\n", - "if jax.local_devices()[0].platform == 'tpu':\n", - " raise RuntimeError('Colab TPU runtime not supported. Change it to GPU via Runtime -> Change Runtime Type -> Hardware accelerator -> GPU.')\n", - "elif jax.local_devices()[0].platform == 'cpu':\n", - " raise RuntimeError('Colab CPU runtime not supported. Change it to GPU via Runtime -> Change Runtime Type -> Hardware accelerator -> GPU.')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W4JpOs6oA-QS" - }, - "source": [ - "## Making a prediction\n", - "\n", - "Please paste the sequence of your protein in the text box below, then run the remaining cells via _Runtime_ > _Run after_. You can also run the cells individually by pressing the _Play_ button on the left.\n", - "\n", - "Note that the search against databases and the actual prediction can take some time, from minutes to hours, depending on the length of the protein and what type of GPU you are allocated by Colab (see FAQ below)." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "form", - "id": "rowN0bVYLe9n" - }, - "outputs": [], - "source": [ - "#@title Enter the amino acid sequence to fold ⬇️\n", - "sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH' #@param {type:\"string\"}\n", - "\n", - "MIN_SEQUENCE_LENGTH = 16\n", - "MAX_SEQUENCE_LENGTH = 2500\n", - "\n", - "# Remove all whitespaces, tabs and end lines; upper-case\n", - "sequence = sequence.translate(str.maketrans('', '', ' \\n\\t')).upper()\n", - "aatypes = set('ACDEFGHIKLMNPQRSTVWY') # 20 standard aatypes\n", - "if not set(sequence).issubset(aatypes):\n", - " raise Exception(f'Input sequence contains non-amino acid letters: {set(sequence) - aatypes}. AlphaFold only supports 20 standard amino acids as inputs.')\n", - "if len(sequence) < MIN_SEQUENCE_LENGTH:\n", - " raise Exception(f'Input sequence is too short: {len(sequence)} amino acids, while the minimum is {MIN_SEQUENCE_LENGTH}')\n", - "if len(sequence) > MAX_SEQUENCE_LENGTH:\n", - " raise Exception(f'Input sequence is too long: {len(sequence)} amino acids, while the maximum is {MAX_SEQUENCE_LENGTH}. Please use the full AlphaFold system for long sequences.')" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "form", - "id": "2tTeTTsLKPjB" - }, - "outputs": [], - "source": [ - "#@title Search against genetic databases\n", - "\n", - "#@markdown Once this cell has been executed, you will see\n", - "#@markdown statistics about the multiple sequence alignment \n", - "#@markdown (MSA) that will be used by AlphaFold. In particular, \n", - "#@markdown you’ll see how well each residue is covered by similar \n", - "#@markdown sequences in the MSA.\n", - "\n", - "# --- Python imports ---\n", - "import sys\n", - "sys.path.append('/opt/conda/lib/python3.7/site-packages')\n", - "\n", - "import os\n", - "os.environ['TF_FORCE_UNIFIED_MEMORY'] = '1'\n", - "os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '2.0'\n", - "\n", - "from urllib import request\n", - "from concurrent import futures\n", - "from google.colab import files\n", - "import json\n", - "from matplotlib import gridspec\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import py3Dmol\n", - "\n", - "from alphafold.model import model\n", - "from alphafold.model import config\n", - "from alphafold.model import data\n", - "\n", - "from alphafold.data import parsers\n", - "from alphafold.data import pipeline\n", - "from alphafold.data.tools import jackhmmer\n", - "\n", - "from alphafold.common import protein\n", - "\n", - "from alphafold.relax import relax\n", - "from alphafold.relax import utils\n", - "\n", - "from IPython import display\n", - "from ipywidgets import GridspecLayout\n", - "from ipywidgets import Output\n", - "\n", - "# Color bands for visualizing plddt\n", - "PLDDT_BANDS = [(0, 50, '#FF7D45'),\n", - " (50, 70, '#FFDB13'),\n", - " (70, 90, '#65CBF3'),\n", - " (90, 100, '#0053D6')]\n", - "\n", - "# --- Find the closest source ---\n", - "test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'\n", - "ex = futures.ThreadPoolExecutor(3)\n", - "def fetch(source):\n", - " request.urlretrieve(test_url_pattern.format(source))\n", - " return source\n", - "fs = [ex.submit(fetch, source) for source in ['', '-europe', '-asia']]\n", - "source = None\n", - "for f in futures.as_completed(fs):\n", - " source = f.result()\n", - " ex.shutdown()\n", - " break\n", - "\n", - "# --- Search against genetic databases ---\n", - "with open('target.fasta', 'wt') as f:\n", - " f.write(f'>query\\n{sequence}')\n", - "\n", - "# Run the search against chunks of genetic databases (since the genetic\n", - "# databases don't fit in Colab ramdisk).\n", - "\n", - "jackhmmer_binary_path = '/usr/bin/jackhmmer'\n", - "dbs = []\n", - "\n", - "num_jackhmmer_chunks = {'uniref90': 59, 'smallbfd': 17, 'mgnify': 71}\n", - "total_jackhmmer_chunks = sum(num_jackhmmer_chunks.values())\n", - "with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT) as pbar:\n", - " def jackhmmer_chunk_callback(i):\n", - " pbar.update(n=1)\n", - "\n", - " pbar.set_description('Searching uniref90')\n", - " jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(\n", - " binary_path=jackhmmer_binary_path,\n", - " database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/uniref90_2021_03.fasta',\n", - " get_tblout=True,\n", - " num_streamed_chunks=num_jackhmmer_chunks['uniref90'],\n", - " streaming_callback=jackhmmer_chunk_callback,\n", - " z_value=135301051)\n", - " dbs.append(('uniref90', jackhmmer_uniref90_runner.query('target.fasta')))\n", - "\n", - " pbar.set_description('Searching smallbfd')\n", - " jackhmmer_smallbfd_runner = jackhmmer.Jackhmmer(\n", - " binary_path=jackhmmer_binary_path,\n", - " database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/bfd-first_non_consensus_sequences.fasta',\n", - " get_tblout=True,\n", - " num_streamed_chunks=num_jackhmmer_chunks['smallbfd'],\n", - " streaming_callback=jackhmmer_chunk_callback,\n", - " z_value=65984053)\n", - " dbs.append(('smallbfd', jackhmmer_smallbfd_runner.query('target.fasta')))\n", - "\n", - " pbar.set_description('Searching mgnify')\n", - " jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(\n", - " binary_path=jackhmmer_binary_path,\n", - " database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/mgy_clusters_2019_05.fasta',\n", - " get_tblout=True,\n", - " num_streamed_chunks=num_jackhmmer_chunks['mgnify'],\n", - " streaming_callback=jackhmmer_chunk_callback,\n", - " z_value=304820129)\n", - " dbs.append(('mgnify', jackhmmer_mgnify_runner.query('target.fasta')))\n", - "\n", - "\n", - "# --- Extract the MSAs and visualize ---\n", - "# Extract the MSAs from the Stockholm files.\n", - "# NB: deduplication happens later in pipeline.make_msa_features.\n", - "\n", - "mgnify_max_hits = 501\n", - "\n", - "msas = []\n", - "deletion_matrices = []\n", - "full_msa = []\n", - "for db_name, db_results in dbs:\n", - " unsorted_results = []\n", - " for i, result in enumerate(db_results):\n", - " msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto'])\n", - " e_values_dict = parsers.parse_e_values_from_tblout(result['tbl'])\n", - " e_values = [e_values_dict[t.split('/')[0]] for t in target_names]\n", - " zipped_results = zip(msa, deletion_matrix, target_names, e_values)\n", - " if i != 0:\n", - " # Only take query from the first chunk\n", - " zipped_results = [x for x in zipped_results if x[2] != 'query']\n", - " unsorted_results.extend(zipped_results)\n", - " sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3])\n", - " db_msas, db_deletion_matrices, _, _ = zip(*sorted_by_evalue)\n", - " if db_msas:\n", - " if db_name == 'mgnify':\n", - " db_msas = db_msas[:mgnify_max_hits]\n", - " db_deletion_matrices = db_deletion_matrices[:mgnify_max_hits]\n", - " full_msa.extend(db_msas)\n", - " msas.append(db_msas)\n", - " deletion_matrices.append(db_deletion_matrices)\n", - " msa_size = len(set(db_msas))\n", - " print(f'{msa_size} Sequences Found in {db_name}')\n", - "\n", - "deduped_full_msa = list(dict.fromkeys(full_msa))\n", - "total_msa_size = len(deduped_full_msa)\n", - "print(f'\\n{total_msa_size} Sequences Found in Total\\n')\n", - "\n", - "aa_map = {restype: i for i, restype in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}\n", - "msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in deduped_full_msa])\n", - "num_alignments, num_res = msa_arr.shape\n", - "\n", - "fig = plt.figure(figsize=(12, 3))\n", - "plt.title('Per-Residue Count of Non-Gap Amino Acids in the MSA')\n", - "plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')\n", - "plt.ylabel('Non-Gap Count')\n", - "plt.yticks(range(0, num_alignments + 1, max(1, int(num_alignments / 3))))\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "form", - "id": "XUo6foMQxwS2" - }, - "outputs": [], - "source": [ - "#@title Run AlphaFold and download prediction\n", - "\n", - "#@markdown Once this cell has been executed, a zip-archive with \n", - "#@markdown the obtained prediction will be automatically downloaded \n", - "#@markdown to your computer.\n", - "\n", - "# --- Run the model ---\n", - "model_names = ['model_1', 'model_2', 'model_3', 'model_4', 'model_5', 'model_2_ptm']\n", - "\n", - "def _placeholder_template_feats(num_templates_, num_res_):\n", - " return {\n", - " 'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),\n", - " 'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37, 3], np.float32),\n", - " 'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37], np.float32),\n", - " 'template_domain_names': np.zeros([num_templates_], np.float32),\n", - " 'template_sum_probs': np.zeros([num_templates_], np.float32),\n", - " }\n", - "\n", - "output_dir = 'prediction'\n", - "os.makedirs(output_dir, exist_ok=True)\n", - "\n", - "plddts = {}\n", - "pae_outputs = {}\n", - "unrelaxed_proteins = {}\n", - "\n", - "with tqdm.notebook.tqdm(total=len(model_names) + 1, bar_format=TQDM_BAR_FORMAT) as pbar:\n", - " for model_name in model_names:\n", - " pbar.set_description(f'Running {model_name}')\n", - " num_templates = 0\n", - " num_res = len(sequence)\n", - "\n", - " feature_dict = {}\n", - " feature_dict.update(pipeline.make_sequence_features(sequence, 'test', num_res))\n", - " feature_dict.update(pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices))\n", - " feature_dict.update(_placeholder_template_feats(num_templates, num_res))\n", - "\n", - " cfg = config.model_config(model_name)\n", - " params = data.get_model_haiku_params(model_name, './alphafold/data')\n", - " model_runner = model.RunModel(cfg, params)\n", - " processed_feature_dict = model_runner.process_features(feature_dict,\n", - " random_seed=0)\n", - " prediction_result = model_runner.predict(processed_feature_dict)\n", - "\n", - " mean_plddt = prediction_result['plddt'].mean()\n", - "\n", - " if 'predicted_aligned_error' in prediction_result:\n", - " pae_outputs[model_name] = (\n", - " prediction_result['predicted_aligned_error'],\n", - " prediction_result['max_predicted_aligned_error']\n", - " )\n", - " else:\n", - " # Get the pLDDT confidence metrics. Do not put pTM models here as they\n", - " # should never get selected.\n", - " plddts[model_name] = prediction_result['plddt']\n", - "\n", - " # Set the b-factors to the per-residue plddt.\n", - " final_atom_mask = prediction_result['structure_module']['final_atom_mask']\n", - " b_factors = prediction_result['plddt'][:, None] * final_atom_mask\n", - " unrelaxed_protein = protein.from_prediction(processed_feature_dict,\n", - " prediction_result,\n", - " b_factors=b_factors)\n", - " unrelaxed_proteins[model_name] = unrelaxed_protein\n", - "\n", - " # Delete unused outputs to save memory.\n", - " del model_runner\n", - " del params\n", - " del prediction_result\n", - " pbar.update(n=1)\n", - "\n", - " # --- AMBER relax the best model ---\n", - " pbar.set_description(f'AMBER relaxation')\n", - " amber_relaxer = relax.AmberRelaxation(\n", - " max_iterations=0,\n", - " tolerance=2.39,\n", - " stiffness=10.0,\n", - " exclude_residues=[],\n", - " max_outer_iterations=20)\n", - " # Find the best model according to the mean pLDDT.\n", - " best_model_name = max(plddts.keys(), key=lambda x: plddts[x].mean())\n", - " relaxed_pdb, _, _ = amber_relaxer.process(\n", - " prot=unrelaxed_proteins[best_model_name])\n", - " pbar.update(n=1) # Finished AMBER relax.\n", - "\n", - "# Construct multiclass b-factors to indicate confidence bands\n", - "# 0=very low, 1=low, 2=confident, 3=very high\n", - "banded_b_factors = []\n", - "for plddt in plddts[best_model_name]:\n", - " for idx, (min_val, max_val, _) in enumerate(PLDDT_BANDS):\n", - " if plddt >= min_val and plddt <= max_val:\n", - " banded_b_factors.append(idx)\n", - " break\n", - "banded_b_factors = np.array(banded_b_factors)[:, None] * final_atom_mask\n", - "to_visualize_pdb = utils.overwrite_b_factors(relaxed_pdb, banded_b_factors)\n", - "\n", - "\n", - "# Write out the prediction\n", - "pred_output_path = os.path.join(output_dir, 'selected_prediction.pdb')\n", - "with open(pred_output_path, 'w') as f:\n", - " f.write(relaxed_pdb)\n", - "\n", - "\n", - "# --- Visualise the prediction & confidence ---\n", - "show_sidechains = True\n", - "def plot_plddt_legend():\n", - " \"\"\"Plots the legend for pLDDT.\"\"\"\n", - " thresh = [\n", - " 'Very low (pLDDT < 50)',\n", - " 'Low (70 > pLDDT > 50)',\n", - " 'Confident (90 > pLDDT > 70)',\n", - " 'Very high (pLDDT > 90)']\n", - "\n", - " colors = [x[2] for x in PLDDT_BANDS]\n", - "\n", - " plt.figure(figsize=(2, 2))\n", - " for c in colors:\n", - " plt.bar(0, 0, color=c)\n", - " plt.legend(thresh, frameon=False, loc='center', fontsize=20)\n", - " plt.xticks([])\n", - " plt.yticks([])\n", - " ax = plt.gca()\n", - " ax.spines['right'].set_visible(False)\n", - " ax.spines['top'].set_visible(False)\n", - " ax.spines['left'].set_visible(False)\n", - " ax.spines['bottom'].set_visible(False)\n", - " plt.title('Model Confidence', fontsize=20, pad=20)\n", - " return plt\n", - "\n", - "# Color the structure by per-residue pLDDT\n", - "color_map = {i: bands[2] for i, bands in enumerate(PLDDT_BANDS)}\n", - "view = py3Dmol.view(width=800, height=600)\n", - "view.addModelsAsFrames(to_visualize_pdb)\n", - "style = {'cartoon': {\n", - " 'colorscheme': {\n", - " 'prop': 'b',\n", - " 'map': color_map}\n", - " }}\n", - "if show_sidechains:\n", - " style['stick'] = {}\n", - "view.setStyle({'model': -1}, style)\n", - "view.zoomTo()\n", - "\n", - "grid = GridspecLayout(1, 2)\n", - "out = Output()\n", - "with out:\n", - " view.show()\n", - "grid[0, 0] = out\n", - "\n", - "out = Output()\n", - "with out:\n", - " plot_plddt_legend().show()\n", - "grid[0, 1] = out\n", - "\n", - "display.display(grid)\n", - "\n", - "# Display pLDDT and predicted aligned error (if output by the model).\n", - "if pae_outputs:\n", - " num_plots = 2\n", - "else:\n", - " num_plots = 1\n", - "\n", - "plt.figure(figsize=[8 * num_plots, 6])\n", - "plt.subplot(1, num_plots, 1)\n", - "plt.plot(plddts[best_model_name])\n", - "plt.title('Predicted LDDT')\n", - "plt.xlabel('Residue')\n", - "plt.ylabel('pLDDT')\n", - "\n", - "if num_plots == 2:\n", - " plt.subplot(1, 2, 2)\n", - " pae, max_pae = list(pae_outputs.values())[0]\n", - " plt.imshow(pae, vmin=0., vmax=max_pae, cmap='Greens_r')\n", - " plt.colorbar(fraction=0.046, pad=0.04)\n", - " plt.title('Predicted Aligned Error')\n", - " plt.xlabel('Scored residue')\n", - " plt.ylabel('Aligned residue')\n", - "\n", - "# Save pLDDT and predicted aligned error (if it exists)\n", - "pae_output_path = os.path.join(output_dir, 'predicted_aligned_error.json')\n", - "if pae_outputs:\n", - " # Save predicted aligned error in the same format as the AF EMBL DB\n", - " rounded_errors = np.round(pae.astype(np.float64), decimals=1)\n", - " indices = np.indices((len(rounded_errors), len(rounded_errors))) + 1\n", - " indices_1 = indices[0].flatten().tolist()\n", - " indices_2 = indices[1].flatten().tolist()\n", - " pae_data = json.dumps([{\n", - " 'residue1': indices_1,\n", - " 'residue2': indices_2,\n", - " 'distance': rounded_errors.flatten().tolist(),\n", - " 'max_predicted_aligned_error': max_pae.item()\n", - " }],\n", - " indent=None,\n", - " separators=(',', ':'))\n", - " with open(pae_output_path, 'w') as f:\n", - " f.write(pae_data)\n", - "\n", - "\n", - "# --- Download the predictions ---\n", - "!zip -q -r {output_dir}.zip {output_dir}\n", - "files.download(f'{output_dir}.zip')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lUQAn5LYC5n4" - }, - "source": [ - "### Interpreting the prediction\n", - "\n", - "Please see the [AlphaFold methods paper](https://www.nature.com/articles/s41586-021-03819-2) and the [AlphaFold predictions of the human proteome paper](https://www.nature.com/articles/s41586-021-03828-1), as well as [our FAQ](https://alphafold.ebi.ac.uk/faq) on how to interpret AlphaFold predictions." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jeb2z8DIA4om" - }, - "source": [ - "## FAQ & Troubleshooting\n", - "\n", - "\n", - "* How do I get a predicted protein structure for my protein?\n", - " * Click on the _Connect_ button on the top right to get started.\n", - " * Paste the amino acid sequence of your protein (without any headers) into the “Enter the amino acid sequence to fold”.\n", - " * Run all cells in the Colab, either by running them individually (with the play button on the left side) or via _Runtime_ > _Run all._\n", - " * The predicted protein structure will be downloaded once all cells have been executed. Note: This can take minutes to hours - see below.\n", - "* How long will this take?\n", - " * Downloading the AlphaFold source code can take up to a few minutes.\n", - " * Downloading and installing the third-party software can take up to a few minutes.\n", - " * The search against genetic databases can take minutes to hours.\n", - " * Running AlphaFold and generating the prediction can take minutes to hours, depending on the length of your protein and on which GPU-type Colab has assigned you.\n", - "* My Colab no longer seems to be doing anything, what should I do?\n", - " * Some steps may take minutes to hours to complete.\n", - " * If nothing happens or if you receive an error message, try restarting your Colab runtime via _Runtime_ > _Restart runtime_.\n", - " * If this doesn’t help, try resetting your Colab runtime via _Runtime_ > _Factory reset runtime_.\n", - "* How does this compare to the open-source version of AlphaFold?\n", - " * This Colab version of AlphaFold searches a selected portion of the BFD dataset and currently doesn’t use templates, so its accuracy is reduced in comparison to the full version of AlphaFold that is described in the [AlphaFold paper](https://doi.org/10.1038/s41586-021-03819-2) and [Github repo](https://github.com/deepmind/alphafold/) (the full version is available via the inference script).\n", - "* What is a Colab?\n", - " * See the [Colab FAQ](https://research.google.com/colaboratory/faq.html).\n", - "* I received a warning “Notebook requires high RAM”, what do I do?\n", - " * The resources allocated to your Colab vary. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n", - " * You can execute the Colab nonetheless.\n", - "* I received an error “Colab CPU runtime not supported” or “No GPU/TPU found”, what do I do?\n", - " * Colab CPU runtime is not supported. Try changing your runtime via _Runtime_ > _Change runtime type_ > _Hardware accelerator_ > _GPU_.\n", - " * The type of GPU allocated to your Colab varies. See the [Colab FAQ](https://research.google.com/colaboratory/faq.html) for more details.\n", - " * If you receive “Cannot connect to GPU backend”, you can try again later to see if Colab allocates you a GPU.\n", - " * [Colab Pro](https://colab.research.google.com/signup) offers priority access to GPUs. \n", - "* Does this tool install anything on my computer?\n", - " * No, everything happens in the cloud on Google Colab.\n", - " * At the end of the Colab execution a zip-archive with the obtained prediction will be automatically downloaded to your computer.\n", - "* How should I share feedback and bug reports?\n", - " * Please share any feedback and bug reports as an [issue](https://github.com/deepmind/alphafold/issues) on Github.\n", - "\n", - "\n", - "## Related work\n", - "\n", - "Take a look at these Colab notebooks provided by the community (please note that these notebooks may vary from our validated AlphaFold system and we cannot guarantee their accuracy):\n", - "\n", - "* The [ColabFold AlphaFold2 notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb) by Sergey Ovchinnikov, Milot Mirdita and Martin Steinegger, which uses an API hosted at the Södinglab based on the MMseqs2 server ([Mirdita et al. 2019, Bioinformatics](https://academic.oup.com/bioinformatics/article/35/16/2856/5280135)) for the multiple sequence alignment creation.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YfPhvYgKC81B" - }, - "source": [ - "# License and Disclaimer\n", - "\n", - "This is not an officially-supported Google product.\n", - "\n", - "This Colab notebook and other information provided is for theoretical modelling only, caution should be exercised in its use. It is provided ‘as-is’ without any warranty of any kind, whether expressed or implied. Information is not intended to be a substitute for professional medical advice, diagnosis, or treatment, and does not constitute medical or other professional advice.\n", - "\n", - "Copyright 2021 DeepMind Technologies Limited.\n", - "\n", - "\n", - "## AlphaFold Code License\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0.\n", - "\n", - "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n", - "\n", - "## Model Parameters License\n", - "\n", - "The AlphaFold parameters are made available for non-commercial use only, under the terms of the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. You can find details at: https://creativecommons.org/licenses/by-nc/4.0/legalcode\n", - "\n", - "\n", - "## Third-party software\n", - "\n", - "Use of the third-party software, libraries or code referred to in the [Acknowledgements section](https://github.com/deepmind/alphafold/#acknowledgements) in the AlphaFold README may be governed by separate terms and conditions or license provisions. Your use of the third-party software, libraries or code is subject to any such terms and you should check that you can comply with any applicable restrictions or terms and conditions before use.\n", - "\n", - "\n", - "## Mirrored Databases\n", - "\n", - "The following databases have been mirrored by DeepMind, and are available with reference to the following:\n", - "* UniRef90: v2021\\_03 (unmodified), by The UniProt Consortium, available under a [Creative Commons Attribution-NoDerivatives 4.0 International License](http://creativecommons.org/licenses/by-nd/4.0/).\n", - "* MGnify: v2019\\_05 (unmodified), by Mitchell AL et al., available free of all copyright restrictions and made fully and freely available for both non-commercial and commercial use under [CC0 1.0 Universal (CC0 1.0) Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/).\n", - "* BFD: (modified), by Steinegger M. and Söding J., modified by DeepMind, available under a [Creative Commons Attribution-ShareAlike 4.0 International License](https://creativecommons.org/licenses/by/4.0/). See the Methods section of the [AlphaFold proteome paper](https://www.nature.com/articles/s41586-021-03828-1) for details." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "AlphaFold.ipynb" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/requirements.txt b/requirements.txt index 40c6539e7..552e6f918 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,6 @@ immutabledict==2.0.0 jax==0.2.14 ml-collections==0.1.0 numpy==1.19.5 +pandas==1.3.4 scipy==1.7.0 tensorflow-cpu==2.5.0 diff --git a/run_alphafold.py b/run_alphafold.py index 6f1a690d8..1d5403c1c 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -18,9 +18,10 @@ import pathlib import pickle import random +import shutil import sys import time -from typing import Dict +from typing import Dict, Union, Optional from absl import app from absl import flags @@ -28,30 +29,47 @@ from alphafold.common import protein from alphafold.common import residue_constants from alphafold.data import pipeline +from alphafold.data import pipeline_multimer from alphafold.data import templates -from alphafold.model import data +from alphafold.data.tools import hhsearch +from alphafold.data.tools import hmmsearch from alphafold.model import config from alphafold.model import model from alphafold.relax import relax import numpy as np + +from alphafold.model import data # Internal import (7716). +logging.set_verbosity(logging.INFO) + flags.DEFINE_list('fasta_paths', None, 'Paths to FASTA files, each containing ' - 'one sequence. Paths should be separated by commas. ' + 'a prediction target. Paths should be separated by commas. ' 'All FASTA paths must have a unique basename as the ' 'basename is used to name the output directories for ' 'each prediction.') +flags.DEFINE_list('is_prokaryote_list', None, 'Optional for multimer system, ' + 'not used by the single chain system. ' + 'This list should contain a boolean for each fasta ' + 'specifying true where the target complex is from a ' + 'prokaryote, and false where it is not, or where the ' + 'origin is unknown. These values determine the pairing ' + 'method for the MSA.') + +flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.') flags.DEFINE_string('output_dir', None, 'Path to a directory that will ' 'store the results.') -flags.DEFINE_list('model_names', None, 'Names of models to use.') -flags.DEFINE_string('data_dir', None, 'Path to directory of supporting data.') -flags.DEFINE_string('jackhmmer_binary_path', '/usr/bin/jackhmmer', +flags.DEFINE_string('jackhmmer_binary_path', shutil.which('jackhmmer'), 'Path to the JackHMMER executable.') -flags.DEFINE_string('hhblits_binary_path', '/usr/bin/hhblits', +flags.DEFINE_string('hhblits_binary_path', shutil.which('hhblits'), 'Path to the HHblits executable.') -flags.DEFINE_string('hhsearch_binary_path', '/usr/bin/hhsearch', +flags.DEFINE_string('hhsearch_binary_path', shutil.which('hhsearch'), 'Path to the HHsearch executable.') -flags.DEFINE_string('kalign_binary_path', '/usr/bin/kalign', +flags.DEFINE_string('hmmsearch_binary_path', shutil.which('hmmsearch'), + 'Path to the hmmsearch executable.') +flags.DEFINE_string('hmmbuild_binary_path', shutil.which('hmmbuild'), + 'Path to the hmmbuild executable.') +flags.DEFINE_string('kalign_binary_path', shutil.which('kalign'), 'Path to the Kalign executable.') flags.DEFINE_string('uniref90_database_path', None, 'Path to the Uniref90 ' 'database for use by JackHMMER.') @@ -63,8 +81,12 @@ 'version of BFD used with the "reduced_dbs" preset.') flags.DEFINE_string('uniclust30_database_path', None, 'Path to the Uniclust30 ' 'database for use by HHblits.') +flags.DEFINE_string('uniprot_database_path', None, 'Path to the Uniprot ' + 'database for use by JackHMMer.') flags.DEFINE_string('pdb70_database_path', None, 'Path to the PDB70 ' 'database for use by HHsearch.') +flags.DEFINE_string('pdb_seqres_database_path', None, 'Path to the PDB ' + 'seqres database for use by hmmsearch.') flags.DEFINE_string('template_mmcif_dir', None, 'Path to a directory with ' 'template mmCIF structures, each named .cif') flags.DEFINE_string('max_template_date', None, 'Maximum template release date ' @@ -72,13 +94,16 @@ flags.DEFINE_string('obsolete_pdbs_path', None, 'Path to file containing a ' 'mapping from obsolete PDB IDs to the PDB IDs of their ' 'replacements.') -flags.DEFINE_enum('preset', 'full_dbs', - ['reduced_dbs', 'full_dbs', 'casp14'], - 'Choose preset model configuration - no ensembling and ' - 'smaller genetic database config (reduced_dbs), no ' - 'ensembling and full genetic database config (full_dbs) or ' - 'full genetic database config and 8 model ensemblings ' - '(casp14).') +flags.DEFINE_enum('db_preset', 'full_dbs', + ['full_dbs', 'reduced_dbs'], + 'Choose preset MSA database configuration - ' + 'smaller genetic database config (reduced_dbs) or ' + 'full genetic database config (full_dbs)') +flags.DEFINE_enum('model_preset', 'monomer', + ['monomer', 'monomer_casp14', 'monomer_ptm', 'multimer'], + 'Choose preset model configuration - the monomer model, ' + 'the monomer model with extra ensembling, monomer model with ' + 'pTM head, or multimer model') flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations ' 'to obtain a timing that excludes the compilation time, ' 'which should be more indicative of the time required for ' @@ -88,6 +113,10 @@ 'that even if this is set, Alphafold may still not be ' 'deterministic, because processes like GPU inference are ' 'nondeterministic.') +flags.DEFINE_boolean('use_precomputed_msas', False, 'Whether to read MSAs that ' + 'have been written to disk. WARNING: This will not check ' + 'if the sequence, database or configuration have changed.') + FLAGS = flags.FLAGS MAX_TEMPLATE_HITS = 20 @@ -95,25 +124,30 @@ RELAX_ENERGY_TOLERANCE = 2.39 RELAX_STIFFNESS = 10.0 RELAX_EXCLUDE_RESIDUES = [] -RELAX_MAX_OUTER_ITERATIONS = 20 +RELAX_MAX_OUTER_ITERATIONS = 3 -def _check_flag(flag_name: str, preset: str, should_be_set: bool): +def _check_flag(flag_name: str, + other_flag_name: str, + should_be_set: bool): if should_be_set != bool(FLAGS[flag_name].value): verb = 'be' if should_be_set else 'not be' - raise ValueError(f'{flag_name} must {verb} set for preset "{preset}"') + raise ValueError(f'{flag_name} must {verb} set when running with ' + f'"--{other_flag_name}={FLAGS[other_flag_name].value}".') def predict_structure( fasta_path: str, fasta_name: str, output_dir_base: str, - data_pipeline: pipeline.DataPipeline, + data_pipeline: Union[pipeline.DataPipeline, pipeline_multimer.DataPipeline], model_runners: Dict[str, model.RunModel], amber_relaxer: relax.AmberRelaxation, benchmark: bool, - random_seed: int): + random_seed: int, + is_prokaryote: Optional[bool] = None): """Predicts structure using AlphaFold for the given sequence.""" + logging.info('Predicting %s', fasta_name) timings = {} output_dir = os.path.join(output_dir_base, fasta_name) if not os.path.exists(output_dir): @@ -124,9 +158,15 @@ def predict_structure( # Get features. t_0 = time.time() - feature_dict = data_pipeline.process( - input_fasta_path=fasta_path, - msa_output_dir=msa_output_dir) + if is_prokaryote is None: + feature_dict = data_pipeline.process( + input_fasta_path=fasta_path, + msa_output_dir=msa_output_dir) + else: + feature_dict = data_pipeline.process( + input_fasta_path=fasta_path, + msa_output_dir=msa_output_dir, + is_prokaryote=is_prokaryote) timings['features'] = time.time() - t_0 # Write out features as a pickled dictionary. @@ -134,33 +174,42 @@ def predict_structure( with open(features_output_path, 'wb') as f: pickle.dump(feature_dict, f, protocol=4) + unrelaxed_pdbs = {} relaxed_pdbs = {} - plddts = {} + ranking_confidences = {} # Run the models. - for model_name, model_runner in model_runners.items(): - logging.info('Running model %s', model_name) + num_models = len(model_runners) + for model_index, (model_name, model_runner) in enumerate( + model_runners.items()): + logging.info('Running model %s on %s', model_name, fasta_name) t_0 = time.time() + model_random_seed = model_index + random_seed * num_models processed_feature_dict = model_runner.process_features( - feature_dict, random_seed=random_seed) + feature_dict, random_seed=model_random_seed) timings[f'process_features_{model_name}'] = time.time() - t_0 t_0 = time.time() - prediction_result = model_runner.predict(processed_feature_dict) + prediction_result = model_runner.predict(processed_feature_dict, + random_seed=model_random_seed) t_diff = time.time() - t_0 timings[f'predict_and_compile_{model_name}'] = t_diff logging.info( - 'Total JAX model %s predict time (includes compilation time, see --benchmark): %.0f?', - model_name, t_diff) + 'Total JAX model %s on %s predict time (includes compilation time, see --benchmark): %.1fs', + model_name, fasta_name, t_diff) if benchmark: t_0 = time.time() - model_runner.predict(processed_feature_dict) - timings[f'predict_benchmark_{model_name}'] = time.time() - t_0 + model_runner.predict(processed_feature_dict, + random_seed=model_random_seed) + t_diff = time.time() - t_0 + timings[f'predict_benchmark_{model_name}'] = t_diff + logging.info( + 'Total JAX model %s on %s predict time (excludes compilation time): %.1fs', + model_name, fasta_name, t_diff) - # Get mean pLDDT confidence metric. plddt = prediction_result['plddt'] - plddts[model_name] = np.mean(plddt) + ranking_confidences[model_name] = prediction_result['ranking_confidence'] # Save the model outputs. result_output_path = os.path.join(output_dir, f'result_{model_name}.pkl') @@ -174,36 +223,45 @@ def predict_structure( unrelaxed_protein = protein.from_prediction( features=processed_feature_dict, result=prediction_result, - b_factors=plddt_b_factors) + b_factors=plddt_b_factors, + remove_leading_feature_dimension=not model_runner.multimer_mode) + unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein) unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb') with open(unrelaxed_pdb_path, 'w') as f: - f.write(protein.to_pdb(unrelaxed_protein)) + f.write(unrelaxed_pdbs[model_name]) - # Relax the prediction. - t_0 = time.time() - relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) - timings[f'relax_{model_name}'] = time.time() - t_0 + if amber_relaxer: + # Relax the prediction. + t_0 = time.time() + relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein) + timings[f'relax_{model_name}'] = time.time() - t_0 - relaxed_pdbs[model_name] = relaxed_pdb_str + relaxed_pdbs[model_name] = relaxed_pdb_str - # Save the relaxed PDB. - relaxed_output_path = os.path.join(output_dir, f'relaxed_{model_name}.pdb') - with open(relaxed_output_path, 'w') as f: - f.write(relaxed_pdb_str) + # Save the relaxed PDB. + relaxed_output_path = os.path.join( + output_dir, f'relaxed_{model_name}.pdb') + with open(relaxed_output_path, 'w') as f: + f.write(relaxed_pdb_str) - # Rank by pLDDT and write out relaxed PDBs in rank order. + # Rank by model confidence and write out relaxed PDBs in rank order. ranked_order = [] for idx, (model_name, _) in enumerate( - sorted(plddts.items(), key=lambda x: x[1], reverse=True)): + sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)): ranked_order.append(model_name) ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb') with open(ranked_output_path, 'w') as f: - f.write(relaxed_pdbs[model_name]) + if amber_relaxer: + f.write(relaxed_pdbs[model_name]) + else: + f.write(unrelaxed_pdbs[model_name]) ranking_output_path = os.path.join(output_dir, 'ranking_debug.json') with open(ranking_output_path, 'w') as f: - f.write(json.dumps({'plddts': plddts, 'order': ranked_order}, indent=4)) + label = 'iptm+ptm' if 'iptm' in prediction_result else 'plddts' + f.write(json.dumps( + {label: ranking_confidences, 'order': ranked_order}, indent=4)) logging.info('Final timings for %s: %s', fasta_name, timings) @@ -216,49 +274,108 @@ def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') - use_small_bfd = FLAGS.preset == 'reduced_dbs' - _check_flag('small_bfd_database_path', FLAGS.preset, + for tool_name in ( + 'jackhmmer', 'hhblits', 'hhsearch', 'hmmsearch', 'hmmbuild', 'kalign'): + if not FLAGS[f'{tool_name}_binary_path'].value: + raise ValueError(f'Could not find path to the "{tool_name}" binary. Make ' + 'sure it is installed on your system.') + + use_small_bfd = FLAGS.db_preset == 'reduced_dbs' + _check_flag('small_bfd_database_path', 'db_preset', should_be_set=use_small_bfd) - _check_flag('bfd_database_path', FLAGS.preset, + _check_flag('bfd_database_path', 'db_preset', should_be_set=not use_small_bfd) - _check_flag('uniclust30_database_path', FLAGS.preset, + _check_flag('uniclust30_database_path', 'db_preset', should_be_set=not use_small_bfd) - if FLAGS.preset in ('reduced_dbs', 'full_dbs'): - num_ensemble = 1 - elif FLAGS.preset == 'casp14': + run_multimer_system = 'multimer' in FLAGS.model_preset + _check_flag('pdb70_database_path', 'model_preset', + should_be_set=not run_multimer_system) + _check_flag('pdb_seqres_database_path', 'model_preset', + should_be_set=run_multimer_system) + _check_flag('uniprot_database_path', 'model_preset', + should_be_set=run_multimer_system) + + if FLAGS.model_preset == 'monomer_casp14': num_ensemble = 8 + else: + num_ensemble = 1 # Check for duplicate FASTA file names. fasta_names = [pathlib.Path(p).stem for p in FLAGS.fasta_paths] if len(fasta_names) != len(set(fasta_names)): raise ValueError('All FASTA paths must have a unique basename.') - template_featurizer = templates.TemplateHitFeaturizer( - mmcif_dir=FLAGS.template_mmcif_dir, - max_template_date=FLAGS.max_template_date, - max_hits=MAX_TEMPLATE_HITS, - kalign_binary_path=FLAGS.kalign_binary_path, - release_dates_path=None, - obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) - - data_pipeline = pipeline.DataPipeline( + # Check that is_prokaryote_list has same number of elements as fasta_paths, + # and convert to bool. + if FLAGS.is_prokaryote_list: + if len(FLAGS.is_prokaryote_list) != len(FLAGS.fasta_paths): + raise ValueError('--is_prokaryote_list must either be omitted or match ' + 'length of --fasta_paths.') + is_prokaryote_list = [] + for s in FLAGS.is_prokaryote_list: + if s in ('true', 'false'): + is_prokaryote_list.append(s == 'true') + else: + raise ValueError('--is_prokaryote_list must contain comma separated ' + 'true or false values.') + else: # Default is_prokaryote to False. + is_prokaryote_list = [False] * len(fasta_names) + + if run_multimer_system: + template_searcher = hmmsearch.Hmmsearch( + binary_path=FLAGS.hmmsearch_binary_path, + hmmbuild_binary_path=FLAGS.hmmbuild_binary_path, + database_path=FLAGS.pdb_seqres_database_path) + template_featurizer = templates.HmmsearchHitFeaturizer( + mmcif_dir=FLAGS.template_mmcif_dir, + max_template_date=FLAGS.max_template_date, + max_hits=MAX_TEMPLATE_HITS, + kalign_binary_path=FLAGS.kalign_binary_path, + release_dates_path=None, + obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) + else: + template_searcher = hhsearch.HHSearch( + binary_path=FLAGS.hhsearch_binary_path, + databases=[FLAGS.pdb70_database_path]) + template_featurizer = templates.HhsearchHitFeaturizer( + mmcif_dir=FLAGS.template_mmcif_dir, + max_template_date=FLAGS.max_template_date, + max_hits=MAX_TEMPLATE_HITS, + kalign_binary_path=FLAGS.kalign_binary_path, + release_dates_path=None, + obsolete_pdbs_path=FLAGS.obsolete_pdbs_path) + + monomer_data_pipeline = pipeline.DataPipeline( jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, hhblits_binary_path=FLAGS.hhblits_binary_path, - hhsearch_binary_path=FLAGS.hhsearch_binary_path, uniref90_database_path=FLAGS.uniref90_database_path, mgnify_database_path=FLAGS.mgnify_database_path, bfd_database_path=FLAGS.bfd_database_path, uniclust30_database_path=FLAGS.uniclust30_database_path, small_bfd_database_path=FLAGS.small_bfd_database_path, - pdb70_database_path=FLAGS.pdb70_database_path, + template_searcher=template_searcher, template_featurizer=template_featurizer, - use_small_bfd=use_small_bfd) + use_small_bfd=use_small_bfd, + use_precomputed_msas=FLAGS.use_precomputed_msas) + + if run_multimer_system: + data_pipeline = pipeline_multimer.DataPipeline( + monomer_data_pipeline=monomer_data_pipeline, + jackhmmer_binary_path=FLAGS.jackhmmer_binary_path, + uniprot_database_path=FLAGS.uniprot_database_path, + use_precomputed_msas=FLAGS.use_precomputed_msas) + else: + data_pipeline = monomer_data_pipeline model_runners = {} - for model_name in FLAGS.model_names: + model_names = config.MODEL_PRESETS[FLAGS.model_preset] + for model_name in model_names: model_config = config.model_config(model_name) - model_config.data.eval.num_ensemble = num_ensemble + if run_multimer_system: + model_config.model.num_ensemble_eval = num_ensemble + else: + model_config.data.eval.num_ensemble = num_ensemble model_params = data.get_model_haiku_params( model_name=model_name, data_dir=FLAGS.data_dir) model_runner = model.RunModel(model_config, model_params) @@ -276,11 +393,13 @@ def main(argv): random_seed = FLAGS.random_seed if random_seed is None: - random_seed = random.randrange(sys.maxsize) + random_seed = random.randrange(sys.maxsize // len(model_names)) logging.info('Using random seed %d for the data pipeline', random_seed) # Predict structure for each of the sequences. - for fasta_path, fasta_name in zip(FLAGS.fasta_paths, fasta_names): + for i, fasta_path in enumerate(FLAGS.fasta_paths): + is_prokaryote = is_prokaryote_list[i] if run_multimer_system else None + fasta_name = fasta_names[i] predict_structure( fasta_path=fasta_path, fasta_name=fasta_name, @@ -289,19 +408,17 @@ def main(argv): model_runners=model_runners, amber_relaxer=amber_relaxer, benchmark=FLAGS.benchmark, - random_seed=random_seed) + random_seed=random_seed, + is_prokaryote=is_prokaryote) if __name__ == '__main__': flags.mark_flags_as_required([ 'fasta_paths', 'output_dir', - 'model_names', 'data_dir', - 'preset', 'uniref90_database_path', 'mgnify_database_path', - 'pdb70_database_path', 'template_mmcif_dir', 'max_template_date', 'obsolete_pdbs_path', diff --git a/run_alphafold_test.py b/run_alphafold_test.py index 9093eb234..4797f1153 100644 --- a/run_alphafold_test.py +++ b/run_alphafold_test.py @@ -26,7 +26,11 @@ class RunAlphafoldTest(parameterized.TestCase): - def test_end_to_end(self): + @parameterized.named_parameters( + ('relax', True), + ('no_relax', False), + ) + def test_end_to_end(self, do_relax): data_pipeline_mock = mock.Mock() model_runner_mock = mock.Mock() @@ -46,11 +50,13 @@ def test_end_to_end(self): 'logits': np.ones((10, 50)), }, 'plddt': np.ones(10) * 42, + 'ranking_confidence': 90, 'ptm': np.array(0.), 'aligned_confidence_probs': np.zeros((10, 10, 50)), 'predicted_aligned_error': np.zeros((10, 10)), 'max_predicted_aligned_error': np.array(0.), } + model_runner_mock.multimer_mode = False amber_relaxer_mock.process.return_value = ('RELAXED', None, None) fasta_path = os.path.join(absltest.get_default_test_tmpdir(), @@ -67,7 +73,7 @@ def test_end_to_end(self): output_dir_base=out_dir, data_pipeline=data_pipeline_mock, model_runners={'model1': model_runner_mock}, - amber_relaxer=amber_relaxer_mock, + amber_relaxer=amber_relaxer_mock if do_relax else None, benchmark=False, random_seed=0) @@ -76,10 +82,13 @@ def test_end_to_end(self): self.assertIn('test', base_output_files) target_output_files = os.listdir(os.path.join(out_dir, 'test')) - self.assertCountEqual( - ['features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json', - 'relaxed_model1.pdb', 'result_model1.pkl', 'timings.json', - 'unrelaxed_model1.pdb'], target_output_files) + expected_files = [ + 'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json', + 'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb', + ] + if do_relax: + expected_files.append('relaxed_model1.pdb') + self.assertCountEqual(expected_files, target_output_files) # Check that pLDDT is set in the B-factor column. with open(os.path.join(out_dir, 'test', 'unrelaxed_model1.pdb')) as f: diff --git a/scripts/download_all_data.sh b/scripts/download_all_data.sh index c88581067..a30013ca0 100755 --- a/scripts/download_all_data.sh +++ b/scripts/download_all_data.sh @@ -20,17 +20,17 @@ set -e if [[ $# -eq 0 ]]; then - echo "Error: download directory must be provided as an input argument." - exit 1 + echo "Error: download directory must be provided as an input argument." + exit 1 fi if ! command -v aria2c &> /dev/null ; then - echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." - exit 1 + echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." + exit 1 fi DOWNLOAD_DIR="$1" -DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs. +DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs. if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]] then echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized." @@ -42,12 +42,12 @@ SCRIPT_DIR="$(dirname "$(realpath "$0")")" echo "Downloading AlphaFold parameters..." bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}" -if [[ "${DOWNLOAD_MODE}" = full_dbs ]] ; then - echo "Downloading BFD..." - bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}" -else +if [[ "${DOWNLOAD_MODE}" = reduced_dbs ]] ; then echo "Downloading Small BFD..." bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}" +else + echo "Downloading BFD..." + bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}" fi echo "Downloading MGnify..." @@ -65,4 +65,10 @@ bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}" echo "Downloading Uniref90..." bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}" +echo "Downloading UniProt..." +bash "${SCRIPT_DIR}/download_uniprot.sh" "${DOWNLOAD_DIR}" + +echo "Downloading PDB SeqRes..." +bash "${SCRIPT_DIR}/download_pdb_seqres.sh" "${DOWNLOAD_DIR}" + echo "All data downloaded." diff --git a/scripts/download_alphafold_params.sh b/scripts/download_alphafold_params.sh index b77e47c89..1b6a9178a 100755 --- a/scripts/download_alphafold_params.sh +++ b/scripts/download_alphafold_params.sh @@ -31,7 +31,7 @@ fi DOWNLOAD_DIR="$1" ROOT_DIR="${DOWNLOAD_DIR}/params" -SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar" +SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar" BASENAME=$(basename "${SOURCE_URL}") mkdir --parents "${ROOT_DIR}" diff --git a/scripts/download_pdb_seqres.sh b/scripts/download_pdb_seqres.sh new file mode 100755 index 000000000..77a908a7b --- /dev/null +++ b/scripts/download_pdb_seqres.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Downloads and unzips the PDB SeqRes database for AlphaFold. +# +# Usage: bash download_pdb_seqres.sh /path/to/download/directory +set -e + +if [[ $# -eq 0 ]]; then + echo "Error: download directory must be provided as an input argument." + exit 1 +fi + +if ! command -v aria2c &> /dev/null ; then + echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." + exit 1 +fi + +DOWNLOAD_DIR="$1" +ROOT_DIR="${DOWNLOAD_DIR}/pdb_seqres" +SOURCE_URL="ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt" +BASENAME=$(basename "${SOURCE_URL}") + +mkdir --parents "${ROOT_DIR}" +aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" diff --git a/scripts/download_uniprot.sh b/scripts/download_uniprot.sh new file mode 100755 index 000000000..e815e3c3f --- /dev/null +++ b/scripts/download_uniprot.sh @@ -0,0 +1,55 @@ +#!/bin/bash +# +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Downloads, unzips and merges the SwissProt and TrEMBL databases for +# AlphaFold-Multimer. +# +# Usage: bash download_uniprot.sh /path/to/download/directory +set -e + +if [[ $# -eq 0 ]]; then + echo "Error: download directory must be provided as an input argument." + exit 1 +fi + +if ! command -v aria2c &> /dev/null ; then + echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." + exit 1 +fi + +DOWNLOAD_DIR="$1" +ROOT_DIR="${DOWNLOAD_DIR}/uniprot" + +TREMBL_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz" +TREMBL_BASENAME=$(basename "${TREMBL_SOURCE_URL}") +TREMBL_UNZIPPED_BASENAME="${TREMBL_BASENAME%.gz}" + +SPROT_SOURCE_URL="ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz" +SPROT_BASENAME=$(basename "${SPROT_SOURCE_URL}") +SPROT_UNZIPPED_BASENAME="${SPROT_BASENAME%.gz}" + +mkdir --parents "${ROOT_DIR}" +aria2c "${TREMBL_SOURCE_URL}" --dir="${ROOT_DIR}" +aria2c "${SPROT_SOURCE_URL}" --dir="${ROOT_DIR}" +pushd "${ROOT_DIR}" +gunzip "${ROOT_DIR}/${TREMBL_BASENAME}" +gunzip "${ROOT_DIR}/${SPROT_BASENAME}" + +# Concatenate TrEMBL and SwissProt, rename to uniprot and clean up. +cat "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" >> "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" +mv "${ROOT_DIR}/${TREMBL_UNZIPPED_BASENAME}" "${ROOT_DIR}/uniprot.fasta" +rm "${ROOT_DIR}/${SPROT_UNZIPPED_BASENAME}" +popd diff --git a/setup.py b/setup.py index 0006a2aa7..ff57c00ed 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name='alphafold', - version='2.0.0', + version='2.1.0', description='An implementation of the inference pipeline of AlphaFold v2.0.' 'This is a completely new model that was entered as AlphaFold2 in CASP14 ' 'and published in Nature.', @@ -38,6 +38,7 @@ 'jax', 'ml-collections', 'numpy', + 'pandas', 'scipy', 'tensorflow-cpu', ], @@ -49,6 +50,9 @@ 'Operating System :: POSIX :: Linux', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Topic :: Scientific/Engineering :: Artificial Intelligence', ], )