Skip to content

Commit

Permalink
add X_df argument to distributed predict (#286)
Browse files: browse the repository at this point in the history
  • Loading branch information
jmoralez authored Dec 13, 2023
1 parent 7b62f6f commit 4fdade5
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 147 deletions.
8 changes: 8 additions & 0 deletions mlforecast/distributed/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def _predict(
horizon,
before_predict_callback=None,
after_predict_callback=None,
X_df=None,
) -> Iterable[pd.DataFrame]:
for serialized_ts, _, serialized_valid in items:
valid = cloudpickle.loads(serialized_valid)
Expand All @@ -437,6 +438,7 @@ def _predict(
horizon=horizon,
before_predict_callback=before_predict_callback,
after_predict_callback=after_predict_callback,
X_df=X_df,
)
if valid is not None:
res = res.merge(valid, how="left")
Expand All @@ -453,6 +455,7 @@ def predict(
h: int,
before_predict_callback: Optional[Callable] = None,
after_predict_callback: Optional[Callable] = None,
X_df: Optional[pd.DataFrame] = None,
new_df: Optional[fugue.AnyDataFrame] = None,
) -> fugue.AnyDataFrame:
"""Compute the predictions for the next `horizon` steps.
Expand All @@ -469,6 +472,8 @@ def predict(
Function to call on the predictions before updating the targets.
This function will take a pandas Series with the predictions and should return another one with the same structure.
The series identifier is on the index.
X_df : pandas DataFrame, optional (default=None)
Dataframe with the future exogenous features. Should have the id column and the time column.
new_df : dask or spark DataFrame, optional (default=None)
Series data of new observations for which forecasts are to be generated.
This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
Expand All @@ -493,6 +498,8 @@ def predict(
else:
partition_results = self.partition_results
schema = self._get_predict_schema()
if X_df is not None and not isinstance(X_df, pd.DataFrame):
raise ValueError("`X_df` should be a pandas DataFrame")
res = fa.transform(
partition_results,
DistributedMLForecast._predict,
Expand All @@ -501,6 +508,7 @@ def predict(
"horizon": h,
"before_predict_callback": before_predict_callback,
"after_predict_callback": after_predict_callback,
"X_df": X_df,
},
schema=schema,
engine=self.engine,
Expand Down
Loading

0 comments on commit 4fdade5

Please sign in to comment.