From fb42b70aefbcfc5d204b23304a8e6bb2517a68f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Fri, 22 Nov 2024 12:31:10 -0600
Subject: [PATCH 1/2] feat(distributed): support ids in predict

---
 mlforecast/distributed/forecast.py | 34 +++++--
 nbs/distributed.forecast.ipynb     | 32 +++++--
 .../quick_start_distributed.ipynb  | 95 +++++++++++--------
 3 files changed, 105 insertions(+), 56 deletions(-)

diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py
index defef0ad..3d033aa8 100644
--- a/mlforecast/distributed/forecast.py
+++ b/mlforecast/distributed/forecast.py
@@ -38,6 +38,7 @@
 except ModuleNotFoundError:
     RAY_INSTALLED = False
 from sklearn.base import clone
+from triad import Schema
 
 from mlforecast.core import (
     DateFeature,
@@ -455,31 +456,43 @@ def _predict(
         before_predict_callback=None,
         after_predict_callback=None,
         X_df=None,
+        ids=None,
+        schema=None,
     ) -> Iterable[pd.DataFrame]:
         for serialized_ts, _, serialized_valid in items:
             valid = cloudpickle.loads(serialized_valid)
             if valid is not None:
                 X_df = valid
             ts = cloudpickle.loads(serialized_ts)
+            if ids is not None:
+                ids = ts.uids.intersection(ids).tolist()
+                if not ids:
+                    yield pd.DataFrame(
+                        {
+                            field.name: pd.Series(dtype=field.type.to_pandas_dtype())
+                            for field in schema.values()
+                        }
+                    )
+                    return
             res = ts.predict(
                 models=models,
                 horizon=horizon,
                 before_predict_callback=before_predict_callback,
                 after_predict_callback=after_predict_callback,
                 X_df=X_df,
+                ids=ids,
             )
             if valid is not None:
                 res = res.merge(valid, how="left")
             yield res
 
-    def _get_predict_schema(self) -> str:
-        model_names = self.models.keys()
-        models_schema = ",".join(f"{model_name}:double" for model_name in model_names)
-        schema = (
-            f"{self._base_ts.id_col}:string,{self._base_ts.time_col}:datetime,"
-            + models_schema
-        )
-        return schema
+    def _get_predict_schema(self) -> Schema:
+        ids_schema = [
+            (self._base_ts.id_col, "string"),
+            (self._base_ts.time_col, "datetime"),
+        ]
+        models_schema = [(model, "double") for model in self.models.keys()]
+        return Schema(ids_schema + models_schema)
 
     def predict(
         self,
@@ -488,6 +501,7 @@
         after_predict_callback: Optional[Callable] = None,
         X_df: Optional[pd.DataFrame] = None,
         new_df: Optional[fugue.AnyDataFrame] = None,
+        ids: Optional[List[str]] = None,
    ) -> fugue.AnyDataFrame:
         """Compute the predictions for the next `horizon` steps.
 
@@ -509,6 +523,8 @@
             Series data of new observations for which forecasts are to be generated.
             This dataframe should have the same structure as the one used to fit the model, including any features and time series data.
             If `new_df` is not None, the method will generate forecasts for the new observations.
+        ids : list of str, optional (default=None)
+            List with subset of ids seen during training for which the forecasts should be computed.
 
         Returns
         -------
@@ -540,6 +556,8 @@
                 "before_predict_callback": before_predict_callback,
                 "after_predict_callback": after_predict_callback,
                 "X_df": X_df,
+                "ids": ids,
+                "schema": schema,
             },
             schema=schema,
             engine=self.engine,
diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb
index e287d034..b9871d64 100644
--- a/nbs/distributed.forecast.ipynb
+++ b/nbs/distributed.forecast.ipynb
@@ -98,6 +98,7 @@
     "except ModuleNotFoundError:\n",
     "    RAY_INSTALLED = False\n",
     "from sklearn.base import clone\n",
+    "from triad import Schema\n",
     "\n",
     "from mlforecast.core import (\n",
     "    DateFeature,\n",
@@ -506,29 +507,41 @@
     "        horizon,\n",
     "        before_predict_callback=None,\n",
     "        after_predict_callback=None,\n",
-    "        X_df=None, \n",
+    "        X_df=None,\n",
+    "        ids=None,\n",
+    "        schema=None,\n",
     "    ) -> Iterable[pd.DataFrame]:\n",
     "        for serialized_ts, _, serialized_valid in items:\n",
     "            valid = cloudpickle.loads(serialized_valid)\n",
     "            if valid is not None:\n",
     "                X_df = valid\n",
     "            ts = cloudpickle.loads(serialized_ts)\n",
+    "            if ids is not None:\n",
+    "                ids = ts.uids.intersection(ids).tolist()\n",
+    "                if not ids:\n",
+    "                    yield pd.DataFrame(\n",
+    "                        {\n",
+    "                            field.name: pd.Series(dtype=field.type.to_pandas_dtype())\n",
+    "                            for field in schema.values()\n",
+    "                        }\n",
+    "                    )\n",
+    "                    return\n",
     "            res = ts.predict(\n",
     "                models=models,\n",
     "                horizon=horizon,\n",
     "                before_predict_callback=before_predict_callback,\n",
     "                after_predict_callback=after_predict_callback,\n",
     "                X_df=X_df,\n",
+    "                ids=ids,\n",
     "            )\n",
     "            if valid is not None:\n",
     "                res = res.merge(valid, how='left')\n",
     "            yield res\n",
     "    \n",
-    "    def _get_predict_schema(self) -> str:\n",
-    "        model_names = self.models.keys()\n",
-    "        models_schema = ','.join(f'{model_name}:double' for model_name in model_names)\n",
-    "        schema = f'{self._base_ts.id_col}:string,{self._base_ts.time_col}:datetime,' + models_schema\n",
-    "        return schema\n",
+    "    def _get_predict_schema(self) -> Schema:\n",
+    "        ids_schema = [(self._base_ts.id_col, 'string'), (self._base_ts.time_col, 'datetime')]\n",
+    "        models_schema = [(model, 'double') for model in self.models.keys()]\n",
+    "        return Schema(ids_schema + models_schema)\n",
     "\n",
     "    def predict(\n",
     "        self,\n",
@@ -537,6 +550,7 @@
     "        after_predict_callback: Optional[Callable] = None,\n",
     "        X_df: Optional[pd.DataFrame] = None,\n",
     "        new_df: Optional[fugue.AnyDataFrame] = None,\n",
+    "        ids: Optional[List[str]] = None,\n",
     "    ) -> fugue.AnyDataFrame:\n",
     "        \"\"\"Compute the predictions for the next `horizon` steps.\n",
     "\n",
@@ -557,7 +571,9 @@
     "        new_df : dask or spark DataFrame, optional (default=None)\n",
     "            Series data of new observations for which forecasts are to be generated.\n",
     "            This dataframe should have the same structure as the one used to fit the model, including any features and time series data.\n",
-    "            If `new_df` is not None, the method will generate forecasts for the new observations. \n",
+    "            If `new_df` is not None, the method will generate forecasts for the new observations.\n",
+    "        ids : list of str, optional (default=None)\n",
+    "            List with subset of ids seen during training for which the forecasts should be computed.\n",
     "\n",
     "        Returns\n",
     "        -------\n",
@@ -589,6 +605,8 @@
     "                'before_predict_callback': before_predict_callback,\n",
     "                'after_predict_callback': after_predict_callback,\n",
     "                'X_df': X_df,\n",
+    "                'ids': ids,\n",
+    "                'schema': schema,\n",
     "            },\n",
     "            schema=schema,\n",
     "            engine=self.engine,\n",
diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb
index 2e8d53c3..ac5b75c4 100644
--- a/nbs/docs/getting-started/quick_start_distributed.ipynb
+++ b/nbs/docs/getting-started/quick_start_distributed.ipynb
@@ -366,32 +366,31 @@
    "source": [
     "#| hide\n",
     "# test num_partitions works properly\n",
-    "if sys.version_info >= (3, 9):\n",
-    "    num_partitions_test = 4\n",
-    "    test_dd = dd.from_pandas(series, npartitions=num_partitions_test) # In this case we dont have to specify the column\n",
-    "    test_dd['unique_id'] = test_dd['unique_id'].astype(str)\n",
-    "    fcst_np = DistributedMLForecast(\n",
-    "        models=models,\n",
-    "        freq='D',\n",
-    "        target_transforms=[Differences([7])], \n",
-    "        lags=[7],\n",
-    "        lag_transforms={\n",
-    "            1: [ExpandingMean()],\n",
-    "            7: [RollingMean(window_size=14)]\n",
-    "        },\n",
-    "        date_features=['dayofweek', 'month'],\n",
-    "        num_threads=1,\n",
-    "        engine=client,\n",
-    "        num_partitions=num_partitions_test\n",
-    "    )\n",
-    "    fcst_np.fit(test_dd)\n",
-    "    test_partition_results_size(fcst_np, num_partitions_test)\n",
-    "    preds_np = fcst_np.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-    "    preds = fcst.predict(7, X_df=future).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-    "    pd.testing.assert_frame_equal(\n",
-    "        preds[['unique_id', 'ds']], \n",
-    "        preds_np[['unique_id', 'ds']], \n",
-    "    )"
+    "num_partitions_test = 4\n",
+    "test_dd = dd.from_pandas(series, npartitions=num_partitions_test) # In this case we don't have to specify the column\n",
+    "test_dd['unique_id'] = test_dd['unique_id'].astype(str)\n",
+    "fcst_np = DistributedMLForecast(\n",
+    "    models=models,\n",
+    "    freq='D',\n",
+    "    target_transforms=[Differences([7])], \n",
+    "    lags=[7],\n",
+    "    lag_transforms={\n",
+    "        1: [ExpandingMean()],\n",
+    "        7: [RollingMean(window_size=14)]\n",
+    "    },\n",
+    "    date_features=['dayofweek', 'month'],\n",
+    "    num_threads=1,\n",
+    "    engine=client,\n",
+    "    num_partitions=num_partitions_test\n",
+    ")\n",
+    "fcst_np.fit(test_dd)\n",
+    "test_partition_results_size(fcst_np, num_partitions_test)\n",
+    "preds_np = fcst_np.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+    "preds = fcst.predict(7, X_df=future).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+    "pd.testing.assert_frame_equal(\n",
+    "    preds[['unique_id', 'ds']], \n",
+    "    preds_np[['unique_id', 'ds']], \n",
+    ")"
    ]
   },
   {
@@ -448,36 +447,36 @@
       "      <th>0</th>\n",
       "      <td>id_00</td>\n",
       "      <td>2002-09-27 00:00:00</td>\n",
-      "      <td>22.267619</td>\n",
-      "      <td>21.835798</td>\n",
+      "      <td>21.722841</td>\n",
+      "      <td>21.725511</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>id_00</td>\n",
       "      <td>2002-09-28 00:00:00</td>\n",
-      "      <td>85.230055</td>\n",
-      "      <td>83.996424</td>\n",
+      "      <td>84.918194</td>\n",
+      "      <td>84.606362</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>id_00</td>\n",
       "      <td>2002-09-29 00:00:00</td>\n",
-      "      <td>168.256154</td>\n",
-      "      <td>163.076652</td>\n",
+      "      <td>162.067624</td>\n",
+      "      <td>163.36802</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>id_00</td>\n",
       "      <td>2002-09-30 00:00:00</td>\n",
-      "      <td>246.712244</td>\n",
-      "      <td>245.827467</td>\n",
+      "      <td>249.001477</td>\n",
+      "      <td>246.422894</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>id_00</td>\n",
       "      <td>2002-10-01 00:00:00</td>\n",
-      "      <td>314.184225</td>\n",
-      "      <td>315.257849</td>\n",
+      "      <td>317.149512</td>\n",
+      "      <td>315.538403</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
@@ -485,22 +484,36 @@
      ],
      "text/plain": [
       "  unique_id                  ds  DaskXGBForecast  DaskLGBMForecast\n",
-      "0     id_00 2002-09-27 00:00:00        22.267619         21.835798\n",
-      "1     id_00 2002-09-28 00:00:00        85.230055         83.996424\n",
-      "2     id_00 2002-09-29 00:00:00       168.256154        163.076652\n",
-      "3     id_00 2002-09-30 00:00:00       246.712244        245.827467\n",
-      "4     id_00 2002-10-01 00:00:00       314.184225        315.257849"
+      "0     id_00 2002-09-27 00:00:00        21.722841         21.725511\n",
+      "1     id_00 2002-09-28 00:00:00        84.918194         84.606362\n",
+      "2     id_00 2002-09-29 00:00:00       162.067624         163.36802\n",
+      "3     id_00 2002-09-30 00:00:00       249.001477        246.422894\n",
+      "4     id_00 2002-10-01 00:00:00       317.149512        315.538403"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds.head()"
   ]
  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0150de6d-88b5-4513-bd82-c835ba945e79",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| hide\n",
+    "# predict with ids\n",
+    "ids = np.random.choice(series['unique_id'].unique(), size=10, replace=False)\n",
+    "preds_ids = fcst.predict(7, X_df=future[future['unique_id'].isin(ids)], ids=ids).compute()\n",
+    "assert set(preds_ids['unique_id']) == set(ids)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,

From aa016aaee346545cb9a9fa011a0e29f1cf526d02 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Morales?=
Date: Fri, 22 Nov 2024 13:46:55 -0600
Subject: [PATCH 2/2] update cv schema

---
 mlforecast/distributed/forecast.py | 5 ++---
 nbs/distributed.forecast.ipynb     | 5 ++++-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py
index 3d033aa8..fab92a53 100644
--- a/mlforecast/distributed/forecast.py
+++ b/mlforecast/distributed/forecast.py
@@ -654,9 +654,8 @@ def cross_validation(
             keep_last_n=keep_last_n,
             window_info=window_info,
         )
-        schema = (
-            self._get_predict_schema()
-            + f",cutoff:datetime,{self._base_ts.target_col}:double"
+        schema = self._get_predict_schema() + Schema(
+            ("cutoff", "datetime"), (self._base_ts.target_col, "double")
         )
         preds = fa.transform(
             partition_results,
diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb
index b9871d64..ab3f528d 100644
--- a/nbs/distributed.forecast.ipynb
+++ b/nbs/distributed.forecast.ipynb
@@ -703,7 +703,10 @@
     "            keep_last_n=keep_last_n,\n",
     "            window_info=window_info,\n",
     "        )\n",
-    "        schema = self._get_predict_schema() + f',cutoff:datetime,{self._base_ts.target_col}:double'\n",
+    "        schema = (\n",
+    "            self._get_predict_schema() + Schema(\n",
+    "                ('cutoff', 'datetime'), (self._base_ts.target_col, 'double'))\n",
+    "        )\n",
     "        preds = fa.transform(\n",
     "            partition_results,\n",
     "            DistributedMLForecast._predict,\n",
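
Note: a minimal usage sketch of the new `ids` argument, mirroring the hidden test cell added in patch 1. This is not part of the patches themselves; it assumes a fitted `DistributedMLForecast` named `fcst` and the `future` exogenous dataframe from the quick start notebook, and the ids `'id_00'`/`'id_01'` are hypothetical training ids.

    # subset of the ids seen during training (hypothetical values)
    ids = ['id_00', 'id_01']
    # restrict the future exogenous features to those ids and forecast only those series
    preds_subset = fcst.predict(
        7,
        X_df=future[future['unique_id'].isin(ids)],
        ids=ids,
    ).compute()
    # only the requested ids come back
    assert set(preds_subset['unique_id']) == set(ids)

Per `_predict`, each partition intersects the requested ids with its own `ts.uids`, so partitions holding none of them yield an empty dataframe typed from `_get_predict_schema()`, keeping the concatenated result's schema consistent; ids never seen during training are silently dropped.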