Skip to content

Commit

Permalink
add DistributedMLForecast.update (#324)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Mar 5, 2024
1 parent 94d9663 commit 4fcfcc7
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 22 deletions.
7 changes: 6 additions & 1 deletion mlforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast._save_ts': ( 'distributed.forecast.html#distributedmlforecast._save_ts',
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast._update': ( 'distributed.forecast.html#distributedmlforecast._update',
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast.cross_validation': ( 'distributed.forecast.html#distributedmlforecast.cross_validation',
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast.fit': ( 'distributed.forecast.html#distributedmlforecast.fit',
Expand All @@ -95,7 +97,9 @@
'mlforecast.distributed.forecast.DistributedMLForecast.save': ( 'distributed.forecast.html#distributedmlforecast.save',
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast.to_local': ( 'distributed.forecast.html#distributedmlforecast.to_local',
'mlforecast/distributed/forecast.py')},
'mlforecast/distributed/forecast.py'),
'mlforecast.distributed.forecast.DistributedMLForecast.update': ( 'distributed.forecast.html#distributedmlforecast.update',
'mlforecast/distributed/forecast.py')},
'mlforecast.distributed.models.dask.lgb': { 'mlforecast.distributed.models.dask.lgb.DaskLGBMForecast': ( 'distributed.models.dask.lgb.html#dasklgbmforecast',
'mlforecast/distributed/models/dask/lgb.py'),
'mlforecast.distributed.models.dask.lgb.DaskLGBMForecast.model_': ( 'distributed.models.dask.lgb.html#dasklgbmforecast.model_',
Expand Down Expand Up @@ -161,6 +165,7 @@
'mlforecast.forecast.MLForecast.preprocess': ( 'forecast.html#mlforecast.preprocess',
'mlforecast/forecast.py'),
'mlforecast.forecast.MLForecast.save': ('forecast.html#mlforecast.save', 'mlforecast/forecast.py'),
'mlforecast.forecast.MLForecast.update': ('forecast.html#mlforecast.update', 'mlforecast/forecast.py'),
'mlforecast.forecast._add_conformal_distribution_intervals': ( 'forecast.html#_add_conformal_distribution_intervals',
'mlforecast/forecast.py'),
'mlforecast.forecast._add_conformal_error_intervals': ( 'forecast.html#_add_conformal_error_intervals',
Expand Down
28 changes: 28 additions & 0 deletions mlforecast/distributed/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,34 @@ def load(path: str, engine) -> "DistributedMLForecast":
fcst.num_partitions = len(paths)
return fcst

@staticmethod
def _update(items: List[List[Any]], new_df) -> Iterable[List[Any]]:
for serialized_ts, serialized_transformed, serialized_valid in items:
ts = cloudpickle.loads(serialized_ts)
partition_mask = ufp.is_in(new_df[ts.id_col], ts.uids)
partition_df = ufp.filter_with_mask(new_df, partition_mask)
ts.update(partition_df)
yield [cloudpickle.dumps(ts), serialized_transformed, serialized_valid]

def update(self, df: pd.DataFrame) -> None:
"""Update the values of the stored series.
Parameters
----------
df : pandas DataFrame
Dataframe with new observations."""
if not isinstance(df, pd.DataFrame):
raise ValueError("`df` must be a pandas DataFrame.")
res = fa.transform(
self._partition_results,
DistributedMLForecast._update,
params={"new_df": df},
schema="ts:binary,train:binary,valid:binary",
engine=self.engine,
as_fugue=True,
)
self._partition_results = fa.persist(res)

def to_local(self) -> MLForecast:
"""Convert this distributed forecast object into a local one
Expand Down
9 changes: 9 additions & 0 deletions mlforecast/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,3 +974,12 @@ def load(path: Union[str, Path]) -> "MLForecast":
fcst.ts = ts
fcst.models_ = models
return fcst

def update(self, df: DataFrame) -> None:
"""Update the values of the stored series.
Parameters
----------
df : pandas or polars DataFrame
Dataframe with new observations."""
self.ts.update(df)
114 changes: 96 additions & 18 deletions nbs/distributed.forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,34 @@
" fcst.num_partitions = len(paths)\n",
" return fcst\n",
"\n",
" @staticmethod\n",
" def _update(items: List[List[Any]], new_df) -> Iterable[List[Any]]:\n",
" for serialized_ts, serialized_transformed, serialized_valid in items:\n",
" ts = cloudpickle.loads(serialized_ts)\n",
" partition_mask = ufp.is_in(new_df[ts.id_col], ts.uids)\n",
" partition_df = ufp.filter_with_mask(new_df, partition_mask)\n",
" ts.update(partition_df)\n",
" yield [cloudpickle.dumps(ts), serialized_transformed, serialized_valid]\n",
"\n",
" def update(self, df: pd.DataFrame) -> None:\n",
" \"\"\"Update the values of the stored series.\n",
"\n",
" Parameters\n",
" ----------\n",
" df : pandas DataFrame\n",
" Dataframe with new observations.\"\"\"\n",
" if not isinstance(df, pd.DataFrame):\n",
" raise ValueError(\"`df` must be a pandas DataFrame.\")\n",
" res = fa.transform(\n",
" self._partition_results,\n",
" DistributedMLForecast._update,\n",
" params={\"new_df\": df},\n",
" schema=\"ts:binary,train:binary,valid:binary\",\n",
" engine=self.engine,\n",
" as_fugue=True,\n",
" )\n",
" self._partition_results = fa.persist(res)\n",
"\n",
" def to_local(self) -> MLForecast:\n",
" \"\"\"Convert this distributed forecast object into a local one\n",
" \n",
Expand Down Expand Up @@ -890,8 +918,8 @@
"> [Iterable[Union[str,Callable]]]=None,\n",
"> num_threads:int=1, target_transforms:Optional[List\n",
"> [Union[mlforecast.target_transforms.BaseTargetTran\n",
"> sform,mlforecast.target_transforms.BaseGroupedArra\n",
"> yTargetTransform]]]=None, engine=None,\n",
"> sform,mlforecast.target_transforms._BaseGroupedArr\n",
"> ayTargetTransform]]]=None, engine=None,\n",
"> num_partitions:Optional[int]=None)\n",
"\n",
"Multi backend distributed pipeline"
Expand All @@ -911,8 +939,8 @@
"> [Iterable[Union[str,Callable]]]=None,\n",
"> num_threads:int=1, target_transforms:Optional[List\n",
"> [Union[mlforecast.target_transforms.BaseTargetTran\n",
"> sform,mlforecast.target_transforms.BaseGroupedArra\n",
"> yTargetTransform]]]=None, engine=None,\n",
"> sform,mlforecast.target_transforms._BaseGroupedArr\n",
"> ayTargetTransform]]]=None, engine=None,\n",
"> num_partitions:Optional[int]=None)\n",
"\n",
"Multi backend distributed pipeline"
Expand All @@ -938,7 +966,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L386){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L390){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.fit\n",
"\n",
Expand All @@ -964,7 +992,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L386){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L390){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.fit\n",
"\n",
Expand Down Expand Up @@ -1008,7 +1036,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L462){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L466){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.predict\n",
"\n",
Expand All @@ -1033,7 +1061,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L462){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L466){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.predict\n",
"\n",
Expand Down Expand Up @@ -1076,7 +1104,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L645){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L649){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.save\n",
"\n",
Expand All @@ -1092,7 +1120,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L645){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L649){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.save\n",
"\n",
Expand Down Expand Up @@ -1126,7 +1154,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L678){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L682){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.load\n",
"\n",
Expand All @@ -1143,7 +1171,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L678){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L682){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.load\n",
"\n",
Expand All @@ -1167,6 +1195,56 @@
"show_doc(DistributedMLForecast.load)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cee581ec-35a7-4443-bb08-1edf15f6c783",
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L728){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.update\n",
"\n",
"> DistributedMLForecast.update (df:pandas.core.frame.DataFrame)\n",
"\n",
"Update the values of the stored series.\n",
"\n",
"| | **Type** | **Details** |\n",
"| -- | -------- | ----------- |\n",
"| df | DataFrame | Dataframe with new observations. |\n",
"| **Returns** | **None** | |"
],
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L728){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.update\n",
"\n",
"> DistributedMLForecast.update (df:pandas.core.frame.DataFrame)\n",
"\n",
"Update the values of the stored series.\n",
"\n",
"| | **Type** | **Details** |\n",
"| -- | -------- | ----------- |\n",
"| df | DataFrame | Dataframe with new observations. |\n",
"| **Returns** | **None** | |"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"show_doc(DistributedMLForecast.update)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -1178,7 +1256,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L715){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L739){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.to_local\n",
"\n",
Expand All @@ -1192,7 +1270,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L715){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L739){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.to_local\n",
"\n",
Expand Down Expand Up @@ -1224,7 +1302,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L290){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L294){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.preprocess\n",
"\n",
Expand All @@ -1251,7 +1329,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L290){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L294){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.preprocess\n",
"\n",
Expand Down Expand Up @@ -1296,7 +1374,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L527){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L531){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.cross_validation\n",
"\n",
Expand Down Expand Up @@ -1339,7 +1417,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L527){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/distributed/forecast.py#L531){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### DistributedMLForecast.cross_validation\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions nbs/docs/getting-started/end_to_end_walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,7 @@
"id": "0a1058f2-e6e2-4566-aad4-67c45bbbbe69",
"metadata": {},
"source": [
"After you've trained a forecast object you can save and load it with the previous methods. If by the time you want to use it you already know the following values of the target you can use the `MLForecast.ts.update` method to incorporate these, which will allow you to use these new values when computing predictions.\n",
"After you've trained a forecast object you can save and load it with the previous methods. If by the time you want to use it you already know the following values of the target you can use the `MLForecast.update` method to incorporate these, which will allow you to use these new values when computing predictions.\n",
"\n",
"* If no new values are provided for a serie that's currently stored, only the previous ones are kept.\n",
"* If new series are included they are added to the existing ones."
Expand Down Expand Up @@ -1907,7 +1907,7 @@
" 'ds': [1009, 1009],\n",
" 'y': [17.0, 14.0],\n",
"})\n",
"fcst.ts.update(new_values)\n",
"fcst.update(new_values)\n",
"preds = fcst.predict(1)\n",
"preds"
]
Expand Down
Loading

0 comments on commit 4fcfcc7

Please sign in to comment.