Skip to content

Commit

Permalink
rename to optimization and define objective function
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Mar 27, 2024
1 parent a364f93 commit cb44a0e
Show file tree
Hide file tree
Showing 6 changed files with 420 additions and 354 deletions.
2 changes: 2 additions & 0 deletions mlforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@
'mlforecast.lgb_cv._rmse': ('lgb_cv.html#_rmse', 'mlforecast/lgb_cv.py'),
'mlforecast.lgb_cv._update': ('lgb_cv.html#_update', 'mlforecast/lgb_cv.py'),
'mlforecast.lgb_cv._update_and_predict': ('lgb_cv.html#_update_and_predict', 'mlforecast/lgb_cv.py')},
'mlforecast.optimization': { 'mlforecast.optimization.mlforecast_objective': ( 'optimization.html#mlforecast_objective',
'mlforecast/optimization.py')},
'mlforecast.target_transforms': { 'mlforecast.target_transforms.AutoDifferences': ( 'target_transforms.html#autodifferences',
'mlforecast/target_transforms.py'),
'mlforecast.target_transforms.AutoDifferences.__init__': ( 'target_transforms.html#autodifferences.__init__',
Expand Down
2 changes: 1 addition & 1 deletion mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _as_tuple(x):
return (x,)

# %% ../nbs/core.ipynb 19
Freq = Union[int, str, pd.offsets.BaseOffset]
Freq = Union[int, str]
Lags = Iterable[int]
LagTransform = Union[Callable, Tuple[Callable, Any]]
LagTransforms = Dict[int, List[LagTransform]]
Expand Down
84 changes: 84 additions & 0 deletions mlforecast/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/optimization.ipynb.

# %% auto 0
__all__ = ['mlforecast_objective']

# %% ../nbs/optimization.ipynb 2
import copy
from typing import Callable, List, Optional

import numpy as np
import optuna
import utilsforecast.processing as ufp
from utilsforecast.compat import DataFrame
from utilsforecast.losses import smape

from . import MLForecast
from .core import Freq

# %% ../nbs/optimization.ipynb 3
def mlforecast_objective(
df: DataFrame,
config_fn: Callable,
eval_fn: Callable,
model_constructor: Callable,
freq: Freq,
n_windows: int,
h: int,
id_col: str = "unique_id",
time_col: str = "ds",
target_col: str = "y",
) -> Callable:
def objective(trial: optuna.Trial) -> float:
config = config_fn(trial)
trial.set_user_attr("config", copy.deepcopy(config))
splits = ufp.backtest_splits(
df,
n_windows=n_windows,
h=h,
id_col=id_col,
time_col=time_col,
freq=freq,
)
metrics = []
for i, (_, train, valid) in enumerate(splits):
mlf = MLForecast(
models={"model": model_constructor(**config["model_params"])},
freq=freq,
**config["mlf_init_params"],
)
mlf.fit(
train,
id_col=id_col,
time_col=time_col,
target_col=target_col,
**config["mlf_fit_params"],
)
static = [c for c in mlf.ts.static_features_.columns if c != id_col]
dynamic = [
c
for c in valid.columns
if c not in static + [id_col, time_col, target_col]
]
if dynamic:
X_df: Optional[DataFrame] = ufp.drop_columns(
valid, static + [target_col]
)
else:
X_df = None
preds = mlf.predict(h=h, X_df=X_df)
full = valid.merge(preds, on=[id_col, time_col], how="left")
if full.shape[0] < valid.shape[0]:
raise ValueError(
"Cross validation result produced less results than expected. "
"Please verify that the passed frequency (freq) matches your series' "
"and that there aren't any missing periods."
)
metric = eval_fn(full)
metrics.append(metric)
trial.report(metric, step=i)
if trial.should_prune():
raise optuna.TrialPruned()
return np.mean(metrics)

return objective
2 changes: 1 addition & 1 deletion nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@
"outputs": [],
"source": [
"#| exporti\n",
"Freq = Union[int, str, pd.offsets.BaseOffset]\n",
"Freq = Union[int, str]\n",
"Lags = Iterable[int]\n",
"LagTransform = Union[Callable, Tuple[Callable, Any]]\n",
"LagTransforms = Dict[int, List[LagTransform]]\n",
Expand Down
Loading

0 comments on commit cb44a0e

Please sign in to comment.