Tests and Style (#55)
* fix the tests

* bump deps

* remove print statements

* spelling pass

* clippy
zachcp authored Nov 30, 2024
1 parent d128849 commit 6e4dc95
Showing 19 changed files with 412 additions and 518 deletions.
646 changes: 295 additions & 351 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -13,7 +13,7 @@ resolver = "2"
 
 [workspace.dependencies]
 serde = { version = "1.0", features = ["derive"] }
-serde_json = "1.0.132"
+serde_json = "1.0.133"
 serde_bytes = "0.11.15"
 serde_repr = "0.1.19"
 urlencoding = "2.1.3"
6 changes: 3 additions & 3 deletions ferritin-bevy/Cargo.toml
@@ -1,4 +1,4 @@
-[package]
+[package]
 name = "ferritin-bevy"
 version.workspace = true
 edition.workspace = true
@@ -7,8 +7,8 @@ license.workspace = true
 description.workspace = true
 
 [dependencies]
-bevy = "=0.15.0-rc.3"
-bon = "3.0.0"
+bevy = "0.15.0"
+bon = "3.1.1"
 ferritin-core = { path = "../ferritin-core" }
 pdbtbx.workspace = true
 
2 changes: 1 addition & 1 deletion ferritin-cellscape/Cargo.toml
@@ -9,5 +9,5 @@ description.workspace = true
 [dependencies]
 ferritin-core = { path = "../ferritin-core" }
 pdbtbx = { workspace = true }
-geo = "0.29.1"
+geo = "0.29.2"
 svg = "0.18.0"
2 changes: 1 addition & 1 deletion ferritin-cellscape/Readme.md
@@ -2,7 +2,7 @@
 
 - cellscape takes a structure and projects the results down to a pretty 2D image
 - this crate aims to port it to Rust.
-- initials work: load a stuct, calc the svg and write it to disk
+- initials work: load a struct, calc the svg and write it to disk
 
 
 ```sh
6 changes: 3 additions & 3 deletions ferritin-core/src/atomcollection.rs
@@ -15,7 +15,7 @@ use pdbtbx::Element;
 ///
 /// The core data structure of ferritin-core.
 ///
-/// it strives to be simple, high perfomance, and extensible using
+/// it strives to be simple, high performance, and extensible using
 /// traits.
 ///
 pub struct AtomCollection {
@@ -147,7 +147,7 @@ impl AtomCollection {
     }
 
     pub fn connect_via_distance(&self) -> Vec<Bond> {
-        // note: was intendin to follow Biotite's algo
+        // note: was intending to follow Biotite's algo
         unimplemented!()
     }
     pub fn get_size(&self) -> usize {
@@ -215,7 +215,7 @@ impl AtomCollection {
     }
     /// Iter_Residues Will Iterate Through the AtomCollection one Residue at a time.
     ///
-    /// This is the base for any onther residue filtration code.
+    /// This is the base for any other residue filtration code.
     pub fn iter_residues_all(&self) -> ResidueIter {
         ResidueIter::new(self, self.get_residue_starts())
     }
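`connect_via_distance` is still `unimplemented!()`. For orientation, a rough sketch of the Biotite-style approach the comment alludes to: bond two atoms whenever their separation is at most the sum of their covalent radii plus a small tolerance. The `Atom` and `Bond` shapes and the 0.45 Å slack below are illustrative assumptions, not ferritin-core's actual API.

```rust
/// Hypothetical minimal types for the sketch; the crate's real
/// AtomCollection and Bond carry much more information.
struct Atom {
    pos: [f64; 3],
    covalent_radius: f64, // Angstroms
}
struct Bond {
    a: usize,
    b: usize,
}

/// Infer bonds by comparing each pairwise distance to the sum of
/// covalent radii plus a fixed tolerance.
fn connect_via_distance(atoms: &[Atom]) -> Vec<Bond> {
    const TOLERANCE: f64 = 0.45; // assumed slack, in Angstroms
    let mut bonds = Vec::new();
    for i in 0..atoms.len() {
        for j in (i + 1)..atoms.len() {
            let d2: f64 = atoms[i]
                .pos
                .iter()
                .zip(&atoms[j].pos)
                .map(|(a, b)| (a - b).powi(2))
                .sum();
            let cutoff = atoms[i].covalent_radius + atoms[j].covalent_radius + TOLERANCE;
            if d2.sqrt() <= cutoff {
                bonds.push(Bond { a: i, b: j });
            }
        }
    }
    bonds
}
```

A real implementation would replace the O(n²) pair loop with a spatial grid (cell lists), which is what Biotite does.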
2 changes: 1 addition & 1 deletion ferritin-featurizers/Cargo.toml
@@ -19,7 +19,7 @@ pdbtbx.workspace = true
 rand = "0.8.5"
 safetensors = "0.4.5"
 strum = { version = "0.26", features = ["derive"] }
-tokenizers = "0.20.3"
+tokenizers = "0.21.0"
 ferritin-test-data = { path = "../ferritin-test-data" }
 
 

5 changes: 3 additions & 2 deletions ferritin-featurizers/src/models/amplify/amplify.rs
@@ -22,7 +22,7 @@ use candle_nn::{
 /// Configuration Struct for AMPLIFY
 ///
 /// Currently only holds the weight params for
-/// those modeld found on GH: the 120M and 350M models.
+/// those models found on GH: the 120M and 350M models.
 ///
 pub struct AMPLIFYConfig {
     pub hidden_size: usize,
@@ -94,14 +94,15 @@ impl AMPLIFYConfig {
     }
 }
 
+//noinspection SpellCheckingInspection
 /// Amplify EncoderBlock implementation
 ///
 /// References for coding the block from similar models.
 ///
 /// - [T5](https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/t5.rs#L331)
 /// - [distilbert](https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/distilbert.rs#L198)
 /// - [glm4](https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/glm4.rs#L340)
-/// - [SwiGLu Imple](https://github.com/facebookresearch/xformers/blob/main/xformers/ops/swiglu_op.py#L462)
+/// - [SwiGLu Implementation](https://github.com/facebookresearch/xformers/blob/main/xformers/ops/swiglu_op.py#L462)
 #[derive(Debug)]
 pub struct EncoderBlock {
     q: Linear,
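The SwiGLU reference in that doc-comment list boils down to FFN(x) = W2(silu(W1 x) * W3 x). A minimal candle sketch with assumed field names; the real EncoderBlock's layout is truncated above:

```rust
use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module};

/// Hypothetical SwiGLU feed-forward block; field names are illustrative.
struct SwiGlu {
    w1: Linear, // gate projection
    w2: Linear, // output projection
    w3: Linear, // value projection
}

impl SwiGlu {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // silu(W1 x) gates (W3 x) element-wise, then W2 projects back down
        let gated = self.w1.forward(x)?.silu()?.mul(&self.w3.forward(x)?)?;
        self.w2.forward(&gated)
    }
}
```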
2 changes: 1 addition & 1 deletion ferritin-featurizers/src/models/ligandmpnn/featurizer.rs
@@ -37,7 +37,7 @@ pub trait LMPNNFeatures {
 /// datasets
 impl LMPNNFeatures for AtomCollection {
     /// Return a 2D tensor of [1, seqlength]
-    fn encode_amino_acids(&self, device: &Device) -> Result<(Tensor)> {
+    fn encode_amino_acids(&self, device: &Device) -> Result<Tensor> {
         let n = self.iter_residues_aminoacid().count();
         let s = self
             .iter_residues_aminoacid()
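The signature change above is one of the clippy fixes from the commit message: parentheses around a single type are pure grouping, so `Result<(Tensor)>` is the same type as `Result<Tensor>` and only triggers the `unused_parens` lint. A tiny self-contained illustration:

```rust
// `(T)` is just `T`; a genuine one-element tuple needs a trailing comma: `(T,)`.
fn grouped() -> (i32) {
    1 // warns: unnecessary parentheses around type
}
fn one_tuple() -> (i32,) {
    (1,) // a real 1-tuple, no warning
}

fn main() {
    assert_eq!(grouped(), 1);
    assert_eq!(one_tuple().0, 1);
}
```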
71 changes: 4 additions & 67 deletions ferritin-featurizers/src/models/ligandmpnn/model.rs
@@ -5,7 +5,7 @@
 //! Consider factoring out model creation of the DEC
 //! and ENC layers using a function.
 //!
-//! here is an example of paramatereizable network creation:
+//! here is an example of paramaterizable network creation:
 //! https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/resnet.rs
 //!
 use super::configs::{ModelTypes, ProteinMPNNConfig};
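A sketch of the factoring-out the module doc suggests, following the resnet.rs pattern of building repeated layers under numbered VarBuilder prefixes. `EncLayer::load` and its arguments are assumptions standing in for whatever constructor the real code uses:

```rust
use candle_core::Result;
use candle_nn::VarBuilder;

/// Hypothetical factory: build `n_layers` encoder layers, each reading its
/// weights from a numbered prefix such as "encoder_layers.0".
fn build_encoder_layers(
    vb: &VarBuilder,
    n_layers: usize,
    hidden_dim: usize,
) -> Result<Vec<EncLayer>> {
    (0..n_layers)
        .map(|i| EncLayer::load(vb.pp(format!("encoder_layers.{i}")), hidden_dim))
        .collect()
}
```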
@@ -16,7 +16,7 @@ use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::encoding::one_hot;
 use candle_nn::ops::{log_softmax, softmax};
 use candle_nn::{embedding, layer_norm, linear, Dropout, Embedding, Linear, VarBuilder};
-use candle_transformers::generation::{LogitsProcessor, Sampling};
+use candle_transformers::generation::LogitsProcessor;
 
 pub fn multinomial_sample(probs: &Tensor, temperature: f64, seed: u64) -> Result<Tensor> {
     // Create the logits processor with its required arguments
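The unused `Sampling` import is dropped; `LogitsProcessor::new(seed, temperature, top_p)` alone covers seeded temperature sampling. The function body is truncated here, so the following is only a plausible sketch of a LogitsProcessor-backed multinomial draw, not the file's actual implementation:

```rust
use candle_core::{Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

/// Sketch: draw one index from a 1-D probability tensor.
fn multinomial_sample_sketch(probs: &Tensor, temperature: f64, seed: u64) -> Result<Tensor> {
    // LogitsProcessor applies a softmax when a temperature is set, so we
    // hand it log-probabilities to recover the original distribution
    // (an assumed detail of the elided body).
    let mut processor = LogitsProcessor::new(seed, Some(temperature), None);
    let token: u32 = processor.sample(&probs.log()?)?;
    Tensor::new(&[token], probs.device())
}
```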
@@ -170,7 +170,6 @@ impl EncLayer {
         };
 
         // Safe division with scale
-        println!("Scale value: {:?}", self.scale);
         let dh = {
             let sum = h_message.sum(D::Minus2)?;
             let scale = if self.scale == 0.0 { 1.0 } else { self.scale };
@@ -223,7 +222,6 @@ impl EncLayer {
                 .forward(&h_message, training.expect("Training Must be specified"))?;
             self.norm3.forward(&(h_e + h_message_dropout)?)?
         };
-        println!("EncoderLayer: Finishing forward pass");
         Ok((h_v, h_e))
     }
 }
@@ -406,53 +404,30 @@ impl ProteinMPNN {
             None => &Tensor::ones_like(&s_true)?,
         };
 
-        println!("Starting encode function");
 
         match self.config.model_type {
             ModelTypes::ProteinMPNN => {
                 let (e, e_idx) = self.features.forward(features, device)?;
-                println!("After embedding dims: {:?}", e.dims());
 
                 let mut h_v = Tensor::zeros(
                     (e.dim(0)?, e.dim(1)?, e.dim(D::Minus1)?),
                     DType::F64,
                     device,
                 )?;
 
                 let mut h_e = self.w_e.forward(&e)?;
 
                 let mask_attend = if let Some(mask) = features.get_sequence_mask() {
-                    println!("Original mask dims: {:?}", mask.dims());
-                    println!(
-                        "Original mask values: {:?}",
-                        mask.get(0)?.narrow(0, 0, 5)?.to_vec1::<f32>()?
-                    );
 
                     // First unsqueeze mask
                     let mask_expanded = mask.unsqueeze(D::Minus1)?; // [B, L, 1]
-                    println!(
-                        "Expanded mask values: {:?}",
-                        mask_expanded.get(0)?.narrow(0, 0, 5)?.to_vec2::<f32>()?
-                    );
 
                     // Gather using E_idx
                     let mask_gathered = gather_nodes(&mask_expanded, &e_idx)?;
-                    println!("Gathered mask dims: {:?}", mask_gathered.dims());
-                    println!(
-                        "Gathered mask values: {:?}",
-                        mask_gathered
-                            .get(0)?
-                            .narrow(0, 0, 5)?
-                            .narrow(1, 0, 5)?
-                            .to_vec3::<f32>()?
-                    );
 
                     let mask_gathered = mask_gathered.squeeze(D::Minus1)?;
 
                     // Multiply original mask with gathered mask
                     let mask_attend = {
                         let mask_unsqueezed = mask.unsqueeze(D::Minus1)?; // [B, L, 1]
-                        println!("mask_unsqueezed dims: {:?}", mask_unsqueezed.dims());
 
                         // Explicitly expand mask_unsqueezed to match mask_gathered dimensions
                         let mask_expanded = mask_unsqueezed
@@ -462,7 +437,6 @@
                             mask_gathered.dim(2)?, // number of neighbors
                         ))?
                         .contiguous()?;
-                        println!("mask_expanded dims: {:?}", mask_expanded.dims());
 
                         // Now do the multiplication with explicit shapes
                         mask_expanded.mul(&mask_gathered)?
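Condensed, the mask logic that survives this cleanup gathers each residue's [B, L] validity flag across its K neighbor indices and multiplies it against the source residue's own flag, yielding a [B, L, K] attention mask. A self-contained sketch of both steps; this `gather_nodes` is an assumed equivalent of the crate's helper, not its actual code:

```rust
use candle_core::{Result, Tensor, D};

/// Assumed equivalent of the crate's gather_nodes:
/// nodes [B, L, C] + neighbor indices e_idx [B, L, K] -> [B, L, K, C].
fn gather_nodes(nodes: &Tensor, e_idx: &Tensor) -> Result<Tensor> {
    let (b, l, c) = nodes.dims3()?;
    let (_, _, k) = e_idx.dims3()?;
    let idx = e_idx
        .reshape((b, l * k))? // flatten the neighbor lists
        .unsqueeze(D::Minus1)?
        .expand((b, l * k, c))?
        .contiguous()?;
    nodes.gather(&idx, 1)?.reshape((b, l, k, c))
}

/// mask [B, L] + e_idx [B, L, K] -> pairwise attention mask [B, L, K].
fn mask_attend(mask: &Tensor, e_idx: &Tensor) -> Result<Tensor> {
    let mask_1d = mask.unsqueeze(D::Minus1)?; // [B, L, 1]
    let gathered = gather_nodes(&mask_1d, e_idx)?.squeeze(D::Minus1)?; // [B, L, K]
    let expanded = mask_1d
        .expand((mask.dim(0)?, mask.dim(1)?, gathered.dim(2)?))?
        .contiguous()?;
    expanded.mul(&gathered) // 1.0 only where both endpoints are real residues
}
```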
@@ -472,37 +446,14 @@
                     let (b, l) = mask.dims2()?;
                     let ones = Tensor::ones((b, l, e_idx.dim(2)?), DType::F32, device)?;
-                    println!("Created default ones mask dims: {:?}", ones.dims());
-                    println!(
-                        "Created default ones mask values: {:?}",
-                        ones.get(0)?
-                            .narrow(0, 0, 5)?
-                            .narrow(1, 0, 5)?
-                            .to_vec2::<f32>()?
-                    );
 
                     ones
                 };
 
                 for (i, layer) in self.encoder_layers.iter().enumerate() {
-                    println!("Starting encoder layer {}", i);
 
-                    // Debug h_v (3D tensor)
-                    println!("h_v before layer {} dims: {:?}", i, h_v.dims());
-                    let h_v_f32 = h_v.to_dtype(DType::F32)?;
-                    println!(
-                        "h_v before layer {} values: {:?}",
-                        i,
-                        h_v_f32.to_vec3::<f32>()?
-                    );
 
-                    // Debug h_e (4D tensor) - access first batch and first sequence position
-                    println!("h_e before layer {} dims: {:?}", i, h_e.dims());
-                    let h_e_f32 = h_e.to_dtype(DType::F32)?;
-                    println!(
-                        "h_e before layer {} first position values: {:?}",
-                        i,
-                        h_e_f32.get(0)?.get(0)?.to_vec2::<f32>()?
-                    );
 
                     let (new_h_v, new_h_e) = layer.forward(
                         &h_v,
                         &h_e,
@@ -511,27 +462,13 @@
                         Some(&mask_attend),
                         Some(false),
                     )?;
-                    println!("After layer {} forward pass:", i);
 
-                    // Debug new_h_v
-                    println!("new_h_v dims: {:?}", new_h_v.dims());
-                    let new_h_v_f32 = new_h_v.to_dtype(DType::F32)?;
-                    println!("new_h_v values: {:?}", new_h_v_f32.to_vec3::<f32>()?);
 
-                    // Debug new_h_e
-                    println!("new_h_e dims: {:?}", new_h_e.dims());
-                    let new_h_e_f32 = new_h_e.to_dtype(DType::F32)?;
-                    println!(
-                        "new_h_e first position values: {:?}",
-                        new_h_e_f32.get(0)?.get(0)?.to_vec2::<f32>()?
-                    );
 
                     h_v = new_h_v;
                     h_e = new_h_e;
                 }
-                println!("Final h_v dims: {:?}", h_v.dims());
-                println!("Final h_e dims: {:?}", h_e.dims());
-                println!("Final e_idx dims: {:?}", e_idx.dims());
 
                 Ok((h_v, h_e, e_idx))
             }
             ModelTypes::LigandMPNN => {
10 changes: 5 additions & 5 deletions ferritin-featurizers/src/models/ligandmpnn/proteinfeatures.rs
@@ -58,7 +58,7 @@ impl ProteinFeaturesModel {
         })
     }
 
-    /// This function calculates the nearest Ca coordinates and retunrs the ditances and indices.
+    /// This function calculates the nearest Ca coordinates and returns the distances and indices.
     fn _dist(&self, x: &Tensor, mask: &Tensor, eps: f64) -> Result<(Tensor, Tensor)> {
         println!("in _dist: ");
         println!("Tensor dims: x, mask: {:?}, {:?}", x.dims(), mask.dims());
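`_dist`'s body is truncated above. As a rough reconstruction of the technique its doc comment names (masked pairwise Ca distances, then the k nearest per residue), under assumed shapes x [B, L, 3] and mask [B, L], and not the crate's actual code:

```rust
use candle_core::{Result, Tensor, D};

/// Masked k-nearest-neighbor distances: returns ([B, L, k] distances,
/// [B, L, k] neighbor indices), pushing masked pairs past all real ones.
fn dist_sketch(x: &Tensor, mask: &Tensor, k: usize, eps: f64) -> Result<(Tensor, Tensor)> {
    let mask_2d = mask.unsqueeze(1)?.broadcast_mul(&mask.unsqueeze(2)?)?; // [B, L, L]
    let diff = x.unsqueeze(1)?.broadcast_sub(&x.unsqueeze(2)?)?; // [B, L, L, 3]
    let d = diff.sqr()?.sum(D::Minus1)?.affine(1.0, eps)?.sqrt()?; // [B, L, L]
    // replace invalid entries with the global maximum so they sort last
    let d_max = d.max_keepdim(D::Minus1)?.max_keepdim(D::Minus2)?; // [B, 1, 1]
    let d_adj = (d.broadcast_mul(&mask_2d)?
        + d_max.broadcast_mul(&mask_2d.affine(-1.0, 1.0)?)?)?;
    let idx = d_adj.arg_sort_last_dim(true)?.narrow(D::Minus1, 0, k)?; // ascending
    let vals = d_adj.gather(&idx, D::Minus1)?;
    Ok((vals, idx))
}
```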
@@ -79,7 +79,7 @@
         // Create centers (μ)
         let d_mu = linspace(D_MIN, D_MAX, self.num_rbf, device)?
             .reshape((1, 1, 1, self.num_rbf))?
-            .to_dtype(candle_core::DType::F32)?;
+            .to_dtype(DType::F32)?;
 
         // Calculate width (σ)
         let d_sigma = (D_MAX - D_MIN) / self.num_rbf as f64;
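For context, the centers and width computed above feed a Gaussian radial basis expansion, rbf_k(d) = exp(-((d - mu_k) / sigma)^2). A standalone sketch, substituting arange + affine for the file's linspace helper:

```rust
use candle_core::{DType, Result, Tensor, D};

/// Expand distances d [B, L, K] into RBF features [B, L, K, num_rbf].
fn rbf_sketch(d: &Tensor, d_min: f64, d_max: f64, num_rbf: usize) -> Result<Tensor> {
    // centers mu_k, evenly spaced over [d_min, d_max]
    let mu = Tensor::arange(0u32, num_rbf as u32, d.device())?
        .to_dtype(DType::F32)?
        .affine((d_max - d_min) / (num_rbf as f64 - 1.0), d_min)?
        .reshape((1, 1, 1, num_rbf))?;
    let sigma = (d_max - d_min) / num_rbf as f64;
    // z = (d - mu) / sigma, broadcast over the new RBF axis
    let z = d.unsqueeze(D::Minus1)?.broadcast_sub(&mu)?.affine(1.0 / sigma, 0.0)?;
    z.sqr()?.neg()?.exp()
}
```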
@@ -249,17 +249,17 @@
         let d_chains = (&chain_labels.unsqueeze(2)?.broadcast_as(target_shape)?
             - &chain_labels.unsqueeze(1)?.broadcast_as(target_shape)?)?
             .eq(0.0)?
-            .to_dtype(candle_core::DType::I64)?;
+            .to_dtype(DType::I64)?;
 
         // E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
         let e_chains = gather_edges(&d_chains.unsqueeze(D::Minus1)?, &e_idx)?.squeeze(D::Minus1)?;
 
         println!("About to start the embeddings calculation...");
         let e_positional = self
             .embeddings
-            .forward(&offset.to_dtype(candle_core::DType::I64)?, &e_chains)?;
+            .forward(&offset.to_dtype(DType::I64)?, &e_chains)?;
 
-        println!("About to cat the pos embeddigns...");
+        println!("About to cat the pos embeddings...");
 
         let e = Tensor::cat(&[e_positional, rbf_all], D::Minus1)?;
         let e = self.edge_embedding.forward(&e)?;