Skip to content

Commit

Permalink
make logging of metrics optional and dependent on metrics setting (#1638)
Browse files Browse the repository at this point in the history
  • Loading branch information
ourownstory authored Aug 28, 2024
1 parent 5503553 commit 00301ad
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
41 changes: 27 additions & 14 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def __init__(
log.info(
DeprecationWarning(
"Providing metrics to collect via `collect_metrics` in NeuralProphet is deprecated and will be "
+ "removed in a future version. The metrics are now configure in the `fit()` method via `metrics`."
+ "removed in a future version. The metrics are now configured in the `fit()` method via `metrics`."
)
)
self.metrics = utils_metrics.get_metrics(collect_metrics)
Expand Down Expand Up @@ -548,7 +548,6 @@ def __init__(
self.data_params = None

# Pytorch Lightning Trainer
self.metrics_logger = MetricsLogger(save_dir=os.getcwd())
self.accelerator = accelerator
self.trainer_config = trainer_config

Expand Down Expand Up @@ -911,6 +910,7 @@ def fit(
early_stopping: bool = False,
minimal: bool = False,
metrics: Optional[np_types.CollectMetricsMode] = None,
metrics_log_dir: Optional[str] = None,
progress: Optional[str] = "bar",
checkpointing: bool = False,
continue_training: bool = False,
Expand Down Expand Up @@ -973,10 +973,15 @@ def fit(
pd.DataFrame
metrics with training and potentially evaluation metrics
"""
if minimal:
checkpointing = False
self.metrics = False
progress = None

if self.fitted:
raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.")

# Configuration
# Train Config overrides
if epochs is not None:
self.config_train.epochs = epochs

Expand All @@ -989,10 +994,7 @@ def fit(
if early_stopping is not None:
self.early_stopping = early_stopping

if metrics is not None:
self.metrics = utils_metrics.get_metrics(metrics)

# Warnings
# Warning for early stopping and regularization
if early_stopping:
reg_enabled = utils.check_for_regularization(
[
Expand All @@ -1012,17 +1014,28 @@ def fit(
number of epochs to train for."
)

if progress == "plot" and metrics is False:
log.info("Progress plot requires metrics to be enabled. Enabling the default metrics.")
metrics = utils_metrics.get_metrics(True)
# Setup Metrics
if metrics is not None:
self.metrics = utils_metrics.get_metrics(metrics)

if progress == "plot" and not self.metrics:
log.info("Progress plot requires metrics to be enabled. Setting progress to bar.")
progress = "bar"

if not self.config_normalization.global_normalization:
log.info("When Global modeling with local normalization, metrics are displayed in normalized scale.")

if minimal:
checkpointing = False
self.metrics = False
progress = None
if metrics_log_dir is not None and not self.metrics:
log.error("Metrics are disabled. Ignoring provided logging directory.")
metrics_log_dir = None
if metrics_log_dir is None and self.metrics:
log.warning("Metrics are enabled. Please provide valid metrics logging directory. Setting to CWD")
metrics_log_dir = os.getcwd()

if self.metrics:
self.metrics_logger = MetricsLogger(save_dir=metrics_log_dir)
else:
self.metrics_logger = None

# Pre-processing
# Copy df and save list of unique time series IDs (the latter for global-local modelling if enabled)
Expand Down
6 changes: 4 additions & 2 deletions neuralprophet/utils_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def get_metrics(metric_input):
Dict of names of torchmetrics.Metric metrics
"""
if metric_input is None:
return {}
return False
elif metric_input is False:
return False
elif metric_input is True:
return {"MAE": METRICS["MAE"], "RMSE": METRICS["RMSE"]}
elif isinstance(metric_input, str):
Expand All @@ -51,5 +53,5 @@ def get_metrics(metric_input):
"All metrics must be valid names of torchmetrics.Metric objects."
)
return {k: [v, {}] for k, v in metric_input.items()}
elif metric_input is not False:
else:
raise ValueError("Received unsupported argument for collect_metrics.")
6 changes: 6 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def test_create_dummy_datestamps():
_ = m.make_future_dataframe(df_dummy, periods=365, n_historic_predictions=True)


def test_no_log():
    """Fit with metrics disabled and verify no metrics logger/directory is required.

    Exercises the `metrics=False` path of `fit()`: with metrics off, no
    MetricsLogger should be created and no logging directory is needed.
    """
    df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
    m = NeuralProphet(epochs=EPOCHS, batch_size=BATCH_SIZE, learning_rate=LR)
    # `metrics_log_dir` is declared as Optional[str]; use None (not False) to
    # express "no logging directory" — False only passed the `is not None`
    # check by accident and triggered a spurious error log in fit().
    _ = m.fit(df, metrics=False, metrics_log_dir=None)


def test_save_load():
df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
m = NeuralProphet(
Expand Down

0 comments on commit 00301ad

Please sign in to comment.