diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index dac356146f3015..90d15c44905a0e 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -532,7 +532,7 @@ def eager_attention_forward(
     dim: int,
     output_attentions: Optional[bool] = False,
     **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]:
+) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
     # qkv: [batch_size, seqlen, 3, nheads, headdim]
     cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
     query, key, value = qkv.transpose(3, 1).unbind(dim=2)
@@ -1278,6 +1278,12 @@ def __init__(self, config: ModernBertConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.embeddings.tok_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.tok_embeddings = value
+
     @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
@@ -1389,6 +1395,12 @@ def __init__(self, config: ModernBertConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_input_embeddings(self):
+        return self.embeddings.tok_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.tok_embeddings = value
+
     @add_start_docstrings_to_model_forward(MODERNBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
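
Why the annotation change in the first hunk: the `X | Y` union syntax is only supported on `typing` generics at runtime from Python 3.10 onward (PEP 604), and the return annotation is evaluated when the function is defined, so `Tuple[...] | Tuple[...]` raises a `TypeError` at import time on older interpreters (unless the module opts into `from __future__ import annotations`). `Union[...]` spells the same type and imports everywhere. A minimal sketch of the failure mode, using only standard-library typing semantics, nothing ModernBert-specific:

    from typing import Tuple, Union

    import torch

    # On Python < 3.10 the following definition raises at import time,
    # because `|` is not defined for subscripted typing generics before
    # PEP 604 landed:
    #
    #     def f() -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]: ...
    #     TypeError: unsupported operand type(s) for |: ...
    #
    # Union expresses the same type and evaluates fine on every version:
    def f(output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        return (torch.zeros(2, 2),)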
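
The remaining two hunks add the `get_input_embeddings`/`set_input_embeddings` accessors that generic `PreTrainedModel` machinery such as `resize_token_embeddings` and weight tying routes through; without an override (or a `base_model` attribute that provides one), the base-class versions raise `NotImplementedError`. A hedged usage sketch of what the accessors enable, assuming the public `answerdotai/ModernBERT-base` checkpoint purely for illustration (the checkpoint name is not part of this diff):

    from transformers import AutoModelForMaskedLM

    # Illustrative checkpoint; any ModernBert checkpoint would do.
    model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base")

    # With the accessors in place, this returns the nn.Embedding stored at
    # embeddings.tok_embeddings instead of raising NotImplementedError.
    emb = model.get_input_embeddings()
    old_size = emb.num_embeddings

    # resize_token_embeddings goes through get/set_input_embeddings: it
    # builds a larger nn.Embedding, copies the existing rows into it, and
    # installs the new module via set_input_embeddings().
    model.resize_token_embeddings(old_size + 8)
    assert model.get_input_embeddings().num_embeddings == old_size + 8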