* updating AMPLIFY
* remove the ProteinTokenizer in favor of the upstream Tokenizer class
* make wasm compatible
* Refactor ferritin-amplify structs to their own files
* allow passing of 120M or 350M
* example into just
Showing 15 changed files with 607 additions and 838 deletions.
@@ -0,0 +1,78 @@
use candle_nn::Activation;
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
/// Configuration struct for AMPLIFY.
///
/// Currently holds only the hyperparameters for the models published
/// on GitHub: the 120M and 350M checkpoints.
///
pub struct AMPLIFYConfig {
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub dropout_prob: f64,
    pub embedding_init_range: f64,
    pub decoder_init_range: f64,
    pub rms_norm: bool,
    pub norm_eps: f64,
    pub hidden_act: Activation,
    pub layer_norm_after_embedding: bool,
    pub layer_norm_before_last_layer: bool,
    pub vocab_size: usize,
    pub ffn_bias: bool,
    pub att_bias: bool,
    pub pad_token_id: usize,
    pub max_length: usize,
}

impl Default for AMPLIFYConfig {
    fn default() -> Self {
        AMPLIFYConfig::amp_120m()
    }
}

impl AMPLIFYConfig {
    pub fn amp_120m() -> Self {
        Self {
            hidden_size: 640,
            num_hidden_layers: 24,
            num_attention_heads: 10,
            intermediate_size: 2560,
            dropout_prob: 0.0,
            embedding_init_range: 0.02,
            decoder_init_range: 0.02,
            rms_norm: true,
            norm_eps: 1e-5,
            hidden_act: Activation::Swiglu,
            layer_norm_after_embedding: false,
            layer_norm_before_last_layer: true,
            vocab_size: 27,
            ffn_bias: false,
            att_bias: false,
            pad_token_id: 0,
            max_length: 2048,
        }
    }

    pub fn amp_350m() -> Self {
        Self {
            hidden_size: 960,
            num_hidden_layers: 32,
            num_attention_heads: 15,
            intermediate_size: 3840,
            dropout_prob: 0.0,
            embedding_init_range: 0.02,
            decoder_init_range: 0.02,
            rms_norm: true,
            norm_eps: 1e-5,
            hidden_act: Activation::Swiglu,
            layer_norm_after_embedding: false,
            layer_norm_before_last_layer: true,
            vocab_size: 27,
            ffn_bias: false,
            att_bias: false,
            pad_token_id: 0,
            max_length: 2048,
        }
    }
}
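A minimal usage sketch (not part of the commit) of how these constructors are meant to be used, based only on the code above; the `ferritin_amplify` import path is an assumption about how the crate re-exports the struct:

// Sketch: selecting a checkpoint size. The import path is assumed, not confirmed by this diff.
use ferritin_amplify::AMPLIFYConfig;

fn main() {
    let cfg_small = AMPLIFYConfig::amp_120m(); // 120M: hidden_size 640, 24 layers
    let cfg_large = AMPLIFYConfig::amp_350m(); // 350M: hidden_size 960, 32 layers
    let cfg_default = AMPLIFYConfig::default(); // Default falls back to amp_120m()
    assert_eq!(cfg_default.hidden_size, cfg_small.hidden_size);
    assert_eq!(cfg_large.num_hidden_layers, 32);
}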
@@ -0,0 +1,269 @@
use super::config::AMPLIFYConfig;
use super::rotary::apply_rotary_emb;
use candle_core::{Module, Result, Tensor, D};
use candle_nn::{
    linear, linear_no_bias, ops::softmax_last_dim, rms_norm, Dropout, Linear, RmsNorm, VarBuilder,
};
/// An encoder block in the AMPLIFY transformer architecture.
///
/// This implements a standard transformer encoder block with:
/// - Multi-head self-attention with rotary positional embeddings
/// - Feed-forward network with SwiGLU activation
/// - RMSNorm for layer normalization
///
/// # Arguments (for `load`)
/// * `config` - Configuration parameters for the model
/// * `vb` - Variable builder for loading weights
/// * `layer` - Layer index in the transformer stack
///
/// # References
/// - [T5](https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/t5.rs#L331)
/// - [distilbert](https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/distilbert.rs#L198)
/// - [glm4](https://github.com/huggingface/candle/blob/e2b6b367fa852ed30ac532f8d77cd8479c7ed092/candle-transformers/src/models/glm4.rs#L340)
/// - [SwiGLU implementation](https://github.com/facebookresearch/xformers/blob/main/xformers/ops/swiglu_op.py#L462)
#[derive(Debug)]
pub struct EncoderBlock {
    q: Linear,
    k: Linear,
    v: Linear,
    wo: Linear,
    resid_dropout: Dropout,
    w12: Linear,
    w3: Linear,
    ffn_norm: RmsNorm,
    attention_norm: RmsNorm,
    ffn_dropout: Dropout,
    d_head: usize,
    config: AMPLIFYConfig,
}

impl EncoderBlock {
    // pub fn new(config: &AMPLIFYConfig, vb: VarBuilder, layer: i32) -> Result<Self> {
    //     let multiple_of = 8;
    //     let intermediate_size = (config.intermediate_size * 2) / 3;
    //     let intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) / multiple_of);
    //     let vb = vb.pp(layer);
    //     let q = linear(config.hidden_size, config.hidden_size, vb.pp("q"))?;
    //     let k = linear(config.hidden_size, config.hidden_size, vb.pp("k"))?;
    //     let v = linear(config.hidden_size, config.hidden_size, vb.pp("v"))?;
    //     let wo = linear(config.hidden_size, config.hidden_size, vb.pp("wo"))?;
    //     let w12 = linear_no_bias(intermediate_size * 2, config.hidden_size, vb.pp("ffn.w12"))?;
    //     let w3 = linear_no_bias(config.hidden_size, intermediate_size, vb.pp("ffn.w3"))?;
    //     let ffn_norm = rms_norm(config.hidden_size, config.norm_eps, vb.pp("ffn_norm"))?;
    //     let attention_norm =
    //         rms_norm(config.hidden_size, config.norm_eps, vb.pp("attention_norm"))?;
    //
    //     Ok(Self {
    //         q,
    //         k,
    //         v,
    //         wo,
    //         resid_dropout: Dropout::new(config.dropout_prob as f32),
    //         w12,
    //         w3,
    //         attention_norm,
    //         ffn_norm,
    //         ffn_dropout: Dropout::new(config.dropout_prob as f32),
    //         d_head: config.hidden_size / config.num_attention_heads,
    //         config: config.clone(), // Todo: remove this clone
    //     })
    // }
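    /// Forward pass: pre-norm multi-head attention with a residual add, followed by a
    /// pre-norm SwiGLU feed-forward block with a second residual add. Returns the block
    /// output and, when `output_attentions` is true, the attention probabilities.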
    pub fn forward(
        &self,
        x: &Tensor,
        pad_mask: Option<&Tensor>,
        freqs_cis: &Tensor,
        output_attentions: bool,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let normed = self.attention_norm.forward(x)?;
        let (attn, contacts) =
            self.attention_block(&normed, pad_mask, freqs_cis, output_attentions)?;
        let x = x.add(&attn)?;
        let normed = self.ffn_norm.forward(&x)?;
        let ffn_output = self.ffn_forward(&normed)?;
        let ff = self.ffn_dropout.forward(&ffn_output, false)?; // Todo: pass in the Inference/Training bit
        let x = x.add(&ff)?;
        Ok((x, contacts))
    }
    /// Process the FFN block using SwiGLU.
    fn ffn_forward(&self, x: &Tensor) -> Result<Tensor> {
        // Todo: see if the apply or add can be done directly
        // Store original batch dimensions
        let dims = x.dims();
        let batch_shape = &dims[..dims.len() - 1];
        // Reshape input to 2D: (batch_size, input_dim)
        let x_flat = self.flatten_last_dim(x)?;
        // Apply the packed W1/W2 linear transformation
        let w12_out = self.w12.forward(&x_flat)?;
        // Split the output into two halves (for SwiGLU activation)
        let chunks = w12_out.chunk(2, 1)?;
        let x1 = &chunks[0];
        let x2 = &chunks[1];
        // Apply SwiGLU: silu(x1) * x2
        let hidden = x1.silu()?.mul(x2)?;
        // Final linear transformation
        let output = self.w3.forward(&hidden)?;
        // Reshape back to original batch dimensions
        let mut new_shape = batch_shape.to_vec();
        new_shape.push(output.dim(1)?);
        output.reshape(new_shape)
    }
    fn flatten_last_dim(&self, x: &Tensor) -> Result<Tensor> {
        let dims = x.dims();
        let last_dim = dims[dims.len() - 1];
        let total_elements = dims.iter().product::<usize>();
        let first_dim = total_elements / last_dim;
        x.reshape((first_dim, last_dim))
    }
    fn scaled_dot_product_attention(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        attn_mask: Option<&Tensor>,
        dropout_p: f64,
        _is_causal: bool,
    ) -> Result<Tensor> {
        // Calculate attention scores: Q K^T / sqrt(d_k)
        let d_k = key.dim(D::Minus1)? as f64;
        let scaling = 1.0 / d_k.sqrt();
        // (B, H, L, S) = (batch, heads, query_length, key_length)
        let scores = (query.matmul(&key.transpose(D::Minus2, D::Minus1)?)? * scaling)?;

        // Apply the additive mask if provided
        let scores = match attn_mask {
            Some(mask) => scores.add(mask)?,
            None => scores,
        };

        // Apply softmax over the key dimension
        let attn = softmax_last_dim(&scores)?;

        // Apply dropout if needed
        let attn = if dropout_p > 0.0 {
            candle_nn::ops::dropout(&attn, dropout_p as f32)?
        } else {
            attn
        };

        // Final matrix multiplication with values
        attn.matmul(value)
    }
    fn attention_block(
        &self,
        x: &Tensor,
        pad_mask: Option<&Tensor>,
        freqs_cis: &Tensor,
        output_attentions: bool,
    ) -> Result<(Tensor, Option<Tensor>)> {
        // Query, Key, Value projections: [batch_size, seq_len, hidden_size]
        let (batch_size, seq_len, _) = x.dims3()?;
        let xq = self.q.forward(x)?.contiguous()?;
        let xk = self.k.forward(x)?.contiguous()?;
        let xv = self.v.forward(x)?.contiguous()?;

        // Reshape for rotary embeddings: [batch, seq, heads, head_dim]
        let xq = xq.reshape((
            batch_size,
            seq_len,
            self.config.num_attention_heads,
            self.d_head,
        ))?;
        let xk = xk.reshape((
            batch_size,
            seq_len,
            self.config.num_attention_heads,
            self.d_head,
        ))?;
        let xv = xv.reshape((
            batch_size,
            seq_len,
            self.config.num_attention_heads,
            self.d_head,
        ))?;
        let (xq, xk) = apply_rotary_emb(&xq, &xk, freqs_cis)?;
        let dropout_prob = self.config.dropout_prob;

        // need to handle pad_mask better ....
        let pad_mask = if let Some(mask) = pad_mask {
            let num_heads = self.config.num_attention_heads;

            // Following PyTorch's implementation:
            // 1. unsqueeze twice to add head dimensions
            // 2. repeat to match attention matrix size
            let mask = mask
                .unsqueeze(1)?
                .unsqueeze(1)?
                .expand((batch_size, num_heads, seq_len, seq_len))?; // Expand to full attention size
            Some(mask)
        } else {
            None
        };

        let attn = self.scaled_dot_product_attention(
            &xq.permute((0, 2, 1, 3))?.contiguous()?,
            &xk.permute((0, 2, 1, 3))?.contiguous()?,
            &xv.permute((0, 2, 1, 3))?.contiguous()?,
            pad_mask.as_ref(),
            dropout_prob,
            false,
        )?;

        // `[batch, num_heads, seq_len, head_dim]` → `[batch, seq_len, num_heads, head_dim]`
        let attn = attn.permute((0, 2, 1, 3))?;
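        // When requested, recompute the softmax(QK^T / sqrt(d)) probabilities separately
        // so they can be returned to the caller (used as `contacts` in `forward`).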
        let attn_probs = if output_attentions {
            let xq_t = xq.permute((0, 2, 1, 3))?.contiguous()?;
            let xk_t = xk.permute((0, 2, 3, 1))?.contiguous()?;
            let mut attn_weights = xq_t.matmul(&xk_t)?;
            let scale = (xq.dim(D::Minus1)? as f64).sqrt();
            attn_weights = (attn_weights / scale)?;
            // attn_weights = attn_weights.add(pad_mask)?; <- Todo: revisit
            Some(softmax_last_dim(&attn_weights)?)
        } else {
            None
        };

        // Final projection and dropout
        let output = attn.reshape((
            batch_size,
            seq_len,
            self.config.num_attention_heads * self.d_head,
        ))?;
        let output = self.wo.forward(&output)?;
        let output = self.resid_dropout.forward(&output, false)?;
        Ok((output, attn_probs))
    }
    /// Load the weights for one encoder block from a `VarBuilder`.
    pub fn load(vb: VarBuilder, config: &AMPLIFYConfig, layer: i32) -> Result<Self> {
        // To keep the number of parameters and the amount of computation constant, we reduce the number of
        // hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and make it a multiple of 8 to
        // avoid RuntimeError due to misaligned operands.
        let multiple_of = 8;
        let intermediate_size = (config.intermediate_size * 2) / 3;
        let intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) / multiple_of);
        let vb = vb.pp(layer); // handle the layer number here
        let q = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("q"))?;
        let k = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("k"))?;
        let v = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("v"))?;
        let wo = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("wo"))?;
        let w12 = linear_no_bias(config.hidden_size, intermediate_size * 2, vb.pp("ffn.w12"))?;
        let w3 = linear_no_bias(intermediate_size, config.hidden_size, vb.pp("ffn.w3"))?;
        let ffn_norm = rms_norm(config.hidden_size, config.norm_eps, vb.pp("ffn_norm"))?;
        let attention_norm =
            rms_norm(config.hidden_size, config.norm_eps, vb.pp("attention_norm"))?;

        Ok(Self {
            q,
            k,
            v,
            wo,
            resid_dropout: Dropout::new(config.dropout_prob as f32),
            w12,
            w3,
            attention_norm,
            ffn_norm,
            ffn_dropout: Dropout::new(config.dropout_prob as f32),
            d_head: config.hidden_size / config.num_attention_heads,
            config: config.clone(),
        })
    }
}
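A hedged sketch (not part of the commit) of constructing one block with on-demand-initialised weights for a shape smoke test. The `ferritin_amplify` re-exports and the absence of a weight-path prefix are assumptions not confirmed by this diff; a real run would load checkpoint weights and supply the rotary `freqs_cis` tensor from `super::rotary` before calling `forward`.

use candle_core::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};
use ferritin_amplify::{AMPLIFYConfig, EncoderBlock}; // assumed re-exports

fn main() -> Result<()> {
    let config = AMPLIFYConfig::amp_120m();
    let device = Device::Cpu;

    // A VarMap-backed VarBuilder materialises missing tensors on demand,
    // which is enough to exercise `load` without a checkpoint file.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // `load` pushes the layer index onto the weight path itself via `vb.pp(layer)`.
    let _block = EncoderBlock::load(vb, &config, 0)?;

    // Calling `forward` additionally requires the precomputed rotary `freqs_cis`
    // tensor from `super::rotary`, whose construction is not shown in this commit.
    Ok(())
}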