---
title: "LMPNN Featurizers"
description: "Protein MPNN Featurizer"
author: "Zachary Charlop-Powers"
date: "2024-11-05"
categories: [rust, ai, proteins]
image: "images/protein_features.jpg"
---

# Intro

One of the motivating reasons for creating `ferritin` was to be able to recreate and run protein language
models in the browser via WASM. This post describes a work-in-progress effort to port
the [ligand-mpnn](https://github.com/dauparas/LigandMPNN) protein model from Python/PyTorch to Rust/[Candle](https://github.com/huggingface/candle).
My general strategy was to nibble at the model from both ends: begin a naive translation of the model itself,
and begin to recreate the protein-processing portion in Rust.

# Coordinates --> Tensors

Protein language models require converting proteins into compact `Tensor` representations of molecular features.
I defined the `LMPNNFeatures` trait to collect the functions I need. These functions operate on `AtomCollection`s
and return `candle_core::Tensor`s. You can see the function signatures below, as well as the `ProteinFeatures`
struct, which holds a set of these `Tensor`s and mimics the LigandMPNN feature dictionary.

```rust
/// Convert the AtomCollection into a struct that can be passed to a model.
pub trait LMPNNFeatures {
    /// Convert an AtomCollection into a set of features.
    fn featurize(&self, device: &Device) -> Result<ProteinFeatures>;
    /// Create a Tensor of dimensions [ <# residues>, 4 <N/CA/C/O>, 3 <xyz> ]
    fn to_numeric_backbone_atoms(&self, device: &Device) -> Result<Tensor>;
    /// Create a Tensor of dimensions [ <# residues>, 37, 3 <xyz> ]
    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor>;
    /// Create 3 Tensors of dimensions [ <# heavy atoms>, 3 <xyz> ]
    /// ( positions, elements, mask )
    fn to_numeric_ligand_atoms(&self, device: &Device)
        -> Result<(Tensor, Tensor, Tensor)>;
}

/// ProteinFeatures
/// Essential Tensors are aligned coordinates of atoms by atom type.
pub struct ProteinFeatures {
    s: Tensor,   // protein amino acids as an i32 Tensor
    x: Tensor,   // protein coords by residue [1, #res, 37, 4]
    y: Tensor,   // ligand coords
    y_t: Tensor, // encoded ligand atom names
    ...
    ...
}
```
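
As a quick illustration of how the trait is meant to be consumed, here is a minimal sketch (not from the repo) that takes anything implementing `LMPNNFeatures` and prints the shapes of the tensors it produces; it only uses the trait methods declared above plus `candle_core` types.

```rust
// Sketch only: exercise the trait methods declared above on any implementor.
use candle_core::{Device, Result};

fn inspect_features<T: LMPNNFeatures>(coll: &T) -> Result<()> {
    let device = Device::Cpu;
    // Backbone-only coordinates: [n_residues, 4 (N/CA/C/O), 3 (xyz)]
    let backbone = coll.to_numeric_backbone_atoms(&device)?;
    println!("backbone dims: {:?}", backbone.dims());
    // Full atom37 representation: [n_residues, 37, 3]
    let atom37 = coll.to_numeric_atom37(&device)?;
    println!("atom37 dims:   {:?}", atom37.dims());
    // Ligand heavy atoms as (positions, elements, mask)
    let (pos, elements, mask) = coll.to_numeric_ligand_atoms(&device)?;
    println!("ligand dims:   {:?} {:?} {:?}", pos.dims(), elements.dims(), mask.dims());
    Ok(())
}
```
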
One nice ergonomic crate I came across was [strum](https://docs.rs/strum/latest/strum/).
In order to generate a `Tensor` with aligned coordinates I want to iterate
through the atom types, and strum's `EnumIter` lets me do that. Below you can see
how that enum is used to build up a preallocated vector, which we populate and
then reshape into a Tensor. Very nice.

```rust
/// AAAtom. This is where the different atom types are registered.
#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, EnumIter)]
pub enum AAAtom {
    N = 0, CA = 1, C = 2, CB = 3, O = 4,
    CG = 5, CG1 = 6, CG2 = 7, OG = 8, OG1 = 9,
    SG = 10, CD = 11, CD1 = 12, CD2 = 13, ND1 = 14,
    ND2 = 15, OD1 = 16, OD2 = 17, SD = 18, CE = 19,
    CE1 = 20, CE2 = 21, CE3 = 22, NE = 23, NE1 = 24,
    NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29,
    NH2 = 30, OH = 31, CZ = 32, CZ2 = 33, CZ3 = 34,
    NZ = 35, OXT = 36,
    Unknown = -1,
}

// By iterating over the enum we can extract all of our coordinate data as follows.
impl LMPNNFeatures for AtomCollection {
    fn to_numeric_atom37(&self, device: &Device) -> Result<Tensor> {
        let res_count = self.iter_residues_aminoacid().count();
        let mut atom37_data = vec![0f32; res_count * 37 * 3];
        for residue in self.iter_residues_aminoacid() {
            let resid = residue.res_id as usize;
            for atom_type in AAAtom::iter().filter(|&a| a != AAAtom::Unknown) {
                if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) {
                    let [x, y, z] = atom.coords;
                    let base_idx = (resid * 37 + atom_type as usize) * 3;
                    atom37_data[base_idx] = *x;
                    atom37_data[base_idx + 1] = *y;
                    atom37_data[base_idx + 2] = *z;
                }
            }
        }
        // Create a tensor with shape [residues, 37, 3]
        Tensor::from_vec(atom37_data, (res_count, 37, 3), device)
    }
}
```
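
To convince myself that the flat-index arithmetic above lines up with the final reshape, here is a small standalone sketch using only `candle_core` (no ferritin types): it writes one xyz triple at the `(resid * 37 + atom_index) * 3` offset and then reads it back from the `[n_res, 37, 3]` tensor.

```rust
// Standalone sketch: verify that the flat layout used above round-trips through
// Tensor::from_vec + indexing. Only candle_core is required.
use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let n_res = 2;
    let mut data = vec![0f32; n_res * 37 * 3];

    // Pretend residue 1 has its CA atom (enum index 1) at (1.0, 2.0, 3.0).
    let (resid, atom_idx) = (1usize, 1usize);
    let base = (resid * 37 + atom_idx) * 3;
    data[base] = 1.0;
    data[base + 1] = 2.0;
    data[base + 2] = 3.0;

    // Reshape the flat buffer into [n_res, 37, 3] and pull the triple back out.
    let t = Tensor::from_vec(data, (n_res, 37, 3), &Device::Cpu)?;
    let ca = t.i((1, 1))?.to_vec1::<f32>()?;
    assert_eq!(ca, vec![1.0, 2.0, 3.0]);
    println!("residue 1, CA: {:?}", ca);
    Ok(())
}
```
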
# The LigandMPNN/ProteinMPNN Model

The section above covers the first part of the problem: converting
proteins to features. This section starts from the other direction,
which is to take the PyTorch code and translate it into Candle code.

As of Nov 5th this is still a work in progress, but it is not in terrible shape.
Below I show a snippet of the high-level model, where you can see there
are a set of encoder and decoder layers and a set of defaults. Internally,
there are some tricky bits that are made potentially trickier than
needed by the many ways LigandMPNN can be invoked. This results
in a lot of branches and conditions within the model, and I had some
trouble teasing out the model code from the data handling. But hopefully you
can catch a glimpse of what the model should look like.

```rust
pub struct ProteinMPNN {
    config: ProteinMPNNConfig, // device here ??
    decoder_layers: Vec<DecLayer>,
    device: Device,
    encoder_layers: Vec<EncLayer>,
    features: ProteinFeaturesModel, // this needs to be a model with weights etc.
    w_e: Linear,
    w_out: Linear,
    w_s: Linear,
}

impl ProteinMPNN {
    pub fn new(config: ProteinMPNNConfig, vb: VarBuilder) {}
    fn predict(&self) {}
    fn train(&mut self) {}
    fn encode(&self, features: &ProteinFeatures) {}
    fn sample(&self, features: &ProteinFeatures) {}
    pub fn score(&self, features: &ProteinFeatures, use_sequence: bool) {}
}

// Model params are contained here. This use of a `Config` struct with
// methods that supply the model-specific values is how most of the
// models in candle_transformers get their args.
impl ProteinMPNNConfig {
    pub fn proteinmpnn() -> Self {
        Self {
            atom_context_num: 0,
            augment_eps: 0.0,
            dropout_ratio: 0.1,
            edge_features: 128,
            hidden_dim: 128,
            k_neighbors: 24,
            ligand_mpnn_use_side_chain_context: false,
            model_type: ModelTypes::ProteinMPNN,
            node_features: 128,
            num_decoder_layers: 3,
            num_encoder_layers: 3,
            num_letters: 48,
            num_rbf: 16,
            scale_factor: 30.0,
            vocab: 48,
        }
    }
}
```
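
To make the config-plus-`VarBuilder` pattern concrete, here is a hedged sketch of how the linear projections might be constructed inside `new`. The weight names (`W_e`, `W_s`, `W_out`) and their dimensions are my assumptions for illustration, not something the post confirms.

```rust
// Sketch only: build the Linear layers from the config and a VarBuilder.
// The parameter names and dimensions below are assumptions for illustration.
use candle_core::Result;
use candle_nn::{linear, Linear, VarBuilder};

fn build_projections(cfg: &ProteinMPNNConfig, vb: &VarBuilder) -> Result<(Linear, Linear, Linear)> {
    let hidden = cfg.hidden_dim as usize;
    // Edge features -> hidden
    let w_e = linear(cfg.edge_features as usize, hidden, vb.pp("W_e"))?;
    // Sequence tokens -> hidden
    let w_s = linear(cfg.vocab as usize, hidden, vb.pp("W_s"))?;
    // Hidden -> amino-acid logits
    let w_out = linear(hidden, cfg.num_letters as usize, vb.pp("W_out"))?;
    Ok((w_e, w_s, w_out))
}
```
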
# CLI and Invocation Configs

There are a LOT of possible configurations that LigandMPNN supports. Based on aggregating
the [example CLI options](https://github.com/dauparas/LigandMPNN/blob/main/run_examples.sh)
from the LigandMPNN repo, I partitioned them into related structs - e.g. AABias, ResidueControl, etc. -
and consolidated all of the configs into a common file. The complicated structure of this config is
not great and I am not sure I am going to stick with it. But I do have the beginnings of a testing setup
that can run CLI tests. At the moment there is only one command, `ferritin-featurizer featurize <input> <output>`,
which will take a PDB or CIF file and create a featurized `safetensors` file.

```rust
// ligandmpnn/config.rs

//! - `ModelTypes` - Enum of supported model architectures
//! - `ProteinMPNNConfig` - Core model parameters
//! - `AABiasConfig` - Amino acid biasing controls
//! - `LigandMPNNConfig` - LigandMPNN specific settings
//! - `MembraneMPNNConfig` - MembraneMPNN specific settings
//! - `MultiPDBConfig` - Multi-PDB mode configuration
//! - `ResidueControl` - Residue-level design controls
//! - `RunConfig` - Runtime execution parameters

pub struct RunConfig {
    pub temperature: Option<f32>,
    pub verbose: Option<i32>,
    pub save_stats: Option<i32>,
    pub batch_size: Option<i32>,
    pub number_of_batches: Option<i32>,
    pub file_ending: Option<String>,
    pub zero_indexed: Option<i32>,
    pub homo_oligomer: Option<i32>,
    pub fasta_seq_separation: Option<String>,
}
```
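
The post does not say how the CLI arguments are parsed, but as a rough sketch of how the `featurize <input> <output>` command plus a few `RunConfig`-style flags could be declared, here is a hypothetical example using the clap crate (the crate choice, flag names, and structure are all my assumptions).

```rust
// Hypothetical sketch using clap; not the actual ferritin-featurizer CLI.
use clap::{Parser, Subcommand};

#[derive(Parser)]
#[command(name = "ferritin-featurizer")]
struct Cli {
    #[command(subcommand)]
    command: Command,
}

#[derive(Subcommand)]
enum Command {
    /// Featurize a PDB or CIF file into a safetensors file.
    Featurize {
        /// Input structure (.pdb or .cif)
        input: String,
        /// Output .safetensors path
        output: String,
        /// Optional sampling temperature (would map onto RunConfig::temperature)
        #[arg(long)]
        temperature: Option<f32>,
    },
}

fn main() {
    let cli = Cli::parse();
    match cli.command {
        Command::Featurize { input, output, temperature } => {
            println!("featurize {input} -> {output} (temperature = {temperature:?})");
        }
    }
}
```
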
# Todo

I am happy with the progress so far, but the next portion of work needs to tie the pieces together into
a coherent, usable model. Because this is my first foray into `Candle`, and because ProteinMPNN is a graph
network, it is not an easy first project to implement. Therefore, I think I am going to take a detour to
implement a simpler model and then, once I have my bearings, return to finish the job.

I think the best candidate for the first complete model would be [Amplify](https://github.com/chandar-lab/AMPLIFY),
a model that touts its small size relative to incumbents; whose weights are already in `safetensors` format
and available on HuggingFace; and whose architecture looks to be a more standard transformer-style model,
which matches Candle's strengths.