From 00301ad166c1dded5ec8db8ecb5f59eae8830142 Mon Sep 17 00:00:00 2001
From: Oskar Triebe
Date: Tue, 27 Aug 2024 22:51:44 -0700
Subject: [PATCH] make logging of metrics optional and dependent on metrics
 setting (#1638)

---
 neuralprophet/forecaster.py    | 41 +++++++++++++++++++++++++++--------------
 neuralprophet/utils_metrics.py |  6 ++++--
 tests/test_utils.py            |  6 ++++++
 3 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 85939955e..3039ca567 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -476,7 +476,7 @@ def __init__(
             log.info(
                 DeprecationWarning(
                     "Providing metrics to collect via `collect_metrics` in NeuralProphet is deprecated and will be "
-                    + "removed in a future version. The metrics are now configure in the `fit()` method via `metrics`."
+                    + "removed in a future version. The metrics are now configured in the `fit()` method via `metrics`."
                 )
             )
             self.metrics = utils_metrics.get_metrics(collect_metrics)
@@ -548,7 +548,6 @@ def __init__(
         self.data_params = None

         # Pytorch Lightning Trainer
-        self.metrics_logger = MetricsLogger(save_dir=os.getcwd())
         self.accelerator = accelerator
         self.trainer_config = trainer_config

@@ -911,6 +910,7 @@ def fit(
         early_stopping: bool = False,
         minimal: bool = False,
         metrics: Optional[np_types.CollectMetricsMode] = None,
+        metrics_log_dir: Optional[str] = None,
         progress: Optional[str] = "bar",
         checkpointing: bool = False,
         continue_training: bool = False,
@@ -973,10 +973,15 @@ def fit(
         pd.DataFrame
             metrics with training and potentially evaluation metrics
         """
+        if minimal:
+            checkpointing = False
+            self.metrics = False
+            progress = None
+
         if self.fitted:
             raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.")

-        # Configuration
+        # Train Config overrides
         if epochs is not None:
             self.config_train.epochs = epochs

@@ -989,10 +994,7 @@ def fit(
         if early_stopping is not None:
             self.early_stopping = early_stopping

-        if metrics is not None:
-            self.metrics = utils_metrics.get_metrics(metrics)
-
-        # Warnings
+        # Warning for early stopping and regularization
         if early_stopping:
             reg_enabled = utils.check_for_regularization(
                 [
@@ -1012,17 +1014,28 @@ def fit(
                 number of epochs to train for."
             )

-        if progress == "plot" and metrics is False:
-            log.info("Progress plot requires metrics to be enabled. Enabling the default metrics.")
-            metrics = utils_metrics.get_metrics(True)
+        # Setup Metrics
+        if metrics is not None:
+            self.metrics = utils_metrics.get_metrics(metrics)
+
+        if progress == "plot" and not self.metrics:
+            log.info("Progress plot requires metrics to be enabled. Setting progress to bar.")
+            progress = "bar"

         if not self.config_normalization.global_normalization:
             log.info("When Global modeling with local normalization, metrics are displayed in normalized scale.")

-        if minimal:
-            checkpointing = False
-            self.metrics = False
-            progress = None
+        if metrics_log_dir is not None and not self.metrics:
+            log.error("Metrics are disabled. Ignoring provided logging directory.")
+            metrics_log_dir = None
+        if metrics_log_dir is None and self.metrics:
+            log.warning("Metrics are enabled but no metrics logging directory was provided. Setting it to CWD.")
Setting to CWD") + metrics_log_dir = os.getcwd() + + if self.metrics: + self.metrics_logger = MetricsLogger(save_dir=metrics_log_dir) + else: + self.metrics_logger = None # Pre-processing # Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled) diff --git a/neuralprophet/utils_metrics.py b/neuralprophet/utils_metrics.py index 6dafc3ef8..368d742ed 100644 --- a/neuralprophet/utils_metrics.py +++ b/neuralprophet/utils_metrics.py @@ -27,7 +27,9 @@ def get_metrics(metric_input): Dict of names of torchmetrics.Metric metrics """ if metric_input is None: - return {} + return False + elif metric_input is False: + return False elif metric_input is True: return {"MAE": METRICS["MAE"], "RMSE": METRICS["RMSE"]} elif isinstance(metric_input, str): @@ -51,5 +53,5 @@ def get_metrics(metric_input): "All metrics must be valid names of torchmetrics.Metric objects." ) return {k: [v, {}] for k, v in metric_input.items()} - elif metric_input is not False: + else: raise ValueError("Received unsupported argument for collect_metrics.") diff --git a/tests/test_utils.py b/tests/test_utils.py index a327f3122..1bda8f108 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -40,6 +40,12 @@ def test_create_dummy_datestamps(): _ = m.make_future_dataframe(df_dummy, periods=365, n_historic_predictions=True) +def test_no_log(): + df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR) + _ = m.fit(df, metrics=False, metrics_log_dir=False) + + def test_save_load(): df = pd.read_csv(PEYTON_FILE, nrows=NROWS) m = NeuralProphet(