Skip to content

Commit

Permalink
Moved logic of update of _best_ckpt_metrics before we build state dict for checkpoint (#2007)
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe authored Jun 2, 2024
1 parent 217353a commit f3b8947
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,15 @@ def _save_checkpoint(
train_metrics_titles = get_metrics_titles(self.train_metrics)
all_metrics["train"] = {metric_name: float(train_metrics_dict[metric_name]) for metric_name in train_metrics_titles}

best_checkpoint = (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
)

if best_checkpoint:
# STORE THE CURRENT metric AS BEST
self.best_metric = curr_tracked_metric
self._best_ckpt_metrics = all_metrics

# BUILD THE state_dict
state = {
"net": unwrap_model(self.net).state_dict(),
Expand Down Expand Up @@ -713,13 +722,7 @@ def _save_checkpoint(
self.sg_logger.add_checkpoint(tag=f"ckpt_epoch_{epoch}.pth", state_dict=state, global_step=epoch)

# OVERRIDE THE BEST CHECKPOINT AND best_metric IF metric GOT BETTER THAN THE PREVIOUS BEST
if (curr_tracked_metric > self.best_metric and self.greater_metric_to_watch_is_better) or (
curr_tracked_metric < self.best_metric and not self.greater_metric_to_watch_is_better
):
# STORE THE CURRENT metric AS BEST
self.best_metric = curr_tracked_metric

self._best_ckpt_metrics = all_metrics
if best_checkpoint:
self.sg_logger.add_checkpoint(tag=self.ckpt_best_name, state_dict=state, global_step=epoch)

# RUN PHASE CALLBACKS
Expand Down

0 comments on commit f3b8947

Please sign in to comment.