use TypeVar for dataframes and distribute py.typed file (#408)
jmoralez authored Aug 30, 2024
1 parent 18bf6ab commit 0b5b1ca
Showing 6 changed files with 82 additions and 81 deletions.
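The bulk of the diff below swaps the `DataFrame` union alias for a `DFType` type variable in method signatures. A minimal sketch of the difference between the two annotations, assuming pandas and polars are installed (the real aliases are imported from `utilsforecast.compat` and their exact definitions may differ):

from typing import TypeVar, Union

import pandas as pd
import polars as pl

# A plain union accepts a frame from either library but "forgets" which one
# was passed in, so every return annotated with it is again the union.
DataFrame = Union[pd.DataFrame, pl.DataFrame]

# A constrained TypeVar also accepts either library, but binds to the concrete
# class at each call site, tying the return type to the argument type.
DFType = TypeVar("DFType", pd.DataFrame, pl.DataFrame)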
21 changes: 11 additions & 10 deletions mlforecast/core.py
@@ -32,6 +32,7 @@
from sklearn.base import BaseEstimator, clone
from sklearn.pipeline import Pipeline
from utilsforecast.compat import (
DFType,
DataFrame,
pl,
pl_DataFrame,
@@ -367,12 +368,12 @@ def _compute_date_feature(self, dates, feature):

def _transform(
self,
df: DataFrame,
df: DFType,
dropna: bool = True,
max_horizon: Optional[int] = None,
return_X_y: bool = False,
as_numpy: bool = False,
) -> pd.DataFrame:
) -> DFType:
"""Add the features to `df`.
if `dropna=True` then all the null rows are dropped."""
@@ -480,7 +481,7 @@ def _transform(

def fit_transform(
self,
data: DataFrame,
data: DFType,
id_col: str,
time_col: str,
target_col: str,
@@ -490,7 +491,7 @@
max_horizon: Optional[int] = None,
return_X_y: bool = False,
as_numpy: bool = False,
) -> Union[DataFrame, Tuple[DataFrame, np.ndarray]]:
) -> Union[DFType, Tuple[DFType, np.ndarray]]:
"""Add the features to `data` and save the required information for the predictions step.
If not all features are static, specify which ones are in `static_features`.
@@ -626,8 +627,8 @@ def _predict_recursive(
horizon: int,
before_predict_callback: Optional[Callable] = None,
after_predict_callback: Optional[Callable] = None,
X_df: Optional[DataFrame] = None,
) -> DataFrame:
X_df: Optional[DFType] = None,
) -> DFType:
"""Use `model` to predict the next `horizon` timesteps."""
for i, (name, model) in enumerate(models.items()):
with self._backup():
@@ -654,8 +655,8 @@ def _predict_multi(
models: Dict[str, BaseEstimator],
horizon: int,
before_predict_callback: Optional[Callable] = None,
X_df: Optional[DataFrame] = None,
) -> DataFrame:
X_df: Optional[DFType] = None,
) -> DFType:
assert self.max_horizon is not None
if horizon > self.max_horizon:
raise ValueError(
@@ -729,9 +730,9 @@ def predict(
horizon: int,
before_predict_callback: Optional[Callable] = None,
after_predict_callback: Optional[Callable] = None,
X_df: Optional[DataFrame] = None,
X_df: Optional[DFType] = None,
ids: Optional[List[str]] = None,
) -> DataFrame:
) -> DFType:
if ids is not None:
unseen = set(ids) - set(self.uids)
if unseen:
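The effect of the new annotations on `TimeSeries._transform`, `fit_transform` and `predict` can be seen with a small self-contained sketch; `passthrough` below is a hypothetical stand-in for those methods, not part of mlforecast:

from typing import TypeVar

import pandas as pd
import polars as pl  # the sketch assumes polars is installed

DFType = TypeVar("DFType", pd.DataFrame, pl.DataFrame)

def passthrough(df: DFType) -> DFType:
    # Whatever frame type comes in is the frame type that goes out.
    return df

pdf = passthrough(pd.DataFrame({"y": [1.0, 2.0]}))
pldf = passthrough(pl.DataFrame({"y": [1.0, 2.0]}))
# Under mypy or pyright, pdf is pd.DataFrame and pldf is pl.DataFrame; the old
# `-> pd.DataFrame` annotation mis-typed the polars case.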
47 changes: 23 additions & 24 deletions mlforecast/forecast.py
@@ -13,10 +13,9 @@
import cloudpickle
import fsspec
import numpy as np
import pandas as pd
import utilsforecast.processing as ufp
from sklearn.base import BaseEstimator, clone
from utilsforecast.compat import DataFrame
from utilsforecast.compat import DFType, DataFrame

from mlforecast.core import (
DateFeature,
@@ -38,15 +37,15 @@

# %% ../nbs/forecast.ipynb 6
def _add_conformal_distribution_intervals(
fcst_df: DataFrame,
cs_df: DataFrame,
fcst_df: DFType,
cs_df: DFType,
model_names: List[str],
level: List[Union[int, float]],
cs_n_windows: int,
cs_h: int,
n_series: int,
horizon: int,
) -> DataFrame:
) -> DFType:
"""
Adds conformal intervals to a `fcst_df` based on conformal scores `cs_df`.
`level` should be already sorted. This strategy creates forecasts paths
@@ -76,15 +75,15 @@ def _add_conformal_distribution_intervals(

# %% ../nbs/forecast.ipynb 7
def _add_conformal_error_intervals(
fcst_df: DataFrame,
cs_df: DataFrame,
fcst_df: DFType,
cs_df: DFType,
model_names: List[str],
level: List[Union[int, float]],
cs_n_windows: int,
cs_h: int,
n_series: int,
horizon: int,
) -> DataFrame:
) -> DFType:
"""
Adds conformal intervals to a `fcst_df` based on conformal scores `cs_df`.
`level` should be already sorted. This strategy creates prediction intervals
@@ -205,7 +204,7 @@ def from_cv(cls, cv: "LightGBMCV") -> "MLForecast":

def preprocess(
self,
df: DataFrame,
df: DFType,
id_col: str = "unique_id",
time_col: str = "ds",
target_col: str = "y",
@@ -215,7 +214,7 @@ def preprocess(
max_horizon: Optional[int] = None,
return_X_y: bool = False,
as_numpy: bool = False,
) -> Union[DataFrame, Tuple[DataFrame, np.ndarray]]:
) -> Union[DFType, Tuple[DFType, np.ndarray]]:
"""Add the features to `data`.
Parameters
@@ -297,7 +296,7 @@ def fit_models(

def _conformity_scores(
self,
df: DataFrame,
df: DFType,
id_col: str,
time_col: str,
target_col: str,
@@ -308,7 +307,7 @@ def _conformity_scores(
n_windows: int = 2,
h: int = 1,
as_numpy: bool = False,
):
) -> DFType:
"""Compute conformity scores.
We need at least two cross validation errors to compute
@@ -349,7 +348,7 @@ def _conformity_scores(
cv_results = ufp.assign_columns(cv_results, model, abs_err)
return ufp.drop_columns(cv_results, target_col)

def _invert_transforms_fitted(self, df: pd.DataFrame) -> pd.DataFrame:
def _invert_transforms_fitted(self, df: DFType) -> DFType:
if self.ts.target_transforms is None:
return df
if any(
@@ -379,9 +378,9 @@ def _invert_transforms_fitted(self, df: pd.DataFrame) -> pd.DataFrame:

def _extract_X_y(
self,
prep: DataFrame,
prep: DFType,
target_col: str,
) -> Tuple[Union[DataFrame, np.ndarray], np.ndarray]:
) -> Tuple[Union[DFType, np.ndarray], np.ndarray]:
X = prep[self.ts.features_order_]
targets = [c for c in prep.columns if re.match(rf"^{target_col}\d*$", c)]
if len(targets) == 1:
@@ -391,14 +390,14 @@ def _compute_fitted_values(

def _compute_fitted_values(
self,
base: DataFrame,
X: Union[DataFrame, np.ndarray],
base: DFType,
X: Union[DFType, np.ndarray],
y: np.ndarray,
id_col: str,
time_col: str,
target_col: str,
max_horizon: Optional[int],
) -> DataFrame:
) -> DFType:
base = ufp.copy_if_pandas(base, deep=False)
sort_idxs = ufp.maybe_compute_sort_indices(base, id_col, time_col)
if sort_idxs is not None:
@@ -597,7 +596,7 @@ def make_future_dataframe(self, h: int) -> DataFrame:
time_col=self.ts.time_col,
)

def get_missing_future(self, h: int, X_df: DataFrame) -> DataFrame:
def get_missing_future(self, h: int, X_df: DFType) -> DFType:
"""Get the missing id and time combinations in `X_df`.
Parameters
@@ -621,11 +620,11 @@ def predict(
h: int,
before_predict_callback: Optional[Callable] = None,
after_predict_callback: Optional[Callable] = None,
new_df: Optional[DataFrame] = None,
new_df: Optional[DFType] = None,
level: Optional[List[Union[int, float]]] = None,
X_df: Optional[DataFrame] = None,
X_df: Optional[DFType] = None,
ids: Optional[List[str]] = None,
) -> DataFrame:
) -> DFType:
"""Compute the predictions for the next `h` steps.
Parameters
@@ -766,7 +765,7 @@ def predict(

def cross_validation(
self,
df: DataFrame,
df: DFType,
n_windows: int,
h: int,
id_col: str = "unique_id",
@@ -785,7 +784,7 @@
input_size: Optional[int] = None,
fitted: bool = False,
as_numpy: bool = False,
) -> DataFrame:
) -> DFType:
"""Perform time series cross validation.
Creates `n_windows` splits where each window has `h` test periods,
trains the models, computes the predictions and merges the actuals.
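Because `preprocess`, `predict` and `cross_validation` are now generic in `DFType`, the concrete frame type also propagates through the high-level `MLForecast` API. A rough usage sketch with pandas (the constructor arguments shown are ordinary MLForecast usage, not part of this diff, and the data values are made up; with a polars frame the same signatures would bind to `pl.DataFrame` instead):

import pandas as pd
from sklearn.linear_model import LinearRegression

from mlforecast import MLForecast

# Tiny synthetic series, for illustration only.
train = pd.DataFrame({
    "unique_id": ["id_0"] * 30,
    "ds": pd.date_range("2024-01-01", periods=30, freq="D"),
    "y": [float(i) for i in range(30)],
})

fcst = MLForecast(models=[LinearRegression()], freq="D", lags=[1, 7])
prep = fcst.preprocess(train)
# Static type of `prep`: Union[pd.DataFrame, Tuple[pd.DataFrame, np.ndarray]],
# i.e. the frame part matches the pandas input rather than the pandas/polars union.
cv = fcst.cross_validation(train, n_windows=2, h=7)
# Static type of `cv`: pd.DataFrame.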
Empty file added mlforecast/py.typed
Empty file.
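The empty `mlforecast/py.typed` file is the PEP 561 marker that tells type checkers the installed package ships inline annotations. The marker only helps if it is included in the built wheel/sdist; one common setuptools arrangement looks like the fragment below (illustrative only; the packaging change itself is not among the files shown here, and mlforecast's actual configuration may differ):

# setup.py fragment (sketch)
from setuptools import find_packages, setup

setup(
    name="mlforecast",
    packages=find_packages(),
    package_data={"mlforecast": ["py.typed"]},
    include_package_data=True,
)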
21 changes: 11 additions & 10 deletions nbs/core.ipynb
@@ -62,6 +62,7 @@
"from sklearn.base import BaseEstimator, clone\n",
"from sklearn.pipeline import Pipeline\n",
"from utilsforecast.compat import (\n",
" DFType,\n",
" DataFrame,\n",
" pl,\n",
" pl_DataFrame,\n",
@@ -848,12 +849,12 @@
"\n",
" def _transform(\n",
" self,\n",
" df: DataFrame,\n",
" df: DFType,\n",
" dropna: bool = True,\n",
" max_horizon: Optional[int] = None,\n",
" return_X_y: bool = False,\n",
" as_numpy: bool = False,\n",
" ) -> pd.DataFrame:\n",
" ) -> DFType:\n",
" \"\"\"Add the features to `df`.\n",
" \n",
" if `dropna=True` then all the null rows are dropped.\"\"\"\n",
@@ -958,7 +959,7 @@
"\n",
" def fit_transform(\n",
" self,\n",
" data: DataFrame,\n",
" data: DFType,\n",
" id_col: str,\n",
" time_col: str,\n",
" target_col: str,\n",
@@ -968,7 +969,7 @@
" max_horizon: Optional[int] = None,\n",
" return_X_y: bool = False,\n",
" as_numpy: bool = False,\n",
" ) -> Union[DataFrame, Tuple[DataFrame, np.ndarray]]:\n",
" ) -> Union[DFType, Tuple[DFType, np.ndarray]]:\n",
" \"\"\"Add the features to `data` and save the required information for the predictions step.\n",
" \n",
" If not all features are static, specify which ones are in `static_features`.\n",
@@ -1103,8 +1104,8 @@
" horizon: int,\n",
" before_predict_callback: Optional[Callable] = None,\n",
" after_predict_callback: Optional[Callable] = None,\n",
" X_df: Optional[DataFrame] = None,\n",
" ) -> DataFrame:\n",
" X_df: Optional[DFType] = None,\n",
" ) -> DFType:\n",
" \"\"\"Use `model` to predict the next `horizon` timesteps.\"\"\"\n",
" for i, (name, model) in enumerate(models.items()):\n",
" with self._backup():\n",
@@ -1131,8 +1132,8 @@
" models: Dict[str, BaseEstimator],\n",
" horizon: int,\n",
" before_predict_callback: Optional[Callable] = None,\n",
" X_df: Optional[DataFrame] = None,\n",
" ) -> DataFrame:\n",
" X_df: Optional[DFType] = None,\n",
" ) -> DFType:\n",
" assert self.max_horizon is not None\n",
" if horizon > self.max_horizon:\n",
" raise ValueError(f'horizon must be at most max_horizon ({self.max_horizon})')\n",
@@ -1201,9 +1202,9 @@
" horizon: int,\n",
" before_predict_callback: Optional[Callable] = None,\n",
" after_predict_callback: Optional[Callable] = None,\n",
" X_df: Optional[DataFrame] = None,\n",
" X_df: Optional[DFType] = None,\n",
" ids: Optional[List[str]] = None,\n",
" ) -> DataFrame:\n",
" ) -> DFType:\n",
" if ids is not None:\n",
" unseen = set(ids) - set(self.uids)\n",
" if unseen:\n",