PyTorch Compatibility (#41)
* follow the rotary embedding shape

* more prints

* update amplify with a missing transpose/permute

* update permute

* get my initial AMPLIFY implementation going.
zachcp authored Nov 14, 2024
1 parent 5a678d8 commit d51120f
Showing 6 changed files with 68 additions and 37 deletions.
28 changes: 14 additions & 14 deletions Cargo.lock


2 changes: 1 addition & 1 deletion ferritin-featurizers/Cargo.toml
@@ -11,7 +11,7 @@ anyhow.workspace = true
candle-core = "0.8"
candle-nn = "0.8"
candle-transformers = "0.8"
clap = "4.5.20"
clap = "4.5.21"
ferritin-core = { path = "../ferritin-core" }
itertools.workspace = true
pdbtbx.workspace = true
60 changes: 43 additions & 17 deletions ferritin-featurizers/src/models/amplify/amplify.rs
@@ -8,7 +8,8 @@
//! - Specialized architecture optimizations
//! - Memory efficient inference
use super::rotary::{apply_rotary_emb, precompute_freqs_cis};
use candle_core::{Module, Result, Tensor, D};
use candle_core::shape::Dim;
use candle_core::{Module, Result, Shape, Tensor, D};
use candle_nn::{
embedding, linear, linear_no_bias, ops::softmax, rms_norm, Activation, Dropout, Embedding,
Linear, RmsNorm, VarBuilder,
@@ -161,10 +162,14 @@ impl EncoderBlock {
// println!("EncoderBlock.forward(): FeedForward Norm");
let normed = self.ffn_norm.forward(&x)?;
// ffn.forward needs to do the SwiGLU steps with w12 and w3
// println!("EncoderBlock.forward(): FeedForward_forward");
println!("FFN_norm shape {:?}", normed.dims());

let ffn_output = self.ffn_forward(&normed)?;
println!("FFN_forward shape {:?}", ffn_output.dims());

// println!("EncoderBlock.forward(): FeedForward_dropout");
let ff = self.ffn_dropout.forward(&ffn_output, false)?; // TODO: pass in the inference/training flag
println!("FFN_dropout shape {:?}", ff.dims());
let x = x.add(&ff)?;
Ok((x, contacts))
}
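As an aside, the w12/w3 naming in the comment above points at a SwiGLU feed-forward with a fused gate/value projection. Below is a minimal candle sketch of that pattern; the struct, the field shapes, and the forward body are assumptions for illustration, not this crate's actual ffn implementation.

use candle_core::{Module, Result, Tensor, D};
use candle_nn::Linear;

// Hypothetical SwiGLU feed-forward: `w12` fuses the gate and value
// projections, `w3` projects back down to the hidden size.
struct SwiGluFfn {
    w12: Linear, // hidden -> 2 * intermediate
    w3: Linear,  // intermediate -> hidden
}

impl SwiGluFfn {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x12 = self.w12.forward(x)?;               // [batch, seq, 2 * intermediate]
        let parts = x12.chunk(2, D::Minus1)?;         // split into gate and value halves
        let gated = parts[0].silu()?.mul(&parts[1])?; // silu(gate) * value
        self.w3.forward(&gated)                       // back to [batch, seq, hidden]
    }
}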
@@ -253,7 +258,7 @@ impl EncoderBlock {
) -> Result<(Tensor, Option<Tensor>)> {
println!("AttentionBlock: commence");
println!(
"Input x shape, freqs_cis shape: {:?},{:?}",
"ATT Block: Input x shape, freqs_cis shape: {:?},{:?}",
x.dims(),
freqs_cis.dims()
);
@@ -323,15 +328,30 @@ impl EncoderBlock {
None
};
println!("calc attention...");

println!("ATTN CALC IN: xq: {:?}", xq.dims());
let xq_permute = xq.permute((0, 2, 1, 3))?;
let xk_permute = xk.permute((0, 2, 1, 3))?;
let xv_permute = xv.permute((0, 2, 1, 3))?;

println!("ATTN CALC IN: xq_permute: {:?}", xq_permute.dims());

let attn = self.scaled_dot_product_attention(
&xq,
&xk,
&xv,
&xq_permute,
&xk_permute,
&xv_permute,
pad_mask.as_ref(),
dropout_prob,
false,
)?;

println!("ATTENTION_pretranspose: {:?}", attn.dims());

// Missed this Transpose!
// `[batch, num_heads, seq_len, head_dim]` → `[batch, seq_len, num_heads, head_dim]`
let attn = attn.permute((0, 2, 1, 3))?;
println!("ATTENTION: {:?}", attn.dims());

let _attn = if output_attentions {
let xq_t = xq.permute((0, 2, 1, 3))?;
let xk_t = xk.permute((0, 2, 3, 1))?;
@@ -350,9 +370,13 @@
seq_len,
self.config.num_attention_heads * self.d_head,
))?;
let output = self.wo.forward(&output)?;
let output = self.resid_dropout.forward(&output, false)?;
Ok((output, _attn))
println!("ATTENTION_reshaped: {:?}", output.dims());
let output01 = self.wo.forward(&output)?;

println!("ATTENTION_output: {:?}", output01.dims());
let output02 = self.resid_dropout.forward(&output01, false)?;
println!("ATTENTION_output_drop: {:?}", output02.dims());
Ok((output02, _attn))
}
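The permutes above are the substance of this fix: q, k, and v move from [batch, seq_len, num_heads, head_dim] to [batch, num_heads, seq_len, head_dim] for the attention matmuls, and the result is permuted back before the heads are flattened into the hidden dimension. A minimal sketch of that round trip follows; naive_attention is a hypothetical stand-in for scaled_dot_product_attention, not the crate's code.

use candle_core::{Result, Tensor};
use candle_nn::ops::softmax;

// q, k, v here are [batch, num_heads, seq_len, head_dim].
fn naive_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let scale = (q.dim(3)? as f64).sqrt();
    let scores = (q.matmul(&k.transpose(2, 3)?)? / scale)?; // [b, h, s, s]
    let weights = softmax(&scores, 3)?;
    weights.matmul(v) // [b, h, s, head_dim]
}

// Inputs are [batch, seq_len, num_heads, head_dim], as produced by the
// rotary-embedding step above.
fn attend(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let (b, s, h, d) = q.dims4()?;
    let attn = naive_attention(
        &q.permute((0, 2, 1, 3))?.contiguous()?, // heads to dim 1 for the matmuls
        &k.permute((0, 2, 1, 3))?.contiguous()?,
        &v.permute((0, 2, 1, 3))?.contiguous()?,
    )?;
    // The previously missing transpose: [b, h, s, d] -> [b, s, h, d],
    // then flatten the heads back into the hidden dimension.
    attn.permute((0, 2, 1, 3))?.contiguous()?.reshape((b, s, h * d))
}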

/// Load Weights from a Model
@@ -442,23 +466,24 @@ impl AMPLIFY {
let mut hidden_states = vec![];
let mut attentions = vec![];

// println!(
// "AMPLIFY.forward(): Freq_CIS. Shape: {:?}",
// &self.freqs_cis.dims()
// );
println!(
"AMPLIFY.forward(): Freq_CIS. Shape: {:?}",
&self.freqs_cis.dims()
);

// Process attention mask if provided
// println!("AMPLIFY.forward(): creating attention mask");
println!("AMPLIFY.forward(): creating attention mask");

let attention_mask =
self.process_attention_mask(pad_mask, self.transformer_encoder.len() as i64)?;
// Get appropriate length of freqs_cis
// println!("AMPLIFY.forward(): creating freqs_cis mask");
let freqs_cis = self.freqs_cis.narrow(0, 0, src.dim(1)?)?;

// Embedding layer
// println!("AMPLIFY.forward(): creating encoder");
println!("AMPLIFY.forward(): creating encoder");
let mut x = self.encoder.forward(src)?;
// println!("X dims: {:?}", x.dims());
println!("X dims: {:?}", x.dims());
// Transform through encoder blocks
// println!("AMPLIFY.forward(): running through the transformer");
for layer in self.transformer_encoder.iter() {
@@ -523,8 +548,9 @@ impl AMPLIFY {
// let freqs_cis =
// precompute_freqs_cis(cfg.hidden_size / cfg.num_attention_heads, cfg.max_length)?;
let head_dim = cfg.hidden_size / cfg.num_attention_heads;

let freqs_cis = precompute_freqs_cis(head_dim, cfg.max_length)?;
// println!("AMPLIFY: Freq_CIS Initiated. Shape: {:?}", freqs_cis.dims());
println!("AMPLIFY: Freq_CIS Initiated. Shape: {:?}", freqs_cis.dims());
// println!("AMPLIFY: freqs_cis Created .");

Ok(Self {
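One detail worth noting from the forward pass above: the precomputed rotary table is trimmed to the current sequence length with self.freqs_cis.narrow(0, 0, src.dim(1)?). A tiny sketch of what that slice does, assuming a max_length of 2048 and a head_dim of 64 (so 32 frequency pairs):

use candle_core::{DType, Device, Result, Tensor};

// narrow(dim, start, len) keeps `len` rows of `dim` starting at `start`;
// the shapes here are assumptions for illustration.
fn narrow_demo() -> Result<()> {
    let device = Device::Cpu;
    let freqs_cis = Tensor::zeros((2048, 32, 2), DType::F32, &device)?; // [max_length, head_dim / 2, 2]
    let seq_len = 256;
    let sliced = freqs_cis.narrow(0, 0, seq_len)?; // keep the first seq_len positions
    assert_eq!(sliced.dims(), &[256, 32, 2]);
    Ok(())
}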
6 changes: 5 additions & 1 deletion ferritin-featurizers/src/models/amplify/rotary.rs
@@ -27,7 +27,11 @@ pub fn precompute_freqs_cis(head_dim: usize, seq_len: usize) -> Result<Tensor> {
freqs_sin.dims()
);

Tensor::stack(&[freqs_cos, freqs_sin], D::Minus1)
let return_tensor = Tensor::stack(&[freqs_cos, freqs_sin], D::Minus1)?;

println!("Precomputed return Tensor: {:?}", return_tensor.dims());

Ok(return_tensor)
}
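The change above binds the stacked tensor so its shape can be printed before returning. Assuming the table is built from the standard RoPE frequency schedule 1/10000^(2i/head_dim) (the schedule itself is outside this hunk), stacking the cosine and sine tables on the last dimension yields [seq_len, head_dim / 2, 2]; a sketch:

use candle_core::{D, DType, Device, Result, Tensor};

// Hypothetical reconstruction of the precompute step; only the final
// stack-on-last-dim matches the diff above, the rest is assumed.
fn precompute_freqs_cis_sketch(head_dim: usize, seq_len: usize) -> Result<Tensor> {
    let device = Device::Cpu;
    let inv_freq: Vec<f32> = (0..head_dim / 2)
        .map(|i| 1.0 / 10000f32.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    let inv_freq = Tensor::new(inv_freq.as_slice(), &device)?; // [head_dim / 2]
    let t = Tensor::arange(0u32, seq_len as u32, &device)?.to_dtype(DType::F32)?; // [seq_len]
    // Outer product of positions and frequencies: [seq_len, head_dim / 2]
    let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
    let freqs_cos = freqs.cos()?;
    let freqs_sin = freqs.sin()?;
    // Stacking on the last dim gives [seq_len, head_dim / 2, 2]
    Tensor::stack(&[freqs_cos, freqs_sin], D::Minus1)
}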

pub fn apply_rotary_emb(xq: &Tensor, xk: &Tensor, freqs_cis: &Tensor) -> Result<(Tensor, Tensor)> {
7 changes: 4 additions & 3 deletions ferritin-featurizers/tests/test_amplify_model.rs
@@ -30,8 +30,9 @@ fn test_amplify_round_trip() -> Result<(), Box<dyn std::error::Error>> {
let tokenizer = repo.get("tokenizer.json")?;
let protein_tokenizer = ProteinTokenizer::new(tokenizer)?;

let sprot_01 = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL";
let pmatrix = protein_tokenizer.encode(&[sprot_01.to_string()], None, false, false)?;
// let sprot_01 = "MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQVECPKAPVEWNNPPSEKGLIVGHFSGIKYKGEKAQASEVDVNKMCCWVSKFKDAMRRYQGIQTCKIPGKVLSDLDAKIKAYNLTVEGVEGFVRYSRVTKQHVAAFLKELRHSKQYENVNLIHYILTDKRVDIQHLEKDLVKDFKALVESAHRMRQGHMINVKYILYQLLKKHGHGPDGPDILTVKTGSKGVLYDDSFRKIYTDLGWKFTPL";
let AMPLIFY_TEST_SEQ = "MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFGPKNR";
let pmatrix = protein_tokenizer.encode(&[AMPLIFY_TEST_SEQ.to_string()], None, false, false)?;
let pmatrix = pmatrix.unsqueeze(0)?; // [batch, length] <- add batch of 1 in this case
let encoded = model.forward(&pmatrix, None, false, false)?;

@@ -41,7 +42,7 @@
let indices: Vec<u32> = predictions.to_vec2()?[0].to_vec();
let decoded = protein_tokenizer.decode(indices.as_slice(), true)?;

assert_eq!(sprot_01, decoded.replace(" ", ""));
assert_eq!(AMPLIFY_TEST_SEQ, decoded.replace(" ", ""));

Ok(())
}
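The updated test adds a batch dimension with unsqueeze before the forward pass and then argmax-decodes the logits back into token indices. A small sketch of just that bookkeeping; token_ids and logits are placeholders for the tokenizer and model outputs, not the crate's real API.

use candle_core::{Result, Tensor, D};

// `token_ids` ([length]) and `logits` ([1, length, vocab_size]) are assumed shapes.
fn add_batch_dim(token_ids: &Tensor) -> Result<Tensor> {
    token_ids.unsqueeze(0) // [length] -> [1, length], a batch of one
}

fn decode_ids(logits: &Tensor) -> Result<Vec<u32>> {
    let predictions = logits.argmax(D::Minus1)?; // [1, length] of predicted token ids
    Ok(predictions.to_vec2::<u32>()?[0].clone())
}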
2 changes: 1 addition & 1 deletion ferritin-pymol/Cargo.toml
@@ -8,7 +8,7 @@ description.workspace = true

[dependencies]
ferritin-molviewspec = { path = "../ferritin-molviewspec" }
clap = { version = "4.5.20", features = ["derive"] }
clap = { version = "4.5.21", features = ["derive"] }
serde-pickle = "1.1"
serde_bytes = { workspace = true }
serde = { workspace = true }
