Skip to content

Commit

Permalink
AMPLIFY + Metal (#69)
Browse files Browse the repository at this point in the history
* get AMPLIFY up and running on Mac // Metal

* simplify the AMPLIFY Example code

* WASM Example of AMPLIFY

* refactor deps to allow WASM compilation

* begin the JS version that pulls the model, the configs, and the tokenizer from the same repo

* realize we should follow that same format for the AMPLIFY structs.
  • Loading branch information
zachcp authored Dec 6, 2024
1 parent c3b7013 commit af62d0e
Show file tree
Hide file tree
Showing 18 changed files with 1,055 additions and 97 deletions.
312 changes: 291 additions & 21 deletions Cargo.lock

Large diffs are not rendered by default.

23 changes: 16 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,29 @@ members = [
"ferritin-molviewspec",
"ferritin-pymol",
"ferritin-test-data",
"ferritin-wasm-examples/*",
]
resolver = "2"

[workspace.dependencies]
anyhow = "1.0"
bitflags = "2.6.0"
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [
] }
candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" }
candle-hf-hub = "0.3.3"
candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers" }
itertools = "0.13.0"
once_cell = "1.20.2"
pdbtbx = "0.12.0"
rand = "0.8.5"
safetensors = "0.4.5"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.133"
serde_bytes = "0.11.15"
serde_json = "1.0.133"
serde_repr = "0.1.19"
urlencoding = "2.1.3"
once_cell = "1.20.2"
bitflags = "2.6.0"
anyhow = "1.0"
pdbtbx = "0.12.0"
itertools = "0.13.0"
tokenizers = { version = "0.21.0", default-features = false }

[workspace.package]
version = "0.1.0"
Expand Down
21 changes: 14 additions & 7 deletions ferritin-amplify/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,30 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-metal-kernels"]

[dependencies]
anyhow.workspace = true
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [
] }
candle-hf-hub = "0.3.3"
candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" }
candle-core.workspace = true
candle-nn.workspace = true
ferritin-test-data = { path = "../ferritin-test-data" }
rand = "0.8.5"
rand.workspace = true
safetensors = "0.4.5"
tokenizers = "0.21.0"

[target.'cfg(target_os = "macos")'.features]
metal = []

[target.'cfg(target_os = "macos")'.dependencies]
candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels", optional = true }

[target.'cfg(target_arch = "wasm32")'.dependencies]
tokenizers = { version = "0.21.0", default-features = false, features = [
"unstable_wasm",
] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
candle-hf-hub = { workspace = true }
tokenizers = { version = "0.21.0" } # full features for non-wasm


[dev-dependencies]
candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" }
candle-examples.workspace = true
ferritin-test-data = { path = "../ferritin-test-data" }
assert_cmd = "2.0.16"
tempfile = "3.14.0"
59 changes: 15 additions & 44 deletions ferritin-amplify/examples/amplify/main.rs
Original file line number Diff line number Diff line change
@@ -1,61 +1,32 @@
use anyhow::Result;
use candle_core::{DType, Device, D};
use candle_hf_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use ferritin_amplify::{AMPLIFYConfig, ProteinTokenizer, AMPLIFY};
use safetensors::SafeTensors;
use candle_core::D;
use candle_examples::device;
use ferritin_amplify::AMPLIFY;

fn main() -> Result<()> {
let model_id = "chandar-lab/AMPLIFY_120M";
let revision = "main";
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
));
let weights_path = repo.get("model.safetensors")?;

// Available for printing Tensor data....
let print_tensor_info = false;
if print_tensor_info {
println!("Model tensors:");
let weights = std::fs::read(&weights_path)?;
let tensors = SafeTensors::deserialize(&weights)?;
tensors.names().iter().for_each(|tensor_name| {
if let Ok(tensor_info) = tensors.tensor(tensor_name) {
println!(
"Tensor: {:<44} || Shape: {:?}",
tensor_name,
tensor_info.shape(),
);
}
});
}
println!("Determining the Device ......");
#[cfg(target_os = "macos")]
let use_gpu = false;
#[cfg(not(target_os = "macos"))]
let use_gpu = true;
let dev = device(use_gpu)?;

println!("Loading the Amplify Model ......");
// https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs#L91C1-L92C101
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path.clone()], DType::F32, &Device::Cpu)?
};
let config = AMPLIFYConfig::amp_120m();
let model = AMPLIFY::load(vb, &config)?;

println!("Tokenizing and Modelling a Sequence from Swissprot...");
let tokenizer = repo.get("tokenizer.json")?;
let protein_tokenizer = ProteinTokenizer::new(tokenizer)?;
let (tokenizer, amplify) = AMPLIFY::load_from_huggingface(dev.clone())?;
let sprot_01 = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL";
let pmatrix = protein_tokenizer.encode(&[sprot_01.to_string()], None, false, false)?;
let pmatrix = tokenizer.encode(&[sprot_01.to_string()], None, false, false)?;
let pmatrix = pmatrix.to_device(&dev)?;

let pmatrix = pmatrix.unsqueeze(0)?; // [batch, length] <- add batch of 1 in this case
let encoded = model.forward(&pmatrix, None, false, false)?;
let encoded = amplify.forward(&pmatrix, None, false, false)?;

println!("Assessing the Predictions.......");
// As of Nov 13 this is definitely not right....
// Input: MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL
// Output: MSVQLNIVGQSAAWTHGAAVCATCAQTFWPMSRGRQPPVNMSRFTARCTECIWYEAAFNARFNFVHLYNCGPNMSECLANMSWWYACQFGVHMSKSHYCGNKPLGTDNTKMMHHRECTSTVVWKHWPLCKVTVCYRHGLVSCTMHQRSTWTPRNEASWVPEWETSTPEHTCGDYWACQMPAGHGVCCCMMTEHWKPHTRVVCQTIEMWTYLQTYYYFWGVPEPCHHHIWTEPMPTSTSTSYDVVMYTTSGFGQHHW
let predictions = encoded.logits.argmax(D::Minus1)?;
let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
let decoded = protein_tokenizer.decode(indices.as_slice(), true)?;
let decoded = tokenizer.decode(indices.as_slice(), true)?;
println!("Encoded Logits Dimension: {:?}, ", encoded.logits);
println!("indices: {:?}", indices);
println!("Decoded Values: {}", decoded.replace(" ", ""));
Expand Down
6 changes: 5 additions & 1 deletion ferritin-amplify/src/amplify/amplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
use super::rotary::{apply_rotary_emb, precompute_freqs_cis};
use super::tokenizer::ProteinTokenizer;
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_hf_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::{
embedding, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Activation, Dropout,
Embedding, Linear, RmsNorm, VarBuilder,
};

#[cfg(not(target_arch = "wasm32"))]
use candle_hf_hub::{api::sync::Api, Repo, RepoType};

#[derive(Debug, Clone)]
/// Configuration Struct for AMPLIFY
///
Expand Down Expand Up @@ -481,6 +483,7 @@ impl AMPLIFY {
config: cfg.clone(),
})
}
#[cfg(not(target_arch = "wasm32"))]
/// Retrieve the model and make it available for usage.
/// hardcode the 120M for the moment...
pub fn load_from_huggingface(device: Device) -> Result<(ProteinTokenizer, Self)> {
Expand All @@ -506,6 +509,7 @@ impl AMPLIFY {
let tokenizer = repo
.get("tokenizer.json")
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;

let protein_tokenizer =
ProteinTokenizer::new(tokenizer).map_err(|e| candle_core::Error::Msg(e.to_string()))?;

Expand Down
16 changes: 8 additions & 8 deletions ferritin-esm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-metal-kernels"]

[dependencies]
anyhow.workspace = true
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [
] }
candle-hf-hub = "0.3.3"
candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" }
candle-core.workspace = true
candle-hf-hub.workspace = true
candle-nn.workspace = true
ferritin-test-data = { path = "../ferritin-test-data" }
rand = "0.8.5"
safetensors = "0.4.5"
tokenizers = "0.21.0"
rand.workspace = true
safetensors.workspace = true
tokenizers.workspace = true


[target.'cfg(target_os = "macos")'.features]
metal = []
Expand All @@ -27,7 +27,7 @@ metal = []
candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels", optional = true }

[dev-dependencies]
candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" }
candle-examples.workspace = true
ferritin-test-data = { path = "../ferritin-test-data" }
assert_cmd = "2.0.16"
tempfile = "3.14.0"
13 changes: 6 additions & 7 deletions ferritin-ligandmpnn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,16 @@ metal = [

[dependencies]
anyhow.workspace = true
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [
] }
candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers" }
candle-core.workspace = true
candle-nn.workspace = true
candle-transformers.workspace = true
clap = "4.5.23"
ferritin-core = { path = "../ferritin-core" }
ferritin-test-data = { path = "../ferritin-test-data" }
itertools.workspace = true
pdbtbx.workspace = true
rand = "0.8.5"
safetensors = "0.4.5"
rand.workspace = true
safetensors.workspace = true
strum = { version = "0.26", features = ["derive"] }


Expand All @@ -37,7 +36,7 @@ metal = []
candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels", optional = true }

[dev-dependencies]
candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" }
candle-examples.workspace = true
ferritin-test-data = { path = "../ferritin-test-data" }
assert_cmd = "2.0.16"
tempfile = "3.14.0"
2 changes: 1 addition & 1 deletion ferritin-molviewspec/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ chrono = "0.4.38"
validator = { version = "0.19.0", features = ["derive"] }
serde = { workspace = true }
serde_json = { workspace = true }
urlencoding = { workspace = true }
urlencoding = "2.1.3"

[dev-dependencies]
ferritin-pymol = { path = "../ferritin-pymol" }
1 change: 1 addition & 0 deletions ferritin-wasm-examples/amplify/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build/
25 changes: 25 additions & 0 deletions ferritin-wasm-examples/amplify/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[package]
name = "ferritin-wasm-example-amplify"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { workspace = true }
candle-nn = { workspace = true }
tokenizers = { workspace = true, default-features = false, features = [
"unstable_wasm",
] }
ferritin-amplify = { path = "../../ferritin-amplify" }

# for wasm
gloo = "0.11.0"
js-sys = "0.3.74"
wasm-bindgen = "0.2.97"
serde-wasm-bindgen = "0.6.5"
serde_json.workspace = true
console_error_panic_hook = "0.1.7"
getrandom = { version = "0.2.15", features = ["js"] }


# [lib]
# crate-type = ["cdylib", "rlib"]
1 change: 1 addition & 0 deletions ferritin-wasm-examples/amplify/Readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Amplify WASM Example
81 changes: 81 additions & 0 deletions ferritin-wasm-examples/amplify/amplifyWorker.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Load the AMPLIFY wasm module (init function and Model class) from ./build
import init, { Model } from "./build/m.js";

/**
 * Download a resource as raw bytes, consulting the Cache Storage API first so
 * that repeat loads of the same URL skip the network.
 *
 * @param {string} url - Resource to fetch.
 * @returns {Promise<Uint8Array>} The response body as bytes.
 * @throws {Error} When the server answers with a non-2xx status.
 */
async function fetchArrayBuffer(url) {
  const cacheName = "amplify-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }
  const res = await fetch(url, { cache: "force-cache" });
  // Cache.put() stores any response it is given, including 4xx/5xx ones.
  // Bail out before caching so an error page is never persisted and later
  // handed back as if it were model data.
  if (!res.ok) {
    throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
  }
  // Caching is best-effort: a failed write (e.g. storage quota exceeded) must
  // not prevent the freshly downloaded bytes from being returned.
  cache.put(url, res.clone()).catch((err) => {
    console.warn(`Could not cache ${url}:`, err);
  });
  return new Uint8Array(await res.arrayBuffer());
}

/**
 * Lazily builds wasm `Model` objects and keeps one per model ID so that
 * repeated requests for the same ID reuse the already-constructed model.
 */
class Amplify {
  // Constructed models, keyed by modelID.
  static instance = {};

  /**
   * Return the model registered under `modelID`, building it on first use.
   *
   * Progress is reported to the worker's owner through `self.postMessage`.
   *
   * @param {string} weightsURL - URL of the model weights.
   * @param {string} tokenizerURL - URL of the tokenizer file.
   * @param {string} configURL - URL of the model config.
   * @param {string} modelID - Key under which the model is stored.
   * @returns {Promise<Model>} The model stored for `modelID`.
   */
  static async getInstance(weightsURL, tokenizerURL, configURL, modelID) {
    const existing = this.instance[modelID];
    if (existing) {
      self.postMessage({ status: "ready", message: "Model Already Loaded" });
      return existing;
    }

    // Initialise the wasm module before any Model is constructed.
    await init();
    self.postMessage({ status: "downloaded", message: "Downloaded Model" });

    // Fetch all three artifacts concurrently.
    const [weightsBytes, tokenizerBytes, configBytes] = await Promise.all(
      [weightsURL, tokenizerURL, configURL].map((url) => fetchArrayBuffer(url))
    );
    self.postMessage({ status: "loading", message: "Loading Model" });

    const model = new Model(weightsBytes, tokenizerBytes, configBytes);
    this.instance[modelID] = model;
    return model;
  }
}

/**
 * Worker entry point: on each incoming message, make sure the requested model
 * is loaded, compute embeddings for the supplied sentences, and post the
 * result (or the error) back to the sender.
 *
 * Expected `event.data` fields: `weightsURL`, `tokenizerURL`, `configURL`,
 * `modelID`, `sentences`, and optional `normalize` (defaults to `true`).
 */
self.addEventListener("message", async (event) => {
  const {
    weightsURL,
    tokenizerURL,
    configURL,
    modelID,
    // `sentences` must be destructured here: it is passed to get_embeddings
    // below, and leaving it undeclared raised a ReferenceError on every call.
    sentences,
    normalize = true,
  } = event.data;
  try {
    self.postMessage({ status: "ready", message: "Starting Bert Model" });
    const model = await Amplify.getInstance(
      weightsURL,
      tokenizerURL,
      configURL,
      modelID
    );
    self.postMessage({
      status: "embedding",
      message: "Calculating Embeddings",
    });
    const output = model.get_embeddings({
      sentences: sentences,
      normalize_embeddings: normalize,
    });

    self.postMessage({
      status: "complete",
      message: "complete",
      output: output.data,
    });
  } catch (e) {
    // Report failures to the owner instead of letting the worker die silently.
    self.postMessage({ error: e });
  }
});
19 changes: 19 additions & 0 deletions ferritin-wasm-examples/amplify/build-lib.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash
# Build this crate for wasm32 with the nightly toolchain, then generate the
# web bindings into ./build with wasm-bindgen.
set -o errexit # stop at the first failing command

readonly WASM_TARGET="wasm32-unknown-unknown"

# Optional clean slate (left disabled):
#   cargo clean
mkdir -p "target/${WASM_TARGET}/release"
cargo update

# Nightly is needed for -Z build-std, which rebuilds std for the wasm target.
cargo +nightly build --target "${WASM_TARGET}" --release \
    -Z build-std=std,panic_abort \
    --no-default-features

# Alternatives kept for reference (possibly needed on a fresh toolchain):
#   cargo build --target wasm32-unknown-unknown --release
#   rustup +nightly target add wasm32-unknown-unknown
#   cargo +nightly build --target wasm32-unknown-unknown --release

# NOTE(review): the artifact is read from the workspace-level target dir two
# levels up, not the local ./target created above; confirm which is intended.
wasm-bindgen "../../target/${WASM_TARGET}/release/m.wasm" --out-dir build --target web
Loading

0 comments on commit af62d0e

Please sign in to comment.