Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small Code Refactor #73

Merged
merged 2 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 30 additions & 32 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ members = [
"ferritin-molviewspec",
"ferritin-pymol",
"ferritin-test-data",
"ferritin-wasm-examples/*",
"ferritin-wasm-examples/amplify",
]
resolver = "2"

Expand Down
64 changes: 1 addition & 63 deletions ferritin-amplify/src/amplify/amplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
use super::config::AMPLIFYConfig;
use super::encoder::EncoderBlock;
use super::rotary::precompute_freqs_cis;
use super::outputs::ModelOutput;
use candle_core::{Device, Module, Result, Tensor, D};
use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
use tokenizers::Tokenizer;
Expand Down Expand Up @@ -133,66 +134,3 @@ impl AMPLIFY {
.map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))
}
}

/// Amplify model output.
///
/// Bundles the logits together with the optional hidden states and
/// attention tensors produced by a forward pass.
///
/// * `logits` — distribution over the sequences.
/// * `hidden_states` — intermediate activations, when collected.
/// * `attentions` — attention maps; the raw material for a contact map.
#[derive(Debug)]
pub struct ModelOutput {
    pub logits: Tensor,
    pub hidden_states: Option<Vec<Tensor>>,
    pub attentions: Option<Vec<Tensor>>,
}

impl ModelOutput {
    /// Average product correction (APC), used for contact prediction.
    ///
    /// Ported from
    /// https://github.com/chandar-lab/AMPLIFY/blob/rc-0.1/examples/utils.py#L83
    fn apc(&self, x: &Tensor) -> Result<Tensor> {
        let a1 = x.sum_keepdim(D::Minus1)?;
        let a2 = x.sum_keepdim(D::Minus2)?;
        let a12 = x.sum_keepdim((D::Minus1, D::Minus2))?;
        // Outer product of the two marginal sums, normalized by the total
        // sum (equivalent to PyTorch's in-place `div_`). `broadcast_div`
        // replaces the manual `broadcast_as` + `div` dance.
        let avg = a1.matmul(&a2)?.broadcast_div(&a12)?;
        x.sub(&avg)
    }

    /// Make the tensor symmetric in its final two dimensions, used for
    /// contact prediction.
    ///
    /// From https://github.com/facebookresearch/esm/blob/main/esm/modules.py
    /// and https://github.com/chandar-lab/AMPLIFY/blob/rc-0.1/examples/utils.py#L77
    fn symmetrize(&self, x: &Tensor) -> Result<Tensor> {
        let x_transpose = x.transpose(D::Minus1, D::Minus2)?;
        x.add(&x_transpose)
    }

    /// Derive a contact map from the self-attentions.
    ///
    /// Returns `Ok(None)` when no attentions were collected, when the
    /// collected list is empty, or when the sequence is too short to trim
    /// the special tokens. Otherwise returns a
    /// `(residues, residues, maps)` tensor.
    pub fn get_contact_map(&self) -> Result<Option<Tensor>> {
        let Some(attentions) = &self.attentions else {
            return Ok(None);
        };
        // Previously `attentions.first().unwrap()` — panicked on Some(vec![]).
        let Some(first) = attentions.first() else {
            return Ok(None);
        };
        // Attention blocks have shape (batch, n_heads, seq_len, seq_len);
        // we need seq_len to reshape below.
        let (_batch, _n_heads, _seq_len, seq_length) = first.dims4()?;
        // Guard the `seq_length - 2` narrows below against usize underflow.
        if seq_length < 2 {
            return Ok(None);
        }
        // Stack all layers, then collapse (layer, batch, head) into one axis.
        let attn_stacked = Tensor::stack(attentions, 0)?;
        let total_elements = attn_stacked.dims().iter().product::<usize>();
        let first_dim = total_elements / (seq_length * seq_length);
        let attn_map = attn_stacked.reshape(&[first_dim, seq_length, seq_length])?;

        // In PyTorch: attn_map = attn_map[:, 1:-1, 1:-1] — strip the first
        // and last positions (special tokens) from both residue axes.
        let attn_map = attn_map
            .narrow(1, 1, seq_length - 2)? // second dim
            .narrow(2, 1, seq_length - 2)?; // third dim
        let symmetric = self.symmetrize(&attn_map)?;
        let normalized = self.apc(&symmetric)?;
        let proximity_map = normalized.permute((1, 2, 0))?; // (residues, residues, map)

        Ok(Some(proximity_map))
    }
}
1 change: 1 addition & 0 deletions ferritin-amplify/src/amplify/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
pub mod amplify;
pub mod config;
pub mod encoder;
pub mod outputs;
pub mod rotary;
// pub mod loss;
64 changes: 64 additions & 0 deletions ferritin-amplify/src/amplify/outputs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use candle_core::{Result, Tensor, D};

/// Amplify model output.
///
/// Bundles the logits together with the optional hidden states and
/// attention tensors produced by a forward pass.
///
/// * `logits` — distribution over the sequences.
/// * `hidden_states` — intermediate activations, when collected.
/// * `attentions` — attention maps; the raw material for a contact map.
#[derive(Debug)]
pub struct ModelOutput {
    pub logits: Tensor,
    pub hidden_states: Option<Vec<Tensor>>,
    pub attentions: Option<Vec<Tensor>>,
}

impl ModelOutput {
    /// Average product correction (APC), used for contact prediction.
    ///
    /// Ported from
    /// https://github.com/chandar-lab/AMPLIFY/blob/rc-0.1/examples/utils.py#L83
    fn apc(&self, x: &Tensor) -> Result<Tensor> {
        let a1 = x.sum_keepdim(D::Minus1)?;
        let a2 = x.sum_keepdim(D::Minus2)?;
        let a12 = x.sum_keepdim((D::Minus1, D::Minus2))?;
        // Outer product of the two marginal sums, normalized by the total
        // sum (equivalent to PyTorch's in-place `div_`). `broadcast_div`
        // replaces the manual `broadcast_as` + `div` dance.
        let avg = a1.matmul(&a2)?.broadcast_div(&a12)?;
        x.sub(&avg)
    }

    /// Make the tensor symmetric in its final two dimensions, used for
    /// contact prediction.
    ///
    /// From https://github.com/facebookresearch/esm/blob/main/esm/modules.py
    /// and https://github.com/chandar-lab/AMPLIFY/blob/rc-0.1/examples/utils.py#L77
    fn symmetrize(&self, x: &Tensor) -> Result<Tensor> {
        let x_transpose = x.transpose(D::Minus1, D::Minus2)?;
        x.add(&x_transpose)
    }

    /// Derive a contact map from the self-attentions.
    ///
    /// Returns `Ok(None)` when no attentions were collected, when the
    /// collected list is empty, or when the sequence is too short to trim
    /// the special tokens. Otherwise returns a
    /// `(residues, residues, maps)` tensor.
    pub fn get_contact_map(&self) -> Result<Option<Tensor>> {
        let Some(attentions) = &self.attentions else {
            return Ok(None);
        };
        // Previously `attentions.first().unwrap()` — panicked on Some(vec![]).
        let Some(first) = attentions.first() else {
            return Ok(None);
        };
        // Attention blocks have shape (batch, n_heads, seq_len, seq_len);
        // we need seq_len to reshape below.
        let (_batch, _n_heads, _seq_len, seq_length) = first.dims4()?;
        // Guard the `seq_length - 2` narrows below against usize underflow.
        if seq_length < 2 {
            return Ok(None);
        }
        // Stack all layers, then collapse (layer, batch, head) into one axis.
        let attn_stacked = Tensor::stack(attentions, 0)?;
        let total_elements = attn_stacked.dims().iter().product::<usize>();
        let first_dim = total_elements / (seq_length * seq_length);
        let attn_map = attn_stacked.reshape(&[first_dim, seq_length, seq_length])?;

        // In PyTorch: attn_map = attn_map[:, 1:-1, 1:-1] — strip the first
        // and last positions (special tokens) from both residue axes.
        let attn_map = attn_map
            .narrow(1, 1, seq_length - 2)? // second dim
            .narrow(2, 1, seq_length - 2)?; // third dim
        let symmetric = self.symmetrize(&attn_map)?;
        let normalized = self.apc(&symmetric)?;
        let proximity_map = normalized.permute((1, 2, 0))?; // (residues, residues, map)

        Ok(Some(proximity_map))
    }
}
Loading