diff --git a/Cargo.lock b/Cargo.lock
index e1bcbd4..1d75e10 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1552,13 +1552,14 @@ dependencies = [
 [[package]]
 name = "candle-core"
 version = "0.8.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e61e6a12c4b0660f105c11cbce42a5b33a392e73caaf465261b24210878fbe0e"
+source = "git+https://github.com/zachcp/candle.git?branch=20241201-SA#5e648c32dee876e2fcf1afd06255392a6ba741e3"
 dependencies = [
  "byteorder",
+ "candle-metal-kernels",
  "gemm",
  "half",
  "memmap2",
+ "metal 0.27.0",
  "num-traits",
  "num_cpus",
  "rand",
@@ -1567,6 +1568,7 @@ dependencies = [
  "safetensors",
  "thiserror",
  "ug",
+ "ug-metal",
  "yoke",
  "zip",
 ]
@@ -1592,14 +1594,26 @@ dependencies = [
  "ureq",
 ]
 
+[[package]]
+name = "candle-metal-kernels"
+version = "0.8.0"
+source = "git+https://github.com/zachcp/candle.git?branch=20241201-SA#5e648c32dee876e2fcf1afd06255392a6ba741e3"
+dependencies = [
+ "metal 0.27.0",
+ "once_cell",
+ "thiserror",
+ "tracing",
+]
+
 [[package]]
 name = "candle-nn"
 version = "0.8.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2032a3a41999801c5997ee7896ce816c361c49988435dedb7830634293c0798e"
+source = "git+https://github.com/zachcp/candle.git?branch=20241201-SA#5e648c32dee876e2fcf1afd06255392a6ba741e3"
 dependencies = [
  "candle-core",
+ "candle-metal-kernels",
  "half",
+ "metal 0.27.0",
  "num-traits",
  "rayon",
  "safetensors",
@@ -1610,8 +1624,7 @@ dependencies = [
 [[package]]
 name = "candle-transformers"
 version = "0.8.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b88bf4dbf13c3fa0a51624e4a2bf4db1d97596bbefd471424892a0b852013b7c"
+source = "git+https://github.com/zachcp/candle.git?branch=20241201-SA#5e648c32dee876e2fcf1afd06255392a6ba741e3"
 dependencies = [
  "byteorder",
  "candle-core",
@@ -2454,6 +2467,7 @@ dependencies = [
  "assert_cmd",
  "candle-core",
  "candle-hf-hub",
+ "candle-metal-kernels",
  "candle-nn",
  "candle-transformers",
  "clap",
@@ -3883,6 +3897,21 @@ dependencies = [
  "stable_deref_trait",
 ]
 
+[[package]]
+name = "metal"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25"
+dependencies = [
+ "bitflags 2.6.0",
+ "block",
+ "core-graphics-types",
+ "foreign-types 0.5.0",
+ "log",
+ "objc",
+ "paste",
+]
+
 [[package]]
 name = "metal"
 version = "0.29.0"
@@ -4241,6 +4270,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1"
 dependencies = [
  "malloc_buf",
+ "objc_exception",
 ]
 
 [[package]]
@@ -4446,6 +4476,15 @@ dependencies = [
  "objc2-foundation",
 ]
 
+[[package]]
+name = "objc_exception"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad970fb455818ad6cba4c122ad012fae53ae8b4795f86378bce65e4f6bab2ca4"
+dependencies = [
+ "cc",
+]
+
 [[package]]
 name = "object"
 version = "0.36.5"
@@ -6102,6 +6141,21 @@ dependencies = [
  "thiserror",
 ]
 
+[[package]]
+name = "ug-metal"
+version = "0.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7e4ed1df2c20a1a138f993041f650cc84ff27aaefb4342b7f986e77d00e80799"
+dependencies = [
+ "half",
+ "metal 0.29.0",
+ "objc",
+ "serde",
+ "serde_json",
+ "thiserror",
+ "ug",
+]
+
 [[package]]
 name = "unicode-bidi"
 version = "0.3.17"
@@ -6519,7 +6573,7 @@ dependencies = [
  "libc",
  "libloading",
  "log",
- "metal",
+ "metal 0.29.0",
  "naga",
  "ndk-sys 0.5.0+25.2.9519653",
  "objc",
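The lockfile changes above swap the crates.io candle packages for a pinned git fork that carries the Metal kernels. A quick way to confirm which backends actually compiled in is to probe them at runtime. The sketch below is illustrative and not part of this PR; it assumes a small binary in the same workspace, and uses the same `cuda_is_available`/`metal_is_available`/`Device::new_metal` calls that `run.rs` adopts in the changes that follow.

```rust
// Hypothetical backend probe (not part of this PR): reports which candle
// backends this build can actually use before running inference.
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result};

fn main() -> Result<()> {
    println!("cuda available:  {}", cuda_is_available());
    println!("metal available: {}", metal_is_available());
    // Constructing the device forces backend initialization, so a broken
    // Metal/CUDA setup fails here rather than mid-forward-pass.
    let device = if metal_is_available() {
        Device::new_metal(0)?
    } else {
        Device::Cpu
    };
    println!("selected device: {:?}", device);
    Ok(())
}
```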
diff --git a/ferritin-featurizers/Cargo.toml b/ferritin-featurizers/Cargo.toml
index ed8265c..776e0f1 100644
--- a/ferritin-featurizers/Cargo.toml
+++ b/ferritin-featurizers/Cargo.toml
@@ -8,10 +8,23 @@
 
 [dependencies]
 anyhow.workspace = true
-candle-core = "0.8"
+
+candle-metal-kernels = { git = "https://github.com/zachcp/candle.git", package = "candle-metal-kernels", branch = "20241201-SA" }
+candle-core = { git = "https://github.com/zachcp/candle.git", package = "candle-core", features = [
+    "metal",
+], branch = "20241201-SA" }
+
+# candle-core = { version = "0.8", features = ["metal"] }
 candle-hf-hub = "0.3.3"
-candle-nn = "0.8"
-candle-transformers = "0.8.0"
+
+# candle-nn = { version = "0.8", features = ["metal"] }
+candle-nn = { git = "https://github.com/zachcp/candle.git", package = "candle-nn", features = [
+    "metal",
+], branch = "20241201-SA" }
+candle-transformers = { git = "https://github.com/zachcp/candle.git", package = "candle-transformers", features = [
+    "metal",
+], branch = "20241201-SA" }
+# candle-transformers = "0.8.0"
 clap = "4.5.21"
 ferritin-core = { path = "../ferritin-core" }
 itertools.workspace = true
diff --git a/ferritin-featurizers/src/commands/run.rs b/ferritin-featurizers/src/commands/run.rs
index 324dca5..88b4941 100644
--- a/ferritin-featurizers/src/commands/run.rs
+++ b/ferritin-featurizers/src/commands/run.rs
@@ -2,7 +2,30 @@ use crate::models::ligandmpnn::configs::{
     AABiasConfig, LigandMPNNConfig, MPNNExecConfig, MembraneMPNNConfig, ModelTypes, MultiPDBConfig,
     ResidueControl, RunConfig,
 };
-use candle_core::Device;
+use candle_core::utils::{cuda_is_available, metal_is_available};
+use candle_core::{Device, Result};
+
+pub fn device(cpu: bool) -> Result<Device> {
+    if cpu {
+        Ok(Device::Cpu)
+    } else if cuda_is_available() {
+        Ok(Device::new_cuda(0)?)
+    } else if metal_is_available() {
+        Ok(Device::new_metal(0)?)
+    } else {
+        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+        {
+            println!(
+                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
+            );
+        }
+        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
+        {
+            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
+        }
+        Ok(Device::Cpu)
+    }
+}
 
 pub fn execute(
     seed: i32,
@@ -17,7 +40,7 @@ pub fn execute(
     multi_pdb_config: MultiPDBConfig,
 ) -> anyhow::Result<()> {
     // todo - whats the best way to handle device?
-    let device = &Device::Cpu;
+    let device = device(false)?;
 
     let model_type = model_type.unwrap_or(ModelTypes::ProteinMPNN);
 
@@ -37,6 +60,7 @@ pub fn execute(
     println!("About to Load the model!");
     let model = exec.load_model()?;
     println!("Model Loaded!");
+    println!("Model Loaded on the {:?}", model.device);
 
     println!("Generating Protein Features");
     let prot_features = exec.generate_protein_features()?;
diff --git a/ferritin-featurizers/src/models/ligandmpnn/configs.rs b/ferritin-featurizers/src/models/ligandmpnn/configs.rs
index ea609d8..43a9ede 100644
--- a/ferritin-featurizers/src/models/ligandmpnn/configs.rs
+++ b/ferritin-featurizers/src/models/ligandmpnn/configs.rs
@@ -35,14 +35,14 @@ pub struct MPNNExecConfig {
     pub(crate) membrane_mpnn_config: Option<MembraneMPNNConfig>,
     pub(crate) multi_pdb_config: Option<MultiPDBConfig>,
     pub(crate) residue_control_config: Option<ResidueControl>,
-    // device: &candle_core::Device,
+    pub(crate) device: Device,
     pub(crate) seed: i32,
 }
 
 impl MPNNExecConfig {
     pub fn new(
         seed: i32,
-        device: &Device,
+        device: Device,
         pdb_path: String,
         model_type: ModelTypes,
         run_config: RunConfig,
@@ -62,15 +62,17 @@ impl MPNNExecConfig {
             residue_control_config: residue_config,
             multi_pdb_config: multi_pdb_specific,
             seed,
-            // device: device,
+            device: device,
         })
     }
     // Todo: refactor this to use loader.
     pub fn load_model(&self) -> Result<ProteinMPNN> {
+        let default_dtype = DType::F32;
+
         // this is a hidden dep....
         let (mpnn_file, _handle) = TestFile::ligmpnn_pmpnn_01().create_temp()?;
         let pth = PthTensors::new(mpnn_file, Some("model_state_dict"))?;
-        let vb = VarBuilder::from_backend(Box::new(pth), DType::F32, Device::Cpu);
+        let vb = VarBuilder::from_backend(Box::new(pth), default_dtype, self.device.clone());
         let pconf = ProteinMPNNConfig::proteinmpnn();
         Ok(ProteinMPNN::load(vb, &pconf).expect("Unable to load the PMPNN Model"))
     }
@@ -78,7 +80,7 @@ impl MPNNExecConfig {
         todo!()
     }
     pub fn generate_protein_features(&self) -> Result<ProteinFeatures> {
-        let device = Device::Cpu;
+        let device = self.device.clone();
         let base_dtype = DType::F32;
 
         // init the Protein Features
diff --git a/ferritin-featurizers/src/models/ligandmpnn/featurizer.rs b/ferritin-featurizers/src/models/ligandmpnn/featurizer.rs
index c367cb3..4be1836 100644
--- a/ferritin-featurizers/src/models/ligandmpnn/featurizer.rs
+++ b/ferritin-featurizers/src/models/ligandmpnn/featurizer.rs
@@ -24,7 +24,7 @@ fn is_heavy_atom(element: &Element) -> bool {
 
 /// Convert the AtomCollection into a struct that can be passed to a model.
 pub trait LMPNNFeatures {
-    fn encode_amino_acids(&self, device: &Device) -> Result<(Tensor)>; // ( residue types )
+    fn encode_amino_acids(&self, device: &Device) -> Result<Tensor>; // ( residue types )
     fn featurize(&self, device: &Device) -> Result<ProteinFeatures>; // need more control over this featurization process
     fn get_res_index(&self) -> Vec<u32>;
     fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>; // [residues, N/CA/C/O, xyz]
diff --git a/ferritin-featurizers/src/models/ligandmpnn/model.rs b/ferritin-featurizers/src/models/ligandmpnn/model.rs
index 0c7de30..4b5f0d6 100644
--- a/ferritin-featurizers/src/models/ligandmpnn/model.rs
+++ b/ferritin-featurizers/src/models/ligandmpnn/model.rs
@@ -27,7 +27,7 @@ pub fn multinomial_sample(probs: &Tensor, temperature: f64, seed: u64) -> Result<Tensor>
     );
 
     // Sample from the probabilities
-    let idx = logits_processor.sample(&probs)?;
+    let idx = logits_processor.sample(probs)?;
 
     // Convert to tensor
     Tensor::new(&[idx], probs.device())
@@ -139,10 +139,9 @@ impl EncLayer {
         training: Option<bool>,
     ) -> Result<(Tensor, Tensor)> {
         println!("EncoderLayer: Starting forward pass");
-
-        let h_ev = cat_neighbors_nodes(h_v, h_e, e_idx)?;
+        let h_v = h_v.to_dtype(DType::F32)?;
+        let h_ev = cat_neighbors_nodes(&h_v, h_e, e_idx)?;
         let h_v_expand = h_v.unsqueeze(D::Minus2)?;
-        // Explicitly specify the expansion dimensions
         let expand_shape = [
             h_ev.dims()[0], // batch size
@@ -150,8 +149,9 @@ impl EncLayer {
             h_ev.dims()[2],       // number of neighbors
             h_v_expand.dims()[3], // hidden dimension
         ];
+        let h_v_expand = h_v_expand.expand(&expand_shape)?.to_dtype(h_ev.dtype())?;
 
-        let h_ev = Tensor::cat(&[&h_v_expand, &h_ev], D::Minus1)?;
+        let h_ev = Tensor::cat(&[&h_v_expand, &h_ev], D::Minus1)?.contiguous()?;
         let h_message = self.w1.forward(&h_ev)?;
         let h_message = h_message.clamp(-20.0, 20.0)?; // Clip after w1
         let h_message = h_message.gelu()?;
@@ -183,7 +183,6 @@ impl EncLayer {
             let h_v = h_v.to_dtype(DType::F32)?;
             self.norm1.forward(&(h_v + dh_dropout)?)?
         };
-
         let dh = self.dense.forward(&h_v)?;
         let h_v = {
             let dh_dropout = self
@@ -208,7 +207,7 @@ impl EncLayer {
         ];
         let h_v_expand = h_v_expand.expand(&expand_shape)?;
         let h_v_expand = h_v_expand.to_dtype(h_ev.dtype())?;
-        let h_ev = Tensor::cat(&[&h_v_expand, &h_ev], D::Minus1)?;
+        let h_ev = Tensor::cat(&[&h_v_expand, &h_ev], D::Minus1)?.contiguous()?;
         let h_message = self
             .w11
             .forward(&h_ev)?
@@ -295,7 +294,8 @@ impl DecLayer {
             h_v.dims()[2], // keep original hidden dim (128)
         ];
         let h_v_expand = h_v.unsqueeze(D::Minus2)?.expand(&expand_shape)?;
-        let h_ev = Tensor::cat(&[&h_v_expand, h_e], D::Minus1)?;
+        let h_ev = Tensor::cat(&[&h_v_expand, h_e], D::Minus1)?.contiguous()?;
+
         let h_message = self
             .w1
             .forward(&h_ev)?
@@ -303,6 +303,7 @@ impl DecLayer {
             .apply(&self.w2)?
             .gelu()?
             .apply(&self.w3)?;
+
         let h_message = if let Some(mask) = mask_attend {
             mask.unsqueeze(D::Minus1)?.broadcast_mul(&h_message)?
         } else {
@@ -329,14 +330,14 @@ impl DecLayer {
 
 // https://github.com/dauparas/LigandMPNN/blob/main/model_utils.py#L10C7-L10C18
 pub struct ProteinMPNN {
-    config: ProteinMPNNConfig,
-    decoder_layers: Vec<DecLayer>,
-    device: Device,
-    encoder_layers: Vec<EncLayer>,
-    features: ProteinFeaturesModel,
-    w_e: Linear,
-    w_out: Linear,
-    w_s: Embedding,
+    pub(crate) config: ProteinMPNNConfig,
+    pub(crate) decoder_layers: Vec<DecLayer>,
+    pub(crate) device: Device,
+    pub(crate) encoder_layers: Vec<EncLayer>,
+    pub(crate) features: ProteinFeaturesModel,
+    pub(crate) w_e: Linear,
+    pub(crate) w_out: Linear,
+    pub(crate) w_s: Embedding,
 }
 
 impl ProteinMPNN {
@@ -377,7 +378,7 @@ impl ProteinMPNN {
         Ok(Self {
             config: config.clone(), // todo: check this clone later...
             decoder_layers,
-            device: Device::Cpu,
+            device: vb.device().clone(),
             encoder_layers,
             features,
             w_e,
@@ -395,8 +396,9 @@ impl ProteinMPNN {
     //     todo!()
     // }
     fn encode(&self, features: &ProteinFeatures) -> Result<(Tensor, Tensor, Tensor)> {
-        let device = &Device::Cpu; // todo: get device more elegantly
+        println!("encoded device! {:?}", self.device);
         let s_true = &features.get_sequence();
+        let base_dtype = DType::F32;
 
         // needed for the MaskAttend
         let mask = match features.get_sequence_mask() {
@@ -404,56 +406,44 @@ impl ProteinMPNN {
             Some(m) => m,
             None => &Tensor::ones_like(&s_true)?,
         };
-
         match self.config.model_type {
             ModelTypes::ProteinMPNN => {
-                let (e, e_idx) = self.features.forward(features, device)?;
+                let (e, e_idx) = self.features.forward(features, &self.device)?;
                 let mut h_v = Tensor::zeros(
                     (e.dim(0)?, e.dim(1)?, e.dim(D::Minus1)?),
-                    DType::F64,
-                    device,
+                    base_dtype,
+                    &self.device,
                 )?;
                 let mut h_e = self.w_e.forward(&e)?;
 
                 let mask_attend = if let Some(mask) = features.get_sequence_mask() {
-                    // First unsqueeze mask
                     let mask_expanded = mask.unsqueeze(D::Minus1)?; // [B, L, 1]
-
-                    // Gather using E_idx
+                    // Gather using E_idx
                    let mask_gathered = gather_nodes(&mask_expanded, &e_idx)?;
-
                    let mask_gathered = mask_gathered.squeeze(D::Minus1)?;
-
                    // Multiply original mask with gathered mask
                    let mask_attend = {
                        let mask_unsqueezed = mask.unsqueeze(D::Minus1)?; // [B, L, 1]
                        // Explicitly expand mask_unsqueezed to match mask_gathered dimensions
-                        let mask_expanded = mask_unsqueezed
-                            .expand((
-                                mask_gathered.dim(0)?, // batch
-                                mask_gathered.dim(1)?, // sequence length
-                                mask_gathered.dim(2)?, // number of neighbors
-                            ))?
-                            .contiguous()?;
-
+                        let mask_expanded = mask_unsqueezed.expand((
+                            mask_gathered.dim(0)?, // batch
+                            mask_gathered.dim(1)?, // sequence length
+                            mask_gathered.dim(2)?, // number of neighbors
+                        ))?;
                        // Now do the multiplication with explicit shapes
                        mask_expanded.mul(&mask_gathered)?
                    };
                    mask_attend
                } else {
                    let (b, l) = mask.dims2()?;
-                    let ones = Tensor::ones((b, l, e_idx.dim(2)?), DType::F32, device)?;
+                    let ones = Tensor::ones((b, l, e_idx.dim(2)?), DType::F32, &self.device)?;
                    println!("Created default ones mask dims: {:?}", ones.dims());
-
                    ones
                };
-
+                println!("Beginning the Encoding...");
                for (i, layer) in self.encoder_layers.iter().enumerate() {
-
-                    let h_v_f32 = h_v.to_dtype(DType::F32)?;
-                    let h_e_f32 = h_e.to_dtype(DType::F32)?;
                    let (new_h_v, new_h_e) = layer.forward(
                        &h_v,
                        &h_e,
@@ -462,9 +452,6 @@ impl ProteinMPNN {
                        Some(&mask_attend),
                        Some(false),
                    )?;
-
-                    let new_h_v_f32 = new_h_v.to_dtype(DType::F32)?;
-                    let new_h_e_f32 = new_h_e.to_dtype(DType::F32)?;
                    h_v = new_h_v;
                    h_e = new_h_e;
                }
@@ -508,6 +495,9 @@ impl ProteinMPNN {
        }
    }
    pub fn sample(&self, features: &ProteinFeatures) -> Result {
+        // "global" dtype
+        let sample_dtype = DType::F32;
+
        let ProteinFeatures {
            x,
            s,
@@ -522,19 +512,19 @@ impl ProteinMPNN {
        let (b, l) = s.dims2()?;
 
        // Todo: This is a hack. we should be passing in encoded chains.
-        let chain_mask = Tensor::ones_like(&x_mask.as_ref().unwrap())?.to_dtype(DType::F32)?;
-        let chain_mask = x_mask.as_ref().unwrap().mul(&chain_mask)?.contiguous()?; // update chain_M to include missing regions;
+        let chain_mask = Tensor::ones_like(&x_mask.as_ref().unwrap())?.to_dtype(sample_dtype)?;
+        let chain_mask = x_mask.as_ref().unwrap().mul(&chain_mask)?;
        let (h_v, h_e, e_idx) = self.encode(features)?;
 
        // this might be a bad rand implementation
-        let rand_tensor = Tensor::randn(0., 0.25, (b, l), device)?.to_dtype(DType::F32)?;
+        let rand_tensor = Tensor::randn(0f32, 0.25f32, (b, l), device)?.to_dtype(sample_dtype)?;
        let decoding_order = (&chain_mask + 0.0001)?
            .mul(&rand_tensor.abs()?)?
            .arg_sort_last_dim(false)?;
 
        // Todo add bias
        // # [B,L,21] - amino acid bias per position
        // bias = feature_dict["bias"]
-        let bias = Tensor::ones((b, l, 21), DType::F32, device)?;
+        let bias = Tensor::ones((b, l, 21), sample_dtype, device)?;
        println!("todo: We need to add the bias!");
 
        // Todo! Fix this hack.
@@ -544,13 +534,18 @@ impl ProteinMPNN {
        let symmetry_residues: Option<Vec<i64>> = None;
        match symmetry_residues {
            None => {
-                let e_idx = e_idx.repeat(&[b, 1, 1])?.contiguous()?;
-                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1., 0.)?;
-                let tril = Tensor::tril2(l, DType::F64, device)?;
+                let e_idx = e_idx.repeat(&[b, 1, 1])?;
+                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1f32, 0f32)?
+                    .to_dtype(sample_dtype)?
+                    .contiguous()?;
+                let tril = Tensor::tril2(l, sample_dtype, device)?;
                let tril = tril.unsqueeze(0)?;
-                let temp = tril.matmul(&permutation_matrix_reverse.transpose(1, 2)?)?; //tensor of shape (b, i, q)
-                let order_mask_backward =
-                    temp.matmul(&permutation_matrix_reverse.transpose(1, 2)?)?; // This will give us a tensor of shape (b, q, p)
+                let temp = tril
+                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
+                    .contiguous()?; //tensor of shape (b, i, q)
+                let order_mask_backward = temp
+                    .matmul(&permutation_matrix_reverse.transpose(1, 2)?)?
+                    .contiguous()?; // This will give us a tensor of shape (b, q, p)
                let mask_attend = order_mask_backward
                    .gather(&e_idx, 2)?
                    .unsqueeze(D::Minus1)?;
@@ -558,8 +553,7 @@ impl ProteinMPNN {
                // Broadcast mask_1d to match mask_attend's shape
                let mask_1d = mask_1d
                    .broadcast_as(mask_attend.shape())?
-                    .to_dtype(DType::F64)?;
-
+                    .to_dtype(sample_dtype)?;
                let mask_bw = mask_1d.mul(&mask_attend)?;
                let mask_fw = mask_1d.mul(&(Tensor::ones_like(&mask_attend)?
                    - mask_attend)?)?;
@@ -568,13 +562,13 @@ impl ProteinMPNN {
                let s_true = s_true.repeat((b, 1))?;
                let h_v = h_v.repeat((b, 1, 1))?;
                let h_e = h_e.repeat((b, 1, 1, 1))?;
-                let chain_mask = &chain_mask.repeat((b, 1))?.contiguous()?;
                let mask = x_mask.as_ref().unwrap().repeat((b, 1))?.contiguous()?;
-                let bias = bias.repeat((b, 1, 1))?.contiguous()?;
-                let mut all_probs = Tensor::zeros((b, l, 20), DType::F32, device)?;
-                let mut all_log_probs = Tensor::zeros((b, l, 21), DType::F32, device)?; // why is this one 21 and the others are 20?
-                let mut h_s = Tensor::zeros_like(&h_v)?.contiguous()?;
-                let s = (Tensor::ones((b, l), DType::I64, device)? * 20.)?;
+                let chain_mask = &chain_mask.repeat((b, 1))?;
+                let bias = bias.repeat((b, 1, 1))?;
+                let mut all_probs = Tensor::zeros((b, l, 20), sample_dtype, device)?;
+                let mut all_log_probs = Tensor::zeros((b, l, 21), sample_dtype, device)?; // why is this one 21 and the others are 20?
+                let mut h_s = Tensor::zeros_like(&h_v)?;
+                let s = (Tensor::ones((b, l), DType::U32, device)? * 20.)?;
                let mut h_v_stack = vec![h_v.clone()];
 
                for i in 0..self.decoder_layers.len() {
@@ -594,19 +588,21 @@ impl ProteinMPNN {
                    // Gather masks and bias
                    let chain_mask_t = chain_mask.gather(&t_gather, 1)?.squeeze(1)?;
-                    let mask_t = mask.gather(&t_gather, 1)?.squeeze(1)?;
+                    let mask_t = mask.gather(&t_gather, 1)?.squeeze(1)?.contiguous()?;
                    let bias_t = bias
                        .gather(&t_gather.unsqueeze(2)?.expand((b, 1, 21))?.contiguous()?, 1)?
                        .squeeze(1)?;
 
                    // Gather edge and node indices/features
-                    let e_idx_t = e_idx.gather(
-                        &t_gather
-                            .unsqueeze(2)?
-                            .expand((b, 1, e_idx.dim(2)?))?
-                            .contiguous()?,
-                        1,
-                    )?;
+                    let e_idx_t = e_idx
+                        .gather(
+                            &t_gather
+                                .unsqueeze(2)?
+                                .expand((b, 1, e_idx.dim(2)?))?
+                                .contiguous()?,
+                            1,
+                        )?
+                        .contiguous()?;
                    let h_e_t = h_e.gather(
                        &t_gather
                            .unsqueeze(2)?
@@ -615,22 +611,18 @@ impl ProteinMPNN {
                            .contiguous()?,
                        1,
                    )?;
-
                    let b = h_s.dim(0)?; // batch size
                    let l = h_s.dim(1)?; // sequence length
                    let n = e_idx_t.dim(2)?; // number of neighbors
                    let c = h_s.dim(2)?; // channels/features
-
                    let h_e_t = h_e_t
                        .squeeze(1)? // [B, N, C]
                        .unsqueeze(1)? // [B, 1, N, C]
                        .expand((b, l, n, c))? // [B, L, N, C]
                        .contiguous()?;
-
                    let e_idx_t = e_idx_t
                        .expand((b, l, n))? // [B, L, N]
                        .contiguous()?;
-
                    let h_es_t = cat_neighbors_nodes(&h_s, &h_e_t, &e_idx_t)?;
                    let h_exv_encoder_t = h_exv_encoder_fw.gather(
                        &t_gather
@@ -640,7 +632,6 @@ impl ProteinMPNN {
                            .contiguous()?,
                        1,
                    )?;
-
                    let mask_bw_t = mask_bw.gather(
                        &t_gather
                            .unsqueeze(2)?
@@ -649,6 +640,7 @@ impl ProteinMPNN {
                            .contiguous()?,
                        1,
                    )?;
+
                    // Decoder layers loop
                    for l in 0..self.decoder_layers.len() {
                        let h_v_stack_l = &h_v_stack[l];
@@ -664,11 +656,12 @@ impl ProteinMPNN {
                        let h_exv_encoder_t = h_exv_encoder_t
                            .expand(h_esv_decoder_t.dims())?
                            .contiguous()?
-                            .to_dtype(DType::F64)?;
+                            .to_dtype(sample_dtype)?;
                        let h_esv_t = mask_bw_t
-                            .mul(&h_esv_decoder_t.to_dtype(DType::F64)?)?
+                            .mul(&h_esv_decoder_t.to_dtype(sample_dtype)?)?
                            .add(&h_exv_encoder_t)?
-                            .to_dtype(DType::F32)?;
+                            .to_dtype(sample_dtype)?
+                            .contiguous()?;
                        let h_v_t = h_v_t
                            .expand((
                                h_esv_t.dim(0)?, // batch size
@@ -676,7 +669,6 @@ impl ProteinMPNN {
                                h_v_t.dim(2)?, // features (128)
                            ))?
                            .contiguous()?;
-
                        let decoder_output = self.decoder_layers[l].forward(
                            &h_v_t,
                            &h_esv_t,
@@ -684,7 +676,6 @@ impl ProteinMPNN {
                            None,
                            None,
                        )?;
-
                        let t_expanded = t_gather.reshape(&[b])?; // This will give us a 1D tensor of shape [b]
                        let decoder_output = decoder_output
                            .narrow(1, 0, 1)?
@@ -726,8 +717,8 @@ impl ProteinMPNN {
                        };
                        let s_t = multinomial_sample(&probs_sample_1d, temperature, seed)?;
                        // todo: move this upstream
-                        let s_t = s_t.to_dtype(DType::F32)?;
-                        let s_true = s_true.to_dtype(DType::F32)?;
+                        let s_t = s_t.to_dtype(sample_dtype)?;
+                        let s_true = s_true.to_dtype(sample_dtype)?;
                        let s_true_t = s_true.gather(&t_gather, 1)?.squeeze(1)?;
                        let s_t = s_t
                            .mul(&chain_mask_t)?
@@ -747,9 +738,9 @@ impl ProteinMPNN {
                        1,
                    )?;
                    h_s = h_s.index_add(&t_gather_expanded, &h_s_update, 1)?;
-                    let zero_mask = t_gather.zeros_like()?.to_dtype(DType::I64)?;
+                    let zero_mask = t_gather.zeros_like()?.to_dtype(DType::U32)?;
                    let s = s.scatter_add(&t_gather, &zero_mask, 1)?; // Zero out
-                    let s_t = s_t.to_dtype(DType::I64)?;
+                    let s_t = s_t.to_dtype(DType::U32)?;
                    let s = s.scatter_add(&t_gather, &s_t.unsqueeze(1)?, 1)?;
                    let probs_update = chain_mask_t
                        .unsqueeze(1)?
@@ -763,7 +754,6 @@ impl ProteinMPNN {
                        .unsqueeze(2)?
                    all_probs =
                        all_probs.index_add(&t_expanded, &Tensor::zeros_like(&probs_update)?, 1)?;
                    all_probs = all_probs.index_add(&t_expanded, &probs_update, 1)?;
-
                    let log_probs_update = chain_mask_t
                        .unsqueeze(1)?
                        .unsqueeze(2)?
@@ -1084,24 +1074,25 @@ impl ProteinMPNN {
    // }
    pub fn score(&self, features: &ProteinFeatures, use_sequence: bool) -> Result {
+        // "global" dtype
+        let sample_dtype = DType::F32;
        let ProteinFeatures { s, x, x_mask, .. } = &features;
        let s_true = &s.clone();
        let device = s_true.device();
        let (b, l) = s_true.dims2()?;
-
        let mask = &x_mask.as_ref().clone();
        let b_decoder: usize = b;
-        // Todo: This is a hack. we shouldbe passing in encoded chains.
+        // Todo: This is a hack. we should be passing in encoded chains.
        // Update chain_mask to include missing regions
-        let chain_mask = Tensor::zeros_like(mask.unwrap())?.to_dtype(DType::F32)?;
+        let chain_mask = Tensor::zeros_like(mask.unwrap())?.to_dtype(sample_dtype)?;
        let chain_mask = mask.unwrap().mul(&chain_mask)?;
 
        // encode ...
        let (h_v, h_e, e_idx) = self.encode(features)?;
-        let rand_tensor = Tensor::randn(0., 1., (b, l), device)?.to_dtype(DType::F32)?;
+        let rand_tensor = Tensor::randn(0., 1., (b, l), device)?.to_dtype(sample_dtype)?;
 
        // Compute decoding order
        let decoding_order = (chain_mask + 0.001)?
@@ -1172,7 +1163,7 @@ impl ProteinMPNN {
            None => {
                let e_idx = e_idx.repeat(&[b_decoder, 1, 1])?;
                let permutation_matrix_reverse = one_hot(decoding_order.clone(), l, 1., 0.)?;
-                let tril = Tensor::tril2(l, DType::F64, device)?;
+                let tril = Tensor::tril2(l, sample_dtype, device)?;
                let tril = tril.unsqueeze(0)?;
                let temp = tril.matmul(&permutation_matrix_reverse.transpose(1, 2)?)?; // shape (b, i, q)
                let order_mask_backward =
@@ -1184,7 +1175,7 @@ impl ProteinMPNN {
                // Broadcast mask_1d to match mask_attend's shape
                let mask_1d = mask_1d
                    .broadcast_as(mask_attend.shape())?
-                    .to_dtype(DType::F64)?;
+                    .to_dtype(sample_dtype)?;
 
                let mask_bw = mask_1d.mul(&mask_attend)?;
                let mask_fw = mask_1d.mul(&(mask_attend - 1.0)?.neg()?)?;
diff --git a/ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs b/ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs
index bcac59c..d194dca 100644
--- a/ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs
+++ b/ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs
@@ -35,7 +35,6 @@ impl ProteinFeaturesModel {
            vb.device(),         // device this should be passed in as param,
            vb.pp("embeddings"), // VarBuilder,
        )?;
-
        let edge_embedding =
            linear::linear_no_bias(edge_in, edge_features, vb.pp("edge_embedding"))?;
        let norm_edges = layer_norm(
@@ -75,11 +74,11 @@ impl ProteinFeaturesModel {
    //    5. Applies the RBF formula: exp(-(x-μ)²/σ²)
        const D_MIN: f64 = 2.0;
        const D_MAX: f64 = 22.0;
-        // Create centers (μ)
-        let d_mu = linspace(D_MIN, D_MAX, self.num_rbf, device)?
+        let d_mu = linspace(D_MIN, D_MAX, self.num_rbf, &Device::Cpu)? // Use CPU device
+            .to_dtype(DType::F32)? // Convert to F32 on CPU
            .reshape((1, 1, 1, self.num_rbf))?
-            .to_dtype(DType::F32)?;
+            .to_device(device)?; // Move to the target device after conversion
 
        // Calculate width (σ)
        let d_sigma = (D_MAX - D_MIN) / self.num_rbf as f64;
@@ -88,10 +87,11 @@ impl ProteinFeaturesModel {
        let d_mu_broadcast = d_mu.broadcast_as((dims[0], dims[1], dims[2], self.num_rbf))?;
        let d_expanded_broadcast =
            d_expanded.broadcast_as((dims[0], dims[1], dims[2], self.num_rbf))?;
-
-        let diff = ((d_expanded_broadcast - d_mu_broadcast)? / d_sigma)?;
+        let d_sigma_tensor =
+            Tensor::new(&[d_sigma as f32], &device)?.broadcast_as(d_expanded_broadcast.shape())?;
+        let d_expanded = d_expanded.to_dtype(DType::F32)?.contiguous()?;
+        let diff = ((d_expanded_broadcast - d_mu_broadcast)? / d_sigma_tensor)?;
        let rbf = diff.powf(2.0)?.neg()?.exp()?;
-
        Ok(rbf)
    }
@@ -228,7 +228,6 @@ impl ProteinFeaturesModel {
        rbf_all.push(self._get_rbf(&c, &o, &e_idx, device)?);
 
        let rbf_all = Tensor::cat(&rbf_all, D::Minus1)?;
-
        let dims = r_idx.dims();
        let target_shape = (dims[0], dims[1], dims[1]);
        let r_idx_expanded1 = r_idx
@@ -240,27 +239,22 @@ impl ProteinFeaturesModel {
            .broadcast_as(target_shape)?
            .to_dtype(DType::F32)?; // [1, 93, 93]
 
+        println!("Preparing the Offset Tensor...");
        let offset = (r_idx_expanded1 - r_idx_expanded2)?;
        let offset = gather_edges(&offset.unsqueeze(D::Minus1)?, &e_idx)?;
        let offset = offset.squeeze(D::Minus1)?;
-
        let dims = chain_labels.dims();
        let target_shape = (dims[0], dims[1], dims[1]);
        let d_chains = (&chain_labels.unsqueeze(2)?.broadcast_as(target_shape)?
            - &chain_labels.unsqueeze(1)?.broadcast_as(target_shape)?)?
            .eq(0.0)?
-            .to_dtype(DType::I64)?;
-
+            .to_dtype(DType::U32)?;
        // E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
        let e_chains = gather_edges(&d_chains.unsqueeze(D::Minus1)?, &e_idx)?.squeeze(D::Minus1)?;
-
-        println!("About to start the embeddings calculation...");
        let e_positional = self
            .embeddings
-            .forward(&offset.to_dtype(DType::I64)?, &e_chains)?;
-
+            .forward(&offset.to_dtype(DType::U32)?, &e_chains)?;
        println!("About to cat the pos embeddings...");
-
        let e = Tensor::cat(&[e_positional, rbf_all], D::Minus1)?;
        let e = self.edge_embedding.forward(&e)?;
        println!("About to start the normalization...");
@@ -303,9 +297,12 @@ impl PositionalEncodings {
    //     return E
    fn forward(&self, offset: &Tensor, mask: &Tensor) -> Result<Tensor> {
        println!("In positional Embedding: forward");
+        println!("Offset: {:?} ", offset);
+        println!("Offset DTYPE: {:?} ", offset.dtype());
 
-        let max_rel = self.max_relative_feature as f64;
+        // Offset: Tensor[dims 1, 93, 24; u32, metal:4294969325]
+        let max_rel = self.max_relative_feature as f64;
        // First part: clip(offset + max_rel, 0, 2*max_rel)
        let d = (offset + max_rel)?;
        let d = d.clamp(0f64, 2.0 * max_rel)?;
@@ -317,11 +314,32 @@ impl PositionalEncodings {
        let d = (masked_d + extra_term?)?;
 
        // Convert to integers for one_hot
-        let d = d.to_dtype(DType::I64)?;
+        // let d = d.to_dtype(DType::U32)?;
+
+        // Todo: confirm this is correct.
+        // Better to move this upstream
+        // Normalize the values by subtracting 97 (ASCII 'a') to make them 0-based
+        // let d_normalized = (d - 97u32)?; // This will make 'a'=0, 'b'=1, etc.
+        let offset_val = Tensor::full(97u32, d.dims(), d.device())?;
+        let d_normalized = (d - offset_val)?;
+
+        println!("After normalization:");
+        let d_cpu = d_normalized.to_device(&Device::Cpu)?;
+        let d_vec = d_cpu.to_vec3::<u32>()?;
+        println!(
+            "Max value after norm: {}",
+            d_vec.iter().flatten().flatten().max().unwrap()
+        );
+        println!(
+            "Min value after norm: {}",
+            d_vec.iter().flatten().flatten().min().unwrap()
+        );
 
        // one_hot with correct depth using candle_nn::encoding::one_hot
        let depth = (2 * self.max_relative_feature + 2) as i64;
-        let d_onehot = one_hot(d, depth as usize, 1f32, 0f32)?;
+        // let d_onehot = one_hot(d, depth as usize, 1f32, 0f32)?;
+        let d_onehot = one_hot(d_normalized, depth as usize, 1f32, 0f32)?;
+
        let d_onehot_float = d_onehot.to_dtype(DType::F32)?;
        self.linear.forward(&d_onehot_float)
diff --git a/ferritin-featurizers/src/models/ligandmpnn/utilities.rs b/ferritin-featurizers/src/models/ligandmpnn/utilities.rs
index 6efdf75..a478d22 100644
--- a/ferritin-featurizers/src/models/ligandmpnn/utilities.rs
+++ b/ferritin-featurizers/src/models/ligandmpnn/utilities.rs
@@ -36,29 +36,16 @@ pub fn cat_neighbors_nodes(
    e_idx: &Tensor,
 ) -> Result<Tensor> {
    let h_nodes_gathered = gather_nodes(h_nodes, e_idx)?;
-    // println!("h_nodes_gathered dims: {:?}", h_nodes_gathered.dims());
-    // println!("h_neighbors dims: {:?}", h_neighbors.dims());
-    // todo: fix this hacky Dtype
-    //
    let h_neighbors = h_neighbors.expand((
        h_neighbors.dim(0)?, // 1
        h_nodes.dim(1)?,     // 93
        h_neighbors.dim(2)?, // 24
        h_neighbors.dim(3)?, // 128
    ))?;
-
-    // println!("h_neighbors dims 02: {:?}", h_neighbors.dims());
-
    let ten = Tensor::cat(
        &[h_neighbors, h_nodes_gathered.to_dtype(DType::F32)?],
        D::Minus1,
    );
-
-    // let dims = &ten.as_ref().unwrap().dims();
-
-    // println!("tensorcat: {:?}", dims);
-
-    // assert_eq!(true, false);
    ten
 }
@@ -70,6 +57,7 @@ pub fn compute_nearest_neighbors(
    k: usize,
    eps: f32,
 ) -> Result<(Tensor, Tensor)> {
+    // Todo: fix the F32/F64 issue
    let (_batch_size, seq_len, _) = coords.dims3()?;
 
    // broadcast_matmul handles broadcasting automatically
@@ -77,7 +65,7 @@ pub fn compute_nearest_neighbors(
    let mask_2d = mask
        .unsqueeze(2)?
        .broadcast_matmul(&mask.unsqueeze(1)?)?
-        .to_dtype(DType::F64)?; // Convert to f64 once, at the start
+        .to_dtype(DType::F32)?; // Convert to f32 once, at the start
 
    // Compute pairwise distances with broadcasting
    let distances = (coords
@@ -86,7 +74,8 @@ pub fn compute_nearest_neighbors(
        .powf(2.)?
        .sum(D::Minus1)?
        + eps as f64)?
-        .sqrt()?;
+        .sqrt()?
+        .to_dtype(DType::F32)?;
 
    // Apply mask
    // Get max values for adjustment
@@ -102,9 +91,10 @@ pub fn compute_nearest_neighbors(
 
 // https://github.com/huggingface/candle/pull/2375/files#diff-e4d52a71060a80ac8c549f2daffcee77f9bf4de8252ad067c47b1c383c3ac828R957
 pub fn topk_last_dim(xs: &Tensor, topk: usize) -> Result<(Tensor, Tensor)> {
-    let sorted_indices = xs.arg_sort_last_dim(false)?.to_dtype(DType::I64)?;
+    let sorted_indices = xs.arg_sort_last_dim(false)?.to_dtype(DType::U32)?;
    let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
-    Ok((xs.gather(&topk_indices, D::Minus1)?, topk_indices))
+    let gathered = xs.gather(&topk_indices, D::Minus1)?;
+    Ok((gathered, topk_indices))
 }
 
 /// Input coords. Output 1 Tensor
@@ -297,10 +287,7 @@ pub fn gather_edges(edges: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
    neighbor_idx
        .unsqueeze(D::Minus1)?
        .expand((d1, d2, d3, edges.dim(D::Minus1)?))?;
-
-    // println!("Neighbors idx: {:?}", neighbors.dims());
    let edge_gather = edges.gather(&neighbors, 2)?;
-    // println!("edge_gather idx: {:?}", edge_gather.dims());
    Ok(edge_gather)
 }
@@ -309,17 +296,10 @@ pub fn gather_edges(edges: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
 /// Features [B,N,C] at Neighbor indices [B,N,K] => [B,N,K,C]
 /// Flatten and expand indices per batch [B,N,K] => [B,NK] => [B,NK,C]
 pub fn gather_nodes(nodes: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
-    // print!(
-    //     "IN GATHER NODES. Nodes, neighbor_idx: {:?}, {:?}",
-    //     nodes.dims(),
-    //     neighbor_idx.dims()
-    // );
    let (batch_size, n_nodes, n_features) = nodes.dims3()?;
    let (_, _, k_neighbors) = neighbor_idx.dims3()?;
-
    // Reshape neighbor_idx to [B, N*K]
    let neighbors_flat = neighbor_idx.reshape((batch_size, n_nodes * k_neighbors))?;
-
    // Add feature dimension and expand
    let neighbors_flat = neighbors_flat
        .unsqueeze(2)? // Add feature dimension [B, N*K, 1]
@@ -329,12 +309,6 @@ pub fn gather_nodes(nodes: &Tensor, neighbor_idx: &Tensor) -> Result<Tensor> {
    let neighbors_flat = neighbors_flat.contiguous()?;
    // Gather features
    let neighbor_features = nodes.gather(&neighbors_flat, 1)?;
-
-    // println!(
-    //     "neighbor_features dims before final reshape: {:?}",
-    //     neighbor_features.dims()
-    // );
-
    // Reshape back to [B, N, K, C]
    neighbor_features.reshape((batch_size, n_nodes, k_neighbors, n_features))
 }
@@ -607,25 +581,29 @@ mod tests {
    // ATOM 4 O O . MET A 1 1 ? 26.748 9.469 -10.197 1.00 37.13 ? 0 MET A O 1
        let backbone_coords = [
            // Methionine - AA00
-            ("N", (0,0, 0, ..), vec![24.277, 8.374, -9.854]),
+            ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
            ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
            ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
            ("O", (0, 0, 3, ..), vec![26.748, 9.469, -10.197]),
            // Valine - AA01
-            ("N", (0,1, 0, ..), vec![25.964, 11.453, -10.903]),
-            ("CA", (0,1, 1, ..), vec![27.263, 11.924, -11.359]),
-            ("C", (0,1, 2, ..), vec![27.392, 13.428, -11.115]),
-            ("O", (0,1, 3, ..), vec![26.443, 14.184, -11.327]),
+            ("N", (0, 1, 0, ..), vec![25.964, 11.453, -10.903]),
+            ("CA", (0, 1, 1, ..), vec![27.263, 11.924, -11.359]),
+            ("C", (0, 1, 2, ..), vec![27.392, 13.428, -11.115]),
+            ("O", (0, 1, 3, ..), vec![26.443, 14.184, -11.327]),
            // Glycine - AAlast
-            ("N", (0,153, 0, ..), vec![23.474, -3.227, 5.994]),
-            ("CA", (0,153, 1, ..), vec![22.818, -2.798, 7.211]),
-            ("C", (0,153, 2, ..), vec![22.695, -1.282, 7.219]),
-            ("O", (0,153, 3, ..), vec![21.870, -0.745, 7.992]),
+            ("N", (0, 153, 0, ..), vec![23.474, -3.227, 5.994]),
+            ("CA", (0, 153, 1, ..), vec![22.818, -2.798, 7.211]),
+            ("C", (0, 153, 2, ..), vec![22.695, -1.282, 7.219]),
+            ("O", (0, 153, 3, ..), vec![21.870, -0.745, 7.992]),
        ];
 
        for (atom_name, (b, i, j, k), expected) in backbone_coords {
            // assert_eq!(ac_backbone_tensor.dims(), &[1, 154, 4, 3])
-            let actual: Vec<f32> = ac_backbone_tensor.i((b, i, j, k)).unwrap().to_vec1().unwrap();
+            let actual: Vec<f32> = ac_backbone_tensor
+                .i((b, i, j, k))
+                .unwrap()
+                .to_vec1()
+                .unwrap();
            println!("ACTUAL: {:?}", actual);
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }
@@ -666,45 +644,49 @@ mod tests {
            // Methionine - AA00
            // We iterate through these positions. Not all AA's have each
            ("N", (0, 0, 0, ..), vec![24.277, 8.374, -9.854]),
-            ("CA", (0,0, 1, ..), vec![24.404, 9.859, -9.939]),
-            ("C", (0,0, 2, ..), vec![25.814, 10.249, -10.359]),
-            ("CB", (0,0, 3, ..), vec![24.070, 10.495, -8.596]),
-            ("O", (0,0, 4, ..), vec![26.748, 9.469, -10.197]),
-            ("CG", (0,0, 5, ..), vec![24.880, 9.939, -7.442]),
-            ("CG1", (0,0, 6, ..), vec![0.0, 0.0, 0.0]),
-            ("CG2", (0,0, 7, ..), vec![0.0, 0.0, 0.0]),
-            ("OG", (0,0, 8, ..), vec![0.0, 0.0, 0.0]),
-            ("OG1", (0,0, 9, ..), vec![0.0, 0.0, 0.0]),
-            ("SG", (0,0, 10, ..), vec![0.0, 0.0, 0.0]),
-            ("CD", (0,0, 11, ..), vec![0.0, 0.0, 0.0]),
-            ("CD1", (0,0, 12, ..), vec![0.0, 0.0, 0.0]),
-            ("CD2", (0,0, 13, ..), vec![0.0, 0.0, 0.0]),
-            ("ND1", (0,0, 14, ..), vec![0.0, 0.0, 0.0]),
-            ("ND2", (0,0, 15, ..), vec![0.0, 0.0, 0.0]),
-            ("OD1", (0,0, 16, ..), vec![0.0, 0.0, 0.0]),
-            ("OD2", (0,0, 17, ..), vec![0.0, 0.0, 0.0]),
-            ("SD", (0,0, 18, ..), vec![24.262, 10.555, -5.873]),
-            ("CE", (0,0, 19, ..), vec![24.822, 12.266, -5.967]),
-            ("CE1", (0,0, 20, ..), vec![0.0, 0.0, 0.0]),
-            ("CE2", (0,0, 21, ..), vec![0.0, 0.0, 0.0]),
-            ("CE3", (0,0, 22, ..), vec![0.0, 0.0, 0.0]),
-            ("NE", (0,0, 23, ..), vec![0.0, 0.0, 0.0]),
-            ("NE1", (0,0, 24, ..), vec![0.0, 0.0, 0.0]),
-            ("NE2", (0,0, 25, ..), vec![0.0, 0.0, 0.0]),
-            ("OE1", (0,0, 26, ..), vec![0.0, 0.0, 0.0]),
-            ("OE2", (0,0, 27, ..), vec![0.0, 0.0, 0.0]),
-            ("CH2", (0,0, 28, ..), vec![0.0, 0.0, 0.0]),
-            ("NH1", (0,0, 29, ..), vec![0.0, 0.0, 0.0]),
-            ("NH2", (0,0, 30, ..), vec![0.0, 0.0, 0.0]),
-            ("OH", (0,0, 31, ..), vec![0.0, 0.0, 0.0]),
-            ("CZ", (0,0, 32, ..), vec![0.0, 0.0, 0.0]),
-            ("CZ2", (0,0, 33, ..), vec![0.0, 0.0, 0.0]),
-            ("CZ3", (0, 0,34, ..), vec![0.0, 0.0, 0.0]),
-            ("NZ", (0, 0,35, ..), vec![0.0, 0.0, 0.0]),
-            ("OXT", (0, 0,36, ..), vec![0.0, 0.0, 0.0]),
+            ("CA", (0, 0, 1, ..), vec![24.404, 9.859, -9.939]),
+            ("C", (0, 0, 2, ..), vec![25.814, 10.249, -10.359]),
+            ("CB", (0, 0, 3, ..), vec![24.070, 10.495, -8.596]),
+            ("O", (0, 0, 4, ..), vec![26.748, 9.469, -10.197]),
+            ("CG", (0, 0, 5, ..), vec![24.880, 9.939, -7.442]),
+            ("CG1", (0, 0, 6, ..), vec![0.0, 0.0, 0.0]),
+            ("CG2", (0, 0, 7, ..), vec![0.0, 0.0, 0.0]),
+            ("OG", (0, 0, 8, ..), vec![0.0, 0.0, 0.0]),
+            ("OG1", (0, 0, 9, ..), vec![0.0, 0.0, 0.0]),
+            ("SG", (0, 0, 10, ..), vec![0.0, 0.0, 0.0]),
+            ("CD", (0, 0, 11, ..), vec![0.0, 0.0, 0.0]),
+            ("CD1", (0, 0, 12, ..), vec![0.0, 0.0, 0.0]),
+            ("CD2", (0, 0, 13, ..), vec![0.0, 0.0, 0.0]),
+            ("ND1", (0, 0, 14, ..), vec![0.0, 0.0, 0.0]),
+            ("ND2", (0, 0, 15, ..), vec![0.0, 0.0, 0.0]),
+            ("OD1", (0, 0, 16, ..), vec![0.0, 0.0, 0.0]),
+            ("OD2", (0, 0, 17, ..), vec![0.0, 0.0, 0.0]),
+            ("SD", (0, 0, 18, ..), vec![24.262, 10.555, -5.873]),
+            ("CE", (0, 0, 19, ..), vec![24.822, 12.266, -5.967]),
+            ("CE1", (0, 0, 20, ..), vec![0.0, 0.0, 0.0]),
+            ("CE2", (0, 0, 21, ..), vec![0.0, 0.0, 0.0]),
+            ("CE3", (0, 0, 22, ..), vec![0.0, 0.0, 0.0]),
+            ("NE", (0, 0, 23, ..), vec![0.0, 0.0, 0.0]),
+            ("NE1", (0, 0, 24, ..), vec![0.0, 0.0, 0.0]),
+            ("NE2", (0, 0, 25, ..), vec![0.0, 0.0, 0.0]),
+            ("OE1", (0, 0, 26, ..), vec![0.0, 0.0, 0.0]),
+            ("OE2", (0, 0, 27, ..), vec![0.0, 0.0, 0.0]),
+            ("CH2", (0, 0, 28, ..), vec![0.0, 0.0, 0.0]),
+            ("NH1", (0, 0, 29, ..), vec![0.0, 0.0, 0.0]),
+            ("NH2", (0, 0, 30, ..), vec![0.0, 0.0, 0.0]),
+            ("OH", (0, 0, 31, ..), vec![0.0, 0.0, 0.0]),
+            ("CZ", (0, 0, 32, ..), vec![0.0, 0.0, 0.0]),
+            ("CZ2", (0, 0, 33, ..), vec![0.0, 0.0, 0.0]),
+            ("CZ3", (0, 0, 34, ..), vec![0.0, 0.0, 0.0]),
+            ("NZ", (0, 0, 35, ..), vec![0.0, 0.0, 0.0]),
+            ("OXT", (0, 0, 36, ..), vec![0.0, 0.0, 0.0]),
        ];
-        for (atom_name, (b,i, j, k), expected) in allatom_coords {
-            let actual: Vec<f32> = ac_backbone_tensor.i((b, i, j, k)).unwrap().to_vec1().unwrap();
+        for (atom_name, (b, i, j, k), expected) in allatom_coords {
+            let actual: Vec<f32> = ac_backbone_tensor
+                .i((b, i, j, k))
+                .unwrap()
+                .to_vec1()
+                .unwrap();
            assert_eq!(actual, expected, "Mismatch for atom {}", atom_name);
        }
    }
@@ -783,7 +765,9 @@ mod tests {
            ],
            &device,
        )
-        .unwrap().to_dtype(test_dtype).unwrap();
+        .unwrap()
+        .to_dtype(test_dtype)
+        .unwrap();
 
        // Create mask indicating all points are valid
        let mask = Tensor::ones((2, 3), test_dtype, &device).unwrap();
@@ -796,7 +780,7 @@ mod tests {
        assert_eq!(indices.dims(), &[2, 3, 2]); // [batch, seq_len, k]
 
        // For first sequence, point [1,0,0] should have [0,0,0] and [2,0,0] as nearest neighbors
-        let point_neighbors: Vec<i64> = indices.i((0, 1, ..)).unwrap().to_vec1().unwrap();
+        let point_neighbors: Vec<u32> = indices.i((0, 1, ..)).unwrap().to_vec1().unwrap();
        assert_eq!(point_neighbors, vec![0, 2]);
 
        // Check distances are correct
diff --git a/ferritin-featurizers/tests/test_cli.rs b/ferritin-featurizers/tests/test_cli.rs
index 6cc1086..d682720 100644
--- a/ferritin-featurizers/tests/test_cli.rs
+++ b/ferritin-featurizers/tests/test_cli.rs
@@ -1,3 +1,6 @@
+//
+// cargo flamegraph --bin ferritin-featurizers -- run --seed 111 --pdb-path ferritin-test-data/data/structures/1bc8.cif --model-type protein_mpnn --out-folder testout
+// cargo instruments -t time --bin ferritin-featurizers -- run --seed 111 --pdb-path ferritin-test-data/data/structures/1bc8.cif --model-type protein_mpnn --out-folder testout
 use assert_cmd::Command;
 use ferritin_test_data::TestFile;
 use std::path::Path;
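For a non-profiling smoke test of the same invocation, the CLI can also be driven through `assert_cmd`, which the test file already imports. This is a hedged sketch rather than part of the PR: the binary name and flags mirror the profiling comments added to `test_cli.rs` above, the relative test-data path assumes the test runs from the workspace root, and the bare success assertion is illustrative only.

```rust
// Hypothetical smoke test mirroring the flamegraph/instruments commands above.
use assert_cmd::Command;

#[test]
fn run_protein_mpnn_smoke() {
    // cargo_bin locates the ferritin-featurizers binary built by cargo.
    Command::cargo_bin("ferritin-featurizers")
        .expect("binary should be built by cargo")
        .args([
            "run",
            "--seed", "111",
            "--pdb-path", "ferritin-test-data/data/structures/1bc8.cif",
            "--model-type", "protein_mpnn",
            "--out-folder", "testout",
        ])
        .assert()
        .success();
}
```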