Skip to content

Commit

Permalink
use context managers
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Feb 29, 2024
1 parent 0c2cc7f commit e08e10c
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 305 deletions.
2 changes: 2 additions & 0 deletions mlforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'mlforecast.core': { 'mlforecast.core.TimeSeries': ('core.html#timeseries', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries.__init__': ('core.html#timeseries.__init__', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries.__repr__': ('core.html#timeseries.__repr__', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries._backup': ('core.html#timeseries._backup', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries._compute_date_feature': ( 'core.html#timeseries._compute_date_feature',
'mlforecast/core.py'),
'mlforecast.core.TimeSeries._compute_transforms': ( 'core.html#timeseries._compute_transforms',
Expand All @@ -35,6 +36,7 @@
'mlforecast/core.py'),
'mlforecast.core.TimeSeries._has_ga_target_tfms': ( 'core.html#timeseries._has_ga_target_tfms',
'mlforecast/core.py'),
'mlforecast.core.TimeSeries._maybe_subset': ('core.html#timeseries._maybe_subset', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries._predict_multi': ('core.html#timeseries._predict_multi', 'mlforecast/core.py'),
'mlforecast.core.TimeSeries._predict_recursive': ( 'core.html#timeseries._predict_recursive',
'mlforecast/core.py'),
Expand Down
288 changes: 163 additions & 125 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,19 @@
import reprlib
import warnings
from collections import Counter, OrderedDict
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
)

import cloudpickle
import fsspec
Expand Down Expand Up @@ -486,7 +497,7 @@ def _update_features(self) -> DataFrame:
)
self.test_dates.append(self.curr_dates)

features = self._compute_transforms(self._transforms, updates_only=True)
features = self._compute_transforms(self.transforms, updates_only=True)

for feature in self.date_features:
feat_name, feat_vals = self._compute_date_feature(self.curr_dates, feature)
Expand All @@ -497,24 +508,24 @@ def _update_features(self) -> DataFrame:
else:
df_constructor = pd.DataFrame
features_df = df_constructor(features)[self.features]
return ufp.horizontal_concat([self._static_features, features_df])
return ufp.horizontal_concat([self.static_features_, features_df])

def _get_raw_predictions(self) -> np.ndarray:
    """Return the accumulated predictions as a single flat array.

    The entries stored in ``self.y_pred`` are stacked into one array,
    which is then flattened in column-major ("F") order.
    """
    stacked = np.array(self.y_pred)
    return stacked.ravel(order="F")

def _get_future_ids(self, h: int):
if isinstance(self._uids, pl_Series):
uids = pl.concat([self._uids for _ in range(h)]).sort()
if isinstance(self.uids, pl_Series):
uids = pl.concat([self.uids for _ in range(h)]).sort()
else:
uids = pd.Series(
np.repeat(self._uids, h), name=self.id_col, dtype=self.uids.dtype
np.repeat(self.uids, h), name=self.id_col, dtype=self.uids.dtype
)
return uids

def _get_predictions(self) -> DataFrame:
"""Get all the predicted values with their corresponding ids and datestamps."""
h = len(self.y_pred)
if isinstance(self._uids, pl_Series):
if isinstance(self.uids, pl_Series):
df_constructor = pl_DataFrame
else:
df_constructor = pd.DataFrame
Expand All @@ -528,28 +539,10 @@ def _get_predictions(self) -> DataFrame:
)
return df

def _predict_setup(self) -> None:
    """Prepare the working state used while predicting with one model.

    Copies ``_ga`` into ``ga`` (shallow) and ``transforms`` into
    ``_transforms`` (deep), copies ``last_dates`` into ``curr_dates`` and,
    when ``_idxs`` is set, narrows all three to those positions. Finally
    the ``test_dates``, ``y_pred`` and ``_h`` accumulators are reset.
    """
    self.ga = copy.copy(self._ga)
    self._transforms = copy.deepcopy(self.transforms)

    dates = self.last_dates
    # polars series expose clone(), the other supported containers copy()
    self.curr_dates = dates.clone() if isinstance(dates, pl_Series) else dates.copy()

    idxs = self._idxs
    if idxs is not None:
        self.ga = self.ga.take(idxs)
        for name in list(self._transforms):
            tfm = self._transforms[name]
            # only transforms that expose `take` are subset
            if hasattr(tfm, "take"):
                self._transforms[name] = tfm.take(idxs)
        self.curr_dates = self.curr_dates[idxs]

    self.test_dates: List[Union[pd.Index, pl_Series]] = []
    self.y_pred = []
    self._h = 0

def _get_features_for_next_step(self, X_df=None):
new_x = self._update_features()
if X_df is not None:
n_series = len(self._uids)
n_series = len(self.uids)
h = X_df.shape[0] // n_series
rows = np.arange(self._h, X_df.shape[0], h)
X = ufp.take_rows(X_df, rows)
Expand All @@ -569,6 +562,28 @@ def _get_features_for_next_step(self, X_df=None):
new_x = ufp.to_numpy(new_x)
return new_x

@contextmanager
def _backup(self) -> Iterator[None]:
    """Snapshot ``ga`` and ``transforms`` and put the snapshots back on exit.

    A shallow copy of ``self.ga`` and a deep copy of ``self.transforms`` are
    taken before the managed block runs. When the block finishes, whether it
    completed normally or raised, both attributes are reassigned to those
    copies.
    """
    # ga gets modified during predict because the predictions are appended
    saved_ga = copy.copy(self.ga)
    # transforms that save state (like ExpandingMean) get modified by the updates
    saved_transforms = copy.deepcopy(self.transforms)
    try:
        yield
    finally:
        self.ga = saved_ga
        self.transforms = saved_transforms

def _predict_setup(self) -> None:
    """Reset the per-model prediction state.

    ``curr_dates`` becomes a copy of ``last_dates`` and the ``test_dates``,
    ``y_pred`` and ``_h`` accumulators are cleared.
    """
    # TODO: move to utils
    dates = self.last_dates
    # polars series expose clone(), the other supported containers copy()
    self.curr_dates = dates.clone() if isinstance(dates, pl_Series) else dates.copy()
    self.test_dates: List[Union[pd.Index, pl_Series]] = []
    self.y_pred = []
    self._h = 0

def _predict_recursive(
self,
models: Dict[str, BaseEstimator],
Expand All @@ -579,22 +594,23 @@ def _predict_recursive(
) -> DataFrame:
"""Use `model` to predict the next `horizon` timesteps."""
for i, (name, model) in enumerate(models.items()):
self._predict_setup()
for _ in range(horizon):
new_x = self._get_features_for_next_step(X_df)
if before_predict_callback is not None:
new_x = before_predict_callback(new_x)
predictions = model.predict(new_x)
if after_predict_callback is not None:
predictions = after_predict_callback(predictions)
self._update_y(predictions)
if i == 0:
preds = self._get_predictions()
rename_dict = {f"{self.target_col}_pred": name}
preds = ufp.rename(preds, rename_dict)
else:
raw_preds = self._get_raw_predictions()
preds = ufp.assign_columns(preds, name, raw_preds)
with self._backup():
self._predict_setup()
for _ in range(horizon):
new_x = self._get_features_for_next_step(X_df)
if before_predict_callback is not None:
new_x = before_predict_callback(new_x)
predictions = model.predict(new_x)
if after_predict_callback is not None:
predictions = after_predict_callback(predictions)
self._update_y(predictions)
if i == 0:
preds = self._get_predictions()
rename_dict = {f"{self.target_col}_pred": name}
preds = ufp.rename(preds, rename_dict)
else:
raw_preds = self._get_raw_predictions()
preds = ufp.assign_columns(preds, name, raw_preds)
return preds

def _predict_multi(
Expand All @@ -619,15 +635,15 @@ def _predict_multi(
df_constructor = pd.DataFrame
result = df_constructor({self.id_col: uids, self.time_col: dates})
for name, model in models.items():
self._predict_setup()
new_x = self._get_features_for_next_step(X_df)
if before_predict_callback is not None:
new_x = before_predict_callback(new_x)
predictions = np.empty((new_x.shape[0], horizon))
for i in range(horizon):
predictions[:, i] = model[i].predict(new_x)
raw_preds = predictions.ravel()
result = ufp.assign_columns(result, name, raw_preds)
with self._backup():
new_x = self._get_features_for_next_step(X_df)
if before_predict_callback is not None:
new_x = before_predict_callback(new_x)
predictions = np.empty((new_x.shape[0], horizon))
for i in range(horizon):
predictions[:, i] = model[i].predict(new_x)
raw_preds = predictions.ravel()
result = ufp.assign_columns(result, name, raw_preds)
return result

def _has_ga_target_tfms(self):
Expand All @@ -636,6 +652,41 @@ def _has_ga_target_tfms(self):
for tfm in self.target_transforms
)

@contextmanager
def _maybe_subset(self, idxs: Optional[np.ndarray]) -> Iterator[None]:
    """Temporarily restrict the stored series state to the positions in ``idxs``.

    While the managed block runs, ``ga``, ``uids``, ``static_features_``,
    ``last_dates``, the grouped-array target transforms and every lag
    transform exposing ``take`` are replaced by their ``idxs`` subset.
    On exit (normal or through an exception) all of them are restored to
    the values they had on entry.

    Parameters
    ----------
    idxs : np.ndarray, optional
        Positions of the series to keep. If ``None`` nothing is subset, but
        the attributes are still restored on exit.
    """
    # save original
    ga = self.ga
    uids = self.uids
    statics = self.static_features_
    last_dates = self.last_dates
    # shallow copy: the list is modified in place below, its items are replaced
    targ_tfms = copy.copy(self.target_transforms)
    lag_tfms = copy.deepcopy(self.transforms)

    try:
        # the subsetting lives inside the try so that a failure halfway
        # through still restores the original state
        if idxs is not None:
            # assign subsets
            self.ga = self.ga.take(idxs)
            self.uids = uids[idxs]
            self.static_features_ = ufp.take_rows(statics, idxs)
            self.static_features_ = ufp.drop_index_if_pandas(self.static_features_)
            self.last_dates = last_dates[idxs]
            if self.target_transforms is not None:
                for i, tfm in enumerate(self.target_transforms):
                    if isinstance(tfm, _BaseGroupedArrayTargetTransform):
                        self.target_transforms[i] = tfm.take(idxs)
            for name, tfm in self.transforms.items():
                if hasattr(tfm, "take"):
                    tfm = tfm.take(idxs)
                self.transforms[name] = tfm
        yield
    finally:
        self.ga = ga
        self.uids = uids
        self.static_features_ = statics
        self.last_dates = last_dates
        self.target_transforms = targ_tfms
        # restore the lag transforms themselves; assigning the snapshot to any
        # other attribute would leave `transforms` holding the subset
        self.transforms = lag_tfms

def predict(
self,
models: Dict[str, Union[BaseEstimator, List[BaseEstimator]]],
Expand All @@ -651,61 +702,52 @@ def predict(
raise ValueError(
f"The following ids weren't seen during training and thus can't be forecasted: {unseen}"
)
self._idxs: Optional[np.ndarray] = np.where(ufp.is_in(self.uids, ids))[0]
self._uids = self.uids[self._idxs]
self._static_features = ufp.take_rows(self.static_features_, self._idxs)
self._static_features = ufp.drop_index_if_pandas(self._static_features)
last_dates = self.last_dates[self._idxs]
idxs: Optional[np.ndarray] = np.where(ufp.is_in(self.uids, ids))[0]
else:
self._idxs = None
self._uids = self.uids
self._static_features = self.static_features_
last_dates = self.last_dates
if X_df is not None:
if self.id_col not in X_df or self.time_col not in X_df:
raise ValueError(
f"X_df must have '{self.id_col}' and '{self.time_col}' columns."
idxs = None
with self._maybe_subset(idxs):
if X_df is not None:
if self.id_col not in X_df or self.time_col not in X_df:
raise ValueError(
f"X_df must have '{self.id_col}' and '{self.time_col}' columns."
)
if X_df.shape[1] < 3:
raise ValueError("Found no exogenous features in `X_df`.")
statics = [c for c in self.static_features_.columns if c != self.id_col]
dynamics = [
c for c in X_df.columns if c not in [self.id_col, self.time_col]
]
common = [c for c in dynamics if c in statics]
if common:
raise ValueError(
f"The following features were provided through `X_df` but were considered as static during fit: {common}.\n"
"Please re-run the fit step using the `static_features` argument to indicate which features are static. "
"If all your features are dynamic please pass an empty list (static_features=[])."
)
starts = ufp.offset_times(self.last_dates, self.freq, 1)
ends = ufp.offset_times(self.last_dates, self.freq, horizon)
dates_validation = type(X_df)(
{
self.id_col: self.uids,
"_start": starts,
"_end": ends,
}
)
if X_df.shape[1] < 3:
raise ValueError("Found no exogenous features in `X_df`.")
statics = [c for c in self.static_features_.columns if c != self.id_col]
dynamics = [
c for c in X_df.columns if c not in [self.id_col, self.time_col]
]
common = [c for c in dynamics if c in statics]
if common:
raise ValueError(
f"The following features were provided through `X_df` but were considered as static during fit: {common}.\n"
"Please re-run the fit step using the `static_features` argument to indicate which features are static. "
"If all your features are dynamic please pass an empty list (static_features=[])."
X_df = ufp.join(X_df, dates_validation, on=self.id_col)
mask = ufp.between(X_df[self.time_col], X_df["_start"], X_df["_end"])
X_df = ufp.filter_with_mask(X_df, mask)
if X_df.shape[0] != len(self.uids) * horizon:
msg = (
"Found missing inputs in X_df. "
"It should have one row per id and time for the complete forecasting horizon.\n"
"You can get the expected structure by running `MLForecast.make_future_dataframe(h)` "
"or get the missing combinatins in your current `X_df` by running `MLForecast.get_missing_future(h, X_df)`."
)
raise ValueError(msg)
drop_cols = [self.id_col, self.time_col, "_start", "_end"]
X_df = ufp.sort(X_df, [self.id_col, self.time_col]).drop(
columns=drop_cols
)
starts = ufp.offset_times(last_dates, self.freq, 1)
ends = ufp.offset_times(last_dates, self.freq, horizon)
df_constructor = type(X_df)
dates_validation = df_constructor(
{
self.id_col: self._uids,
"_start": starts,
"_end": ends,
}
)
X_df = ufp.join(X_df, dates_validation, on=self.id_col)
mask = ufp.between(X_df[self.time_col], X_df["_start"], X_df["_end"])
X_df = ufp.filter_with_mask(X_df, mask)
if X_df.shape[0] != len(self._uids) * horizon:
msg = (
"Found missing inputs in X_df. "
"It should have one row per id and time for the complete forecasting horizon.\n"
"You can get the expected structure by running `MLForecast.make_future_dataframe(h)` "
"or get the missing combinatins in your current `X_df` by running `MLForecast.get_missing_future(h, X_df)`."
)
raise ValueError(msg)
drop_cols = [self.id_col, self.time_col, "_start", "_end"]
X_df = ufp.sort(X_df, [self.id_col, self.time_col]).drop(columns=drop_cols)
# backup original series. the ga attribute gets modified
# and is copied from _ga at the start of each model's predict
self._ga = copy.copy(self.ga)
try:
if getattr(self, "max_horizon", None) is None:
preds = self._predict_recursive(
models=models,
Expand All @@ -721,28 +763,24 @@ def predict(
before_predict_callback=before_predict_callback,
X_df=X_df,
)
finally:
self.ga = self._ga
del self._ga
if self.target_transforms is not None:
if self._has_ga_target_tfms():
model_cols = [
c for c in preds.columns if c not in (self.id_col, self.time_col)
]
indptr = np.arange(0, horizon * (len(self._uids) + 1), horizon)
for tfm in self.target_transforms[::-1]:
if isinstance(tfm, _BaseGroupedArrayTargetTransform):
if self._idxs is not None:
tfm = tfm.take(self._idxs)
for col in model_cols:
ga = GroupedArray(
preds[col].to_numpy().astype(self.ga.data.dtype), indptr
)
ga = tfm.inverse_transform(ga)
preds = ufp.assign_columns(preds, col, ga.data)
else:
preds = tfm.inverse_transform(preds)
del self._uids, self._idxs, self._static_features, self._transforms
if self.target_transforms is not None:
if self._has_ga_target_tfms():
model_cols = [
c
for c in preds.columns
if c not in (self.id_col, self.time_col)
]
indptr = np.arange(0, horizon * (len(self.uids) + 1), horizon)
for tfm in self.target_transforms[::-1]:
if isinstance(tfm, _BaseGroupedArrayTargetTransform):
for col in model_cols:
ga = GroupedArray(
preds[col].to_numpy().astype(self.ga.data.dtype), indptr
)
ga = tfm.inverse_transform(ga)
preds = ufp.assign_columns(preds, col, ga.data)
else:
preds = tfm.inverse_transform(preds)
return preds

def save(self, path: Union[str, Path]) -> None:
Expand Down
Loading

0 comments on commit e08e10c

Please sign in to comment.