Skip to content

Commit

Permalink
Refactor AMPLIFY CLI (#70)
Browse files Browse the repository at this point in the history
* make an examples crate

* stub out an AMPLIFY example

* get AMPLIFY to work

* use remote JSON loading

* update the WASM version as well
  • Loading branch information
zachcp authored Dec 7, 2024
1 parent af62d0e commit d41b75f
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 46 deletions.
29 changes: 25 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"ferritin-cellscape",
"ferritin-core",
"ferritin-esm",
"ferritin-examples",
"ferritin-amplify",
"ferritin-ligandmpnn",
"ferritin-molviewspec",
Expand Down
3 changes: 2 additions & 1 deletion ferritin-amplify/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ candle-core.workspace = true
candle-nn.workspace = true
ferritin-test-data = { path = "../ferritin-test-data" }
rand.workspace = true
safetensors = "0.4.5"
safetensors.workspace = true
serde.workspace = true

[target.'cfg(target_os = "macos")'.features]
metal = []
Expand Down
34 changes: 0 additions & 34 deletions ferritin-amplify/examples/amplify/main.rs

This file was deleted.

7 changes: 5 additions & 2 deletions ferritin-amplify/src/amplify/amplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ use candle_nn::{
embedding, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Activation, Dropout,
Embedding, Linear, RmsNorm, VarBuilder,
};
use serde::Deserialize;

#[cfg(not(target_arch = "wasm32"))]
use candle_hf_hub::{api::sync::Api, Repo, RepoType};

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize)]
/// Configuration Struct for AMPLIFY
///
/// Currently only holds the weight params for
Expand Down Expand Up @@ -487,7 +488,6 @@ impl AMPLIFY {
/// Retrieve the model and make it available for usage.
/// hardcode the 120M for the moment...
pub fn load_from_huggingface(device: Device) -> Result<(ProteinTokenizer, Self)> {
let ampconfig = AMPLIFYConfig::amp_120m();
let model_id = "chandar-lab/AMPLIFY_120M";
let revision = "main";
let api = Api::new().map_err(|e| candle_core::Error::Msg(e.to_string()))?;
Expand Down Expand Up @@ -515,6 +515,9 @@ impl AMPLIFY {

Ok((protein_tokenizer, model))
}
pub fn get_device(&self) -> &Device {
self.freqs_cis.device()
}
}

// Helper structs and enums
Expand Down
40 changes: 40 additions & 0 deletions ferritin-examples/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Manifest for the `ferritin-examples` crate: runnable examples (e.g. the
# AMPLIFY CLI example) that exercise the other workspace crates.
[package]
name = "ferritin-examples"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
description.workspace = true

# Opt-in `metal` feature pulls in the Metal backends of the candle crates
# plus the macOS-only kernel crate declared below.
[features]
metal = ["candle-core/metal", "candle-nn/metal", "candle-metal-kernels"]

[dependencies]
anyhow.workspace = true
candle-core.workspace = true
candle-nn.workspace = true
clap = { version = "4.5.23", features = ["derive"] }
ferritin-amplify = { path = "../ferritin-amplify" }
serde_json.workspace = true

# NOTE(review): Cargo does not document target-specific `[features]` tables;
# this section is likely ignored (the `metal` feature is already declared in
# the `[features]` table above) — verify with `cargo check` warnings.
[target.'cfg(target_os = "macos")'.features]
metal = []

# The Metal kernels only build on macOS, so the dependency is both
# target-gated and optional (activated via the `metal` feature).
[target.'cfg(target_os = "macos")'.dependencies]
candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels", optional = true }

# On wasm32 the tokenizers crate needs the trimmed-down `unstable_wasm`
# feature set; everywhere else the full default features are used.
[target.'cfg(target_arch = "wasm32")'.dependencies]
tokenizers = { version = "0.21.0", default-features = false, features = [
    "unstable_wasm",
] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
candle-hf-hub = { workspace = true }
tokenizers = { version = "0.21.0" } # full features for non-wasm


[dev-dependencies]
candle-examples.workspace = true
ferritin-test-data = { path = "../ferritin-test-data" }
assert_cmd = "2.0.16"
tempfile = "3.14.0"
File renamed without changes.
116 changes: 116 additions & 0 deletions ferritin-examples/examples/amplify/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use anyhow::{Error as E, Result};
use candle_core::{DType, Tensor, D};
use candle_examples::device;
use candle_hf_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use clap::Parser;
use ferritin_amplify::{AMPLIFYConfig as Config, AMPLIFY};
use tokenizers::Tokenizer;

/// Model weights are loaded and evaluated in 32-bit floats.
pub const DTYPE: DType = DType::F32;

/// Command-line arguments for the AMPLIFY inference example.
///
/// NOTE(review): `tracing`, `prompt`, `n`, `normalize_embeddings`, and
/// `approximate_gelu` are parsed but not yet read by `main`; they are kept
/// so the CLI stays flag-compatible with the candle examples this file was
/// derived from — confirm before removing.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The HuggingFace model repo to use; defaults to chandar-lab/AMPLIFY_120M.
    /// See https://huggingface.co/chandar-lab for available AMPLIFY checkpoints.
    #[arg(long)]
    model_id: Option<String>,

    /// The model revision (branch, tag, or commit) to fetch; defaults to "main".
    #[arg(long)]
    revision: Option<String>,

    /// When set, compute embeddings for this prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// Use the pytorch weights rather than the safetensors ones
    #[arg(long)]
    use_pth: bool,

    /// The number of times to run the prompt.
    #[arg(long, default_value = "1")]
    n: usize,

    /// L2 normalization for embeddings.
    #[arg(long, default_value = "true")]
    normalize_embeddings: bool,

    /// Use tanh based approximation for Gelu instead of erf implementation.
    #[arg(long, default_value = "false")]
    approximate_gelu: bool,
}

impl Args {
    /// Resolve the model repo/revision, download the config, tokenizer, and
    /// weights from the HuggingFace Hub, and construct the AMPLIFY model
    /// together with its tokenizer.
    ///
    /// Defaults to `chandar-lab/AMPLIFY_120M` at revision `main` when the
    /// corresponding CLI flags are absent.
    fn build_model_and_tokenizer(&self) -> Result<(AMPLIFY, Tokenizer)> {
        let device = device(self.cpu)?;
        // Fall back to the 120M checkpoint on `main` unless overridden; this
        // replaces a 4-arm match that hard-coded "main" in two places.
        let model_id = self
            .model_id
            .clone()
            .unwrap_or_else(|| "chandar-lab/AMPLIFY_120M".to_string());
        let revision = self.revision.clone().unwrap_or_else(|| "main".to_string());
        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let (config_filename, tokenizer_filename, weights_filename) = {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = if self.use_pth {
                api.get("pytorch_model.bin")?
            } else {
                api.get("model.safetensors")?
            };
            (config, tokenizer, weights)
        };
        // The upstream config spells the activation "SwiGLU"/"Swiglu";
        // normalize the casing so serde can deserialize it into the config.
        let config_str = std::fs::read_to_string(config_filename)?
            .replace("SwiGLU", "swiglu")
            .replace("Swiglu", "swiglu");
        let config: Config = serde_json::from_str(&config_str)?;
        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
        let vb = if self.use_pth {
            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
        } else {
            // SAFETY: mmap-ing the safetensors file is sound as long as the
            // downloaded file is not modified while the model is in use.
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
        };
        let model = AMPLIFY::load(vb, &config)?;
        Ok((model, tokenizer))
    }
}

/// Run masked-token prediction with AMPLIFY on a sample protein sequence:
/// tokenize, forward through the model, argmax the logits, and decode the
/// predicted token ids back to residues.
fn main() -> Result<()> {
    let args = Args::parse();
    let (model, tokenizer) = args.build_model_and_tokenizer()?;
    // `get_device` already returns `&Device`; the original took an extra
    // `&`, yielding `&&Device` that only worked via deref coercion.
    let device = model.get_device();

    // Example protein sequence used as the model input.
    let sprot_01 = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL";

    // `encode` accepts `&str` directly; no need to allocate a String.
    let tokens = tokenizer
        .encode(sprot_01, false)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();

    // Add a leading batch dimension: (seq_len,) -> (1, seq_len).
    let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
    println!("Encoding.......");
    let encoded = model.forward(&token_ids, None, false, false)?;

    println!("Predicting.......");
    // Most-likely token id at each position.
    let predictions = encoded.logits.argmax(D::Minus1)?;

    println!("Decoding.......");
    let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
    // Propagate tokenizer errors instead of Debug-printing the raw Result.
    let decoded = tokenizer.decode(indices.as_slice(), true).map_err(E::msg)?;

    println!("Decoded: {:?}, ", decoded);
    Ok(())
}
14 changes: 10 additions & 4 deletions ferritin-wasm-examples/amplify/src/bin/m.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use ferritin_amplify::{AMPLIFYConfig, ProteinTokenizer, AMPLIFY};
use ferritin_amplify::{AMPLIFYConfig as Config, ProteinTokenizer, AMPLIFY};
use ferritin_wasm_example_amplify::console_log;
use tokenizers::{PaddingParams, Tokenizer};
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
pub struct Model {
amplify: AMPLIFY,
tokenizer: ProteinTokenizer,
tokenizer: Tokenizer,
}

#[wasm_bindgen]
Expand All @@ -18,14 +18,20 @@ impl Model {
console_error_panic_hook::set_once();
let device = &Device::Cpu;
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
let config: AMPLIFYConfig = serde_json::from_slice(&config)?;

let config_str = String::from_utf8(config).map_err(|e| JsError::new(&e.to_string()))?;
let config_str = config_str
.replace("SwiGLU", "swiglu")
.replace("Swiglu", "swiglu");

let config: Config = serde_json::from_str(&config_str)?;
let amplify = AMPLIFY::load(vb, &config)?;

// currently tokenizer fetches from HuggingFace
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;

Ok(Self { amplify, tokenizer })
// Ok(Self { amplify })
}

// pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> {
Expand Down
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ test-full:

amplify:
# cargo run --example amplify
cargo run --example amplify --features metal
RUST_BACKTRACE=1 cargo run --example amplify --features metal

test-ligandmpnn:
cargo test --features metal -p ferritin-ligandmpnn test_cli_command_run_example_06 -- --nocapture
Expand Down

0 comments on commit d41b75f

Please sign in to comment.