Skip to content

Commit

Permalink
ensure static features are constant (#391)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jul 25, 2024
1 parent 66b8203 commit 2edde26
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 41 deletions.
2 changes: 1 addition & 1 deletion mlforecast/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-__version__ = "0.13.2"
+__version__ = "0.13.3"
__all__ = ['MLForecast']
from mlforecast.forecast import MLForecast
25 changes: 20 additions & 5 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,19 +293,34 @@ def _fit(
tfm.set_column_names(id_col, time_col, target_col)
sorted_df = tfm.fit_transform(sorted_df)
ga.data = sorted_df[target_col].to_numpy()
-self.ga = ga
-last_idxs_per_serie = self.ga.indptr[1:] - 1
to_drop = [id_col, time_col, target_col]
if static_features is None:
static_features = [c for c in df.columns if c not in [time_col, target_col]]
elif id_col not in static_features:
static_features = [id_col] + static_features
else: # static_features defined and contain id_col
to_drop = [time_col, target_col]
+self.ga = ga
+series_starts = ga.indptr[:-1]
+series_ends = ga.indptr[1:] - 1
if self._sort_idxs is not None:
-last_idxs_per_serie = self._sort_idxs[last_idxs_per_serie]
-self.static_features_ = ufp.take_rows(df, last_idxs_per_serie)[static_features]
-self.static_features_ = ufp.drop_index_if_pandas(self.static_features_)
+series_starts = self._sort_idxs[series_starts]
+series_ends = self._sort_idxs[series_ends]
+statics_on_starts = ufp.drop_index_if_pandas(
+ufp.take_rows(df, series_starts)[static_features]
+)
+statics_on_ends = ufp.drop_index_if_pandas(
+ufp.take_rows(df, series_ends)[static_features]
+)
+for feat in static_features:
+if (statics_on_starts[feat] != statics_on_ends[feat]).any():
+raise ValueError(
+f"{feat} is declared as a static feature but its values change "
+"over time. Please set the `static_features` argument to "
+"indicate which features are static.\nIf all of your features "
+"are dynamic please set `static_features=[]`."
+)
+self.static_features_ = statics_on_ends
self.features_order_ = [
c for c in df.columns if c not in to_drop
] + self.features
Expand Down
53 changes: 35 additions & 18 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -776,19 +776,34 @@
" tfm.set_column_names(id_col, time_col, target_col)\n",
" sorted_df = tfm.fit_transform(sorted_df)\n",
" ga.data = sorted_df[target_col].to_numpy()\n",
" self.ga = ga\n",
" last_idxs_per_serie = self.ga.indptr[1:] - 1\n",
" to_drop = [id_col, time_col, target_col]\n",
" if static_features is None:\n",
" static_features = [c for c in df.columns if c not in [time_col, target_col]]\n",
" elif id_col not in static_features:\n",
" static_features = [id_col] + static_features\n",
" else: # static_features defined and contain id_col\n",
" to_drop = [time_col, target_col]\n",
" self.ga = ga\n",
" series_starts = ga.indptr[:-1]\n",
" series_ends = ga.indptr[1:] - 1\n",
" if self._sort_idxs is not None:\n",
" last_idxs_per_serie = self._sort_idxs[last_idxs_per_serie]\n",
" self.static_features_ = ufp.take_rows(df, last_idxs_per_serie)[static_features]\n",
" self.static_features_ = ufp.drop_index_if_pandas(self.static_features_)\n",
" series_starts = self._sort_idxs[series_starts]\n",
" series_ends = self._sort_idxs[series_ends]\n",
" statics_on_starts = ufp.drop_index_if_pandas(\n",
" ufp.take_rows(df, series_starts)[static_features]\n",
" )\n",
" statics_on_ends = ufp.drop_index_if_pandas(\n",
" ufp.take_rows(df, series_ends)[static_features]\n",
" )\n",
" for feat in static_features:\n",
" if (statics_on_starts[feat] != statics_on_ends[feat]).any():\n",
" raise ValueError(\n",
" f'{feat} is declared as a static feature but its values change '\n",
" 'over time. Please set the `static_features` argument to '\n",
" 'indicate which features are static.\\nIf all of your features '\n",
" 'are dynamic please set `static_features=[]`.'\n",
" )\n",
" self.static_features_ = statics_on_ends\n",
" self.features_order_ = [c for c in df.columns if c not in to_drop] + self.features\n",
" return self\n",
"\n",
Expand Down Expand Up @@ -1647,7 +1662,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L466){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L481){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.fit_transform\n",
"\n",
Expand All @@ -1660,16 +1675,16 @@
"> max_horizon:Optional[int]=None,\n",
"> return_X_y:bool=False, as_numpy:bool=False)\n",
"\n",
"Add the features to `data` and save the required information for the predictions step.\n",
"*Add the features to `data` and save the required information for the predictions step.\n",
"\n",
"If not all features are static, specify which ones are in `static_features`.\n",
"If you don't want to drop rows with null values after the transformations set `dropna=False`\n",
"If `keep_last_n` is not None then that number of observations is kept across all series for updates."
"If `keep_last_n` is not None then that number of observations is kept across all series for updates.*"
],
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L466){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L481){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.fit_transform\n",
"\n",
Expand All @@ -1682,11 +1697,11 @@
"> max_horizon:Optional[int]=None,\n",
"> return_X_y:bool=False, as_numpy:bool=False)\n",
"\n",
"Add the features to `data` and save the required information for the predictions step.\n",
"*Add the features to `data` and save the required information for the predictions step.\n",
"\n",
"If not all features are static, specify which ones are in `static_features`.\n",
"If you don't want to drop rows with null values after the transformations set `dropna=False`\n",
"If `keep_last_n` is not None then that number of observations is kept across all series for updates."
"If `keep_last_n` is not None then that number of observations is kept across all series for updates.*"
]
},
"execution_count": null,
Expand Down Expand Up @@ -1917,7 +1932,9 @@
"ts = TimeSeries(**flow_config)\n",
"df = ts.fit_transform(series, id_col='unique_id', time_col='ds', target_col='y')\n",
"non_std_series = series.reset_index().rename(columns={'unique_id': 'some_id', 'ds': 'timestamp', 'y': 'value'})\n",
"non_std_res = ts.fit_transform(non_std_series, id_col='some_id', time_col='timestamp', target_col='value')\n",
"non_std_res = ts.fit_transform(\n",
" non_std_series, id_col='some_id', time_col='timestamp', target_col='value', static_features=[]\n",
")\n",
"non_std_res = non_std_res.reset_index(drop=True)\n",
"pd.testing.assert_frame_equal(\n",
" df.reset_index(),\n",
Expand Down Expand Up @@ -1960,7 +1977,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L711){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L726){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.predict\n",
"\n",
Expand All @@ -1975,7 +1992,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L711){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L726){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.predict\n",
"\n",
Expand Down Expand Up @@ -2114,28 +2131,28 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L817){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L831){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.update\n",
"\n",
"> TimeSeries.update\n",
"> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n",
"> .frame.DataFrame])\n",
"\n",
"Update the values of the stored series."
"*Update the values of the stored series.*"
],
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L817){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L831){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.update\n",
"\n",
"> TimeSeries.update\n",
"> (df:Union[pandas.core.frame.DataFrame,polars.dataframe\n",
"> .frame.DataFrame])\n",
"\n",
"Update the values of the stored series."
"*Update the values of the stored series.*"
]
},
"execution_count": null,
Expand Down
2 changes: 1 addition & 1 deletion nbs/docs/how-to-guides/transforming_exog.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@
" lags=[1],\n",
" date_features=['dayofweek'],\n",
")\n",
"fcst.preprocess(series_with_prices, dropna=True).head()"
"fcst.preprocess(series_with_prices, static_features=[], dropna=True).head()"
]
},
{
Expand Down
21 changes: 8 additions & 13 deletions nbs/forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4525,21 +4525,12 @@
"execution_count": null,
"id": "4e9188fd-8264-41d1-a4d7-89fa51d915d4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_79672/2131933040.py:4: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.\n",
" non_std_series['ds'] = non_std_series.groupby('unique_id').cumcount()\n"
]
}
],
"outputs": [],
"source": [
"#| hide\n",
"series = generate_daily_series(100, equal_ends=True, n_static_features=2, static_as_categorical=False)\n",
"non_std_series = series.copy()\n",
"non_std_series['ds'] = non_std_series.groupby('unique_id').cumcount()\n",
"non_std_series['ds'] = non_std_series.groupby('unique_id', observed=True).cumcount()\n",
"non_std_series = non_std_series.rename(columns={'unique_id': 'some_id', 'ds': 'time', 'y': 'value'})\n",
"models = [\n",
" lgb.LGBMRegressor(n_jobs=1, random_state=0, verbosity=-1),\n",
Expand Down Expand Up @@ -4705,7 +4696,9 @@
" fitted_pl = fcst_pl.forecast_fitted_values()\n",
" preds_pl = fcst_pl.predict(X_df=prices_pl, **predict_kwargs)\n",
" preds_pl_subset = fcst_pl.predict(X_df=prices_pl, ids=fcst_pl.ts.uids[[0, 6]], **predict_kwargs)\n",
" cv_pl = fcst_pl.cross_validation(series_pl, n_windows=2, h=horizon, fitted=True, as_numpy=as_numpy)\n",
" cv_pl = fcst_pl.cross_validation(\n",
" series_pl, n_windows=2, h=horizon, fitted=True, static_features=['product_id', 'static_1'], as_numpy=as_numpy\n",
" )\n",
" cv_fitted_pl = fcst_pl.cross_validation_fitted_values()\n",
" \n",
" fcst_pd = MLForecast(**cfg)\n",
Expand All @@ -4714,7 +4707,9 @@
" preds_pd = fcst_pd.predict(X_df=prices_pd, **predict_kwargs)\n",
" preds_pd_subset = fcst_pd.predict(X_df=prices_pd, ids=fcst_pd.ts.uids[[0, 6]], **predict_kwargs)\n",
" assert preds_pd_subset['unique_id'].unique().tolist() == ['id_0', 'id_6']\n",
" cv_pd = fcst_pd.cross_validation(series_pd, n_windows=2, h=horizon, fitted=True, as_numpy=as_numpy)\n",
" cv_pd = fcst_pd.cross_validation(\n",
" series_pd, n_windows=2, h=horizon, fitted=True, static_features=['product_id', 'static_1'], as_numpy=as_numpy\n",
" )\n",
" cv_fitted_pd = fcst_pd.cross_validation_fitted_values()\n",
"\n",
" if max_horizon is not None:\n",
Expand Down
6 changes: 3 additions & 3 deletions settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ author = José Morales
author_email = [email protected]
copyright = Nixtla
branch = main
-version = 0.13.2
+version = 0.13.3
min_python = 3.8
audience = Developers
language = English
Expand All @@ -35,8 +35,8 @@ title = mlforecast
tst_flags = polars
black_formatting = True
readme_nb = index.ipynb
-allowed_metadata_keys =
-allowed_cell_metadata_keys =
+allowed_metadata_keys =
+allowed_cell_metadata_keys =
jupyter_hooks = True
clean_ids = True
clear_all = False
Expand Down

0 comments on commit 2edde26

Please sign in to comment.