From f2d4aee4782a4565f08fb84118705e4d0255ca58 Mon Sep 17 00:00:00 2001
From: "Federico G. Schwindt"
Date: Sun, 21 Jul 2024 00:03:31 +0100
Subject: [PATCH 1/2] Add masked_fill under Tensor

---
 candle-core/src/tensor.rs         |  7 +++++++
 candle-core/tests/tensor_tests.rs | 18 ++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index dd1b44b0a0..73597f0913 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2454,6 +2454,13 @@ impl Tensor {
     pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
         rhs.broadcast_mul(&self.log()?)?.exp()
     }
+
+    pub fn masked_fill(&self, rhs: &Tensor, value: f32) -> Result<Self> {
+        rhs.where_cond(
+            &Tensor::new(value, self.device())?.broadcast_as(rhs.shape().dims())?,
+            self,
+        )
+    }
 }
 
 macro_rules! bin_trait {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index cd5f4ca148..8b748cb139 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1345,3 +1345,21 @@ fn pow() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn masked_fill() -> Result<()> {
+    let lhs = Tensor::zeros((5, 5), DType::F32, &Device::Cpu)?;
+    let rhs = Tensor::eye(5, DType::I64, &Device::Cpu)?;
+    let res = lhs.masked_fill(&rhs, f32::NEG_INFINITY)?;
+    assert_eq!(
+        res.to_vec2::<f32>()?,
+        [
+            [f32::NEG_INFINITY, 0.0, 0.0, 0.0, 0.0],
+            [0.0, f32::NEG_INFINITY, 0.0, 0.0, 0.0],
+            [0.0, 0.0, f32::NEG_INFINITY, 0.0, 0.0],
+            [0.0, 0.0, 0.0, f32::NEG_INFINITY, 0.0],
+            [0.0, 0.0, 0.0, 0.0, f32::NEG_INFINITY],
+        ]
+    );
+    Ok(())
+}

From 4b083de8e72a40291876a5b6112553be88c02670 Mon Sep 17 00:00:00 2001
From: "Federico G. Schwindt"
Date: Sun, 21 Jul 2024 00:14:52 +0100
Subject: [PATCH 2/2] Update masked_fill usage

---
 candle-transformers/src/models/chatglm.rs          | 10 +---------
 candle-transformers/src/models/distilbert.rs       | 11 +++--------
 candle-transformers/src/models/falcon.rs           | 13 +++----------
 candle-transformers/src/models/llama.rs            |  9 +--------
 candle-transformers/src/models/llama2_c.rs         |  9 +--------
 candle-transformers/src/models/mpt.rs              | 14 ++------------
 candle-transformers/src/models/phi.rs              | 10 +---------
 candle-transformers/src/models/quantized_llama.rs  | 13 +------------
 .../src/models/quantized_llama2_c.rs               |  9 +--------
 .../src/models/quantized_mixformer.rs              | 10 +---------
 candle-transformers/src/models/quantized_mpt.rs    |  7 ++-----
 candle-transformers/src/models/quantized_phi.rs    | 11 +----------
 candle-transformers/src/models/quantized_phi3.rs   | 11 +----------
 candle-transformers/src/models/quantized_qwen2.rs  | 12 +-----------
 candle-transformers/src/models/quantized_t5.rs     | 10 +---------
 candle-transformers/src/models/t5.rs               | 10 +---------
 candle-wasm-examples/llama2-c/src/model.rs         |  9 +--------
 17 files changed, 23 insertions(+), 155 deletions(-)

diff --git a/candle-transformers/src/models/chatglm.rs b/candle-transformers/src/models/chatglm.rs
index 0686b34ef3..4ed4801e59 100644
--- a/candle-transformers/src/models/chatglm.rs
+++ b/candle-transformers/src/models/chatglm.rs
@@ -107,13 +107,6 @@ struct CoreAttention {
     norm_factor: f64,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 impl CoreAttention {
     fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
         let norm_factor = (cfg.kv_channels as f64).sqrt();
@@ -152,8 +145,7 @@ impl CoreAttention {
             Some(coeff) => (matmul_result * coeff)?,
         };
         let attention_scores = match attention_mask {
-            Some(mask) => masked_fill(
-                &matmul_result,
+            Some(mask) => matmul_result.masked_fill(
                 &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
                 f32::NEG_INFINITY,
             )?,
diff --git a/candle-transformers/src/models/distilbert.rs b/candle-transformers/src/models/distilbert.rs
index ea074c9782..1f534467a2 100644
--- a/candle-transformers/src/models/distilbert.rs
+++ b/candle-transformers/src/models/distilbert.rs
@@ -5,13 +5,6 @@ use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
 #[serde(rename_all = "lowercase")]
 enum HiddenAct {
@@ -180,7 +173,9 @@ impl MultiHeadSelfAttention {
 
         let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
         let mask = attention_mask.broadcast_as(scores.shape())?;
-        let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
+        let scores = scores
+            .to_dtype(DType::F32)?
+            .masked_fill(&mask, f32::NEG_INFINITY)?;
         let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
         let context = weights.matmul(&v.contiguous()?)?;
 
diff --git a/candle-transformers/src/models/falcon.rs b/candle-transformers/src/models/falcon.rs
index 3a3575aac2..92489f7c14 100644
--- a/candle-transformers/src/models/falcon.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -177,15 +177,6 @@ impl FalconRotaryEmbedding {
     }
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?
-        .to_dtype(on_false.dtype())?
-        .broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 #[derive(Debug, Clone)]
 struct FalconAttention {
     query_key_value: Linear,
@@ -298,7 +289,9 @@ impl FalconAttention {
         let attention_scores = match mask {
             None => attention_scores,
             Some(mask) => {
-                let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
+                let mask = mask
+                    .to_dtype(DType::F32)?
+                    .masked_fill(mask, -1e9)?
                     .to_dtype(query.dtype())?;
                 attention_scores.broadcast_add(&mask.squeeze(1)?)?
             }
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index a1f43d35b8..c9196a9b55 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -255,7 +255,7 @@ impl CausalSelfAttention {
             att
         } else {
             let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+            att.masked_fill(&mask, f32::NEG_INFINITY)?
         };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -295,13 +288,6 @@ impl CausalSelfAttention {
     }
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 #[derive(Debug, Clone)]
 struct Mlp {
     c_fc1: Linear,
diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs
index bba8b66607..30a0e8085d 100644
--- a/candle-transformers/src/models/llama2_c.rs
+++ b/candle-transformers/src/models/llama2_c.rs
@@ -198,7 +198,7 @@ impl CausalSelfAttention {
             att
         } else {
             let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+            att.masked_fill(&mask, f32::NEG_INFINITY)?
         };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -242,13 +235,6 @@ impl CausalSelfAttention {
     }
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 #[derive(Debug, Clone)]
 struct Mlp {
     c_fc1: Linear,
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index d46524fcc2..1e8eaeec0b 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -118,11 +118,8 @@ impl GroupedQueryAttention {
         let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => masked_fill(
-                &attn_weights,
-                &mask.broadcast_as(attn_weights.shape())?,
-                f32::NEG_INFINITY,
-            )?,
+            Some(mask) => attn_weights
+                .masked_fill(&mask.broadcast_as(attn_weights.shape())?, f32::NEG_INFINITY)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let attn_output = attn_weights
@@ -281,10 +278,3 @@ pub(crate) fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
         .collect();
     Tensor::from_slice(&mask, (size, size), device)
 }
-
-pub(crate) fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
diff --git a/candle-transformers/src/models/phi.rs b/candle-transformers/src/models/phi.rs
index bffc14faed..11a91f619d 100644
--- a/candle-transformers/src/models/phi.rs
+++ b/candle-transformers/src/models/phi.rs
@@ -126,13 +126,6 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     Tensor::from_slice(&mask, (size, size), device)
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 impl Attention {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let num_heads = cfg.num_attention_heads;
@@ -233,8 +226,7 @@ impl Attention {
             * self.softmax_scale)?;
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => masked_fill(
-                &attn_weights,
+            Some(mask) => attn_weights.masked_fill(
                 &mask.broadcast_left((b_size, self.num_heads))?,
                 f32::NEG_INFINITY,
             )?,
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 6b326fbe92..efbb523f0b 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -138,19 +138,12 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
-    neg_inf: Tensor,
     kv_cache: Option<(Tensor, Tensor)>,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
     span_mlp: tracing::Span,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
-    let shape = mask.shape();
-    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
-    Ok(m)
-}
-
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
@@ -214,7 +207,7 @@ impl LayerWeights {
             None => att,
             Some(mask) => {
                 let mask = mask.broadcast_as(att.shape())?;
-                masked_fill(&att, &mask, &self.neg_inf)?
+                att.masked_fill(&mask, f32::NEG_INFINITY)?
             }
         };
         let att = candle_nn::ops::softmax_last_dim(&att)?;
@@ -260,7 +253,6 @@ impl ModelWeights {
     pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
         let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
-        let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
         let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
@@ -300,7 +292,6 @@ impl ModelWeights {
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
-                neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
                 span_rot,
@@ -349,7 +340,6 @@ impl ModelWeights {
             .and_then(|m| m.to_f32())
             .unwrap_or(10000f32);
         let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
-        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
@@ -420,7 +410,6 @@ impl ModelWeights {
                 head_dim: embedding_length / head_count,
                 cos: cos.clone(),
                 sin: sin.clone(),
-                neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
                 span_rot,
diff --git a/candle-transformers/src/models/quantized_llama2_c.rs b/candle-transformers/src/models/quantized_llama2_c.rs
index cbb8aad8da..ac8b321f30 100644
--- a/candle-transformers/src/models/quantized_llama2_c.rs
+++ b/candle-transformers/src/models/quantized_llama2_c.rs
@@ -75,7 +75,7 @@ impl CausalSelfAttention {
             att
         } else {
             let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
-            masked_fill(&att, &mask, f32::NEG_INFINITY)?
+            att.masked_fill(&mask, f32::NEG_INFINITY)?
         };
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -119,13 +112,6 @@ impl CausalSelfAttention {
     }
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 #[derive(Debug, Clone)]
 struct Mlp {
     c_fc1: Linear,
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index fa72672a9e..0b8cdedb32 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -32,13 +32,6 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     Tensor::from_slice(&mask, (size, size), device)
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 #[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
@@ -219,8 +212,7 @@ impl MHA {
         // scores = scores + causal_mask.to(dtype=scores.dtype)
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => masked_fill(
-                &attn_weights,
+            Some(mask) => attn_weights.masked_fill(
                 &mask.broadcast_left(b_size * self.n_head)?,
                 f32::NEG_INFINITY,
             )?,
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
index 056fcac2d1..5845f8f481 100644
--- a/candle-transformers/src/models/quantized_mpt.rs
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -85,11 +85,8 @@ impl GroupedQueryAttention {
         let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => super::mpt::masked_fill(
-                &attn_weights,
-                &mask.broadcast_as(attn_weights.shape())?,
-                f32::NEG_INFINITY,
-            )?,
+            Some(mask) => attn_weights
+                .masked_fill(&mask.broadcast_as(attn_weights.shape())?, f32::NEG_INFINITY)?,
         };
         let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
         let attn_output = attn_weights
diff --git a/candle-transformers/src/models/quantized_phi.rs b/candle-transformers/src/models/quantized_phi.rs
index 0ebf7f4d4b..d5acdefee7 100644
--- a/candle-transformers/src/models/quantized_phi.rs
+++ b/candle-transformers/src/models/quantized_phi.rs
@@ -61,18 +61,11 @@ struct LayerWeights {
     cos: Tensor,
     sin: Tensor,
     rope_dim: usize,
-    neg_inf: Tensor,
     kv_cache: Option<(Tensor, Tensor)>,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
-    let shape = mask.shape();
-    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
-    Ok(m)
-}
-
 impl LayerWeights {
     fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
@@ -131,7 +124,7 @@ impl LayerWeights {
             None => att,
             Some(mask) => {
                 let mask = mask.broadcast_as(att.shape())?;
-                masked_fill(&att, &mask, &self.neg_inf)?
+                att.masked_fill(&mask, f32::NEG_INFINITY)?
             }
         };
         let att = candle_nn::ops::softmax_last_dim(&att)?;
@@ -199,7 +192,6 @@ impl ModelWeights {
         let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
         let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
         let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
-        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
@@ -233,7 +225,6 @@ impl ModelWeights {
                 cos: cos.clone(),
                 sin: sin.clone(),
                 rope_dim,
-                neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
                 span_rot,
diff --git a/candle-transformers/src/models/quantized_phi3.rs b/candle-transformers/src/models/quantized_phi3.rs
index 257ad98379..1ef7dd06c4 100644
--- a/candle-transformers/src/models/quantized_phi3.rs
+++ b/candle-transformers/src/models/quantized_phi3.rs
@@ -67,19 +67,12 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
-    neg_inf: Tensor,
     kv_cache: KvCache,
     use_flash_attn: bool,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
-    let shape = mask.shape();
-    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
-    Ok(m)
-}
-
 impl LayerWeights {
     fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
@@ -141,7 +134,7 @@ impl LayerWeights {
             None => att,
             Some(mask) => {
                 let mask = mask.broadcast_as(att.shape())?;
-                masked_fill(&att, &mask, &self.neg_inf)?
+                att.masked_fill(&mask, f32::NEG_INFINITY)?
             }
         };
         let att = candle_nn::ops::softmax_last_dim(&att)?;
@@ -224,7 +217,6 @@ impl ModelWeights {
         let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
         let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
         let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
-        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
@@ -263,7 +255,6 @@ impl ModelWeights {
                 head_dim,
                 cos: cos.clone(),
                 sin: sin.clone(),
-                neg_inf: neg_inf.clone(),
                 kv_cache,
                 use_flash_attn,
                 span_attn,
diff --git a/candle-transformers/src/models/quantized_qwen2.rs b/candle-transformers/src/models/quantized_qwen2.rs
index addfab2b04..09c9b33a55 100644
--- a/candle-transformers/src/models/quantized_qwen2.rs
+++ b/candle-transformers/src/models/quantized_qwen2.rs
@@ -39,19 +39,12 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
-    neg_inf: Tensor,
     kv_cache: Option<(Tensor, Tensor)>,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
     span_mlp: tracing::Span,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
-    let shape = mask.shape();
-    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
-    Ok(m)
-}
-
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let _enter = self.span_rot.enter();
@@ -120,7 +113,7 @@ impl LayerWeights {
             None => att,
             Some(mask) => {
                 let mask = mask.broadcast_as(att.shape())?;
-                masked_fill(&att, &mask, &self.neg_inf)?
+                att.masked_fill(&mask, f32::NEG_INFINITY)?
             }
         };
         let att = candle_nn::ops::softmax_last_dim(&att)?;
@@ -185,8 +178,6 @@ impl ModelWeights {
 
         let head_dim = embedding_length / head_count;
 
-        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
-
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
         let norm = RmsNorm::from_qtensor(
@@ -256,7 +247,6 @@ impl ModelWeights {
                 n_head: head_count,
                 n_kv_head: head_count_kv,
                 head_dim,
-                neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
                 span_rot,
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index 88224d2da3..f1ef429626 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -33,13 +33,6 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     Tensor::from_slice(&mask, (size, size), device)
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
     vocab_size: usize,
@@ -351,8 +344,7 @@ impl T5Attention {
         };
         let scores = match mask {
             None => scores,
-            Some(mask) => masked_fill(
-                &scores,
+            Some(mask) => scores.masked_fill(
                 &mask
                     .unsqueeze(0)?
                     .unsqueeze(0)?
diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index 8a7a8955b6..8d3b9d2131 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -30,13 +30,6 @@ fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     Tensor::from_slice(&mask, (size, size), device)
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 #[derive(Debug, Deserialize, Default, Clone, PartialEq)]
 pub struct ActivationWithOptionalGating {
     pub gated: bool,
@@ -410,8 +403,7 @@ impl T5Attention {
         };
         let scores = match mask {
             None => scores,
-            Some(mask) => masked_fill(
-                &scores,
+            Some(mask) => scores.masked_fill(
                 &mask
                     .unsqueeze(0)?
                     .unsqueeze(0)?
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index ab9333d27e..aa6aa533e6 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -119,7 +119,7 @@ impl CausalSelfAttention {
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = att.masked_fill(&mask, f32::NEG_INFINITY)?;
        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
@@ -163,13 +163,6 @@ impl CausalSelfAttention {
     }
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 struct Mlp {
     c_fc1: Linear,
     c_fc2: Linear,
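
Usage sketch (not part of the patch series above): a minimal, standalone example of calling the new Tensor::masked_fill from downstream code, assuming both patches are applied and candle_core is available. The shapes, the mask values, and the main() wrapper are illustrative only; the mask convention matches the test in patch 1, i.e. non-zero mask entries are replaced by the fill value and zero entries keep the original element.

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        // Hypothetical attention-style scores, all zeros for simplicity.
        let scores = Tensor::zeros((3, 3), DType::F32, &Device::Cpu)?;
        // Upper-triangular mask: 1 marks positions to overwrite, 0 keeps `scores`.
        let mask = Tensor::new(&[[0u8, 1, 1], [0, 0, 1], [0, 0, 0]], &Device::Cpu)?;
        // Every masked position becomes -inf, e.g. before a softmax over the last dim.
        let masked = scores.masked_fill(&mask, f32::NEG_INFINITY)?;
        println!("{masked}");
        Ok(())
    }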