[LLM] valid loss before optimizer step (#9255) (#9705)
SylarTiaNII authored Dec 27, 2024
1 parent 7197b79 commit 691ae01
Showing 1 changed file with 8 additions and 1 deletion.
paddlenlp/trainer/trainer.py: 9 changes (8 additions & 1 deletion)
@@ -1133,6 +1133,9 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                 if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss():
                     tr_loss /= self.args.gradient_accumulation_steps
 
+                # assert if loss is invalid
+                self._check_loss_valid(tr_loss)
+
                 self.timers and self.timers("forward-backward").stop()
                 # Maunally collect gradients
                 # Case 1: Use recompute and dp
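This first hunk places the new validity check between the loss averaging over the accumulation window and the gradient collection and optimizer step that follow, so a NaN or Inf loss aborts the step before any parameters are touched. Below is a minimal sketch (not the PaddleNLP code) of that ordering; `training_step_sketch`, `accumulated_loss`, and `apply_optimizer_step` are hypothetical names used only for illustration.

```python
import math

def training_step_sketch(accumulated_loss, gradient_accumulation_steps, apply_optimizer_step):
    # Average the loss over the accumulation window (mirrors `tr_loss /= gradient_accumulation_steps`).
    loss_value = accumulated_loss / gradient_accumulation_steps

    # Fail fast if the loss is NaN or Inf, before any gradient sync or parameter update runs.
    if not math.isfinite(loss_value):
        raise ValueError(f"Loss contains inf or nan values, its value is {loss_value}")

    # Only reached with a finite loss: collect gradients and update parameters.
    apply_optimizer_step()
    return loss_value

# Example: averaging 3.2 over 4 accumulation steps yields 0.8 and the step proceeds.
training_step_sketch(3.2, 4, lambda: None)
```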
@@ -1431,13 +1434,17 @@ def _print_timer(self):
         if timer_info or paddle_timer_info:
             logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}")
 
-    def _get_item_from_loss(self, loss):
+    def _check_loss_valid(self, loss):
         assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
         loss_value = loss.item()
         if not self.args.fp16:
             if not np.isfinite(loss_value).all():
                 err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
                 raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")
 
+    def _get_item_from_loss(self, loss):
+        assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
+        loss_value = loss.item()
         return loss_value
 
     def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
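The second hunk factors the NaN/Inf check out of `_get_item_from_loss` into a dedicated `_check_loss_valid` helper, leaving `_get_item_from_loss` to only convert the tensor to a Python scalar. The sketch below mimics that split with plain numpy values instead of `paddle.Tensor` so it runs without Paddle; the function names drop the leading underscore and the error-message strings are placeholders for PaddleNLP's real `LOSS_NAN_ERROR` / `LOSS_INF_ERROR` constants.

```python
import numpy as np

# Placeholder texts; the real LOSS_NAN_ERROR / LOSS_INF_ERROR constants live in PaddleNLP.
LOSS_NAN_ERROR = "Loss is NaN"
LOSS_INF_ERROR = "Loss is Inf"

def check_loss_valid(loss_value, fp16=False):
    # Mirrors `_check_loss_valid`: the check only runs when not training in fp16.
    if not fp16:
        if not np.isfinite(loss_value).all():
            err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
            raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")

def get_item_from_loss(loss):
    # Mirrors the slimmed-down `_get_item_from_loss`: just return the scalar value.
    return float(loss)

check_loss_valid(np.float64(0.73))      # finite loss passes silently
# check_loss_valid(np.float64("nan"))   # would raise ValueError with the NaN message
print(get_item_from_loss(np.float64(0.73)))
```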
