From baf17e65172009ff2340946f3c70800cf87c1d94 Mon Sep 17 00:00:00 2001 From: zachcp Date: Wed, 4 Dec 2024 10:35:03 -0500 Subject: [PATCH] Further Updating Tests (#64) * Begin Adding Support for Fixed Residues * implementing chain logic * add chains * begin implementation of residue bias --- ferritin-featurizers/src/commands/run.rs | 95 ++++++---- .../src/models/ligandmpnn/configs.rs | 46 +++-- .../src/models/ligandmpnn/featurizer.rs | 168 ++++++++++++++---- .../src/models/ligandmpnn/proteinfeatures.rs | 5 +- .../tests/test_cli_ligandmpnn.rs | 16 +- justfile | 2 +- 6 files changed, 224 insertions(+), 108 deletions(-) diff --git a/ferritin-featurizers/src/commands/run.rs b/ferritin-featurizers/src/commands/run.rs index 5ce19ef9..0e7a843c 100644 --- a/ferritin-featurizers/src/commands/run.rs +++ b/ferritin-featurizers/src/commands/run.rs @@ -2,9 +2,8 @@ use crate::models::ligandmpnn::configs::{ AABiasConfig, LigandMPNNConfig, MPNNExecConfig, MembraneMPNNConfig, ModelTypes, MultiPDBConfig, ResidueControl, RunConfig, }; -use crate::models::ligandmpnn::model::ScoreOutput; use candle_core::utils::{cuda_is_available, metal_is_available}; -use candle_core::{Device, Result}; +use candle_core::{Device, Result, Tensor}; use rand::Rng; pub fn device(cpu: bool) -> Result { @@ -43,7 +42,7 @@ pub fn execute( let device = device(false)?; let exec = MPNNExecConfig::new( - device, + device.clone(), pdb_path, // will need to omdify this for multiple run_config, Some(residue_control_config), @@ -53,7 +52,7 @@ pub fn execute( Some(multi_pdb_config), )?; - // Crete Default Values ------------------------------------------------------------ + // Create Default Values ------------------------------------------------------------ // let model_type = exec .run_config @@ -74,40 +73,70 @@ pub fn execute( // Load The model ------------------------------------------------------------ let model = exec.load_model(model_type)?; - println!("Model Loaded!"); + let mut prot_features = exec.generate_protein_features()?; - println!("Generating Protein Features"); - let prot_features = exec.generate_protein_features()?; - println!("Protein Features Loaded!"); + // Calculate Masks ------------------------------------------------------------ - // Create the output folders - println!("Creating the Outputs"); - std::fs::create_dir_all(format!("{}/seqs", out_folder))?; - std::fs::create_dir_all(format!("{}/backbones", out_folder))?; - std::fs::create_dir_all(format!("{}/packed", out_folder))?; - std::fs::create_dir_all(format!("{}/stats", out_folder))?; + println!("Generating Chains to Design. Tensor of [B,L]"); + let chains_to_design: Vec = match &exec.residue_control_config { + None => prot_features.chain_letters.clone(), + Some(config) => match &config.chains_to_design { + None => prot_features.chain_letters.clone(), + Some(chains) => chains.split(' ').map(String::from).collect(), + }, + }; + + // Chain tensor is the base. Additional Tensors can be added on top. + let mut chain_mask_tensor = prot_features.get_chain_mask_tensor(chains_to_design, &device)?; + // Residue-Related ------------------------------------------- + if let Some(res) = exec.residue_control_config { + // Residues + let fixed_residues = res.fixed_residues.unwrap(); + let fixed_positions_tensor = prot_features.get_encoded_tensor(fixed_residues, &device)?; + // multiply the fixed positions to the chain tensor + chain_mask_tensor = chain_mask_tensor.mul(&fixed_positions_tensor)?; + } + + // bias-Related ------------------------------------------- + // Todo + // if let Some(aabias) = exec.aabias_config { + // let bias_tensor = &prot_features.create_bias_tensor(exec.aabias_config?.bias_aa).unwrap_or(''); + // let (batch_size, seq_length) = &prot_features.s.dims2()?; + // let mut base_bias = Tensor::zeros_like(&prot_features.s)?; + // println!("BIAS!! Dims for S {:?}", base_bias.dims()); + // // let bias_aa: Tensor = match aabias.bias_aa { + // None => + // } + + // Update the Mask Here + prot_features.update_mask(chain_mask_tensor)?; + + // Sample from the Model ------------------------------------------- println!("Sampling from the Model..."); println!("Temp and Seed are: temp: {:}, seed: {:}", temperature, seed); let model_sample = model.sample(&prot_features, temperature as f64, seed as u64)?; println!("{:?}", model_sample); - std::fs::create_dir_all(format!("{}/seqs", out_folder))?; - let sequences = model_sample.get_sequences()?; - // println!("DECODING ORDER: {:?}", model_sample.get_decoding_order()?); - - let fasta_path = format!("{}/seqs/output.fasta", out_folder); - let mut fasta_content = String::new(); - for (i, seq) in sequences.iter().enumerate() { - fasta_content.push_str(&format!(">sequence_{}\n{}\n", i + 1, seq)); - } - std::fs::write(fasta_path, fasta_content)?; - - // Score a Protein! - println!("Scoring the Protein..."); - let model_score = model.score(&prot_features, false)?; - println!("Protein Score: {:?}", model_score); + let _ = { + // Create the output folders + println!("Creating the Outputs"); + std::fs::create_dir_all(format!("{}/seqs", out_folder))?; + std::fs::create_dir_all(format!("{}/backbones", out_folder))?; + std::fs::create_dir_all(format!("{}/packed", out_folder))?; + std::fs::create_dir_all(format!("{}/stats", out_folder))?; + std::fs::create_dir_all(format!("{}/seqs", out_folder))?; + let sequences = model_sample.get_sequences()?; + let fasta_path = format!("{}/seqs/output.fasta", out_folder); + let mut fasta_content = String::new(); + for (i, seq) in sequences.iter().enumerate() { + fasta_content.push_str(&format!(">sequence_{}\n{}\n", i + 1, seq)); + } + std::fs::write(fasta_path, fasta_content)?; + }; + // note this is only the Score outputs. + // It doesn't have the other fields in the pytorch implmentation if save_stats { // out_dict = {} // out_dict["generated_sequences"] = S_stack.cpu() @@ -119,16 +148,8 @@ pub fn execute( // out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu() // out_dict["seed"] = seed // out_dict["temperature"] = args.temperature - // if args.save_stats: - // torch.save(out_dict, output_stats_path) - // - // model_score.get_decoding_order() - // model_score.get_sequences() let outfile = format!("{}/stats/stats.safetensors", out_folder); - // note this is only the Score outputs. - // It doesn't have the other fields in teh pytoch implmentaiton model_sample.save_as_safetensors(outfile); } - Ok(()) } diff --git a/ferritin-featurizers/src/models/ligandmpnn/configs.rs b/ferritin-featurizers/src/models/ligandmpnn/configs.rs index 4cc11470..bc1537d4 100644 --- a/ferritin-featurizers/src/models/ligandmpnn/configs.rs +++ b/ferritin-featurizers/src/models/ligandmpnn/configs.rs @@ -87,24 +87,32 @@ impl MPNNExecConfig { let (pdb, _) = pdbtbx::open(self.protein_inputs.clone()).expect("A PDB or CIF file"); let ac = AtomCollection::from(&pdb); - // let s = ac.encode_amino_acids(&device)?; let s = ac .encode_amino_acids(&device) .expect("A complete convertion to locations"); - let x_37 = ac.to_numeric_atom37(&device)?; - - // Note: default to 1! let x_37_mask = Tensor::ones((x_37.dim(0)?, x_37.dim(1)?), base_dtype, &device)?; - // println!("This is the atom map: {:?}", x_37_mask.dims()); - let (y, y_t, y_m) = ac.to_numeric_ligand_atoms(&device)?; - - // R_idx = np.array(CA_resnums, dtype=np.int32) let res_idx = ac.get_res_index(); let res_idx_len = res_idx.len(); let res_idx_tensor = Tensor::from_vec(res_idx, (1, res_idx_len), &device)?; + // chain residues + let chain_letters: Vec = ac + .iter_residues_aminoacid() + .map(|res| res.chain_id) + .collect(); + + // unique Chains + let chain_list: Vec = chain_letters + .clone() + .into_iter() + .collect::>() + .into_iter() + .collect(); + + // assert_eq!(true, false); + // update residue info // residue_config: Option, // handle these: @@ -149,17 +157,17 @@ impl MPNNExecConfig { // println!("Returning Protein Features...."); // return ligand MPNN. Ok(ProteinFeatures { - s, // protein amino acids sequences as 1D Tensor of u32 - x: x_37, // protein co-oords by residue [1, 37, 4] - x_mask: Some(x_37_mask), // protein mask by residue - y, // ligand coords - y_t, // encoded ligand atom names - y_m: Some(y_m), // ligand mask - r_idx: Some(res_idx_tensor), // protein residue indices shape=[length] - chain_labels: None, // # protein chain letters shape=[length] - chain_letters: None, // chain_letters: shape=[length] - mask_c: None, // mask_c: shape=[length] - chain_list: None, + s, // protein amino acids sequences as 1D Tensor of u32 + x: x_37, // protein co-oords by residue [1, 37, 4] + x_mask: Some(x_37_mask), // protein mask by residue + y, // ligand coords + y_t, // encoded ligand atom names + y_m: Some(y_m), // ligand mask + r_idx: res_idx_tensor, // protein residue indices shape=[length] + chain_labels: None, // # protein chain letters shape=[length] + chain_letters, // chain_letters: shape=[length] + mask_c: None, // mask_c: shape=[length] + chain_list, }) } } diff --git a/ferritin-featurizers/src/models/ligandmpnn/featurizer.rs b/ferritin-featurizers/src/models/ligandmpnn/featurizer.rs index 4be18369..b9840f23 100644 --- a/ferritin-featurizers/src/models/ligandmpnn/featurizer.rs +++ b/ferritin-featurizers/src/models/ligandmpnn/featurizer.rs @@ -14,7 +14,7 @@ use candle_core::{DType, Device, Result, Tensor}; use ferritin_core::AtomCollection; use itertools::MultiUnzip; use pdbtbx::Element; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use strum::IntoEnumIterator; // Helper Fns -------------------------------------- @@ -49,45 +49,46 @@ impl LMPNNFeatures for AtomCollection { } // equivalent to protien MPNN's parse_PDB fn featurize(&self, device: &Device) -> Result { - let x_37 = self.to_numeric_atom37(device)?; - let x_37_m = Tensor::zeros((x_37.dim(0)?, x_37.dim(1)?), DType::F64, device)?; - let (y, y_t, y_m) = self.to_numeric_ligand_atoms(device)?; + todo!(); + // let x_37 = self.to_numeric_atom37(device)?; + // let x_37_m = Tensor::zeros((x_37.dim(0)?, x_37.dim(1)?), DType::F64, device)?; + // let (y, y_t, y_m) = self.to_numeric_ligand_atoms(device)?; - // get CB locations... - // although we have these already for our full set... - let cb = calculate_cb(&x_37); + // // get CB locations... + // // although we have these already for our full set... + // let cb = calculate_cb(&x_37); - // chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32) - let chain_labels = self.get_resids(); // <-- need to double-check shape. I think this is all-atom + // // chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32) + // let chain_labels = self.get_resids(); // <-- need to double-check shape. I think this is all-atom - // R_idx = np.array(CA_resnums, dtype=np.int32) - // let _r_idx = self.get_resids(); // todo()! + // // R_idx = np.array(CA_resnums, dtype=np.int32) + // // let _r_idx = self.get_resids(); // todo()! - // amino acid names as int.... - let s = self.encode_amino_acids(device)?; + // // amino acid names as int.... + // let s = self.encode_amino_acids(device)?; - // coordinates of the backbone atoms - let indices = Tensor::from_slice( - &[0i64, 1i64, 2i64, 4i64], // index of N/CA/C/O as integers - (4,), - &device, - )?; + // // coordinates of the backbone atoms + // let indices = Tensor::from_slice( + // &[0i64, 1i64, 2i64, 4i64], // index of N/CA/C/O as integers + // (4,), + // &device, + // )?; + + // let x = x_37.index_select(&indices, 1)?; - let x = x_37.index_select(&indices, 1)?; - - Ok(ProteinFeatures { - s, - x, - x_mask: Some(x_37_m), - y, - y_t, - y_m: Some(y_m), - r_idx: None, - chain_labels: None, - chain_letters: None, - mask_c: None, - chain_list: None, - }) + // Ok(ProteinFeatures { + // s, + // x, + // x_mask: Some(x_37_m), + // y, + // y_t, + // y_m: Some(y_m), + // r_idx: None, + // chain_labels: None, + // chain_letters: None, + // mask_c: None, + // chain_list: None, + // }) } /// create numeric Tensor of shape [1, , 4, 3] where the 4 is N/CA/C/O fn to_numeric_backbone_atoms(&self, device: &Device) -> Result { @@ -287,14 +288,14 @@ pub struct ProteinFeatures { /// ligand mask pub(crate) y_m: Option, /// R_idx: Tensor dimensions: torch.Size([93]) # protein residue indices shape=[length] - pub(crate) r_idx: Option, + pub(crate) r_idx: Tensor, /// chain_labels: Tensor dimensions: torch.Size([93]) # protein chain letters shape=[length] pub(crate) chain_labels: Option>, /// chain_letters: NumPy array dimensions: (93,) - pub(crate) chain_letters: Option>, + pub(crate) chain_letters: Vec, /// mask_c: Tensor dimensions: torch.Size([93]) pub(crate) mask_c: Option, - pub(crate) chain_list: Option>, + pub(crate) chain_list: Vec, // CA_icodes: NumPy array dimensions: (93) // put these here temporarily // bias_AA: Option, @@ -321,8 +322,8 @@ impl ProteinFeatures { pub fn get_sequence_mask(&self) -> Option<&Tensor> { self.x_mask.as_ref() } - pub fn get_residue_index(&self) -> Option<&Tensor> { - self.r_idx.as_ref() + pub fn get_residue_index(&self) -> &Tensor { + &self.r_idx } pub fn save_to_safetensor(&self, path: &str) -> Result<()> { let mut tensors: HashMap = HashMap::new(); @@ -335,4 +336,93 @@ impl ProteinFeatures { candle_core::safetensors::save(&tensors, path)?; Ok(()) } + pub fn get_encoded( + &self, + ) -> Result<(Vec, HashMap, HashMap)> { + // Creates a set of mappings from + + let r_idx_list = &self.r_idx.flatten_all()?.to_vec1::()?; + let chain_letters_list = &self.chain_letters; + + let encoded_residues: Vec = r_idx_list + .iter() + .enumerate() + .map(|(i, r_idx)| format!("{}{}", chain_letters_list[i], r_idx)) + .collect(); + + let encoded_residue_dict: HashMap = encoded_residues + .iter() + .enumerate() + .map(|(i, s)| (s.clone(), i)) + .collect(); + + let encoded_residue_dict_rev: HashMap = encoded_residues + .iter() + .enumerate() + .map(|(i, s)| (i, s.clone())) + .collect(); + + Ok(( + encoded_residues, + encoded_residue_dict, + encoded_residue_dict_rev, + )) + } + // Fixed Residue List --> Tensor of 1/0 + // Inputs: `"C1 C2 C3 C4 C5 C6 C7 C8 C9 C10` + pub fn get_encoded_tensor(&self, fixed_residues: String, device: &Device) -> Result { + let res_set: HashSet = fixed_residues.split(' ').map(String::from).collect(); + let (encoded_res, _, _) = &self.get_encoded()?; + candle_core::Tensor::from_iter( + encoded_res + .iter() + .map(|item| u32::from(!res_set.contains(item))), + device, + ) + } + pub fn get_chain_mask_tensor( + &self, + chains_to_design: Vec, + device: &Device, + ) -> Result { + let mask_values: Vec = self + .chain_letters + .iter() + .map(|chain| u32::from(chains_to_design.contains(chain))) + .collect(); + + Tensor::from_iter(mask_values, device) + } + pub fn update_mask(&mut self, tensor: Tensor) -> Result<()> { + if let Some(ref mask) = self.x_mask { + self.x_mask = Some(mask.mul(&tensor)?); + } else { + self.x_mask = Some(tensor); + } + Ok(()) + } + // Fixed Residue List --> Tensor of length 21 + // Inputs: `A:10.0"` + pub fn create_bias_tensor(&self, bias_aa: Option) -> Result { + let device = self.s.device(); + let dtype = self.s.dtype(); + match bias_aa { + None => Tensor::zeros((21), dtype, device), + Some(bias_aa) => { + let mut bias_values = vec![0.0f32; 21]; + for pair in bias_aa.split(',') { + if let Some((aa, value_str)) = pair.split_once(':') { + if let Ok(value) = value_str.parse::() { + // Get first char from aa str and convert u32 to usize for indexing + if let Some(aa_char) = aa.chars().next() { + let idx = aa1to_int(aa_char) as usize; + bias_values[idx] = value; + } + } + } + } + Tensor::from_slice(&bias_values, (21), device) + } + } + } } diff --git a/ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs b/ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs index 06b141d4..93dca281 100644 --- a/ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs +++ b/ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs @@ -144,12 +144,10 @@ impl ProteinFeaturesModel { input_features: &ProteinFeatures, device: &Device, ) -> Result<(Tensor, Tensor)> { - println!("In the Features Forward!"); let x = input_features.get_coords(); let mask = input_features.x_mask.as_ref().unwrap(); - let r_idx = input_features.get_residue_index().unwrap(); + let r_idx = input_features.get_residue_index(); // let chain_labels = input_features.chain_labels.as_ref(); - // todo: fix // let chain_labels = input_features.get_chain_labels(); let chain_labels = Tensor::zeros_like(r_idx)?; @@ -222,7 +220,6 @@ impl ProteinFeaturesModel { .broadcast_as(target_shape)? .to_dtype(DType::F32)?; // [1, 93, 93] - println!("Prepraring the Offset Tensor..."); let offset = (r_idx_expanded1 - r_idx_expanded2)?; let offset = gather_edges(&offset.unsqueeze(D::Minus1)?, &e_idx)?; let offset = offset.squeeze(D::Minus1)?; diff --git a/ferritin-featurizers/tests/test_cli_ligandmpnn.rs b/ferritin-featurizers/tests/test_cli_ligandmpnn.rs index cae1c1c1..e8e9ee00 100644 --- a/ferritin-featurizers/tests/test_cli_ligandmpnn.rs +++ b/ferritin-featurizers/tests/test_cli_ligandmpnn.rs @@ -123,13 +123,12 @@ mod tests { } #[test] - #[ignore] fn test_cli_command_run_example_06() { - let (pdbfile, _tmp) = TestFile::protein_03().create_temp().unwrap(); - let out_folder = tempfile::tempdir().unwrap().into_path(); - let mut cmd = Command::cargo_bin("ferritin-featurizers").unwrap(); + let (pdbfile, _tmp, out_folder) = setup("./outputs/fix_residues".to_string()); - cmd.arg("run") + let assert = Command::cargo_bin("ferritin-featurizers") + .unwrap() + .arg("run") .arg("--seed") .arg("111") .arg("--pdb-path") @@ -138,10 +137,11 @@ mod tests { .arg(&out_folder) .arg("--fixed-residues") .arg("C1 C2 C3 C4 C5 C6 C7 C8 C9 C10") - .arg("--bias-AA") - .arg("A:10.0"); + .arg("--bias-aa") + .arg("A:10.0") + .assert() + .success(); - let assert = cmd.assert().success(); println!("Successful command...."); assert!(out_folder.exists()); println!("Output: {:?}", assert.get_output()); diff --git a/justfile b/justfile index abeb3ad8..188eae28 100644 --- a/justfile +++ b/justfile @@ -33,7 +33,7 @@ upgrade: test2: - cargo test --features metal -p ferritin-featurizers test_cli_command_run_example_05 -- --nocapture + cargo test --features metal -p ferritin-featurizers test_cli_command_run_example_06 -- --nocapture test: cargo test