From 99eb4826e82466f9963f0200d3f4f4d4940b88c6 Mon Sep 17 00:00:00 2001 From: Janice Lan Date: Mon, 17 Jul 2023 17:14:54 -0700 Subject: [PATCH] default for get task metrics --- ocpmodels/trainers/base_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index cb7ba19bc..71b4fd893 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -181,7 +181,7 @@ def __init__( self.evaluator = Evaluator( task=name, eval_metrics=self.config["task"].get( - "evaluation_metrics", Evaluator.task_metrics[name] + "evaluation_metrics", Evaluator.task_metrics.get(name, {}) ), ) @@ -960,7 +960,7 @@ def validate(self, split: str = "val", disable_tqdm: bool = False): evaluator = Evaluator( task=self.name, eval_metrics=self.config["task"].get( - "evaluation_metrics", Evaluator.task_metrics[self.name] + "evaluation_metrics", Evaluator.task_metrics.get(self.name, {}) ), )