Skip to content

Commit

Permalink
LigandMPNN Test Suite (#62)
Browse files Browse the repository at this point in the history
* creating the score->fasta fn

* get the run script ready

* gitignore and model

*  rework test harness to conditionally compile for Metal

* begin exploring why outputs are homogenous
  • Loading branch information
zachcp authored Dec 2, 2024
1 parent 18138f0 commit aa2d4b5
Show file tree
Hide file tree
Showing 13 changed files with 1,060 additions and 74 deletions.
111 changes: 102 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 0 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ members = [
]
resolver = "2"


[workspace.dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.133"
Expand All @@ -29,11 +28,3 @@ edition = "2021"
authors = ["Zach Charlop-Powers<[email protected]>"]
description = "Molecular visualization tools"
license = "MIT OR Apache-2.0"

[profile.dev]
# Disabling debug info speeds up builds a bunch,
# and we don't rely on it for debugging that much.
debug = 0

[workspace.features]
metal = []
2 changes: 1 addition & 1 deletion ferritin-bevy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ description.workspace = true

[dependencies]
bevy = "0.15.0"
bon = "3.1.1"
bon = "3.2.0"
ferritin-core = { path = "../ferritin-core" }
pdbtbx.workspace = true

Expand Down
1 change: 1 addition & 0 deletions ferritin-featurizers/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
*.safetensors
outputs/
33 changes: 15 additions & 18 deletions ferritin-featurizers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,18 @@ authors.workspace = true
license.workspace = true
description.workspace = true

[target.'cfg(target_os = "macos")'.features]
default = ["metal"]


# Metal-enabled dependencies when feature is active
[target.'cfg(feature = "metal")'.dependencies]
candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn", features = [
"metal",
] }
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [
"metal",
] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers", features = [
"metal",
] }
candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels" }
[features]
metal = [
"candle-core/metal",
"candle-nn/metal",
"candle-transformers/metal",
"candle-metal-kernels",
]

[dependencies]
anyhow.workspace = true
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core" }
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [
] }
candle-hf-hub = "0.3.3"
candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers" }
Expand All @@ -39,9 +31,14 @@ safetensors = "0.4.5"
strum = { version = "0.26", features = ["derive"] }
tokenizers = "0.21.0"

[target.'cfg(target_os = "macos")'.features]
metal = []

[target.'cfg(target_os = "macos")'.dependencies]
candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels", optional = true }

[dev-dependencies]
candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" }
ferritin-test-data = { path = "../ferritin-test-data" }
assert_cmd = "2.0.16"
tempfile = "3.14.0"
candle-hf-hub = "0.3.3"
35 changes: 27 additions & 8 deletions ferritin-featurizers/src/commands/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,8 @@ pub fn execute(
Some(multi_pdb_config),
)?;

println!("About to Load the model!");
let model = exec.load_model()?;
println!("Model Loaded!");
println!("Model Loaded on the {:?}", model.device);

println!("Generating Protein Features");
let prot_features = exec.generate_protein_features()?;
Expand All @@ -72,16 +70,37 @@ pub fn execute(
std::fs::create_dir_all(format!("{}/backbones", out_folder))?;
std::fs::create_dir_all(format!("{}/packed", out_folder))?;

// Score a Protein!
println!("Sampling from the Model...");
let model_sample = model.sample(&prot_features)?;
println!("{:?}", model_sample);

// // Score a Protein!
// println!("Scoring the Protein...");
// let model_score = model.score(&prot_features, false);
// println!("{:?}", model_score);
// let model_score = model.score(&prot_features, false)?;
// println!("Protein Score: {:?}", model_score);
std::fs::create_dir_all(format!("{}/seqs", out_folder))?;
let sequences = model_sample.get_sequences()?;
println!("OUTPUT FASTA: {:?}", sequences);
println!("DECODING ORDER: {:?}", model_sample.get_decoding_order()?);

let fasta_path = format!("{}/seqs/output.fasta", out_folder);

let mut fasta_content = String::new();
for (i, seq) in sequences.iter().enumerate() {
fasta_content.push_str(&format!(">sequence_{}\n{}\n", i + 1, seq));
}
std::fs::write(fasta_path, fasta_content)?;

// Sample from the Model!
// Note: sampling from the model
println!("Sampling from the Model...");
let model_sample = model.sample(&prot_features);
println!("{:?}", model_sample);
// println!("Sampling from the Model...");
// let model_sample = model.sample(&prot_features);
// println!("{:?}", model_sample);

// assert_eq!(true, false);

// prot_features
// generate_protein_features()

// model.score() -> Result<ScoreOutput>

Expand Down
1 change: 1 addition & 0 deletions ferritin-featurizers/src/models/ligandmpnn/configs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
use super::featurizer::ProteinFeatures;
use super::model::ProteinMPNN;
use crate::models::ligandmpnn::featurizer::LMPNNFeatures;
use crate::models::ligandmpnn::model::ScoreOutput;
use anyhow::Error;
use candle_core::pickle::PthTensors;
use candle_core::{DType, Device, Tensor};
Expand Down
Loading

0 comments on commit aa2d4b5

Please sign in to comment.