ESMC Init Part II (#67)
* fix imports

* allow ESMCConfig clone

* update regression_head, sequential, rotary

* get the main.rs going
zachcp authored Dec 6, 2024
1 parent 7b99766 commit c3b7013
Showing 20 changed files with 572 additions and 250 deletions.
20 changes: 10 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default.

251 changes: 250 additions & 1 deletion ferritin-esm/Readme.md

Large diffs are not rendered by default.

28 changes: 11 additions & 17 deletions ferritin-esm/examples/esmc/main.rs
@@ -3,8 +3,7 @@ use candle_core::pickle::PthTensors;
use candle_core::{DType, Device, D};
use candle_hf_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
// use ferritin_esm::{AMPLIFYConfig, ProteinTokenizer, AMPLIFY};
use safetensors::SafeTensors;
use ferritin_esm::{ESMCConfig, ESMC};

// pub fn esmc_300m_202412(device: &Device) -> Result<Box<dyn Model>> {
// let tokenizer = get_model_tokenizers(ESM3_OPEN_SMALL)?.sequence;
@@ -35,23 +34,18 @@ fn main() -> Result<()> {
for (name, tensor) in pth.tensor_infos() {
println!("{}: {:?}", name, tensor);
}
// let vb = VarBuilder::from_backend(Box::new(pth), DType::F32, Device::Cpu);

// def ESMC_300M_202412(device: torch.device | str = "cpu"):
// with torch.device(device):
// model = ESMC(
// d_model=960,
// n_heads=15,
// n_layers=30,
// tokenizer=get_model_tokenizers(ESM3_OPEN_SMALL).sequence,
// ).eval()
// state_dict = torch.load(
// data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
// map_location=device,
// )
// model.load_state_dict(state_dict)
let vb = VarBuilder::from_backend(Box::new(pth), DType::F32, Device::Cpu);
let config = ESMCConfig::esmc_300m();
let esmc = ESMC::load(vb.clone(), config)?;
// println!("ESMC Loaded: {}", esmc);

// return model
// Error: cannot find tensor transformer.layer.attention.layer_norm.weight

println!(
"VB: {}",
vb.contains_tensor("transformer.blocks.6.attn.layernorm_qkv.1.weight")
);

Ok(())
}
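
A minimal, self-contained sketch (not part of this commit) of the checkpoint-inspection pattern used in main.rs above: list the tensor names stored in the .pth pickle, wrap the pickle in a VarBuilder, and probe one expected key before attempting the full ESMC::load. The checkpoint filename is a placeholder, and the optional second argument to PthTensors::new (the top-level dict key) may vary across candle versions.

use candle_core::pickle::PthTensors;
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    // Placeholder path; the real weights come from the HF hub repo fetched above.
    let pth = PthTensors::new("esmc_300m_2024_12_v0.pth", None)?;

    // Dump every tensor name and shape stored in the pickle so the keys can be
    // compared against the VarBuilder paths the Rust modules expect.
    for (name, info) in pth.tensor_infos() {
        println!("{name}: {info:?}");
    }

    // Back a VarBuilder with the pickle and probe a single expected key before
    // trying a full model load.
    let vb = VarBuilder::from_backend(Box::new(pth), DType::F32, Device::Cpu);
    println!(
        "has qkv weight: {}",
        vb.contains_tensor("transformer.blocks.6.attn.layernorm_qkv.1.weight")
    );
    Ok(())
}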
33 changes: 17 additions & 16 deletions ferritin-esm/src/esm/layers/attention.rs
@@ -45,31 +45,32 @@ impl MultiHeadAttention {
// rotary: RotaryEmbedding::new(d_model / n_heads)?,
// })
// }
pub fn load(vb: VarBuilder, config: ESMCConfig) -> Result<Self> {
pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
let ESMCConfig {
d_model,
expansion_ratio,
n_heads,
bias,
..
d_model, n_heads, ..
} = config;

let d_head = d_model / n_heads;
let ln_conf = LayerNormConfig::from(1e-5);
let layernorm = nn::layer_norm(d_model, ln_conf, vb.pp("layer_norm"))?;
let linear = nn::linear(d_model, d_model * 3, vb.pp("linear1"))?;
// let ln_conf = LayerNormConfig::from(1e-5);
let ln_conf = LayerNormConfig {
eps: 1e-5,
remove_mean: true,
affine: false,
};
let layernorm = nn::layer_norm(*d_model, ln_conf, vb.pp("layernorm_qkv.0"))?;
let linear = nn::linear_no_bias(*d_model, d_model * 3, vb.pp("layernorm_qkv.1"))?;
let layernorm_qkv = nn::seq().add(layernorm).add(linear);
let out_proj = nn::linear(d_model, d_model, vb.pp("out_proj"))?;

let out_proj = nn::linear_no_bias(*d_model, *d_model, vb.pp("out_proj"))?;
// note: only handling the True case for the moment
// let qk_layernorm = true
let q_ln = Box::new(nn::layer_norm(d_model, ln_conf, vb.pp("q_ln"))?);
let k_ln = Box::new(nn::layer_norm(d_model, ln_conf, vb.pp("k_ln"))?);
let rotary: RotaryEmbedding::load(vb.pp("rotary"), config)?;
let q_ln = Box::new(nn::layer_norm(*d_model, ln_conf, vb.pp("q_ln"))?);
let k_ln = Box::new(nn::layer_norm(*d_model, ln_conf, vb.pp("k_ln"))?);

let rotary = RotaryEmbedding::load(vb.pp("rotary"), config)?;

Ok(Self {
d_model,
n_heads,
d_model: *d_model,
n_heads: *n_heads,
d_head,
layernorm_qkv,
out_proj,
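
For reference, the fused qkv projection that MultiHeadAttention::load now builds is a LayerNorm with affine: false (so no bias tensor is loaded) followed by a bias-free Linear from d_model to 3 * d_model, keyed layernorm_qkv.0 / layernorm_qkv.1 so the VarBuilder paths match the PyTorch nn.Sequential layout. A standalone sketch of just that piece, not part of the diff:

use candle_core::Result;
use candle_nn::{self as nn, LayerNormConfig, Sequential, VarBuilder};

// Builds the LayerNorm -> fused qkv Linear pair under "layernorm_qkv.{0,1}".
fn layernorm_qkv(vb: VarBuilder, d_model: usize) -> Result<Sequential> {
    let ln_conf = LayerNormConfig {
        eps: 1e-5,
        remove_mean: true,
        affine: false, // no bias tensor; mirrors LayerNorm(..., bias=False) upstream
    };
    let layernorm = nn::layer_norm(d_model, ln_conf, vb.pp("layernorm_qkv.0"))?;
    let linear = nn::linear_no_bias(d_model, d_model * 3, vb.pp("layernorm_qkv.1"))?;
    Ok(nn::seq().add(layernorm).add(linear))
}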
59 changes: 28 additions & 31 deletions ferritin-esm/src/esm/layers/blocks.rs
@@ -1,7 +1,7 @@
use super::attention::MultiHeadAttention;
use super::geom_attention::GeometricReasoningOriginalImpl;
use crate::esm::models::esmc::{ESMCConfig, Ffn_Type};
use crate::esm::utils::structure::affine3d::Affine3D;
// use crate::esm::utils::structure::affine3d::Affine3D;
use candle_core::{Module, Result, Tensor, D};
use candle_nn::ops::silu;
use candle_nn::{self as nn, VarBuilder};
@@ -18,20 +18,19 @@ impl SwiGLU {
((expansion_ratio * d_model as f64 + 255.0) / 256.0).floor() as usize * 256
}

pub fn load(vb: VarBuilder, config: ESMCConfig) -> Result<Self> {
pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
let ESMCConfig {
d_model,
expansion_ratio,
bias,
..
} = config;

let hidden_dim = Self::swiglu_correction_fn(expansion_ratio, d_model);
let hidden_dim = Self::swiglu_correction_fn(*expansion_ratio, *d_model);

Ok(Self {
layer_norm: nn::layer_norm(d_model, 1e-5, vb.pp("layer_norm"))?,
linear1: nn::linear(d_model, hidden_dim * 2, vb.pp("linear1"))?,
linear2: nn::linear(hidden_dim, d_model, vb.pp("linear2"))?,
layer_norm: nn::layer_norm(*d_model, 1e-5, vb.pp("0"))?,
linear1: nn::linear_no_bias(*d_model, hidden_dim * 2, vb.pp("1"))?,
linear2: nn::linear_no_bias(hidden_dim, *d_model, vb.pp("3"))?,
})
}
}
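
As a sanity check on swiglu_correction_fn above, which rounds the SwiGLU hidden width up to a multiple of 256: assuming the ESMC-300M values d_model = 960 and expansion_ratio = 8/3 (those exact numbers are an assumption, not shown in this hunk), the rounding is a no-op.

fn swiglu_correction_fn(expansion_ratio: f64, d_model: usize) -> usize {
    ((expansion_ratio * d_model as f64 + 255.0) / 256.0).floor() as usize * 256
}

fn main() {
    // 960 * 8/3 = 2560, already a multiple of 256, so it is kept as-is.
    let hidden_dim = swiglu_correction_fn(8.0 / 3.0, 960);
    assert_eq!(hidden_dim, 2560);
    println!("SwiGLU hidden dim: {hidden_dim}");
}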
@@ -128,49 +127,47 @@ impl UnifiedTransformerBlock {
// scaling_factor: residue_scaling_factor,
// })
// }
pub fn load(vb: VarBuilder, config: ESMCConfig, layer: usize) -> Self {
pub fn load(vb: VarBuilder, config: &ESMCConfig, layer: usize) -> Result<Self> {
let ESMCConfig {
d_model,
n_heads,
n_layers,
v_head_transformer,
ffn_type,
tokenizer,
v_head_transformer,
use_plain_attn,
n_layers_geom,
scale_residue,
residue_scaling_factor,
mask_and_zero_frameless,
bias,
qk_layernorm,
expansion_ratio,
..
} = config;

let use_geom_attn: bool = layer < n_layers_geom;

let attn = match use_plain_attn {
false => None,
true => Some(MultiHeadAttention::load(vb, config)),
true => Some(MultiHeadAttention::load(vb.pp("attn"), config)?),
};

let geom_attn = match use_geom_attn {
false => None,
true => Some(GeometricReasoningOriginalImpl::load(vb, config)?),
};
// println!("LAYER; GEOM: {}, {}", layer, n_layers_geom);
let use_geom_attn: bool = layer < *n_layers_geom;
// println!("Geom ATTN {}", use_geom_attn);
// let geom_attn = match use_geom_attn {
// false => None,
// true => Some(GeometricReasoningOriginalImpl::load(
// vb.pp("geometric"),
// config,
// )?),
// };

let geom_attn = None;

let ffn = match ffn_type {
Ffn_Type::SWIGLU => SwiGLU::load(vb, config),
Ffn_Type::SWIGLU => SwiGLU::load(vb.pp("ffn"), config)?,
_ => unimplemented!(), // Ffn_Type::GLU => unimplemented!(),
};

Self {
use_plain_attn,
Ok(Self {
use_plain_attn: *use_plain_attn,
attn,
use_geom_attn,
geom_attn,
ffn: ffn.unwrap(),
scaling_factor: residue_scaling_factor,
}
ffn,
scaling_factor: *residue_scaling_factor,
})
}
}

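
The key probed in main.rs ("transformer.blocks.6.attn.layernorm_qkv.1.weight") shows how the per-layer prefixes are expected to line up: each UnifiedTransformerBlock loads under blocks.{i}, its attention under attn, and its feed-forward under ffn. A hedged sketch of the stacking loop (illustrative only, assuming the crate's ESMCConfig and UnifiedTransformerBlock types are in scope and that the caller passes a VarBuilder already scoped to "transformer"; the real stack lives elsewhere in the crate):

use candle_core::Result;
use candle_nn::VarBuilder;

// Loads n_layers blocks, giving each one its own "blocks.{i}" prefix.
fn load_blocks(vb: VarBuilder, config: &ESMCConfig) -> Result<Vec<UnifiedTransformerBlock>> {
    (0..config.n_layers)
        .map(|i| UnifiedTransformerBlock::load(vb.pp(format!("blocks.{i}")), config, i))
        .collect()
}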
16 changes: 9 additions & 7 deletions ferritin-esm/src/esm/layers/geom_attention.rs
@@ -39,7 +39,7 @@ impl GeometricReasoningOriginalImpl {
// rotation_scale_per_head: Tensor::zeros((v_heads,), device)?,
// })
// }
pub fn load(vb: VarBuilder, config: ESMCConfig) -> Result<Self> {
pub fn load(vb: VarBuilder, config: &ESMCConfig) -> Result<Self> {
let ESMCConfig {
d_model,
v_head_transformer,
@@ -48,24 +48,26 @@ impl GeometricReasoningOriginalImpl {
} = config;

let num_vector_messages = 1usize;
let v_heads = v_head_transformer.unwrap();

// todo: this is a hidden param. Needs to be fixed
let v_heads = v_head_transformer.unwrap_or(128);

let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages;
let channels_out = v_heads * 3 * num_vector_messages;

let ln_conf = LayerNormConfig::from(1e-5);
let s_norm = nn::layer_norm(d_model, ln_conf, vb.pp("layer_norm"))?;
let s_norm = nn::layer_norm(*d_model, ln_conf, vb.pp("layer_norm"))?;

let proj = nn::linear(d_model, dim_proj, vb.pp("linear1"))?;
let out_proj = nn::linear(channels_out, d_model, vb.pp("outproj"))?;
let proj = nn::linear(*d_model, dim_proj, vb.pp("linear1"))?;
let out_proj = nn::linear(channels_out, *d_model, vb.pp("outproj"))?;
let distance_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;
let rotation_scale_per_head = Tensor::zeros((v_heads,), vb.dtype(), vb.device())?;

Ok(Self {
c_s: d_model as usize,
c_s: *d_model as usize,
v_heads: v_heads as usize,
num_vector_messages,
mask_and_zero_frameless,
mask_and_zero_frameless: *mask_and_zero_frameless,
s_norm,
proj,
out_proj,
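
A quick check of the projection sizes computed in GeometricReasoningOriginalImpl::load, using the v_heads = 128 fallback the code flags as a hidden parameter and num_vector_messages = 1 (illustrative, not part of the commit):

fn main() {
    let (v_heads, num_vector_messages) = (128usize, 1usize);
    let dim_proj = 4 * v_heads * 3 + v_heads * 3 * num_vector_messages; // 1536 + 384
    let channels_out = v_heads * 3 * num_vector_messages;
    assert_eq!((dim_proj, channels_out), (1920, 384));
}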
43 changes: 33 additions & 10 deletions ferritin-esm/src/esm/layers/regression_head.rs
@@ -1,20 +1,43 @@
use candle_nn::{Linear, Module, Sequential};
use crate::esm::models::esmc::ESMCConfig;
use candle_core::Tensor;
use candle_nn::{self as nn, LayerNormConfig, Module, Sequential, VarBuilder};

pub struct RegressionHead {
model: Sequential,
}

impl RegressionHead {
pub fn new(d_model: usize, output_dim: usize, hidden_dim: Option<usize>) -> candle_core::Result<Self> {
let hidden_dim = hidden_dim.unwrap_or(d_model);

let model = Sequential::new(vec![
Linear::new(d_model as usize, hidden_dim as usize)?.into(),
candle_nn::Activation::Gelu.into(),
candle_nn::LayerNorm::new(vec![hidden_dim])?.into(),
Linear::new(hidden_dim as usize, output_dim as usize)?.into(),
]);
// pub fn new(d_model: usize, output_dim: usize, hidden_dim: Option<usize>) -> candle_core::Result<Self> {
// let hidden_dim = hidden_dim.unwrap_or(d_model);

// let model = Sequential::new(vec![
// Linear::new(d_model as usize, hidden_dim as usize)?.into(),
// candle_nn::Activation::Gelu.into(),
// candle_nn::LayerNorm::new(vec![hidden_dim])?.into(),
// Linear::new(hidden_dim as usize, output_dim as usize)?.into(),
// ]);

// Ok(Self { model })
// }
pub fn load(vb: VarBuilder, config: &ESMCConfig) -> candle_core::Result<Self> {
let ESMCConfig {
d_model,
regression_head_output_dim,
regression_head_hidden_dim,
..
} = config;

let linear1 = nn::linear(*d_model, *regression_head_hidden_dim, vb.pp("0"))?;
let gelu = candle_nn::Activation::Gelu;
let ln_conf = LayerNormConfig::from(1e-5);
let norm = nn::layer_norm(*regression_head_hidden_dim, ln_conf, vb.pp("2"))?;
let linear2 = nn::linear(
*regression_head_hidden_dim,
*regression_head_output_dim,
vb.pp("3"),
)?;

let model = nn::seq().add(linear1).add(gelu).add(norm).add(linear2);

Ok(Self { model })
}
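
The Sequential built in RegressionHead::load is linear(d_model -> hidden) -> GELU -> LayerNorm -> linear(hidden -> output_dim), keyed "0" / "2" / "3" so the VarBuilder paths mirror the PyTorch nn.Sequential indices (the GELU at index 1 holds no weights). A hedged sketch of the forward pass this implies, not taken from the diff; the real impl may already live further down in regression_head.rs, outside this hunk:

impl Module for RegressionHead {
    fn forward(&self, xs: &Tensor) -> candle_core::Result<Tensor> {
        // Delegates to the Sequential: linear -> GELU -> layer_norm -> linear.
        self.model.forward(xs)
    }
}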
