Skip to content

Commit

Permalink
Further Updating Tests (#64)
Browse files Browse the repository at this point in the history
* Begin Adding Support for Fixed Residues

* implementing chain logic

* add chains

* begin implementation of residue bias
  • Loading branch information
zachcp authored Dec 4, 2024
1 parent cb2a955 commit baf17e6
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 108 deletions.
95 changes: 58 additions & 37 deletions ferritin-featurizers/src/commands/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ use crate::models::ligandmpnn::configs::{
AABiasConfig, LigandMPNNConfig, MPNNExecConfig, MembraneMPNNConfig, ModelTypes, MultiPDBConfig,
ResidueControl, RunConfig,
};
use crate::models::ligandmpnn::model::ScoreOutput;
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result};
use candle_core::{Device, Result, Tensor};
use rand::Rng;

pub fn device(cpu: bool) -> Result<Device> {
Expand Down Expand Up @@ -43,7 +42,7 @@ pub fn execute(
let device = device(false)?;

let exec = MPNNExecConfig::new(
device,
device.clone(),
pdb_path, // will need to modify this for multiple
run_config,
Some(residue_control_config),
Expand All @@ -53,7 +52,7 @@ pub fn execute(
Some(multi_pdb_config),
)?;

// Crete Default Values ------------------------------------------------------------
// Create Default Values ------------------------------------------------------------
//
let model_type = exec
.run_config
Expand All @@ -74,40 +73,70 @@ pub fn execute(
// Load The model ------------------------------------------------------------

let model = exec.load_model(model_type)?;
println!("Model Loaded!");
let mut prot_features = exec.generate_protein_features()?;

println!("Generating Protein Features");
let prot_features = exec.generate_protein_features()?;
println!("Protein Features Loaded!");
// Calculate Masks ------------------------------------------------------------

// Create the output folders
println!("Creating the Outputs");
std::fs::create_dir_all(format!("{}/seqs", out_folder))?;
std::fs::create_dir_all(format!("{}/backbones", out_folder))?;
std::fs::create_dir_all(format!("{}/packed", out_folder))?;
std::fs::create_dir_all(format!("{}/stats", out_folder))?;
println!("Generating Chains to Design. Tensor of [B,L]");
let chains_to_design: Vec<String> = match &exec.residue_control_config {
None => prot_features.chain_letters.clone(),
Some(config) => match &config.chains_to_design {
None => prot_features.chain_letters.clone(),
Some(chains) => chains.split(' ').map(String::from).collect(),
},
};

// Chain tensor is the base. Additional Tensors can be added on top.
let mut chain_mask_tensor = prot_features.get_chain_mask_tensor(chains_to_design, &device)?;

// Residue-Related -------------------------------------------
if let Some(res) = exec.residue_control_config {
// Residues
let fixed_residues = res.fixed_residues.unwrap();
let fixed_positions_tensor = prot_features.get_encoded_tensor(fixed_residues, &device)?;
// multiply the fixed positions to the chain tensor
chain_mask_tensor = chain_mask_tensor.mul(&fixed_positions_tensor)?;
}

// bias-Related -------------------------------------------
// Todo
// if let Some(aabias) = exec.aabias_config {
// let bias_tensor = &prot_features.create_bias_tensor(exec.aabias_config?.bias_aa).unwrap_or('');
// let (batch_size, seq_length) = &prot_features.s.dims2()?;
// let mut base_bias = Tensor::zeros_like(&prot_features.s)?;
// println!("BIAS!! Dims for S {:?}", base_bias.dims());
// // let bias_aa: Tensor = match aabias.bias_aa {
// None =>
// }

// Update the Mask Here
prot_features.update_mask(chain_mask_tensor)?;

// Sample from the Model -------------------------------------------
println!("Sampling from the Model...");
println!("Temp and Seed are: temp: {:}, seed: {:}", temperature, seed);
let model_sample = model.sample(&prot_features, temperature as f64, seed as u64)?;
println!("{:?}", model_sample);

std::fs::create_dir_all(format!("{}/seqs", out_folder))?;
let sequences = model_sample.get_sequences()?;
// println!("DECODING ORDER: {:?}", model_sample.get_decoding_order()?);

let fasta_path = format!("{}/seqs/output.fasta", out_folder);
let mut fasta_content = String::new();
for (i, seq) in sequences.iter().enumerate() {
fasta_content.push_str(&format!(">sequence_{}\n{}\n", i + 1, seq));
}
std::fs::write(fasta_path, fasta_content)?;

// Score a Protein!
println!("Scoring the Protein...");
let model_score = model.score(&prot_features, false)?;
println!("Protein Score: {:?}", model_score);
let _ = {
// Create the output folders
println!("Creating the Outputs");
std::fs::create_dir_all(format!("{}/seqs", out_folder))?;
std::fs::create_dir_all(format!("{}/backbones", out_folder))?;
std::fs::create_dir_all(format!("{}/packed", out_folder))?;
std::fs::create_dir_all(format!("{}/stats", out_folder))?;
std::fs::create_dir_all(format!("{}/seqs", out_folder))?;
let sequences = model_sample.get_sequences()?;
let fasta_path = format!("{}/seqs/output.fasta", out_folder);
let mut fasta_content = String::new();
for (i, seq) in sequences.iter().enumerate() {
fasta_content.push_str(&format!(">sequence_{}\n{}\n", i + 1, seq));
}
std::fs::write(fasta_path, fasta_content)?;
};

// note this is only the Score outputs.
// It doesn't have the other fields in the pytorch implementation
if save_stats {
// out_dict = {}
// out_dict["generated_sequences"] = S_stack.cpu()
Expand All @@ -119,16 +148,8 @@ pub fn execute(
// out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu()
// out_dict["seed"] = seed
// out_dict["temperature"] = args.temperature
// if args.save_stats:
// torch.save(out_dict, output_stats_path)
//
// model_score.get_decoding_order()
// model_score.get_sequences()
let outfile = format!("{}/stats/stats.safetensors", out_folder);
// note this is only the Score outputs.
// It doesn't have the other fields in teh pytoch implmentaiton
model_sample.save_as_safetensors(outfile);
}

Ok(())
}
46 changes: 27 additions & 19 deletions ferritin-featurizers/src/models/ligandmpnn/configs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,32 @@ impl MPNNExecConfig {
let (pdb, _) = pdbtbx::open(self.protein_inputs.clone()).expect("A PDB or CIF file");
let ac = AtomCollection::from(&pdb);

// let s = ac.encode_amino_acids(&device)?;
let s = ac
.encode_amino_acids(&device)
.expect("A complete convertion to locations");

let x_37 = ac.to_numeric_atom37(&device)?;

// Note: default to 1!
let x_37_mask = Tensor::ones((x_37.dim(0)?, x_37.dim(1)?), base_dtype, &device)?;
// println!("This is the atom map: {:?}", x_37_mask.dims());

let (y, y_t, y_m) = ac.to_numeric_ligand_atoms(&device)?;

// R_idx = np.array(CA_resnums, dtype=np.int32)
let res_idx = ac.get_res_index();
let res_idx_len = res_idx.len();
let res_idx_tensor = Tensor::from_vec(res_idx, (1, res_idx_len), &device)?;

// chain residues
let chain_letters: Vec<String> = ac
.iter_residues_aminoacid()
.map(|res| res.chain_id)
.collect();

// unique Chains
let chain_list: Vec<String> = chain_letters
.clone()
.into_iter()
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();

// assert_eq!(true, false);

// update residue info
// residue_config: Option<ResidueControl>,
// handle these:
Expand Down Expand Up @@ -149,17 +157,17 @@ impl MPNNExecConfig {
// println!("Returning Protein Features....");
// return ligand MPNN.
Ok(ProteinFeatures {
s, // protein amino acids sequences as 1D Tensor of u32
x: x_37, // protein co-oords by residue [1, 37, 4]
x_mask: Some(x_37_mask), // protein mask by residue
y, // ligand coords
y_t, // encoded ligand atom names
y_m: Some(y_m), // ligand mask
r_idx: Some(res_idx_tensor), // protein residue indices shape=[length]
chain_labels: None, // # protein chain letters shape=[length]
chain_letters: None, // chain_letters: shape=[length]
mask_c: None, // mask_c: shape=[length]
chain_list: None,
s, // protein amino acids sequences as 1D Tensor of u32
x: x_37, // protein coords by residue [1, 37, 4]
x_mask: Some(x_37_mask), // protein mask by residue
y, // ligand coords
y_t, // encoded ligand atom names
y_m: Some(y_m), // ligand mask
r_idx: res_idx_tensor, // protein residue indices shape=[length]
chain_labels: None, // # protein chain letters shape=[length]
chain_letters, // chain_letters: shape=[length]
mask_c: None, // mask_c: shape=[length]
chain_list,
})
}
}
Expand Down
Loading

0 comments on commit baf17e6

Please sign in to comment.