Skip to content

Commit

Permalink
Begin Refactoring LigandMPNN CLI Call to use Metal (#56)
Browse files Browse the repository at this point in the history
* begin factoring out dtype conversions

* dtype conversion pass

* remove contiguous calls.

* update device passing

* update utilities

* I64->U32

* Metal ops for Gather and Scatter-add

* Cargo coords to temp branch
  • Loading branch information
zachcp authored Dec 1, 2024
1 parent 6e4dc95 commit e83f661
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 211 deletions.
68 changes: 61 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 16 additions & 3 deletions ferritin-featurizers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,23 @@ description.workspace = true

[dependencies]
anyhow.workspace = true
candle-core = "0.8"

candle-metal-kernels = { git = "https://github.com/zachcp/candle.git", package = "candle-metal-kernels", branch = "20241201-SA" }
candle-core = { git = "https://github.com/zachcp/candle.git", package = "candle-core", features = [
"metal",
], branch = "20241201-SA" }

# candle-core = { version = "0.8", features = ["metal"] }
candle-hf-hub = "0.3.3"
candle-nn = "0.8"
candle-transformers = "0.8.0"

# candle-nn = { version = "0.8", features = ["metal"] }
candle-nn = { git = "https://github.com/zachcp/candle.git", package = "candle-nn", features = [
"metal",
], branch = "20241201-SA" }
candle-transformers = { git = "https://github.com/zachcp/candle.git", package = "candle-transformers", features = [
"metal",
], branch = "20241201-SA" }
# candle-transformers = "0.8.0"
clap = "4.5.21"
ferritin-core = { path = "../ferritin-core" }
itertools.workspace = true
Expand Down
28 changes: 26 additions & 2 deletions ferritin-featurizers/src/commands/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,30 @@ use crate::models::ligandmpnn::configs::{
AABiasConfig, LigandMPNNConfig, MPNNExecConfig, MembraneMPNNConfig, ModelTypes, MultiPDBConfig,
ResidueControl, RunConfig,
};
use candle_core::Device;
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{Device, Result};

pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!(
"Running on CPU, to run on GPU(metal), build this example with `--features metal`"
);
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(Device::Cpu)
}
}

pub fn execute(
seed: i32,
Expand All @@ -17,7 +40,7 @@ pub fn execute(
multi_pdb_config: MultiPDBConfig,
) -> anyhow::Result<()> {
// todo - whats the best way to handle device?
let device = &Device::Cpu;
let device = device(false)?;

let model_type = model_type.unwrap_or(ModelTypes::ProteinMPNN);

Expand All @@ -37,6 +60,7 @@ pub fn execute(
println!("About to Load the model!");
let model = exec.load_model()?;
println!("Model Loaded!");
println!("Model Loaded on the {:?}", model.device);

println!("Generating Protein Features");
let prot_features = exec.generate_protein_features()?;
Expand Down
12 changes: 7 additions & 5 deletions ferritin-featurizers/src/models/ligandmpnn/configs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ pub struct MPNNExecConfig {
pub(crate) membrane_mpnn_config: Option<MembraneMPNNConfig>,
pub(crate) multi_pdb_config: Option<MultiPDBConfig>,
pub(crate) residue_control_config: Option<ResidueControl>,
// device: &candle_core::Device,
pub(crate) device: Device,
pub(crate) seed: i32,
}

impl MPNNExecConfig {
pub fn new(
seed: i32,
device: &Device,
device: Device,
pdb_path: String,
model_type: ModelTypes,
run_config: RunConfig,
Expand All @@ -62,23 +62,25 @@ impl MPNNExecConfig {
residue_control_config: residue_config,
multi_pdb_config: multi_pdb_specific,
seed,
// device: device,
device: device,
})
}
// Todo: refactor this to use loader.
pub fn load_model(&self) -> Result<ProteinMPNN, Error> {
let default_dtype = DType::F32;

// this is a hidden dep....
let (mpnn_file, _handle) = TestFile::ligmpnn_pmpnn_01().create_temp()?;
let pth = PthTensors::new(mpnn_file, Some("model_state_dict"))?;
let vb = VarBuilder::from_backend(Box::new(pth), DType::F32, Device::Cpu);
let vb = VarBuilder::from_backend(Box::new(pth), default_dtype, self.device.clone());
let pconf = ProteinMPNNConfig::proteinmpnn();
Ok(ProteinMPNN::load(vb, &pconf).expect("Unable to load the PMPNN Model"))
}
pub fn generate_model(self) {
todo!()
}
pub fn generate_protein_features(&self) -> Result<ProteinFeatures, Error> {
let device = Device::Cpu;
let device = self.device.clone();
let base_dtype = DType::F32;

// init the Protein Features
Expand Down
2 changes: 1 addition & 1 deletion ferritin-featurizers/src/models/ligandmpnn/featurizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn is_heavy_atom(element: &Element) -> bool {

/// Convert the AtomCollection into a struct that can be passed to a model.
pub trait LMPNNFeatures {
fn encode_amino_acids(&self, device: &Device) -> Result<(Tensor)>; // ( residue types )
fn encode_amino_acids(&self, device: &Device) -> Result<Tensor>; // ( residue types )
fn featurize(&self, device: &Device) -> Result<ProteinFeatures>; // need more control over this featurization process
fn get_res_index(&self) -> Vec<u32>;
fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>; // [residues, N/CA/C/O, xyz]
Expand Down
Loading

0 comments on commit e83f661

Please sign in to comment.