diff --git a/local_environment.yml b/local_environment.yml
index 0a4f5771..ded66892 100644
--- a/local_environment.yml
+++ b/local_environment.yml
@@ -2,7 +2,7 @@ name: mlforecast
channels:
- conda-forge
dependencies:
- - coreforecast>=0.0.5
+ - fsspec
- holidays<0.21
- lightgbm
- matplotlib
diff --git a/mlforecast/target_transforms.py b/mlforecast/target_transforms.py
index fe79a5d0..af8efe01 100644
--- a/mlforecast/target_transforms.py
+++ b/mlforecast/target_transforms.py
@@ -147,6 +147,11 @@ def stack(scalers: Sequence["Differences"]) -> "Differences":
core_scaler = first_scaler.scalers_[0]
diffs = first_scaler.differences
out = Differences(diffs)
+ out.fitted_ = []
+ for i in range(len(scalers[0].fitted_)):
+ data = np.hstack([sc.fitted_[i].data for sc in scalers])
+ sizes = np.hstack([np.diff(sc.fitted_[i].indptr) for sc in scalers])
+ out.fitted_.append(GroupedArray(data, np.append(0, sizes.cumsum())))
out.scalers_ = [
core_scaler.stack([sc.scalers_[i] for sc in scalers])
for i in range(len(diffs))
diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb
index 3a9b2bf8..ebc386af 100644
--- a/nbs/docs/getting-started/quick_start_distributed.ipynb
+++ b/nbs/docs/getting-started/quick_start_distributed.ipynb
@@ -278,7 +278,7 @@
"fcst = DistributedMLForecast(\n",
" models=models,\n",
" freq='D',\n",
- " target_transforms=[Differences([1])],\n",
+ " target_transforms=[Differences([7])],\n",
" lags=[7],\n",
" lag_transforms={\n",
" 1: [ExpandingMean()],\n",
@@ -350,6 +350,7 @@
"fcst_np = DistributedMLForecast(\n",
" models=models,\n",
" freq='D',\n",
+ " target_transforms=[Differences([7])], \n",
" lags=[7],\n",
" lag_transforms={\n",
" 1: [ExpandingMean()],\n",
@@ -424,36 +425,36 @@
"
0 | \n",
" id_00 | \n",
" 2002-09-27 | \n",
- " 3.828538 | \n",
- " 3.519942 | \n",
+ " 20.999371 | \n",
+ " 21.892795 | \n",
" \n",
" \n",
" 1 | \n",
" id_00 | \n",
" 2002-09-28 | \n",
- " 91.890763 | \n",
- " 84.406499 | \n",
+ " 84.771692 | \n",
+ " 83.002009 | \n",
"
\n",
" \n",
" 2 | \n",
" id_00 | \n",
" 2002-09-29 | \n",
- " 165.303125 | \n",
- " 157.733402 | \n",
+ " 162.389419 | \n",
+ " 163.528475 | \n",
"
\n",
" \n",
" 3 | \n",
" id_00 | \n",
" 2002-09-30 | \n",
- " 243.340097 | \n",
- " 231.514761 | \n",
+ " 245.002456 | \n",
+ " 245.472042 | \n",
"
\n",
" \n",
" 4 | \n",
" id_00 | \n",
" 2002-10-01 | \n",
- " 314.397042 | \n",
- " 305.351770 | \n",
+ " 317.240952 | \n",
+ " 313.948840 | \n",
"
\n",
" \n",
"\n",
@@ -461,11 +462,11 @@
],
"text/plain": [
" unique_id ds DaskXGBForecast DaskLGBMForecast\n",
- "0 id_00 2002-09-27 3.828538 3.519942\n",
- "1 id_00 2002-09-28 91.890763 84.406499\n",
- "2 id_00 2002-09-29 165.303125 157.733402\n",
- "3 id_00 2002-09-30 243.340097 231.514761\n",
- "4 id_00 2002-10-01 314.397042 305.351770"
+ "0 id_00 2002-09-27 20.999371 21.892795\n",
+ "1 id_00 2002-09-28 84.771692 83.002009\n",
+ "2 id_00 2002-09-29 162.389419 163.528475\n",
+ "3 id_00 2002-09-30 245.002456 245.472042\n",
+ "4 id_00 2002-10-01 317.240952 313.948840"
]
},
"execution_count": null,
@@ -510,6 +511,7 @@
"fcst_exog = DistributedMLForecast(\n",
" models=models,\n",
" freq='D',\n",
+ " target_transforms=[Differences([7])], \n",
" lags=[7],\n",
" lag_transforms={\n",
" 1: [ExpandingMean()],\n",
@@ -627,7 +629,6 @@
"metadata": {},
"outputs": [],
"source": [
- "preds = fa.as_pandas(fcst.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
"local_fcst = fcst.to_local()\n",
"local_preds = local_fcst.predict(10)\n",
"# we don't check the dtype because sometimes these are arrow dtypes\n",
@@ -663,7 +664,107 @@
"execution_count": null,
"id": "07e80450-d582-42bb-8bf4-ff925d5e74e1",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " unique_id | \n",
+ " ds | \n",
+ " DaskXGBForecast | \n",
+ " DaskLGBMForecast | \n",
+ " cutoff | \n",
+ " y | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " id_00 | \n",
+ " 2002-08-16 | \n",
+ " 22.706938 | \n",
+ " 21.967568 | \n",
+ " 2002-08-15 | \n",
+ " 11.878591 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " id_00 | \n",
+ " 2002-08-17 | \n",
+ " 95.885948 | \n",
+ " 98.285482 | \n",
+ " 2002-08-15 | \n",
+ " 75.108162 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " id_00 | \n",
+ " 2002-08-18 | \n",
+ " 172.546631 | \n",
+ " 171.527272 | \n",
+ " 2002-08-15 | \n",
+ " 175.278407 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " id_00 | \n",
+ " 2002-08-19 | \n",
+ " 238.256594 | \n",
+ " 238.375726 | \n",
+ " 2002-08-15 | \n",
+ " 226.062025 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " id_00 | \n",
+ " 2002-08-20 | \n",
+ " 306.005923 | \n",
+ " 305.146636 | \n",
+ " 2002-08-15 | \n",
+ " 318.433401 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " unique_id ds DaskXGBForecast DaskLGBMForecast cutoff \\\n",
+ "0 id_00 2002-08-16 22.706938 21.967568 2002-08-15 \n",
+ "1 id_00 2002-08-17 95.885948 98.285482 2002-08-15 \n",
+ "2 id_00 2002-08-18 172.546631 171.527272 2002-08-15 \n",
+ "3 id_00 2002-08-19 238.256594 238.375726 2002-08-15 \n",
+ "4 id_00 2002-08-20 306.005923 305.146636 2002-08-15 \n",
+ "\n",
+ " y \n",
+ "0 11.878591 \n",
+ "1 75.108162 \n",
+ "2 175.278407 \n",
+ "3 226.062025 \n",
+ "4 318.433401 "
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"cv_res.compute().head()"
]
@@ -743,6 +844,7 @@
"non_std_series = non_std_series.rename(columns={'ds': 'time', 'y': 'value', 'unique_id': 'some_id'})\n",
"flow_params = dict(\n",
" models=[DaskXGBForecast(random_state=0)],\n",
+ " target_transforms=[Differences([7])], \n",
" lags=[7],\n",
" lag_transforms={\n",
" 1: [ExpandingMean()],\n",
@@ -865,7 +967,7 @@
"metadata": {},
"outputs": [],
"source": [
- "models = [SparkLGBMForecast(), SparkXGBForecast()]"
+ "models = [SparkLGBMForecast(seed=0), SparkXGBForecast(random_state=0)]"
]
},
{
@@ -886,6 +988,7 @@
"fcst = DistributedMLForecast(\n",
" models,\n",
" freq='D',\n",
+ " target_transforms=[Differences([7])], \n",
" lags=[1],\n",
" lag_transforms={\n",
" 1: [ExpandingMean()],\n",
@@ -922,7 +1025,7 @@
"num_partitions_test = 10\n",
"fcst_np = DistributedMLForecast(\n",
" models=models,\n",
- " freq='D',\n",
+ " freq='D', \n",
" lags=[7],\n",
" lag_transforms={\n",
" 1: [ExpandingMean()],\n",
@@ -965,7 +1068,88 @@
"execution_count": null,
"id": "5da6bbd2-9806-4b12-91a2-1b5571ae1550",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " unique_id | \n",
+ " ds | \n",
+ " SparkLGBMForecast | \n",
+ " SparkXGBForecast | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " id_00 | \n",
+ " 2001-05-15 | \n",
+ " 431.677682 | \n",
+ " 424.488985 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " id_00 | \n",
+ " 2001-05-16 | \n",
+ " 503.673189 | \n",
+ " 502.923172 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " id_00 | \n",
+ " 2001-05-17 | \n",
+ " 8.150285 | \n",
+ " 8.019412 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " id_00 | \n",
+ " 2001-05-18 | \n",
+ " 97.620923 | \n",
+ " 97.031792 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " id_00 | \n",
+ " 2001-05-19 | \n",
+ " 194.568960 | \n",
+ " 193.862475 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " unique_id ds SparkLGBMForecast SparkXGBForecast\n",
+ "0 id_00 2001-05-15 431.677682 424.488985\n",
+ "1 id_00 2001-05-16 503.673189 502.923172\n",
+ "2 id_00 2001-05-17 8.150285 8.019412\n",
+ "3 id_00 2001-05-18 97.620923 97.031792\n",
+ "4 id_00 2001-05-19 194.568960 193.862475"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"preds.head()"
]
@@ -993,7 +1177,15 @@
"execution_count": null,
"id": "3a1bff80-c052-4a24-a225-33b774d7d75a",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
"source": [
"save_dir = build_unique_name('spark')\n",
"save_path = f's3://nixtla-tmp/mlf/{save_dir}'\n",
@@ -1005,7 +1197,7 @@
"id": "9bd4a6de-1c34-4112-bb18-2818b808b8eb",
"metadata": {},
"source": [
- "Once you've saved your forecast object you can then load it back by specifying the path where it was saved along with an engine, which will be used to perform the distributed computations (in this case the dask client)."
+ "Once you've saved your forecast object, you can then load it back by specifying the path where it was saved along with an engine, which will be used to perform the distributed computations (in this case, the Spark session)."
]
},
{
@@ -1013,7 +1205,15 @@
"execution_count": null,
"id": "6191d277-ce70-464b-a44f-0332a96ec7c6",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
"source": [
"fcst2 = DistributedMLForecast.load(save_path, engine=spark)"
]
@@ -1089,7 +1289,107 @@
"execution_count": null,
"id": "0c735385-7104-4ced-a253-d4a16b8bbb4e",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " unique_id | \n",
+ " ds | \n",
+ " SparkLGBMForecast | \n",
+ " SparkXGBForecast | \n",
+ " cutoff | \n",
+ " y | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " id_12 | \n",
+ " 2001-04-03 | \n",
+ " 342.978379 | \n",
+ " 341.930127 | \n",
+ " 2001-04-02 | \n",
+ " 328.907629 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " id_23 | \n",
+ " 2001-04-03 | \n",
+ " 429.591043 | \n",
+ " 428.320398 | \n",
+ " 2001-04-02 | \n",
+ " 424.716749 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " id_26 | \n",
+ " 2001-04-10 | \n",
+ " 7.554284 | \n",
+ " 7.707686 | \n",
+ " 2001-04-02 | \n",
+ " 19.814264 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " id_18 | \n",
+ " 2001-04-11 | \n",
+ " 98.885044 | \n",
+ " 98.848126 | \n",
+ " 2001-04-02 | \n",
+ " 98.877898 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " id_00 | \n",
+ " 2001-04-13 | \n",
+ " 122.727000 | \n",
+ " 117.713487 | \n",
+ " 2001-04-02 | \n",
+ " 98.526008 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " unique_id ds SparkLGBMForecast SparkXGBForecast cutoff \\\n",
+ "0 id_12 2001-04-03 342.978379 341.930127 2001-04-02 \n",
+ "1 id_23 2001-04-03 429.591043 428.320398 2001-04-02 \n",
+ "2 id_26 2001-04-10 7.554284 7.707686 2001-04-02 \n",
+ "3 id_18 2001-04-11 98.885044 98.848126 2001-04-02 \n",
+ "4 id_00 2001-04-13 122.727000 117.713487 2001-04-02 \n",
+ "\n",
+ " y \n",
+ "0 328.907629 \n",
+ "1 424.716749 \n",
+ "2 19.814264 \n",
+ "3 98.877898 \n",
+ "4 98.526008 "
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"cv_res.head()"
]
@@ -1198,10 +1498,7 @@
"metadata": {},
"outputs": [],
"source": [
- "models = [\n",
- " RayLGBMForecast(),\n",
- " RayXGBForecast(),\n",
- "]"
+ "models = [RayLGBMForecast(random_state=0), RayXGBForecast(random_state=0)]"
]
},
{
@@ -1240,6 +1537,7 @@
"fcst = DistributedMLForecast(\n",
" models,\n",
" freq='D',\n",
+ " target_transforms=[Differences([7])],\n",
" lags=[1],\n",
" lag_transforms={\n",
" 1: [ExpandingMean()],\n",
@@ -1322,7 +1620,88 @@
"execution_count": null,
"id": "962ee74e-1346-4991-bcf2-5b2a887eb5b0",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " unique_id | \n",
+ " ds | \n",
+ " RayLGBMForecast | \n",
+ " RayXGBForecast | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " id_00 | \n",
+ " 2001-05-15 | \n",
+ " 431.677682 | \n",
+ " 427.262462 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " id_00 | \n",
+ " 2001-05-16 | \n",
+ " 503.673189 | \n",
+ " 502.605670 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " id_00 | \n",
+ " 2001-05-17 | \n",
+ " 8.150285 | \n",
+ " 7.604773 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " id_00 | \n",
+ " 2001-05-18 | \n",
+ " 97.620923 | \n",
+ " 97.582869 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " id_00 | \n",
+ " 2001-05-19 | \n",
+ " 194.568960 | \n",
+ " 192.818578 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " unique_id ds RayLGBMForecast RayXGBForecast\n",
+ "0 id_00 2001-05-15 431.677682 427.262462\n",
+ "1 id_00 2001-05-16 503.673189 502.605670\n",
+ "2 id_00 2001-05-17 8.150285 7.604773\n",
+ "3 id_00 2001-05-18 97.620923 97.582869\n",
+ "4 id_00 2001-05-19 194.568960 192.818578"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"preds.head()"
]
@@ -1362,7 +1741,7 @@
"id": "fcce44d0-3035-4919-830a-d45a3d47d9b1",
"metadata": {},
"source": [
- "Once you've saved your forecast object you can then load it back by specifying the path where it was saved along with an engine, which will be used to perform the distributed computations (in this case the dask client)."
+ "Once you've saved your forecast object, you can then load it back by specifying the path where it was saved along with an engine, which will be used to perform the distributed computations (in this case, the string 'ray')."
]
},
{
@@ -1446,7 +1825,100 @@
"execution_count": null,
"id": "d9ecc467-80c7-4c0e-98d6-af77c9fc7fe8",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " unique_id | \n",
+ " ds | \n",
+ " RayLGBMForecast | \n",
+ " RayXGBForecast | \n",
+ " cutoff | \n",
+ " y | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " id_10 | \n",
+ " 2001-05-01 | \n",
+ " 24.962461 | \n",
+ " 22.998615 | \n",
+ " 2001-04-30 | \n",
+ " 31.878545 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " id_10 | \n",
+ " 2001-05-02 | \n",
+ " 53.219645 | \n",
+ " 54.298105 | \n",
+ " 2001-04-30 | \n",
+ " 48.349363 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " id_10 | \n",
+ " 2001-05-03 | \n",
+ " 78.068732 | \n",
+ " 76.111907 | \n",
+ " 2001-04-30 | \n",
+ " 71.607111 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " id_10 | \n",
+ " 2001-05-04 | \n",
+ " 103.153889 | \n",
+ " 104.344135 | \n",
+ " 2001-04-30 | \n",
+ " 103.482107 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " id_10 | \n",
+ " 2001-05-05 | \n",
+ " 116.708231 | \n",
+ " 115.950523 | \n",
+ " 2001-04-30 | \n",
+ " 124.719690 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " unique_id ds RayLGBMForecast RayXGBForecast cutoff y\n",
+ "0 id_10 2001-05-01 24.962461 22.998615 2001-04-30 31.878545\n",
+ "1 id_10 2001-05-02 53.219645 54.298105 2001-04-30 48.349363\n",
+ "2 id_10 2001-05-03 78.068732 76.111907 2001-04-30 71.607111\n",
+ "3 id_10 2001-05-04 103.153889 104.344135 2001-04-30 103.482107\n",
+ "4 id_10 2001-05-05 116.708231 115.950523 2001-04-30 124.719690"
+ ]
+ },
+ "execution_count": null,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"cv_res.head()"
]
diff --git a/nbs/target_transforms.ipynb b/nbs/target_transforms.ipynb
index 26221d6f..8975bca5 100644
--- a/nbs/target_transforms.ipynb
+++ b/nbs/target_transforms.ipynb
@@ -221,6 +221,11 @@
" core_scaler = first_scaler.scalers_[0]\n",
" diffs = first_scaler.differences\n",
" out = Differences(diffs)\n",
+ " out.fitted_ = []\n",
+ " for i in range(len(scalers[0].fitted_)):\n",
+ " data = np.hstack([sc.fitted_[i].data for sc in scalers])\n",
+ " sizes = np.hstack([np.diff(sc.fitted_[i].indptr) for sc in scalers])\n",
+ " out.fitted_.append(GroupedArray(data, np.append(0, sizes.cumsum())))\n",
" out.scalers_ = [\n",
" core_scaler.stack([sc.scalers_[i] for sc in scalers])\n",
" for i in range(len(diffs))\n",
diff --git a/settings.ini b/settings.ini
index 6eb44dad..17c5b1c3 100644
--- a/settings.ini
+++ b/settings.ini
@@ -15,11 +15,10 @@ language = English
custom_sidebar = True
license = apache2
status = 3
-requirements = cloudpickle coreforecast>=0.0.5 fsspec numba packaging pandas scikit-learn utilsforecast>=0.0.27 window-ops
+requirements = cloudpickle coreforecast>=0.0.7 fsspec numba packaging pandas scikit-learn utilsforecast>=0.0.27 window-ops
dask_requirements = fugue dask[complete] lightgbm xgboost
ray_requirements = fugue[ray] lightgbm_ray xgboost_ray
spark_requirements = fugue pyspark>=3.3 lightgbm xgboost
-lag_tfms_requirements = coreforecast>=0.0.5
aws_requirements = fsspec[s3]
gcp_requirements = fsspec[gcs]
azure_requirements = fsspec[adl]
diff --git a/setup.py b/setup.py
index 33bf3c83..342cc7fa 100644
--- a/setup.py
+++ b/setup.py
@@ -61,7 +61,7 @@
'dask': dask_requirements,
'ray': ray_requirements,
'spark': spark_requirements,
- 'lag_transforms': lag_tfms_requirements,
+ 'lag_transforms': [],
'aws': aws_requirements,
'azure': azure_requirements,
'gcp': gcp_requirements,
diff --git a/tests/test_m4.py b/tests/test_m4.py
index 3e320987..853b4397 100644
--- a/tests/test_m4.py
+++ b/tests/test_m4.py
@@ -68,9 +68,9 @@
},
"metrics": {
"lgb": {
- "SMAPE": 2.984803,
- "MASE": 3.202900,
- "OWA": 0.978585,
+ "SMAPE": 2.984652,
+ "MASE": 3.205519,
+ "OWA": 0.978931,
},
"enet": {
"SMAPE": 2.989625,