Commit
LigandMPNN Blog Post (#60)
* stub blog post

* remove metal for CI
zachcp authored Dec 2, 2024
1 parent 0c8a8e0 commit 9992f55
Showing 6 changed files with 139 additions and 69 deletions.
59 changes: 1 addition & 58 deletions Cargo.lock


Binary file added docs/posts/20241202_lmpnn_01/images/claude.png
125 changes: 125 additions & 0 deletions docs/posts/20241202_lmpnn_01/index.qmd
---
title: "Return to LigandMPNN"
description: "Model updates and Metal Implementation"
author: "Zachary Charlop-Powers"
date: "2024-12-02"
categories: [rust, ai, proteins]
image: "images/metal_pre.png"
---



# Returning to LigandMPNN

The original motivating goal of this project was a pure-rust-WASM ProteinMPNN/LigandMPNN
implementation for fast local design. In previous posts I discussed:

1. Core Protein Data Structures for efficient protein representation using the Struct-of-Arrays style. [link](https://zachcp.github.io/ferritin/posts/20241029_first_post/)
2. The LigandMPNN Trait to define functions for extracting features from a protein representation. [link](https://zachcp.github.io/ferritin/posts/20241105_lmpnnfeaturizer/)
3. The Candle implementation of the Amplify model (parts [1](https://zachcp.github.io/ferritin/posts/20241111_amplify/), [2](https://zachcp.github.io/ferritin/posts/20241114_amplify_02/), and [3](https://zachcp.github.io/ferritin/posts/20241115_amplify_03/))


Now that I have gotten my feet wet on a protein language model implementation, I am ready to return to the more architecturally challenging problem of {Protein/Ligand}-MPNN. This post will describe a few of the challenges faced
in porting that library over and the current state of the model.

# Issues to Solve.

## Model Complexity

There are a few differences between Candle and PyTorch in how they handle:

- dimension selection (PyTorch: NumPy-like bracket indexing; Candle: methods like `i`, `narrow`, `squeeze`, and `unsqueeze`)
- in-place mutation of Tensors (PyTorch: yes; Candle: no)
- matrix contiguity (PyTorch: ?; Candle: explicit `.contiguous()` calls)
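
To make the first difference concrete, here is a sketch in NumPy terms (the bracket style that PyTorch shares) of what Candle's method-based selection roughly corresponds to. The `i`/`narrow` mapping in the comments is my approximation, not Candle's documentation.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# PyTorch-style: one bracket expression does select + slice at once.
a = x[1, 0:2, :]                 # shape (2, 4)

# Candle-style: each step is an explicit method call.
b = x.take(1, axis=0)            # ~ tensor.i(1)        -> shape (3, 4)
b = b[0:2, :]                    # ~ .narrow(0, 0, 2)   -> shape (2, 4)

assert np.array_equal(a, b)
```

The bracket form is terser, but the method form makes every dimension change an explicit, checkable step — which is exactly what makes a port line-by-line verifiable.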

As a ballpark approximation, we can look at a few of the function-call types that touch those differences. I counted the occurrences of these calls in the code bases [here](https://github.com/dauparas/LigandMPNN) and [here](https://huggingface.co/chandar-lab/AMPLIFY_120M) and show
the aggregated results in the table below. The LigandMPNN numbers are inflated because they include the `sc.py` file, which handles side-chain packing. Nonetheless,
from the perspective of implementation, it should be clear that the potentially tricky bits in LigandMPNN far outnumber those in Amplify.

```shell
# run in the LigandMPNN or huggingface AMPLIFY_120M dirs
rg -c '\[' *py
rg -c 'gather' *py
rg -c 'scatter' *py
rg -c 'mul' *py
```


| Term | LigandMPNN | AMPLIFY_120M |
|------|------------|--------------|
| `[` | 910 | 12 |
| `gather` | 55 | 0 |
| `scatter` | 5 | 0 |
| `mul` | 52 | 3 |
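
For a concrete sense of what the `gather` and `scatter` counts represent, here is a NumPy sketch of the two access patterns; `np.take_along_axis` and `np.add.at` are rough stand-ins for `torch.gather` and `scatter_add`, chosen here for illustration.

```python
import numpy as np

x = np.array([[10., 20., 30.],
              [40., 50., 60.]])

# gather: pick one element per row, by an index tensor (like torch.gather, dim=1).
idx = np.array([[2], [0]])
gathered = np.take_along_axis(x, idx, axis=1)   # [[30.], [40.]]

# scatter_add: accumulate values into positions named by an index tensor.
out = np.zeros(3)
np.add.at(out, np.array([0, 0, 2]), np.array([1., 2., 3.]))  # out = [3., 0., 3.]
```

Both patterns couple a data tensor to an index tensor, so every port has to get two shapes and a dtype right at once — one reason these calls are a good proxy for porting difficulty.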


## Model Loading

In working with AMPLIFY, one of the key successes was being able to load Amplify's model into a [VarBuilder](https://docs.rs/candle-nn/latest/candle_nn/var_builder/type.VarBuilder.html) using
`from_mmaped_safetensors`. You can then build your model by accessing the Tensors by name, which lets you match the PyTorch model using layers whose names make sense. During this process,
I noticed that there is a similar function for PyTorch files - `from_pth`. Excellent! This gave me a tool I had lacked: the ability to load the model with the exact same
names as the PyTorch model. I should be able to load this file, have all the Tensors match, and use/account for all layers. It turned out there was a hiccup in that the PTH
and safetensors formats differ a bit, and you need to be able to access the PyTorch state dict. I submitted a fix [here](https://github.com/huggingface/candle/pull/2639) and began
assembling the model into my pre-existing code. This is where I began to run into a number of issues related to Tensor dimensions.
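
The idea behind name-based loading can be sketched with a plain dictionary standing in for a PyTorch state dict. The layer names and shapes below are purely illustrative — they are not LigandMPNN's actual names.

```python
import numpy as np

# Stand-in for a PyTorch state dict: a flat name -> tensor mapping.
state = {
    "encoder.W1.weight": np.zeros((128, 64)),
    "encoder.W1.bias": np.zeros(128),
    "decoder.W_out.weight": np.zeros((21, 128)),
}

def get(prefix, name):
    """Mimic VarBuilder-style access: fetch a tensor by its
    hierarchical name, and fail loudly if it is missing."""
    key = f"{prefix}.{name}"
    if key not in state:
        raise KeyError(f"missing tensor: {key}")
    return state[key]

w = get("encoder.W1", "weight")   # shape (128, 64)
```

Failing loudly on a missing name is the point: it turns "did I wire up every layer?" into an error you hit immediately instead of a silent mismatch you debug later.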


## Dimension Matching

As I began to load the Tensors in from the PyTorch file, I began to hit errors introduced by the incompatible syntax mentioned above. In each case I needed to compare
the PyTorch code with my Rust code and 1) introduce the Candle syntax while 2) maintaining the flow/intention of the model. For this I leaned quite heavily on Claude/Sonnet 3.5 via
the Zed editor. This was an invaluable experience and further impressed me with LLM capability. Here is a taste of Claude's explanatory power; full gist [here](https://gist.github.com/zachcp/45ae897bd0db389b6a288a99d25011bd)

![](images/claude.png)
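
The flavor of translation involved can be sketched in NumPy: PyTorch permits in-place masked assignment, while in Candle you build a new tensor instead (roughly a `where_cond` call — the mapping in the comment is my approximation).

```python
import numpy as np

x = np.array([-1.0, 2.0, -3.0, 4.0])
mask = x < 0

# PyTorch-style in-place mutation: x[mask] = 0.0
x_inplace = x.copy()
x_inplace[mask] = 0.0

# Candle-style functional rewrite: construct a new tensor rather than
# mutating (roughly mask.where_cond(&zeros, &x) in Candle).
x_functional = np.where(mask, 0.0, x)

assert np.array_equal(x_inplace, x_functional)
```

The two forms compute the same result, but proving that equivalence for every mutation site in the model is exactly the tedious, error-prone work an LLM assistant helps with.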



## Speed

After a bit of work I was able to load the model and `run` it, where `run` means execute the model with an input and get an output without failing. My strategy had always been to get it
running first and then get it to pass tests, so I was pretty pleased. However, the model took minutes to run! Not what I was looking for. So after pinging the Candle Discord, I realized
that although I was on macOS, I had been using `Device::Cpu`. What if we switch to `Device::Metal`? I had to rework the code a bit to get the Device passed in, but then I hit a bunch of errors like:

```rust
// Metal doesn't support 64 bits!
Err(WithBacktrace { inner: Msg("Metal strided to_dtype F64 F32 not implemented")
// No gather calls on integers!
Err(WithBacktrace { inner: Msg("Metal gather U32 I64 not implemented"),
// No scatter-add on integers!
Err(Metal(UnexpectedDType { msg: "scatter-add ids should be u8/u32/i64", expected: U32, got: U32 }))
```

The first fix was to convert a number of F64/I64s to F32/U32. I then needed to track down and implement a few 2-line additions to Candle's Metal kernels to allow them
to work. These were the PRs for [Gather](https://github.com/huggingface/candle/pull/2653) and [Scatter Add](https://github.com/huggingface/candle/pull/2656). The result was impressive:
my initial model ran in 3 minutes; the new model in 8 seconds! Okay, we can work with that.
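
A minimal sketch of the dtype half of the fix, in NumPy terms (in Candle the equivalent is a `to_dtype` call; the array names here are illustrative):

```python
import numpy as np

coords = np.random.rand(10, 3)              # float64 by default
indices = np.array([0, 3, 7], dtype=np.int64)

# The Metal kernels lacked F64/I64 support, so narrow the dtypes
# before dispatch (~ tensor.to_dtype(DType::F32) in Candle).
coords32 = coords.astype(np.float32)
indices_u32 = indices.astype(np.uint32)
```

For protein coordinates, F32 precision is ample, so the cast costs nothing but has to be applied consistently everywhere the tensors meet a kernel.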



```shell
cargo instruments -t time \
--bin ferritin-featurizers \
-- run --seed 111 \
--pdb-path ferritin-test-data/data/structures/1bc8.cif \
--model-type protein_mpnn --out-folder testout
```




:::: {.columns}

::: {.column width="50%"}
![](images/metal_pre.png)
:::

::: {.column width="50%"}
![](images/metal_post.png)
:::

::::


## Testing Suite.

I've started a [test suite](https://github.com/zachcp/ferritin/pull/59) to match LigandMPNN's outputs and have begun implementing the CLI code for it. As of today, December 2, there is
not much to show. But I am satisfied with where the project has gotten and am impressed by Justas Dauparas and his collaborators on this implementation. There are still some hard bits ahead.
24 changes: 13 additions & 11 deletions ferritin-featurizers/Cargo.toml
[dependencies]
anyhow.workspace = true

candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels" }
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [
"metal",
] }
candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn", features = [
"metal",
] }
candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers", features = [
"metal",
] }

candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn" }
#candle-metal-kernels = { git = "https://github.com/huggingface/candle.git", package = "candle-metal-kernels" }
candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers" }
# candle-core = { git = "https://github.com/huggingface/candle.git", package = "candle-core", features = [
# "metal",
# ] }
# candle-nn = { git = "https://github.com/huggingface/candle.git", package = "candle-nn", features = [
# "metal",
# ] }
# candle-transformers = { git = "https://github.com/huggingface/candle.git", package = "candle-transformers", features = [
# "metal",
# ] }
# candle-core = { version = "0.8", features = ["metal"] }
# # candle-nn = { version = "0.8", features = ["metal"] }
candle-hf-hub = "0.3.3"
