From 7ff39fa75b9e4714ada2e50c4e66333e4874c72e Mon Sep 17 00:00:00 2001 From: Koichi Yasuoka Date: Fri, 27 Dec 2024 01:23:23 +0900 Subject: [PATCH 1/2] Update modular_modernbert.py --- src/transformers/models/modernbert/modular_modernbert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index dac356146f3015..4424e8b2fead5d 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -532,7 +532,7 @@ def eager_attention_forward( dim: int, output_attentions: Optional[bool] = False, **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: +) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: # qkv: [batch_size, seqlen, 3, nheads, headdim] cos, sin = module.rotary_emb(qkv, position_ids=position_ids) query, key, value = qkv.transpose(3, 1).unbind(dim=2) From e51b48ab108aea43f9fca0f7dc35efd783c39128 Mon Sep 17 00:00:00 2001 From: Koichi Yasuoka Date: Fri, 27 Dec 2024 17:42:07 +0900 Subject: [PATCH 2/2] support {set,get}_input_embeddings --- .../models/modernbert/modular_modernbert.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 4424e8b2fead5d..90d15c44905a0e 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -1278,6 +1278,12 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1389,6 +1395,12 @@ def __init__(self, config: ModernBertConfig): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.embeddings.tok_embeddings + + def set_input_embeddings(self, value): + self.embeddings.tok_embeddings = value + @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC,