From c83395276e9c4be6ef5644350a6486531349ce42 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Fri, 22 Nov 2024 14:44:31 -0600
Subject: [PATCH] feat(distributed): support ids in predict (#454)

---
 mlforecast/distributed/forecast.py | 39 +++++---
 nbs/distributed.forecast.ipynb     | 37 ++++++--
 .../quick_start_distributed.ipynb  | 95 +++++++++++--------
 3 files changed, 111 insertions(+), 60 deletions(-)

diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py
index defef0ad..fab92a53 100644
--- a/mlforecast/distributed/forecast.py
+++ b/mlforecast/distributed/forecast.py
@@ -38,6 +38,7 @@ except ModuleNotFoundError:
     RAY_INSTALLED = False
 from sklearn.base import clone
+from triad import Schema
 
 from mlforecast.core import (
     DateFeature,
@@ -455,31 +456,43 @@ def _predict(
         before_predict_callback=None,
         after_predict_callback=None,
         X_df=None,
+        ids=None,
+        schema=None,
     ) -> Iterable[pd.DataFrame]:
         for serialized_ts, _, serialized_valid in items:
             valid = cloudpickle.loads(serialized_valid)
             if valid is not None:
                 X_df = valid
             ts = cloudpickle.loads(serialized_ts)
+            if ids is not None:
+                ids = ts.uids.intersection(ids).tolist()
+                if not ids:
+                    yield pd.DataFrame(
+                        {
+                            field.name: pd.Series(dtype=field.type.to_pandas_dtype())
+                            for field in schema.values()
+                        }
+                    )
+                    return
             res = ts.predict(
                 models=models,
                 horizon=horizon,
                 before_predict_callback=before_predict_callback,
                 after_predict_callback=after_predict_callback,
                 X_df=X_df,
+                ids=ids,
             )
             if valid is not None:
                 res = res.merge(valid, how="left")
             yield res
 
-    def _get_predict_schema(self) -> str:
-        model_names = self.models.keys()
-        models_schema = ",".join(f"{model_name}:double" for model_name in model_names)
-        schema = (
-            f"{self._base_ts.id_col}:string,{self._base_ts.time_col}:datetime,"
-            + models_schema
-        )
-        return schema
+    def _get_predict_schema(self) -> Schema:
+        ids_schema = [
+            (self._base_ts.id_col, "string"),
+            (self._base_ts.time_col, "datetime"),
+        ]
+        models_schema = [(model, "double") for model in self.models.keys()]
+        return Schema(ids_schema + models_schema)
 
     def predict(
         self,
@@ -488,6 +501,7 @@ def predict(
         after_predict_callback: Optional[Callable] = None,
         X_df: Optional[pd.DataFrame] = None,
         new_df: Optional[fugue.AnyDataFrame] = None,
+        ids: Optional[List[str]] = None,
     ) -> fugue.AnyDataFrame:
         """Compute the predictions for the next `horizon` steps.
 
@@ -509,6 +523,8 @@ def predict(
             Series data of new observations for which forecasts are to be generated.
             This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
             If `new_df` is not None, the method will generate forecasts for the new observations.
+        ids : list of str, optional (default=None)
+            List with a subset of ids seen during training for which the forecasts should be computed.
 
         Returns
         -------
@@ -540,6 +556,8 @@ def predict(
                 "before_predict_callback": before_predict_callback,
                 "after_predict_callback": after_predict_callback,
                 "X_df": X_df,
+                "ids": ids,
+                "schema": schema,
             },
             schema=schema,
             engine=self.engine,
@@ -636,9 +654,8 @@ def cross_validation(
             keep_last_n=keep_last_n,
             window_info=window_info,
         )
-        schema = (
-            self._get_predict_schema()
-            + f",cutoff:datetime,{self._base_ts.target_col}:double"
+        schema = self._get_predict_schema() + Schema(
+            ("cutoff", "datetime"), (self._base_ts.target_col, "double")
         )
         preds = fa.transform(
             partition_results,
diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb
index e287d034..ab3f528d 100644
--- a/nbs/distributed.forecast.ipynb
+++ b/nbs/distributed.forecast.ipynb
@@ -98,6 +98,7 @@
    "except ModuleNotFoundError:\n",
    "    RAY_INSTALLED = False\n",
    "from sklearn.base import clone\n",
+   "from triad import Schema\n",
    "\n",
    "from mlforecast.core import (\n",
    "    DateFeature,\n",
@@ -506,29 +507,41 @@
    "        horizon,\n",
    "        before_predict_callback=None,\n",
    "        after_predict_callback=None,\n",
-   "        X_df=None, \n",
+   "        X_df=None,\n",
+   "        ids=None,\n",
+   "        schema=None,\n",
    "    ) -> Iterable[pd.DataFrame]:\n",
    "        for serialized_ts, _, serialized_valid in items:\n",
    "            valid = cloudpickle.loads(serialized_valid)\n",
    "            if valid is not None:\n",
    "                X_df = valid\n",
    "            ts = cloudpickle.loads(serialized_ts)\n",
+   "            if ids is not None:\n",
+   "                ids = ts.uids.intersection(ids).tolist()\n",
+   "                if not ids:\n",
+   "                    yield pd.DataFrame(\n",
+   "                        {\n",
+   "                            field.name: pd.Series(dtype=field.type.to_pandas_dtype())\n",
+   "                            for field in schema.values()\n",
+   "                        }\n",
+   "                    )\n",
+   "                    return\n",
    "            res = ts.predict(\n",
    "                models=models,\n",
    "                horizon=horizon,\n",
    "                before_predict_callback=before_predict_callback,\n",
    "                after_predict_callback=after_predict_callback,\n",
    "                X_df=X_df,\n",
+   "                ids=ids,\n",
    "            )\n",
    "            if valid is not None:\n",
    "                res = res.merge(valid, how='left')\n",
    "            yield res\n",
    "    \n",
-   "    def _get_predict_schema(self) -> str:\n",
-   "        model_names = self.models.keys()\n",
-   "        models_schema = ','.join(f'{model_name}:double' for model_name in model_names)\n",
-   "        schema = f'{self._base_ts.id_col}:string,{self._base_ts.time_col}:datetime,' + models_schema\n",
-   "        return schema\n",
+   "    def _get_predict_schema(self) -> Schema:\n",
+   "        ids_schema = [(self._base_ts.id_col, 'string'), (self._base_ts.time_col, 'datetime')]\n",
+   "        models_schema = [(model, 'double') for model in self.models.keys()]\n",
+   "        return Schema(ids_schema + models_schema)\n",
    "\n",
    "    def predict(\n",
    "        self,\n",
@@ -537,6 +550,7 @@
    "        after_predict_callback: Optional[Callable] = None,\n",
    "        X_df: Optional[pd.DataFrame] = None,\n",
    "        new_df: Optional[fugue.AnyDataFrame] = None,\n",
+   "        ids: Optional[List[str]] = None,\n",
    "    ) -> fugue.AnyDataFrame:\n",
    "        \"\"\"Compute the predictions for the next `horizon` steps.\n",
    "\n",
@@ -557,7 +571,9 @@
    "        new_df : dask or spark DataFrame, optional (default=None)\n",
    "            Series data of new observations for which forecasts are to be generated.\n",
    "            This dataframe should have the same structure as the one used to fit the model, including any features and time series data.\n",
-   "            If `new_df` is not None, the method will generate forecasts for the new observations. \n",
+   "            If `new_df` is not None, the method will generate forecasts for the new observations.\n",
+   "        ids : list of str, optional (default=None)\n",
+   "            List with a subset of ids seen during training for which the forecasts should be computed.\n",
\n", "\n", " Returns\n", " -------\n", @@ -589,6 +605,8 @@ " 'before_predict_callback': before_predict_callback,\n", " 'after_predict_callback': after_predict_callback,\n", " 'X_df': X_df,\n", + " 'ids': ids,\n", + " 'schema': schema,\n", " },\n", " schema=schema,\n", " engine=self.engine,\n", @@ -685,7 +703,10 @@ " keep_last_n=keep_last_n,\n", " window_info=window_info,\n", " )\n", - " schema = self._get_predict_schema() + f',cutoff:datetime,{self._base_ts.target_col}:double'\n", + " schema = (\n", + " self._get_predict_schema() + Schema(\n", + " ('cutoff', 'datetime'), (self._base_ts.target_col, 'double'))\n", + " )\n", " preds = fa.transform(\n", " partition_results,\n", " DistributedMLForecast._predict,\n", diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb index 2e8d53c3..ac5b75c4 100644 --- a/nbs/docs/getting-started/quick_start_distributed.ipynb +++ b/nbs/docs/getting-started/quick_start_distributed.ipynb @@ -366,32 +366,31 @@ "source": [ "#| hide\n", "# test num_partitions works properly\n", - "if sys.version_info >= (3, 9):\n", - " num_partitions_test = 4\n", - " test_dd = dd.from_pandas(series, npartitions=num_partitions_test) # In this case we dont have to specify the column\n", - " test_dd['unique_id'] = test_dd['unique_id'].astype(str)\n", - " fcst_np = DistributedMLForecast(\n", - " models=models,\n", - " freq='D',\n", - " target_transforms=[Differences([7])], \n", - " lags=[7],\n", - " lag_transforms={\n", - " 1: [ExpandingMean()],\n", - " 7: [RollingMean(window_size=14)]\n", - " },\n", - " date_features=['dayofweek', 'month'],\n", - " num_threads=1,\n", - " engine=client,\n", - " num_partitions=num_partitions_test\n", - " )\n", - " fcst_np.fit(test_dd)\n", - " test_partition_results_size(fcst_np, num_partitions_test)\n", - " preds_np = fcst_np.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n", - " preds = fcst.predict(7, X_df=future).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n", - " pd.testing.assert_frame_equal(\n", - " preds[['unique_id', 'ds']], \n", - " preds_np[['unique_id', 'ds']], \n", - " )" + "num_partitions_test = 4\n", + "test_dd = dd.from_pandas(series, npartitions=num_partitions_test) # In this case we dont have to specify the column\n", + "test_dd['unique_id'] = test_dd['unique_id'].astype(str)\n", + "fcst_np = DistributedMLForecast(\n", + " models=models,\n", + " freq='D',\n", + " target_transforms=[Differences([7])], \n", + " lags=[7],\n", + " lag_transforms={\n", + " 1: [ExpandingMean()],\n", + " 7: [RollingMean(window_size=14)]\n", + " },\n", + " date_features=['dayofweek', 'month'],\n", + " num_threads=1,\n", + " engine=client,\n", + " num_partitions=num_partitions_test\n", + ")\n", + "fcst_np.fit(test_dd)\n", + "test_partition_results_size(fcst_np, num_partitions_test)\n", + "preds_np = fcst_np.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n", + "preds = fcst.predict(7, X_df=future).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n", + "pd.testing.assert_frame_equal(\n", + " preds[['unique_id', 'ds']], \n", + " preds_np[['unique_id', 'ds']], \n", + ")" ] }, { @@ -448,36 +447,36 @@ " 0\n", " id_00\n", " 2002-09-27 00:00:00\n", - " 22.267619\n", - " 21.835798\n", + " 21.722841\n", + " 21.725511\n", " \n", " \n", " 1\n", " id_00\n", " 2002-09-28 00:00:00\n", - " 85.230055\n", - " 83.996424\n", + " 84.918194\n", + " 84.606362\n", " \n", " \n", " 2\n", " id_00\n", " 2002-09-29 
-       "      <td>168.256154</td>\n",
-       "      <td>163.076652</td>\n",
+       "      <td>162.067624</td>\n",
+       "      <td>163.36802</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>3</th>\n",
        "      <td>id_00</td>\n",
        "      <td>2002-09-30 00:00:00</td>\n",
-       "      <td>246.712244</td>\n",
-       "      <td>245.827467</td>\n",
+       "      <td>249.001477</td>\n",
+       "      <td>246.422894</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>4</th>\n",
        "      <td>id_00</td>\n",
        "      <td>2002-10-01 00:00:00</td>\n",
-       "      <td>314.184225</td>\n",
-       "      <td>315.257849</td>\n",
+       "      <td>317.149512</td>\n",
+       "      <td>315.538403</td>\n",
        "    </tr>\n",
        "  </tbody>\n",
        "</table>\n",
@@ -485,11 +484,11 @@
       ],
       "text/plain": [
        "  unique_id                  ds  DaskXGBForecast  DaskLGBMForecast\n",
-       "0     id_00 2002-09-27 00:00:00        22.267619         21.835798\n",
-       "1     id_00 2002-09-28 00:00:00        85.230055         83.996424\n",
-       "2     id_00 2002-09-29 00:00:00       168.256154        163.076652\n",
-       "3     id_00 2002-09-30 00:00:00       246.712244        245.827467\n",
-       "4     id_00 2002-10-01 00:00:00       314.184225        315.257849"
+       "0     id_00 2002-09-27 00:00:00        21.722841         21.725511\n",
+       "1     id_00 2002-09-28 00:00:00        84.918194         84.606362\n",
+       "2     id_00 2002-09-29 00:00:00       162.067624         163.36802\n",
+       "3     id_00 2002-09-30 00:00:00       249.001477        246.422894\n",
+       "4     id_00 2002-10-01 00:00:00       317.149512        315.538403"
      ]
     },
     "execution_count": null,
@@ -496,12 +495,26 @@
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds = fcst.predict(7, X_df=future)\n",
    "preds.head()"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "0150de6d-88b5-4513-bd82-c835ba945e79",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "#| hide\n",
+   "# predict with ids\n",
+   "ids = np.random.choice(series['unique_id'].unique(), size=10, replace=False)\n",
+   "preds_ids = fcst.predict(7, X_df=future[future['unique_id'].isin(ids)], ids=ids).compute()\n",
+   "assert set(preds_ids['unique_id']) == set(ids)"
+  ]
+ },
 {
  "cell_type": "code",
  "execution_count": null,
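
A note on the mechanism this patch adds: when `ids` is passed, each worker intersects the requested ids with the ids in its partition, and a partition holding none of them must still yield a DataFrame with the declared columns and dtypes so the Fugue engine can concatenate all partition results under the `schema=` given to `fa.transform`. Below is a minimal standalone sketch of that empty-frame construction, mirroring the calls used in the patch; the concrete column and model names are illustrative only, not the library's.

import pandas as pd
from triad import Schema

# The same shape of schema that _get_predict_schema now returns:
# id and time columns followed by one double column per model.
schema = Schema([
    ("unique_id", "string"),
    ("ds", "datetime"),
    ("DaskXGBForecast", "double"),
])

# A typed empty frame: one correctly-dtyped empty Series per schema field,
# which is exactly what _predict yields for a partition with no matching ids.
empty = pd.DataFrame(
    {
        field.name: pd.Series(dtype=field.type.to_pandas_dtype())
        for field in schema.values()
    }
)
print(empty.dtypes)  # e.g. object/string, datetime64[ns], float64

Passing the precomputed `schema` into `_predict` (rather than rebuilding it on each worker) keeps this empty output aligned with the output schema that `predict` declares to the engine.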