From 1c0f29ec446fb298a484e64da101b27feef51201 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Mon, 23 Dec 2024 12:45:51 +0100
Subject: [PATCH] push to device

---
 src/transformers/loss/loss_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index 7f6aaaa44264ca..2e8e2bb5f149d7 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -34,6 +34,7 @@ def ForCausalLMLoss(
 ):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
+    labels = labels.to(logits.device)
     # Shift so that tokens < n predict n
     shift_logits = logits[..., :-1, :].contiguous()
     shift_labels = labels[..., 1:].contiguous()
@@ -52,6 +53,7 @@ def ForMaskedLMLoss(
 ):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.float()
+    labels = labels.to(logits.device)

     # Flatten the tokens
     logits = logits.view(-1, vocab_size)
@@ -73,6 +75,7 @@ def ForSequenceClassificationLoss(labels, pooled_logits, config, **kwargs):
         else:
             config.problem_type = "multi_label_classification"

+    labels = labels.to(pooled_logits.device)
     if config.problem_type == "regression":
         loss_fct = MSELoss()
         if num_labels == 1:
@@ -109,7 +112,7 @@ def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_posi
 def ForTokenClassification(logits, labels, config, **kwargs):
     # Upcast to float if we need to compute the loss to avoid potential precision issues
     logits = logits.view(-1, config.num_labels)
-    labels = labels.view(-1)
+    labels = labels.view(-1).to(logits.device)
     logits = logits.float()
     # Flatten the tokens
     return fixed_cross_entropy(logits, labels, **kwargs)
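
The pattern the patch adds is simply aligning labels to the device of the logits before the loss computation, which matters when labels stay on CPU while logits come off an accelerator (for example with a sharded model). Below is a minimal standalone sketch of that pattern for the causal LM case; causal_lm_loss_sketch is a hypothetical helper, and unlike the real loss_utils functions it calls torch.nn.functional.cross_entropy directly rather than fixed_cross_entropy with its extra kwargs.

import torch
import torch.nn.functional as F

def causal_lm_loss_sketch(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Upcast to float to avoid potential precision issues, as in the patched helpers.
    logits = logits.float()
    # The point of the patch: labels may live on a different device than logits
    # (e.g. CPU labels vs. GPU logits), so move them to the logits device first.
    labels = labels.to(logits.device)
    # Shift so that tokens < n predict n.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )

# Example: logits on GPU, labels still on CPU; the .to(logits.device) call
# avoids the usual device-mismatch RuntimeError.
if torch.cuda.is_available():
    logits = torch.randn(2, 8, 32000, device="cuda")
    labels = torch.randint(0, 32000, (2, 8))  # CPU tensor
    print(causal_lm_loss_sketch(logits, labels))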