-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* stub blog post * remove metal for CI
- Loading branch information
Showing
6 changed files
with
139 additions
and
69 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
--- | ||
title: "Return to LigandMPNN" | ||
description: "Model updates and Metal Implementation" | ||
author: "Zachary Charlop-Powers" | ||
date: "2024-12-02" | ||
categories: [rust, ai, proteins] | ||
image: "images/metal_pre.png" | ||
--- | ||
|
||
|
||
|
||
# Returning to LigandMPNN | ||
|
||
The original motivating goal of this project was a pure-rust-WASM ProteinMPNN/LigandMPNN | ||
implementation for fast local design. In previous posts I discussed: | ||
|
||
1. Core Protein Data Structures for effecient protein representation using the Stuct-of-Arrays style. [link](https://zachcp.github.io/ferritin/posts/20241029_first_post/) | ||
2. The LigandMPNN Trait to define functions for extracting features from a protein representation. [link](https://zachcp.github.io/ferritin/posts/20241105_lmpnnfeaturizer/) | ||
3. The Candle implementation of the Amplify model (parts [1](https://zachcp.github.io/ferritin/posts/20241111_amplify/), [2](https://zachcp.github.io/ferritin/posts/20241114_amplify_02/), and [3](https://zachcp.github.io/ferritin/posts/20241115_amplify_03/)) | ||
|
||
|
||
Now that I was able to get my feet wet on a protein language model implmentation, I am ready to return to the more architecturally challenging problem of {Protein/Ligand}-MPNN. This post will descrive a few of the challenges faced | ||
in porting that library over and the current state of the model. | ||
|
||
# Issues to Solve. | ||
|
||
## Model Complexity | ||
|
||
There are a few differences between Candle and Pytorch in terms of how: | ||
|
||
- how they handle dimensions selection (pytorch: NUMPY-like; Candle: methods like `i`, `narrow`, `squeeze` and `unsqueeze`) | ||
- whether the Tensors can be mutated in place (pytorch:yes; Candle: no) | ||
- specifying matrix contiguity. (pytorch: ?; candle: `.contiguous()`) | ||
|
||
As a ballpark approximation we can take a look at a few of the function call types that handle those differences. I am | ||
calculating the occurences of function calls in the code bases [here](https://github.com/dauparas/LigandMPNN) and [here](https://huggingface.co/chandar-lab/AMPLIFY_120M) and showing | ||
the aggregated results in the table below. The LigandMPNN data is inflated because these calls also include the `sc.py` file which allows for side-cahin packing. Nonetheless, | ||
from the perspective of implementation, it should be clear that the potentially tricky implementation bits in LigandMPNN are far greater than in Amplify. | ||
|
||
```shell | ||
# ligandMPNN or hugginface AMPLIFY_120M dirs | ||
rg -c '\[' *py | ||
rg -c 'gather' *py | ||
rg -c 'scatter' *py | ||
rg -c 'mul' *py | ||
``` | ||
|
||
|
||
| Term | LigandMPNN | AMPLIFY_120M | | ||
|------|------------|--------------| | ||
| `[` | 910 | 12 | | ||
| `gather` | 55 | 0 | | ||
| `scatter` | 5 | 0 | | ||
| `mul` | 52 | 3 | | ||
|
||
|
||
## Model Loading | ||
|
||
In working with AMPLIFY, one of the key successes was being able to load Amplify's model into a [VarBuilder(https://docs.rs/candle-nn/latest/candle_nn/var_builder/type.VarBuilder.html) using the | ||
`from_mmaped_safetensors`. You can then build your model by accessing the Tensors by name. It allows you to match the pytorch model using layers that make sense by name. During this process, | ||
I noticed that there is a similar function for pytorch files - `from_pth`. Excellent! This gives me a new tool I lacked - the ability to load the model with the exact same | ||
names as the pytorch model. I should be able to laod this file and have all the Tensors match and use/account for all layers. It turned out there was a hiccup in that the PTH | ||
and safetensor formats differ a bit and you need to be able to access the pytorch statemap. I submitted a fix [here](https://github.com/huggingface/candle/pull/2639) and began | ||
assemble the model into my pre-existing code. This is where I began to run into a nubmer of issues related to Tensor dimension. | ||
|
||
|
||
## Dimension Matching | ||
|
||
As I began to load the Tensors in from the pytorch file, I began to hit errors introduced by the incompatible syntax mentioned above. In this case I would need to compare | ||
the pytorch code with my Rust code and 1) introduce the Candle syntax while 2) maintaining the flow/intention of the model. For this I leaned quite heavily on Claude/Sonnet3.5 via | ||
the Zed editor. This was an invaluable experience and further impressed me as to LLM capability. Here is a taste of Claude's explanatory power; full gist [here](https://gist.github.com/zachcp/45ae897bd0db389b6a288a99d25011bd) | ||
|
||
![](images/claude.png) | ||
|
||
|
||
|
||
## Speed | ||
|
||
After a bit of work I was able to load the model and `run` it where `run` means execute the model with an input and get an output without failing. My strategy had alwasy been to get it | ||
running then get it to pass tests so I was pretty pleased. However, the model took minutes to run! Not what I was looking for. So after pinging on the Candle Discord, I realized | ||
that I was on MacOS but that I had been using `Device::CPU`. What if we switch to `Device::Metal`? I had to rework the code a bit to get the Device passed in but then I hit a bunch of errors like: | ||
|
||
```rust | ||
// Metal doesn't support 64 bits! | ||
Err(WithBacktrace { inner: Msg("Metal strided to_dtype F64 F32 not implemented") | ||
// No gather calls on integers! | ||
Err(WithBacktrace { inner: Msg("Metal gather U32 I64 not implemented"), | ||
// No scatter-add on integers! | ||
Err(Metal(UnexpectedDType { msg: "scatter-add ids should be u8/u32/i64", expected: U32, got: U32 })) | ||
``` | ||
|
||
The first fix is to convert a number of F64/I64s to F32/U32. I then needed to track down and implement a few 2-line additions to Candle's Metal kernels that would allow the kernels | ||
to work. These were the PRs for [Gather](https://github.com/huggingface/candle/pull/2653) and [Scatter Add](https://github.com/huggingface/candle/pull/2656). The result was impressive. | ||
My inital model ran in 3 minutes; the new model in 8 seconds! Okay, we can work with that. | ||
|
||
|
||
|
||
```shell | ||
cargo instruments -t time \ | ||
--bin ferritin-featurizers \ | ||
-- run --seed 111 \ | ||
--pdb-path ferritin-test-data/data/structures/1bc8.cif \ | ||
--model-type protein_mpnn --out-folder testout | ||
``` | ||
|
||
|
||
|
||
|
||
:::: {.columns} | ||
|
||
::: {.column width="50%"} | ||
![](images/metal_pre.png) | ||
::: | ||
|
||
::: {.column width="50%"} | ||
![](images/metal_post.png) | ||
::: | ||
|
||
:::: | ||
|
||
|
||
## Testing Suite. | ||
|
||
I've started a [test suite](https://github.com/zachcp/ferritin/pull/59) to match LigandMPNNs and have begun implementing the CLI code for it. As of Today, December 2 there is | ||
not much to show. But I am satisfied with where the project has gotten and am impressed by Justas Dauparas and his colloaborators on this implementation. There are still some hard bits ahead. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters