Skip to content

Commit

Permalink
fix(auto): support custom column names (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Nov 12, 2024
1 parent e869465 commit 5ec7a6b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 12 deletions.
20 changes: 17 additions & 3 deletions mlforecast/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,12 @@ def fit(
if loss is None:

def loss(df, train_df): # noqa: ARG001
return smape(df, models=["model"])["model"].mean()
return smape(
df,
models=["model"],
id_col=id_col,
target_col=target_col,
)["model"].mean()

if study_kwargs is None:
study_kwargs = {}
Expand Down Expand Up @@ -554,8 +559,14 @@ def config_fn(trial: optuna.Trial) -> Dict[str, Any]:
study.optimize(objective, n_trials=num_samples, **optimize_kwargs)
self.results_[name] = study
best_config = study.best_trial.user_attrs["config"]
best_config["mlf_fit_params"].pop("fitted", None)
best_config["mlf_fit_params"].pop("prediction_intervals", None)
for arg in (
"fitted",
"prediction_intervals",
"id_col",
"time_col",
"target_col",
):
best_config["mlf_fit_params"].pop(arg, None)
best_model = clone(auto_model.model)
best_model.set_params(**best_config["model_params"])
self.models_[name] = MLForecast(
Expand All @@ -567,6 +578,9 @@ def config_fn(trial: optuna.Trial) -> Dict[str, Any]:
df,
fitted=fitted,
prediction_intervals=prediction_intervals,
id_col=id_col,
time_col=time_col,
target_col=target_col,
**best_config["mlf_fit_params"],
)
return self
Expand Down
66 changes: 57 additions & 9 deletions nbs/auto.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,12 @@
"\n",
" if loss is None:\n",
" def loss(df, train_df): # noqa: ARG001\n",
" return smape(df, models=['model'])['model'].mean()\n",
" return smape(\n",
" df,\n",
" models=['model'],\n",
" id_col=id_col,\n",
" target_col=target_col,\n",
" )['model'].mean()\n",
" if study_kwargs is None:\n",
" study_kwargs = {}\n",
" if 'sampler' not in study_kwargs:\n",
Expand Down Expand Up @@ -629,8 +634,10 @@
" study.optimize(objective, n_trials=num_samples, **optimize_kwargs)\n",
" self.results_[name] = study\n",
" best_config = study.best_trial.user_attrs['config']\n",
" best_config['mlf_fit_params'].pop('fitted', None)\n",
" best_config['mlf_fit_params'].pop('prediction_intervals', None)\n",
" for arg in (\n",
" 'fitted', 'prediction_intervals', 'id_col', 'time_col', 'target_col'\n",
" ):\n",
" best_config['mlf_fit_params'].pop(arg, None)\n",
" best_model = clone(auto_model.model)\n",
" best_model.set_params(**best_config['model_params'])\n",
" self.models_[name] = MLForecast(\n",
Expand All @@ -642,6 +649,9 @@
" df,\n",
" fitted=fitted,\n",
" prediction_intervals=prediction_intervals,\n",
" id_col=id_col,\n",
" time_col=time_col,\n",
" target_col=target_col,\n",
" **best_config['mlf_fit_params'],\n",
" )\n",
" return self\n",
Expand Down Expand Up @@ -904,7 +914,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L574){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.predict\n",
"\n",
Expand All @@ -924,7 +934,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L570){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L574){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.predict\n",
"\n",
Expand Down Expand Up @@ -962,7 +972,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L606){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.save\n",
"\n",
Expand All @@ -978,7 +988,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L602){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L606){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.save\n",
"\n",
Expand Down Expand Up @@ -1012,7 +1022,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L616){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.forecast_fitted_values\n",
"\n",
Expand All @@ -1030,7 +1040,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L612){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/auto.py#L616){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### AutoMLForecast.forecast_fitted_values\n",
"\n",
Expand Down Expand Up @@ -1062,6 +1072,7 @@
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from datasetsforecast.m4 import M4, M4Evaluation, M4Info\n",
"from sklearn.linear_model import Ridge\n",
"from sklearn.compose import ColumnTransformer\n",
Expand Down Expand Up @@ -1740,6 +1751,43 @@
"metric_step_1 = auto_mlf2.results_['ridge'].best_trial.value\n",
"assert abs(metric_step_h / metric_step_1 - 1) > 0.02"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3c00fd0-ebee-4d40-b2aa-dc7ae98ee94c",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# default loss with non standard names\n",
"auto_mlf = AutoMLForecast(\n",
" freq=1,\n",
" season_length=season_length,\n",
" models={'ridge': AutoRidge()},\n",
")\n",
"fit_kwargs = dict(\n",
" n_windows=2,\n",
" h=h,\n",
" step_size=1,\n",
" num_samples=2,\n",
" optimize_kwargs={'timeout': 60}, \n",
")\n",
"preds = auto_mlf.fit(df=train, **fit_kwargs).predict(5)\n",
"\n",
"train2 = train.rename(columns={'unique_id': 'id', 'ds': 'time', 'y': 'target'})\n",
"preds2 = auto_mlf.fit(\n",
" df=train2,\n",
" id_col='id',\n",
" time_col='time',\n",
" target_col='target',\n",
" **fit_kwargs,\n",
").predict(5)\n",
"pd.testing.assert_frame_equal(\n",
" preds,\n",
" preds2.rename(columns={'id': 'unique_id', 'time': 'ds'}),\n",
")"
]
}
],
"metadata": {
Expand Down

0 comments on commit 5ec7a6b

Please sign in to comment.