From af62d0ea95f2d9f7d22c778ce2922a65b8fdd1a9 Mon Sep 17 00:00:00 2001 From: zachcp Date: Fri, 6 Dec 2024 16:41:20 -0500 Subject: [PATCH] AMPLIFY + Metal (#69) * get AMPLIFY up and running on Mac // Metal * simplify the AMPLIFY Example code * WASM Example of AMPLIFY * refactor deps to let WASM compilation * being the JS version that pulls the model, the configs, and the tokenizer from the same repo * realize we should follow that same format for the AMPLIFy structs. --- Cargo.lock | 312 ++++++++++++++- Cargo.toml | 23 +- ferritin-amplify/Cargo.toml | 21 +- ferritin-amplify/examples/amplify/main.rs | 59 +-- ferritin-amplify/src/amplify/amplify.rs | 6 +- ferritin-esm/Cargo.toml | 16 +- ferritin-ligandmpnn/Cargo.toml | 13 +- ferritin-molviewspec/Cargo.toml | 2 +- ferritin-wasm-examples/amplify/.gitignore | 1 + ferritin-wasm-examples/amplify/Cargo.toml | 25 ++ ferritin-wasm-examples/amplify/Readme.md | 1 + .../amplify/amplifyWorker.js | 81 ++++ ferritin-wasm-examples/amplify/build-lib.sh | 19 + .../amplify/lib-example.html | 375 ++++++++++++++++++ ferritin-wasm-examples/amplify/src/bin/m.rs | 101 +++++ ferritin-wasm-examples/amplify/src/lib.rs | 19 + ferritin-wasm-examples/amplify/utils.js | 75 ++++ justfile | 3 +- 18 files changed, 1055 insertions(+), 97 deletions(-) create mode 100644 ferritin-wasm-examples/amplify/.gitignore create mode 100644 ferritin-wasm-examples/amplify/Cargo.toml create mode 100644 ferritin-wasm-examples/amplify/Readme.md create mode 100644 ferritin-wasm-examples/amplify/amplifyWorker.js create mode 100644 ferritin-wasm-examples/amplify/build-lib.sh create mode 100644 ferritin-wasm-examples/amplify/lib-example.html create mode 100644 ferritin-wasm-examples/amplify/src/bin/m.rs create mode 100644 ferritin-wasm-examples/amplify/src/lib.rs create mode 100644 ferritin-wasm-examples/amplify/utils.js diff --git a/Cargo.lock b/Cargo.lock index 36a7ae7a..319639e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -892,7 +892,7 @@ dependencies = [ "proc-macro2 
1.0.92", "quote 1.0.37", "syn", - "toml_edit", + "toml_edit 0.22.22", ] [[package]] @@ -1352,6 +1352,15 @@ dependencies = [ "winit", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bindgen" version = "0.70.1" @@ -1604,7 +1613,7 @@ checksum = "ca5f45ce8fe55a9e9246a3fc60000d7ed11b88a84d72f753488f7264ce04b102" dependencies = [ "dirs", "futures", - "http", + "http 1.2.0", "indicatif", "log", "num_cpus", @@ -2398,7 +2407,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2465,9 +2474,9 @@ checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" [[package]] name = "fdeflate" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" dependencies = [ "simd-adler32", ] @@ -2597,6 +2606,23 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ferritin-wasm-example-amplify" +version = "0.1.0" +dependencies = [ + "candle-core", + "candle-nn", + "console_error_panic_hook", + "ferritin-amplify", + "getrandom", + "gloo", + "js-sys", + "serde-wasm-bindgen", + "serde_json", + "tokenizers 0.21.0", + "wasm-bindgen", +] + [[package]] name = "fixedbitset" version = "0.4.2" @@ -3071,6 +3097,187 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "gloo" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d15282ece24eaf4bd338d73ef580c6714c8615155c4190c781290ee3fa0fd372" +dependencies = [ + "gloo-console", + "gloo-dialogs", + "gloo-events", + "gloo-file", + "gloo-history", + "gloo-net", + "gloo-render", + "gloo-storage", + "gloo-timers", + "gloo-utils", + "gloo-worker", +] + +[[package]] +name = "gloo-console" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a17868f56b4a24f677b17c8cb69958385102fa879418052d60b50bc1727e261" +dependencies = [ + "gloo-utils", + "js-sys", + "serde", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-dialogs" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4748e10122b01435750ff530095b1217cf6546173459448b83913ebe7815df" +dependencies = [ + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-events" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27c26fb45f7c385ba980f5fa87ac677e363949e065a083722697ef1b2cc91e41" +dependencies = [ + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-file" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97563d71863fb2824b2e974e754a81d19c4a7ec47b09ced8a0e6656b6d54bd1f" +dependencies = [ + "gloo-events", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-history" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "903f432be5ba34427eac5e16048ef65604a82061fe93789f2212afc73d8617d6" +dependencies = [ + "getrandom", + "gloo-events", + "gloo-utils", + "serde", + "serde-wasm-bindgen", + "serde_urlencoded", + "thiserror", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-net" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43aaa242d1239a8822c15c645f02166398da4f8b5c4bae795c1f5b44e9eee173" +dependencies = [ + "futures-channel", + "futures-core", + "futures-sink", 
+ "gloo-utils", + "http 0.2.12", + "js-sys", + "pin-project", + "serde", + "serde_json", + "thiserror", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "gloo-render" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56008b6744713a8e8d98ac3dcb7d06543d5662358c9c805b4ce2167ad4649833" +dependencies = [ + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-storage" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc8031e8c92758af912f9bc08fbbadd3c6f3cfcbf6b64cdf3d6a81f0139277a" +dependencies = [ + "gloo-utils", + "js-sys", + "serde", + "serde_json", + "thiserror", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-timers" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb143cf96099802033e0d4f4963b19fd2e0b728bcf076cd9cf7f6634f092994" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "gloo-utils" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5555354113b18c547c1d3a98fbf7fb32a9ff4f6fa112ce823a21641a0ba3aa" +dependencies = [ + "js-sys", + "serde", + "serde_json", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "gloo-worker" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "085f262d7604911c8150162529cefab3782e91adb20202e8658f7275d2aefe5d" +dependencies = [ + "bincode", + "futures", + "gloo-utils", + "gloo-worker-macros", + "js-sys", + "pinned", + "serde", + "thiserror", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "gloo-worker-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "956caa58d4857bc9941749d55e4bd3000032d8212762586fa5705632967140e7" +dependencies = [ + "proc-macro-crate 1.3.1", + "proc-macro2 1.0.92", + "quote 1.0.37", + "syn", +] + 
[[package]] name = "glow" version = "0.14.2" @@ -3206,7 +3413,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.2.0", "indexmap", "slab", "tokio", @@ -3301,6 +3508,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.2.0" @@ -3319,7 +3537,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.2.0", ] [[package]] @@ -3330,7 +3548,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http", + "http 1.2.0", "http-body", "pin-project-lite", ] @@ -3351,7 +3569,7 @@ dependencies = [ "futures-channel", "futures-util", "h2", - "http", + "http 1.2.0", "http-body", "httparse", "itoa", @@ -3368,7 +3586,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", - "http", + "http 1.2.0", "hyper", "hyper-util", "rustls", @@ -3403,7 +3621,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http", + "http 1.2.0", "http-body", "hyper", "pin-project-lite", @@ -3855,7 +4073,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -4332,7 +4550,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" dependencies = [ - "proc-macro-crate", + "proc-macro-crate 3.2.0", "proc-macro2 1.0.92", "quote 1.0.37", "syn", @@ -4809,6 +5027,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pinned" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a829027bd95e54cfe13e3e258a1ae7b645960553fb82b75ff852c29688ee595b" +dependencies = [ + "futures", + "rustversion", + "thiserror", +] + [[package]] name = "piper" version = "0.2.4" @@ -4921,13 +5150,23 @@ dependencies = [ "syn", ] +[[package]] +name = "proc-macro-crate" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +dependencies = [ + "once_cell", + "toml_edit 0.19.15", +] + [[package]] name = "proc-macro-crate" version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" dependencies = [ - "toml_edit", + "toml_edit 0.22.22", ] [[package]] @@ -5223,7 +5462,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 1.2.0", "http-body", "http-body-util", "hyper", @@ -5516,6 +5755,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde-wasm-bindgen" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", +] + [[package]] name = "serde_bytes" version = "0.11.15" @@ -5890,7 +6140,7 @@ dependencies = [ "fastrand", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -6012,6 +6262,7 @@ dependencies = [ "aho-corasick", "derive_builder", 
"esaxx-rs", + "fancy-regex", "getrandom", "indicatif", "itertools 0.12.1", @@ -6074,12 +6325,11 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ "rustls", - "rustls-pki-types", "tokio", ] @@ -6102,6 +6352,17 @@ version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +[[package]] +name = "toml_edit" +version = "0.19.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow 0.5.40", +] + [[package]] name = "toml_edit" version = "0.22.22" @@ -6110,7 +6371,7 @@ checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ "indexmap", "toml_datetime", - "winnow", + "winnow 0.6.20", ] [[package]] @@ -6738,7 +6999,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -7162,6 +7423,15 @@ dependencies = [ "xkbcommon-dl", ] +[[package]] +name = "winnow" +version = "0.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" +dependencies = [ + "memchr", +] + [[package]] name = "winnow" version = "0.6.20" diff --git a/Cargo.toml b/Cargo.toml index f5254ccf..54e66789 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,20 +9,29 @@ members = [ "ferritin-molviewspec", "ferritin-pymol", "ferritin-test-data", + "ferritin-wasm-examples/*", ] 
resolver = "2" [workspace.dependencies] +anyhow = "1.0" +bitflags = "2.6.0" +candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [ +] } +candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" } +candle-hf-hub = "0.3.3" +candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" } +candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers" } +itertools = "0.13.0" +once_cell = "1.20.2" +pdbtbx = "0.12.0" +rand = "0.8.5" +safetensors = "0.4.5" serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0.133" serde_bytes = "0.11.15" +serde_json = "1.0.133" serde_repr = "0.1.19" -urlencoding = "2.1.3" -once_cell = "1.20.2" -bitflags = "2.6.0" -anyhow = "1.0" -pdbtbx = "0.12.0" -itertools = "0.13.0" +tokenizers = { version = "0.21.0", default-features = false } [workspace.package] version = "0.1.0" diff --git a/ferritin-amplify/Cargo.toml b/ferritin-amplify/Cargo.toml index db217a9b..2349501b 100644 --- a/ferritin-amplify/Cargo.toml +++ b/ferritin-amplify/Cargo.toml @@ -11,14 +11,11 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-metal-kernels"] [dependencies] anyhow.workspace = true -candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [ -] } -candle-hf-hub = "0.3.3" -candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" } +candle-core.workspace = true +candle-nn.workspace = true ferritin-test-data = { path = "../ferritin-test-data" } -rand = "0.8.5" +rand.workspace = true safetensors = "0.4.5" -tokenizers = "0.21.0" [target.'cfg(target_os = "macos")'.features] metal = [] @@ -26,8 +23,18 @@ metal = [] [target.'cfg(target_os = "macos")'.dependencies] candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels", optional = true } +[target.'cfg(target_arch 
= "wasm32")'.dependencies] +tokenizers = { version = "0.21.0", default-features = false, features = [ + "unstable_wasm", +] } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +candle-hf-hub = { workspace = true } +tokenizers = { version = "0.21.0" } # full features for non-wasm + + [dev-dependencies] -candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" } +candle-examples.workspace = true ferritin-test-data = { path = "../ferritin-test-data" } assert_cmd = "2.0.16" tempfile = "3.14.0" diff --git a/ferritin-amplify/examples/amplify/main.rs b/ferritin-amplify/examples/amplify/main.rs index d635e6cb..f1d1b492 100644 --- a/ferritin-amplify/examples/amplify/main.rs +++ b/ferritin-amplify/examples/amplify/main.rs @@ -1,53 +1,24 @@ use anyhow::Result; -use candle_core::{DType, Device, D}; -use candle_hf_hub::{api::sync::Api, Repo, RepoType}; -use candle_nn::VarBuilder; -use ferritin_amplify::{AMPLIFYConfig, ProteinTokenizer, AMPLIFY}; -use safetensors::SafeTensors; +use candle_core::D; +use candle_examples::device; +use ferritin_amplify::AMPLIFY; fn main() -> Result<()> { - let model_id = "chandar-lab/AMPLIFY_120M"; - let revision = "main"; - let api = Api::new()?; - let repo = api.repo(Repo::with_revision( - model_id.to_string(), - RepoType::Model, - revision.to_string(), - )); - let weights_path = repo.get("model.safetensors")?; - - // Available for printing Tensor data.... 
- let print_tensor_info = false; - if print_tensor_info { - println!("Model tensors:"); - let weights = std::fs::read(&weights_path)?; - let tensors = SafeTensors::deserialize(&weights)?; - tensors.names().iter().for_each(|tensor_name| { - if let Ok(tensor_info) = tensors.tensor(tensor_name) { - println!( - "Tensor: {:<44} || Shape: {:?}", - tensor_name, - tensor_info.shape(), - ); - } - }); - } + println!("Determining the Device ......"); + #[cfg(target_os = "macos")] + let use_gpu = false; + #[cfg(not(target_os = "macos"))] + let use_gpu = true; + let dev = device(use_gpu)?; println!("Loading the Amplify Model ......"); - // https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs#L91C1-L92C101 - let vb = unsafe { - VarBuilder::from_mmaped_safetensors(&[weights_path.clone()], DType::F32, &Device::Cpu)? - }; - let config = AMPLIFYConfig::amp_120m(); - let model = AMPLIFY::load(vb, &config)?; - - println!("Tokenizing and Modelling a Sequence from Swissprot..."); - let tokenizer = repo.get("tokenizer.json")?; - let protein_tokenizer = ProteinTokenizer::new(tokenizer)?; + let (tokenizer, amplify) = AMPLIFY::load_from_huggingface(dev.clone())?; let sprot_01 = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL"; - let pmatrix = protein_tokenizer.encode(&[sprot_01.to_string()], None, false, false)?; + let pmatrix = tokenizer.encode(&[sprot_01.to_string()], None, false, false)?; + let pmatrix = pmatrix.to_device(&dev)?; + let pmatrix = pmatrix.unsqueeze(0)?; // [batch, length] <- add batch of 1 in this case - let encoded = model.forward(&pmatrix, None, false, false)?; + let encoded = amplify.forward(&pmatrix, None, false, false)?; println!("Assessing the Predictions......."); // As of Nov 13 this is definitely not right.... 
@@ -55,7 +26,7 @@ fn main() -> Result<()> { // Output: MSVQLNIVGQSAAWTHGAAVCATCAQTFWPMSRGRQPPVNMSRFTARCTECIWYEAAFNARFNFVHLYNCGPNMSECLANMSWWYACQFGVHMSKSHYCGNKPLGTDNTKMMHHRECTSTVVWKHWPLCKVTVCYRHGLVSCTMHQRSTWTPRNEASWVPEWETSTPEHTCGDYWACQMPAGHGVCCCMMTEHWKPHTRVVCQTIEMWTYLQTYYYFWGVPEPCHHHIWTEPMPTSTSTSYDVVMYTTSGFGQHHW let predictions = encoded.logits.argmax(D::Minus1)?; let indices: Vec = predictions.to_vec2()?[0].to_vec(); - let decoded = protein_tokenizer.decode(indices.as_slice(), true)?; + let decoded = tokenizer.decode(indices.as_slice(), true)?; println!("Encoded Logits Dimension: {:?}, ", encoded.logits); println!("indices: {:?}", indices); println!("Decoded Values: {}", decoded.replace(" ", "")); diff --git a/ferritin-amplify/src/amplify/amplify.rs b/ferritin-amplify/src/amplify/amplify.rs index 444f7dd6..62ed24db 100644 --- a/ferritin-amplify/src/amplify/amplify.rs +++ b/ferritin-amplify/src/amplify/amplify.rs @@ -12,12 +12,14 @@ use super::rotary::{apply_rotary_emb, precompute_freqs_cis}; use super::tokenizer::ProteinTokenizer; use candle_core::{DType, Device, Module, Result, Tensor, D}; -use candle_hf_hub::{api::sync::Api, Repo, RepoType}; use candle_nn::{ embedding, linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Activation, Dropout, Embedding, Linear, RmsNorm, VarBuilder, }; +#[cfg(not(target_arch = "wasm32"))] +use candle_hf_hub::{api::sync::Api, Repo, RepoType}; + #[derive(Debug, Clone)] /// Configuration Struct for AMPLIFY /// @@ -481,6 +483,7 @@ impl AMPLIFY { config: cfg.clone(), }) } + #[cfg(not(target_arch = "wasm32"))] /// Retreive the model and make it available for usage. /// hardcode the 120M for the moment... 
pub fn load_from_huggingface(device: Device) -> Result<(ProteinTokenizer, Self)> { @@ -506,6 +509,7 @@ impl AMPLIFY { let tokenizer = repo .get("tokenizer.json") .map_err(|e| candle_core::Error::Msg(e.to_string()))?; + let protein_tokenizer = ProteinTokenizer::new(tokenizer).map_err(|e| candle_core::Error::Msg(e.to_string()))?; diff --git a/ferritin-esm/Cargo.toml b/ferritin-esm/Cargo.toml index 6f046be0..29f9b8c9 100644 --- a/ferritin-esm/Cargo.toml +++ b/ferritin-esm/Cargo.toml @@ -11,14 +11,14 @@ metal = ["candle-core/metal", "candle-nn/metal", "candle-metal-kernels"] [dependencies] anyhow.workspace = true -candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [ -] } -candle-hf-hub = "0.3.3" -candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" } +candle-core.workspace = true +candle-hf-hub.workspace = true +candle-nn.workspace = true ferritin-test-data = { path = "../ferritin-test-data" } -rand = "0.8.5" -safetensors = "0.4.5" -tokenizers = "0.21.0" +rand.workspace = true +safetensors.workspace = true +tokenizers.workspace = true + [target.'cfg(target_os = "macos")'.features] metal = [] @@ -27,7 +27,7 @@ metal = [] candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels", optional = true } [dev-dependencies] -candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" } +candle-examples.workspace = true ferritin-test-data = { path = "../ferritin-test-data" } assert_cmd = "2.0.16" tempfile = "3.14.0" diff --git a/ferritin-ligandmpnn/Cargo.toml b/ferritin-ligandmpnn/Cargo.toml index 4168c38e..5b408726 100644 --- a/ferritin-ligandmpnn/Cargo.toml +++ b/ferritin-ligandmpnn/Cargo.toml @@ -16,17 +16,16 @@ metal = [ [dependencies] anyhow.workspace = true -candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [ -] } -candle-nn = { git = 
"https://github.com/huggingface/candle.git", package = "candle-nn" } -candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers" } +candle-core.workspace = true +candle-nn.workspace = true +candle-transformers.workspace = true clap = "4.5.23" ferritin-core = { path = "../ferritin-core" } ferritin-test-data = { path = "../ferritin-test-data" } itertools.workspace = true pdbtbx.workspace = true -rand = "0.8.5" -safetensors = "0.4.5" +rand.workspace = true +safetensors.workspace = true strum = { version = "0.26", features = ["derive"] } @@ -37,7 +36,7 @@ metal = [] candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels", optional = true } [dev-dependencies] -candle-examples = { git = "https://github.com/huggingface/candle.git", package = "candle-examples" } +candle-examples.workspace = true ferritin-test-data = { path = "../ferritin-test-data" } assert_cmd = "2.0.16" tempfile = "3.14.0" diff --git a/ferritin-molviewspec/Cargo.toml b/ferritin-molviewspec/Cargo.toml index 70f51710..44584767 100644 --- a/ferritin-molviewspec/Cargo.toml +++ b/ferritin-molviewspec/Cargo.toml @@ -11,7 +11,7 @@ chrono = "0.4.38" validator = { version = "0.19.0", features = ["derive"] } serde = { workspace = true } serde_json = { workspace = true } -urlencoding = { workspace = true } +urlencoding = "2.1.3" [dev.dependencies] ferritin-pymol = { path = "../ferritin-pymol" } diff --git a/ferritin-wasm-examples/amplify/.gitignore b/ferritin-wasm-examples/amplify/.gitignore new file mode 100644 index 00000000..567609b1 --- /dev/null +++ b/ferritin-wasm-examples/amplify/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/ferritin-wasm-examples/amplify/Cargo.toml b/ferritin-wasm-examples/amplify/Cargo.toml new file mode 100644 index 00000000..4a95e734 --- /dev/null +++ b/ferritin-wasm-examples/amplify/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "ferritin-wasm-example-amplify" +version = "0.1.0" +edition 
= "2021" + +[dependencies] +candle-core = { workspace = true } +candle-nn = { workspace = true } +tokenizers = { workspace = true, default-features = false, features = [ + "unstable_wasm", +] } +ferritin-amplify = { path = "../../ferritin-amplify" } + +# for wasm +gloo = "0.11.0" +js-sys = "0.3.74" +wasm-bindgen = "0.2.97" +serde-wasm-bindgen = "0.6.5" +serde_json.workspace = true +console_error_panic_hook = "0.1.7" +getrandom = { version = "0.2.15", features = ["js"] } + + +# [lib] +# crate-type = ["cdylib", "rlib"] diff --git a/ferritin-wasm-examples/amplify/Readme.md b/ferritin-wasm-examples/amplify/Readme.md new file mode 100644 index 00000000..8fbf9902 --- /dev/null +++ b/ferritin-wasm-examples/amplify/Readme.md @@ -0,0 +1 @@ +# Amplify WASM Example diff --git a/ferritin-wasm-examples/amplify/amplifyWorker.js b/ferritin-wasm-examples/amplify/amplifyWorker.js new file mode 100644 index 00000000..bff6d24e --- /dev/null +++ b/ferritin-wasm-examples/amplify/amplifyWorker.js @@ -0,0 +1,81 @@ +//load Candle Bert Module wasm module +import init, { Model } from "./build/m.js"; + +async function fetchArrayBuffer(url) { + const cacheName = "amplify-candle-cache"; + const cache = await caches.open(cacheName); + const cachedResponse = await cache.match(url); + if (cachedResponse) { + const data = await cachedResponse.arrayBuffer(); + return new Uint8Array(data); + } + const res = await fetch(url, { cache: "force-cache" }); + cache.put(url, res.clone()); + return new Uint8Array(await res.arrayBuffer()); +} + +class Amplify { + static instance = {}; + + static async getInstance(weightsURL, tokenizerURL, configURL, modelID) { + if (!this.instance[modelID]) { + await init(); + + self.postMessage({ status: "downloaded", message: "Downloaded Model" }); + + const [weightsArrayU8, tokenizerArrayU8, mel_filtersArrayU8] = + await Promise.all([ + fetchArrayBuffer(weightsURL), + fetchArrayBuffer(tokenizerURL), + fetchArrayBuffer(configURL), + ]); + + self.postMessage({ status: 
"loading", message: "Loading Model" }); + + this.instance[modelID] = new Model( + weightsArrayU8, + tokenizerArrayU8, + mel_filtersArrayU8 + ); + } else { + self.postMessage({ status: "ready", message: "Model Already Loaded" }); + } + return this.instance[modelID]; + } +} + +self.addEventListener("message", async (event) => { + const { + weightsURL, + tokenizerURL, + configURL, + modelID, + // sentences, + normalize = true, + } = event.data; + try { + self.postMessage({ status: "ready", message: "Starting Bert Model" }); + const model = await Amplify.getInstance( + weightsURL, + tokenizerURL, + configURL, + modelID + ); + self.postMessage({ + status: "embedding", + message: "Calculating Embeddings", + }); + const output = model.get_embeddings({ + sentences: sentences, + normalize_embeddings: normalize, + }); + + self.postMessage({ + status: "complete", + message: "complete", + output: output.data, + }); + } catch (e) { + self.postMessage({ error: e }); + } +}); diff --git a/ferritin-wasm-examples/amplify/build-lib.sh b/ferritin-wasm-examples/amplify/build-lib.sh new file mode 100644 index 00000000..7c2cf0c5 --- /dev/null +++ b/ferritin-wasm-examples/amplify/build-lib.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -e # Exit on error + +# cargo clean +mkdir -p target/wasm32-unknown-unknown/release +cargo update +cargo +nightly build \ + --target wasm32-unknown-unknown \ + --release \ + -Z build-std=std,panic_abort \ + --no-default-features + + + +# cargo build --target wasm32-unknown-unknown --release +# rustup +nightly target add wasm32-unknown-unknown +# cargo +nightly build --target wasm32-unknown-unknown --release +# +wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web diff --git a/ferritin-wasm-examples/amplify/lib-example.html b/ferritin-wasm-examples/amplify/lib-example.html new file mode 100644 index 00000000..d279f54c --- /dev/null +++ b/ferritin-wasm-examples/amplify/lib-example.html @@ -0,0 +1,375 @@ + + + + Ferritin Amplify 
+ + + + + + + + + + + + + + + +
+ 🕯️ +
+

Ferritin AMPLIFY

+

Rust/WASM Demo

+

+ Calculating protein language model embeddings in the browser using + the AMPLIFY Model written with + Candle + + and compiled to Wasm. Models and weights are derived from the + + AMPLIFY model + . +

+
+ +
+ + +
+
+

Examples:

+
+ + + + + + + +
+
+
+ + + +
+
+

Input text:

+
+
+ +
+
+
+ +

+ Input text to perform semantic similarity search... +

+
+
+
+
+
diff --git a/ferritin-wasm-examples/amplify/src/bin/m.rs b/ferritin-wasm-examples/amplify/src/bin/m.rs
new file mode 100644
index 00000000..bc38a12c
--- /dev/null
+++ b/ferritin-wasm-examples/amplify/src/bin/m.rs
@@ -0,0 +1,113 @@
+use candle_core::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use ferritin_amplify::{AMPLIFYConfig, ProteinTokenizer, AMPLIFY};
+use ferritin_wasm_example_amplify::console_log;
+use tokenizers::{PaddingParams, Tokenizer};
+use wasm_bindgen::prelude::*;
+
+/// AMPLIFY model and its tokenizer, exported to JavaScript through `wasm_bindgen`.
+#[wasm_bindgen]
+pub struct Model {
+    amplify: AMPLIFY,
+    // NOTE(review): declared as `ProteinTokenizer`, but `load` stores the value
+    // returned by `Tokenizer::from_bytes` -- confirm the two types agree.
+    tokenizer: ProteinTokenizer,
+}
+
+#[wasm_bindgen]
+impl Model {
+    /// Builds a `Model` from the raw bytes handed over by the JavaScript caller.
+    ///
+    /// * `weights` - safetensors buffer, loaded as `F32` on the CPU device.
+    /// * `tokenizer` - tokenizer definition, parsed with `Tokenizer::from_bytes`.
+    /// * `config` - JSON bytes, deserialized into an `AMPLIFYConfig`.
+    ///
+    /// Returns a `JsError` when any of the three inputs fails to load or parse.
+    #[wasm_bindgen(constructor)]
+    pub fn load(weights: Vec<u8>, tokenizer: Vec<u8>, config: Vec<u8>) -> Result<Model, JsError> {
+        console_error_panic_hook::set_once();
+        let device = &Device::Cpu;
+        let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
+        let config: AMPLIFYConfig = serde_json::from_slice(&config)?;
+        let amplify = AMPLIFY::load(vb, &config)?;
+        // The tokenizer is parsed from the caller-supplied bytes; nothing is fetched here.
+        let tokenizer =
+            Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
+
+        Ok(Self { amplify, tokenizer })
+        // Ok(Self { amplify })
+    }
+
+    // TODO: port `get_embeddings` to AMPLIFY. The draft below still calls
+    // `self.bert`, which `Model` does not have.
+    // pub fn get_embeddings(&mut self, input: JsValue) -> Result<JsValue, JsError> {
+    //     let input: Params =
+    //         serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
+    //     let sentences = input.sentences;
+    //     let normalize_embeddings = input.normalize_embeddings;
+    //     let device = &Device::Cpu;
+    //     if let Some(pp) = self.tokenizer.get_padding_mut() {
+    //         pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+    //     } else {
+    //         let pp = PaddingParams {
+    //             strategy: tokenizers::PaddingStrategy::BatchLongest,
+    //             ..Default::default()
+    //         };
+    //         self.tokenizer.with_padding(Some(pp));
+    //     }
+    //     let tokens = self
+    //         .tokenizer
+    //         .encode_batch(sentences.to_vec(), true)
+    //         .map_err(|m| JsError::new(&m.to_string()))?;
+
+    //     let token_ids: Vec<Tensor> = tokens
+    //         .iter()
+    //         .map(|tokens| {
+    //             let tokens = tokens.get_ids().to_vec();
+    //             Tensor::new(tokens.as_slice(), device)
+    //         })
+    //         .collect::<Result<Vec<_>, _>>()?;
+    //     let attention_mask: Vec<Tensor> = tokens
+    //         .iter()
+    //         .map(|tokens| {
+    //             let tokens = tokens.get_attention_mask().to_vec();
+    //             Tensor::new(tokens.as_slice(), device)
+    //         })
+    //         .collect::<Result<Vec<_>, _>>()?;
+    //     let token_ids = Tensor::stack(&token_ids, 0)?;
+    //     let attention_mask = Tensor::stack(&attention_mask, 0)?;
+    //     let token_type_ids = token_ids.zeros_like()?;
+    //     console_log!("running inference on batch {:?}", token_ids.shape());
+    //     let embeddings = self
+    //         .bert
+    //         .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
+    //     console_log!("generated embeddings {:?}", embeddings.shape());
+    //     // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+    //     let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
+    //     let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
+    //     let embeddings = if normalize_embeddings {
+    //         embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?
+    //     } else {
+    //         embeddings
+    //     };
+    //     let embeddings_data = embeddings.to_vec2()?;
+    //     Ok(serde_wasm_bindgen::to_value(&Embeddings {
+    //         data: embeddings_data,
+    //     })?)
+    // }
+}
+
+// #[derive(serde::Serialize, serde::Deserialize)]
+// struct Embeddings {
+//     data: Vec<Vec<f32>>,
+// }
+
+// #[derive(serde::Serialize, serde::Deserialize)]
+// pub struct Params {
+//     sentences: Vec<String>,
+//     normalize_embeddings: bool,
+// }
+
+fn main() {
+    console_error_panic_hook::set_once();
+}
diff --git a/ferritin-wasm-examples/amplify/src/lib.rs b/ferritin-wasm-examples/amplify/src/lib.rs
new file mode 100644
index 00000000..b16196d6
--- /dev/null
+++ b/ferritin-wasm-examples/amplify/src/lib.rs
@@ -0,0 +1,19 @@
+use wasm_bindgen::prelude::*;
+
+pub use ferritin_amplify::AMPLIFY;
+pub use tokenizers::{PaddingParams, Tokenizer};
+
+#[wasm_bindgen]
+extern "C" {
+    // Use `js_namespace` here to bind `console.log(..)` instead of just
+    // `log(..)`
+    #[wasm_bindgen(js_namespace = console)]
+    pub fn log(s: &str);
+}
+
+/// Formats its arguments like `format!` and passes the resulting string to
+/// the `log` function imported above.
+#[macro_export]
+macro_rules! console_log {
+    ($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
+}
diff --git a/ferritin-wasm-examples/amplify/utils.js b/ferritin-wasm-examples/amplify/utils.js
new file mode 100644
index 00000000..7a96a6be
--- /dev/null
+++ b/ferritin-wasm-examples/amplify/utils.js
@@ -0,0 +1,103 @@
+/**
+ * Asks the worker to run the model and resolves with its final message.
+ *
+ * Posts one request carrying the three URLs and the model id, then listens
+ * for replies. Every reply is also forwarded to `updateStatus` when given.
+ *
+ * @param {Worker} worker - Worker that receives the request.
+ * @param {string} weightsURL - URL of the model weights.
+ * @param {string} tokenizerURL - URL of the tokenizer definition.
+ * @param {string} configURL - URL of the model config.
+ * @param {string} modelID - Model identifier forwarded to the worker.
+ * @param {?function(object): void} [updateStatus=null] - Optional callback
+ *   invoked with the data of every message the worker sends back.
+ * @returns {Promise<object>} Resolves with the message data whose `status`
+ *   is "complete"; rejects with an `Error` when a message has an `error` key.
+ */
+export async function getEmbeddings(
+  worker,
+  weightsURL,
+  tokenizerURL,
+  configURL,
+  modelID,
+  // sentences,
+  updateStatus = null
+) {
+  return new Promise((resolve, reject) => {
+    worker.postMessage({
+      weightsURL,
+      tokenizerURL,
+      configURL,
+      modelID,
+      // sentences,
+    });
+    function messageHandler(event) {
+      if ("error" in event.data) {
+        worker.removeEventListener("message", messageHandler);
+        reject(new Error(event.data.error));
+      }
+      if (event.data.status === "complete") {
+        worker.removeEventListener("message", messageHandler);
+        resolve(event.data);
+      }
+      if (updateStatus) updateStatus(event.data);
+    }
+    worker.addEventListener("message", messageHandler);
+  });
+}
+
+const MODELS = {
+  "amplify_120M": {
+    base_url: "https://huggingface.co/chandar-lab/AMPLIFY_120M/resolve/main/",
+  },
+};
+
+/**
+ * Returns the download URLs for a known model.
+ *
+ * @param {string} id - Key of an entry in `MODELS`.
+ * @returns {{modelURL: string, configURL: string, tokenizerURL: string}}
+ *   URLs of the weights, config and tokenizer files under the model's base URL.
+ * @throws {Error} If `id` is not a key of `MODELS`.
+ */
+export function getModelInfo(id) {
+  if (!Object.prototype.hasOwnProperty.call(MODELS, id)) {
+    throw new Error(`Unknown model id: ${id}`);
+  }
+  const baseURL = MODELS[id].base_url;
+  return {
+    modelURL: `${baseURL}model.safetensors`,
+    configURL: `${baseURL}config.json`,
+    tokenizerURL: `${baseURL}tokenizer.json`,
+  };
+}
+
+// export function cosineSimilarity(vec1, vec2) {
+//   const dot = vec1.reduce((acc, val, i) => acc + val * vec2[i], 0);
+//   const a = Math.sqrt(vec1.reduce((acc, val) => acc + val * val, 0));
+//   const b = Math.sqrt(vec2.reduce((acc, val) => acc + val * val, 0));
+//   return dot / (a * b);
+// }
+
+// export async function getWikiText(article) {
+//   // thanks to wikipedia for the API
+//   const URL = `https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exlimit=1&titles=${article}&explaintext=1&exsectionformat=plain&format=json&origin=*`;
+//   return fetch(URL, {
+//     method: "GET",
+//     headers: {
+//       Accept: "application/json",
+//     },
+//   })
+//     .then((r) => r.json())
+//     .then((data) => {
+//       const pages = data.query.pages;
+//       const pageId = Object.keys(pages)[0];
+//       const extract = pages[pageId].extract;
+//       if (extract === undefined || extract === "") {
+//         throw new Error("No article found");
+//       }
+//       return extract;
+//     })
+//     .catch((error) => console.error("Error:", error));
+// }
diff --git a/justfile b/justfile
index c022372d..b320c474 100644
--- a/justfile
+++ b/justfile
@@ -39,7 +39,8 @@ test-full:
     cargo test -- --include-ignored
 
 amplify:
-    cargo run --example amplify
+    # cargo run --example amplify
+    cargo run --example amplify --features metal
 
 test-ligandmpnn:
     cargo test --features metal -p ferritin-ligandmpnn test_cli_command_run_example_06 -- --nocapture