diff --git a/mlforecast/core.py b/mlforecast/core.py
index 846330e3..6357ab07 100644
--- a/mlforecast/core.py
+++ b/mlforecast/core.py
@@ -402,16 +402,23 @@ def _transform(
target = target[self._restore_idxs]
# determine rows to keep
+ target_nulls = np.isnan(target)
+ if target_nulls.ndim == 2:
+ # target nulls for each horizon are dropped in MLForecast.fit_models
+ # we just drop rows here for which all the target values are null
+ target_nulls = target_nulls.all(axis=1)
if dropna:
feature_nulls = np.full(df.shape[0], False)
for feature_vals in features.values():
feature_nulls |= np.isnan(feature_vals)
- target_nulls = np.isnan(target)
- if target_nulls.ndim == 2:
- # target nulls for each horizon are dropped in MLForecast.fit_models
- # we just drop rows here for which all the target values are null
- target_nulls = target_nulls.all(axis=1)
keep_rows = ~(feature_nulls | target_nulls)
+ else:
+ # we always want to drop rows with nulls in the target
+ keep_rows = ~target_nulls
+
+ self._dropped_series: Optional[np.ndarray] = None
+ if not keep_rows.all():
+ # remove rows with nulls
for k, v in features.items():
features[k] = v[keep_rows]
target = target[keep_rows]
@@ -422,7 +429,7 @@ def _transform(
last_idxs = self._sort_idxs[last_idxs]
last_vals_nan = ~keep_rows[last_idxs]
if last_vals_nan.any():
- self._dropped_series: Optional[np.ndarray] = np.where(last_vals_nan)[0]
+ self._dropped_series = np.where(last_vals_nan)[0]
dropped_ids = reprlib.repr(list(self.uids[self._dropped_series]))
warnings.warn(
"The following series were dropped completely "
@@ -430,11 +437,9 @@ def _transform(
"These series won't show up if you use `MLForecast.forecast_fitted_values()`.\n"
"You can set `dropna=False` or use transformations that require less samples to mitigate this"
)
- else:
- self._dropped_series = None
elif isinstance(df, pd.DataFrame):
+ # we'll be assigning columns below, so we need to copy
df = df.copy(deep=False)
- self._dropped_series = None
# once we've computed the features and target we can slice the series
update_samples = [
diff --git a/nbs/core.ipynb b/nbs/core.ipynb
index df3e198c..96a58a49 100644
--- a/nbs/core.ipynb
+++ b/nbs/core.ipynb
@@ -20,16 +20,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"#|hide\n",
"%load_ext autoreload\n",
@@ -895,16 +886,23 @@
" target = target[self._restore_idxs] \n",
"\n",
" # determine rows to keep\n",
+ " target_nulls = np.isnan(target)\n",
+ " if target_nulls.ndim == 2:\n",
+ " # target nulls for each horizon are dropped in MLForecast.fit_models\n",
+ " # we just drop rows here for which all the target values are null\n",
+ " target_nulls = target_nulls.all(axis=1)\n",
" if dropna:\n",
" feature_nulls = np.full(df.shape[0], False)\n",
" for feature_vals in features.values():\n",
" feature_nulls |= np.isnan(feature_vals)\n",
- " target_nulls = np.isnan(target)\n",
- " if target_nulls.ndim == 2:\n",
- " # target nulls for each horizon are dropped in MLForecast.fit_models\n",
- " # we just drop rows here for which all the target values are null\n",
- " target_nulls = target_nulls.all(axis=1)\n",
" keep_rows = ~(feature_nulls | target_nulls)\n",
+ " else:\n",
+ " # we always want to drop rows with nulls in the target\n",
+ " keep_rows = ~target_nulls\n",
+ "\n",
+ " self._dropped_series: Optional[np.ndarray] = None\n",
+ " if not keep_rows.all():\n",
+ " # remove rows with nulls\n",
" for k, v in features.items():\n",
" features[k] = v[keep_rows]\n",
" target = target[keep_rows]\n",
@@ -915,7 +913,7 @@
" last_idxs = self._sort_idxs[last_idxs]\n",
" last_vals_nan = ~keep_rows[last_idxs]\n",
" if last_vals_nan.any():\n",
- " self._dropped_series: Optional[np.ndarray] = np.where(last_vals_nan)[0] \n",
+ " self._dropped_series = np.where(last_vals_nan)[0] \n",
" dropped_ids = reprlib.repr(list(self.uids[self._dropped_series]))\n",
" warnings.warn(\n",
" \"The following series were dropped completely \"\n",
@@ -923,11 +921,9 @@
" \"These series won't show up if you use `MLForecast.forecast_fitted_values()`.\\n\"\n",
" \"You can set `dropna=False` or use transformations that require less samples to mitigate this\"\n",
" )\n",
- " else:\n",
- " self._dropped_series = None\n",
" elif isinstance(df, pd.DataFrame):\n",
+ " # we'll be assigning columns below, so we need to copy\n",
" df = df.copy(deep=False)\n",
- " self._dropped_series = None\n",
"\n",
" # once we've computed the features and target we can slice the series\n",
" update_samples = [\n",
@@ -1701,7 +1697,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L496){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L511){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.fit_transform\n",
"\n",
@@ -1723,7 +1719,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L496){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L511){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.fit_transform\n",
"\n",
@@ -2016,7 +2012,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L743){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.predict\n",
"\n",
@@ -2030,7 +2026,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L743){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.predict\n",
"\n",
@@ -2168,7 +2164,7 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L848){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L863){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.update\n",
"\n",
@@ -2181,7 +2177,7 @@
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L848){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L863){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.update\n",
"\n",
@@ -2575,6 +2571,23 @@
"ts.fit_transform(series, 'unique_id', 'ds', 'y')\n",
"assert ts.keep_last_n is None"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# no target nulls when dropna=False\n",
+ "ts = TimeSeries(\n",
+ " freq='D',\n",
+ " lags=[1, 2],\n",
+ " target_transforms=[Differences([5])],\n",
+ ")\n",
+ "prep = ts.fit_transform(series, 'unique_id', 'ds', 'y', dropna=False)\n",
+ "assert not prep['y'].isnull().any()"
+ ]
}
],
"metadata": {
diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb
index 6f355810..4330ffdf 100644
--- a/nbs/docs/getting-started/quick_start_distributed.ipynb
+++ b/nbs/docs/getting-started/quick_start_distributed.ipynb
@@ -448,36 +448,36 @@
"
0 | \n",
" id_00 | \n",
" 2002-09-27 00:00:00 | \n",
- " 21.609526 | \n",
- " 22.114111 | \n",
+ " 22.267619 | \n",
+ " 21.835798 | \n",
" \n",
" \n",
" 1 | \n",
" id_00 | \n",
" 2002-09-28 00:00:00 | \n",
- " 85.623013 | \n",
- " 84.309696 | \n",
+ " 85.230055 | \n",
+ " 83.996424 | \n",
"
\n",
" \n",
" 2 | \n",
" id_00 | \n",
" 2002-09-29 00:00:00 | \n",
- " 163.107685 | \n",
- " 163.20679 | \n",
+ " 168.256154 | \n",
+ " 163.076652 | \n",
"
\n",
" \n",
" 3 | \n",
" id_00 | \n",
" 2002-09-30 00:00:00 | \n",
- " 246.96872 | \n",
- " 245.510858 | \n",
+ " 246.712244 | \n",
+ " 245.827467 | \n",
"
\n",
" \n",
" 4 | \n",
" id_00 | \n",
" 2002-10-01 00:00:00 | \n",
- " 318.521367 | \n",
- " 314.479718 | \n",
+ " 314.184225 | \n",
+ " 315.257849 | \n",
"
\n",
" \n",
"\n",
@@ -485,11 +485,11 @@
],
"text/plain": [
" unique_id ds DaskXGBForecast DaskLGBMForecast\n",
- "0 id_00 2002-09-27 00:00:00 21.609526 22.114111\n",
- "1 id_00 2002-09-28 00:00:00 85.623013 84.309696\n",
- "2 id_00 2002-09-29 00:00:00 163.107685 163.20679\n",
- "3 id_00 2002-09-30 00:00:00 246.96872 245.510858\n",
- "4 id_00 2002-10-01 00:00:00 318.521367 314.479718"
+ "0 id_00 2002-09-27 00:00:00 22.267619 21.835798\n",
+ "1 id_00 2002-09-28 00:00:00 85.230055 83.996424\n",
+ "2 id_00 2002-09-29 00:00:00 168.256154 163.076652\n",
+ "3 id_00 2002-09-30 00:00:00 246.712244 245.827467\n",
+ "4 id_00 2002-10-01 00:00:00 314.184225 315.257849"
]
},
"execution_count": null,
@@ -792,68 +792,68 @@
" \n",
" \n",
" \n",
- " 0 | \n",
- " id_00 | \n",
- " 2002-08-16 00:00:00 | \n",
- " 23.192749 | \n",
- " 21.986437 | \n",
+ " 61 | \n",
+ " id_04 | \n",
+ " 2002-08-21 00:00:00 | \n",
+ " 68.3418 | \n",
+ " 68.944539 | \n",
" 2002-08-15 00:00:00 | \n",
- " 11.878591 | \n",
+ " 69.699857 | \n",
"
\n",
" \n",
- " 30 | \n",
- " id_02 | \n",
- " 2002-08-18 00:00:00 | \n",
- " 96.59974 | \n",
- " 96.568057 | \n",
+ " 83 | \n",
+ " id_15 | \n",
+ " 2002-08-29 00:00:00 | \n",
+ " 199.315403 | \n",
+ " 199.663555 | \n",
" 2002-08-15 00:00:00 | \n",
- " 94.706551 | \n",
+ " 206.082864 | \n",
"
\n",
" \n",
- " 80 | \n",
- " id_05 | \n",
- " 2002-08-26 00:00:00 | \n",
- " 257.210466 | \n",
- " 255.908309 | \n",
+ " 103 | \n",
+ " id_17 | \n",
+ " 2002-08-21 00:00:00 | \n",
+ " 156.822598 | \n",
+ " 158.018246 | \n",
" 2002-08-15 00:00:00 | \n",
- " 246.051086 | \n",
+ " 152.227984 | \n",
"
\n",
" \n",
- " 36 | \n",
- " id_12 | \n",
- " 2002-08-24 00:00:00 | \n",
- " 401.081335 | \n",
- " 401.697836 | \n",
+ " 61 | \n",
+ " id_24 | \n",
+ " 2002-08-21 00:00:00 | \n",
+ " 136.598356 | \n",
+ " 136.576865 | \n",
" 2002-08-15 00:00:00 | \n",
- " 424.296882 | \n",
+ " 138.559945 | \n",
"
\n",
" \n",
- " 91 | \n",
- " id_16 | \n",
- " 2002-08-23 00:00:00 | \n",
- " 315.036479 | \n",
- " 315.368377 | \n",
+ " 36 | \n",
+ " id_33 | \n",
+ " 2002-08-24 00:00:00 | \n",
+ " 95.6072 | \n",
+ " 96.249354 | \n",
" 2002-08-15 00:00:00 | \n",
- " 300.419406 | \n",
+ " 102.068997 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " unique_id ds DaskXGBForecast DaskLGBMForecast \\\n",
- "0 id_00 2002-08-16 00:00:00 23.192749 21.986437 \n",
- "30 id_02 2002-08-18 00:00:00 96.59974 96.568057 \n",
- "80 id_05 2002-08-26 00:00:00 257.210466 255.908309 \n",
- "36 id_12 2002-08-24 00:00:00 401.081335 401.697836 \n",
- "91 id_16 2002-08-23 00:00:00 315.036479 315.368377 \n",
+ " unique_id ds DaskXGBForecast DaskLGBMForecast \\\n",
+ "61 id_04 2002-08-21 00:00:00 68.3418 68.944539 \n",
+ "83 id_15 2002-08-29 00:00:00 199.315403 199.663555 \n",
+ "103 id_17 2002-08-21 00:00:00 156.822598 158.018246 \n",
+ "61 id_24 2002-08-21 00:00:00 136.598356 136.576865 \n",
+ "36 id_33 2002-08-24 00:00:00 95.6072 96.249354 \n",
"\n",
- " cutoff y \n",
- "0 2002-08-15 00:00:00 11.878591 \n",
- "30 2002-08-15 00:00:00 94.706551 \n",
- "80 2002-08-15 00:00:00 246.051086 \n",
- "36 2002-08-15 00:00:00 424.296882 \n",
- "91 2002-08-15 00:00:00 300.419406 "
+ " cutoff y \n",
+ "61 2002-08-15 00:00:00 69.699857 \n",
+ "83 2002-08-15 00:00:00 206.082864 \n",
+ "103 2002-08-15 00:00:00 152.227984 \n",
+ "61 2002-08-15 00:00:00 138.559945 \n",
+ "36 2002-08-15 00:00:00 102.068997 "
]
},
"execution_count": null,
@@ -918,7 +918,8 @@
" ),\n",
" static_features=['static_0', 'static_1'],\n",
")\n",
- "assert reduced_train.groupby('unique_id').size().compute().max() == input_size"
+ "dropped_samples = fcst._base_ts.target_transforms[0].differences[0]\n",
+ "assert reduced_train.groupby('unique_id').size().compute().max() == input_size - dropped_samples"
]
},
{
@@ -1183,7 +1184,15 @@
"execution_count": null,
"id": "d06d2230-60f5-47f1-820b-af2ca7311b41",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
"source": [
"preds = fcst.predict(7, X_df=future).toPandas()"
]
@@ -1226,35 +1235,35 @@
" 0 | \n",
" id_00 | \n",
" 2002-09-27 | \n",
- " 15.102403 | \n",
+ " 15.053577 | \n",
" 18.631477 | \n",
" \n",
" \n",
" 1 | \n",
" id_00 | \n",
" 2002-09-28 | \n",
- " 92.980261 | \n",
+ " 93.010037 | \n",
" 93.796269 | \n",
"
\n",
" \n",
" 2 | \n",
" id_00 | \n",
" 2002-09-29 | \n",
- " 160.090375 | \n",
+ " 160.120148 | \n",
" 159.582315 | \n",
"
\n",
" \n",
" 3 | \n",
" id_00 | \n",
" 2002-09-30 | \n",
- " 250.416113 | \n",
+ " 250.445885 | \n",
" 250.861651 | \n",
"
\n",
" \n",
" 4 | \n",
" id_00 | \n",
" 2002-10-01 | \n",
- " 323.306184 | \n",
+ " 323.335956 | \n",
" 321.564089 | \n",
"
\n",
" \n",
@@ -1263,11 +1272,11 @@
],
"text/plain": [
" unique_id ds SparkLGBMForecast SparkXGBForecast\n",
- "0 id_00 2002-09-27 15.102403 18.631477\n",
- "1 id_00 2002-09-28 92.980261 93.796269\n",
- "2 id_00 2002-09-29 160.090375 159.582315\n",
- "3 id_00 2002-09-30 250.416113 250.861651\n",
- "4 id_00 2002-10-01 323.306184 321.564089"
+ "0 id_00 2002-09-27 15.053577 18.631477\n",
+ "1 id_00 2002-09-28 93.010037 93.796269\n",
+ "2 id_00 2002-09-29 160.120148 159.582315\n",
+ "3 id_00 2002-09-30 250.445885 250.861651\n",
+ "4 id_00 2002-10-01 323.335956 321.564089"
]
},
"execution_count": null,
@@ -2008,48 +2017,48 @@
" \n",
" \n",
" 0 | \n",
- " id_04 | \n",
- " 2002-09-20 | \n",
- " 118.982094 | \n",
- " 117.577477 | \n",
+ " id_05 | \n",
+ " 2002-09-21 | \n",
+ " 108.285187 | \n",
+ " 108.619698 | \n",
" 2002-09-12 | \n",
- " 118.603489 | \n",
+ " 108.726387 | \n",
"
\n",
" \n",
" 1 | \n",
- " id_04 | \n",
- " 2002-09-24 | \n",
- " 51.461491 | \n",
- " 50.120552 | \n",
+ " id_08 | \n",
+ " 2002-09-16 | \n",
+ " 26.287956 | \n",
+ " 26.589603 | \n",
" 2002-09-12 | \n",
- " 52.668389 | \n",
+ " 27.980670 | \n",
"
\n",
" \n",
" 2 | \n",
- " id_05 | \n",
- " 2002-09-20 | \n",
- " 27.594826 | \n",
- " 24.421537 | \n",
+ " id_08 | \n",
+ " 2002-09-25 | \n",
+ " 83.210945 | \n",
+ " 84.194962 | \n",
" 2002-09-12 | \n",
- " 20.120710 | \n",
+ " 86.344885 | \n",
"
\n",
" \n",
" 3 | \n",
- " id_05 | \n",
- " 2002-09-25 | \n",
- " 411.615204 | \n",
- " 412.093384 | \n",
+ " id_11 | \n",
+ " 2002-09-22 | \n",
+ " 416.994843 | \n",
+ " 417.106506 | \n",
" 2002-09-12 | \n",
- " 419.621422 | \n",
+ " 425.434661 | \n",
"
\n",
" \n",
" 4 | \n",
- " id_08 | \n",
- " 2002-09-25 | \n",
- " 83.210945 | \n",
- " 83.842705 | \n",
+ " id_16 | \n",
+ " 2002-09-14 | \n",
+ " 377.916382 | \n",
+ " 375.421600 | \n",
" 2002-09-12 | \n",
- " 86.344885 | \n",
+ " 400.361977 | \n",
"
\n",
" \n",
"\n",
@@ -2057,11 +2066,11 @@
],
"text/plain": [
" unique_id ds RayLGBMForecast RayXGBForecast cutoff y\n",
- "0 id_04 2002-09-20 118.982094 117.577477 2002-09-12 118.603489\n",
- "1 id_04 2002-09-24 51.461491 50.120552 2002-09-12 52.668389\n",
- "2 id_05 2002-09-20 27.594826 24.421537 2002-09-12 20.120710\n",
- "3 id_05 2002-09-25 411.615204 412.093384 2002-09-12 419.621422\n",
- "4 id_08 2002-09-25 83.210945 83.842705 2002-09-12 86.344885"
+ "0 id_05 2002-09-21 108.285187 108.619698 2002-09-12 108.726387\n",
+ "1 id_08 2002-09-16 26.287956 26.589603 2002-09-12 27.980670\n",
+ "2 id_08 2002-09-25 83.210945 84.194962 2002-09-12 86.344885\n",
+ "3 id_11 2002-09-22 416.994843 417.106506 2002-09-12 425.434661\n",
+ "4 id_16 2002-09-14 377.916382 375.421600 2002-09-12 400.361977"
]
},
"execution_count": null,
diff --git a/nbs/target_transforms.ipynb b/nbs/target_transforms.ipynb
index 589b9c57..b6163f55 100644
--- a/nbs/target_transforms.ipynb
+++ b/nbs/target_transforms.ipynb
@@ -732,6 +732,7 @@
" sk_boxcox.fit_transform(series[['y']])[:, 0], index=series['unique_id']\n",
" ).groupby('unique_id', observed=True)\n",
" .diff()\n",
+ " .dropna()\n",
" .values\n",
")\n",
"np.testing.assert_allclose(prep['y'].values, expected)"
@@ -754,7 +755,7 @@
" target_transforms=[boxcox_global, single_difference]\n",
")\n",
"prep_pl = fcst_pl.preprocess(series_pl, dropna=False)\n",
- "pd.testing.assert_frame_equal(prep, prep_pl.to_pandas())"
+ "pd.testing.assert_frame_equal(prep.reset_index(drop=True), prep_pl.to_pandas())"
]
}
],