diff --git a/mlforecast/core.py b/mlforecast/core.py index 846330e3..6357ab07 100644 --- a/mlforecast/core.py +++ b/mlforecast/core.py @@ -402,16 +402,23 @@ def _transform( target = target[self._restore_idxs] # determine rows to keep + target_nulls = np.isnan(target) + if target_nulls.ndim == 2: + # target nulls for each horizon are dropped in MLForecast.fit_models + # we just drop rows here for which all the target values are null + target_nulls = target_nulls.all(axis=1) if dropna: feature_nulls = np.full(df.shape[0], False) for feature_vals in features.values(): feature_nulls |= np.isnan(feature_vals) - target_nulls = np.isnan(target) - if target_nulls.ndim == 2: - # target nulls for each horizon are dropped in MLForecast.fit_models - # we just drop rows here for which all the target values are null - target_nulls = target_nulls.all(axis=1) keep_rows = ~(feature_nulls | target_nulls) + else: + # we always want to drop rows with nulls in the target + keep_rows = ~target_nulls + + self._dropped_series: Optional[np.ndarray] = None + if not keep_rows.all(): + # remove rows with nulls for k, v in features.items(): features[k] = v[keep_rows] target = target[keep_rows] @@ -422,7 +429,7 @@ def _transform( last_idxs = self._sort_idxs[last_idxs] last_vals_nan = ~keep_rows[last_idxs] if last_vals_nan.any(): - self._dropped_series: Optional[np.ndarray] = np.where(last_vals_nan)[0] + self._dropped_series = np.where(last_vals_nan)[0] dropped_ids = reprlib.repr(list(self.uids[self._dropped_series])) warnings.warn( "The following series were dropped completely " @@ -430,11 +437,9 @@ def _transform( "These series won't show up if you use `MLForecast.forecast_fitted_values()`.\n" "You can set `dropna=False` or use transformations that require less samples to mitigate this" ) - else: - self._dropped_series = None elif isinstance(df, pd.DataFrame): + # we'll be assigning columns below, so we need to copy df = df.copy(deep=False) - self._dropped_series = None # once we've computed the features and target we can slice the series update_samples = [ diff --git a/nbs/core.ipynb b/nbs/core.ipynb index df3e198c..96a58a49 100644 --- a/nbs/core.ipynb +++ b/nbs/core.ipynb @@ -20,16 +20,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "#|hide\n", "%load_ext autoreload\n", @@ -895,16 +886,23 @@ " target = target[self._restore_idxs] \n", "\n", " # determine rows to keep\n", + " target_nulls = np.isnan(target)\n", + " if target_nulls.ndim == 2:\n", + " # target nulls for each horizon are dropped in MLForecast.fit_models\n", + " # we just drop rows here for which all the target values are null\n", + " target_nulls = target_nulls.all(axis=1)\n", " if dropna:\n", " feature_nulls = np.full(df.shape[0], False)\n", " for feature_vals in features.values():\n", " feature_nulls |= np.isnan(feature_vals)\n", - " target_nulls = np.isnan(target)\n", - " if target_nulls.ndim == 2:\n", - " # target nulls for each horizon are dropped in MLForecast.fit_models\n", - " # we just drop rows here for which all the target values are null\n", - " target_nulls = target_nulls.all(axis=1)\n", " keep_rows = ~(feature_nulls | target_nulls)\n", + " else:\n", + " # we always want to drop rows with nulls in the target\n", + " keep_rows = ~target_nulls\n", + "\n", + " self._dropped_series: Optional[np.ndarray] = None\n", + " if not keep_rows.all():\n", + " # remove rows with nulls\n", " for k, v in features.items():\n", " features[k] = v[keep_rows]\n", " target = target[keep_rows]\n", @@ -915,7 +913,7 @@ " last_idxs = self._sort_idxs[last_idxs]\n", " last_vals_nan = ~keep_rows[last_idxs]\n", " if last_vals_nan.any():\n", - " self._dropped_series: Optional[np.ndarray] = np.where(last_vals_nan)[0] \n", + " self._dropped_series = np.where(last_vals_nan)[0] \n", " dropped_ids = reprlib.repr(list(self.uids[self._dropped_series]))\n", " warnings.warn(\n", " \"The following series were dropped completely \"\n", @@ -923,11 +921,9 @@ " \"These series won't show up if you use `MLForecast.forecast_fitted_values()`.\\n\"\n", " \"You can set `dropna=False` or use transformations that require less samples to mitigate this\"\n", " )\n", - " else:\n", - " self._dropped_series = None\n", " elif isinstance(df, pd.DataFrame):\n", + " # we'll be assigning columns below, so we need to copy\n", " df = df.copy(deep=False)\n", - " self._dropped_series = None\n", "\n", " # once we've computed the features and target we can slice the series\n", " update_samples = [\n", @@ -1701,7 +1697,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L496){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L511){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TimeSeries.fit_transform\n", "\n", @@ -1723,7 +1719,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L496){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L511){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TimeSeries.fit_transform\n", "\n", @@ -2016,7 +2012,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L743){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TimeSeries.predict\n", "\n", @@ -2030,7 +2026,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L743){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TimeSeries.predict\n", "\n", @@ -2168,7 +2164,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L848){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L863){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TimeSeries.update\n", "\n", @@ -2181,7 +2177,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L848){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L863){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TimeSeries.update\n", "\n", @@ -2575,6 +2571,23 @@ "ts.fit_transform(series, 'unique_id', 'ds', 'y')\n", "assert ts.keep_last_n is None" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# no target nulls when dropna=False\n", + "ts = TimeSeries(\n", + " freq='D',\n", + " lags=[1, 2],\n", + " target_transforms=[Differences([5])],\n", + ")\n", + "prep = ts.fit_transform(series, 'unique_id', 'ds', 'y', dropna=False)\n", + "assert not prep['y'].isnull().any()" + ] } ], "metadata": { diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb index 6f355810..4330ffdf 100644 --- a/nbs/docs/getting-started/quick_start_distributed.ipynb +++ b/nbs/docs/getting-started/quick_start_distributed.ipynb @@ -448,36 +448,36 @@ " 0\n", " id_00\n", " 2002-09-27 00:00:00\n", - " 21.609526\n", - " 22.114111\n", + " 22.267619\n", + " 21.835798\n", " \n", " \n", " 1\n", " id_00\n", " 2002-09-28 00:00:00\n", - " 85.623013\n", - " 84.309696\n", + " 85.230055\n", + " 83.996424\n", " \n", " \n", " 2\n", " id_00\n", " 2002-09-29 00:00:00\n", - " 163.107685\n", - " 163.20679\n", + " 168.256154\n", + " 163.076652\n", " \n", " \n", " 3\n", " id_00\n", " 2002-09-30 00:00:00\n", - " 246.96872\n", - " 245.510858\n", + " 246.712244\n", + " 245.827467\n", " \n", " \n", " 4\n", " id_00\n", " 2002-10-01 00:00:00\n", - " 318.521367\n", - " 314.479718\n", + " 314.184225\n", + " 315.257849\n", " \n", " \n", "\n", @@ -485,11 +485,11 @@ ], "text/plain": [ " unique_id ds DaskXGBForecast DaskLGBMForecast\n", - "0 id_00 2002-09-27 00:00:00 21.609526 22.114111\n", - "1 id_00 2002-09-28 00:00:00 85.623013 84.309696\n", - "2 id_00 2002-09-29 00:00:00 163.107685 163.20679\n", - "3 id_00 2002-09-30 00:00:00 246.96872 245.510858\n", - "4 id_00 2002-10-01 00:00:00 318.521367 314.479718" + "0 id_00 2002-09-27 00:00:00 22.267619 21.835798\n", + "1 id_00 2002-09-28 00:00:00 85.230055 83.996424\n", + "2 id_00 2002-09-29 00:00:00 168.256154 163.076652\n", + "3 id_00 2002-09-30 00:00:00 246.712244 245.827467\n", + "4 id_00 2002-10-01 00:00:00 314.184225 315.257849" ] }, "execution_count": null, @@ -792,68 +792,68 @@ " \n", " \n", " \n", - " 0\n", - " id_00\n", - " 2002-08-16 00:00:00\n", - " 23.192749\n", - " 21.986437\n", + " 61\n", + " id_04\n", + " 2002-08-21 00:00:00\n", + " 68.3418\n", + " 68.944539\n", " 2002-08-15 00:00:00\n", - " 11.878591\n", + " 69.699857\n", " \n", " \n", - " 30\n", - " id_02\n", - " 2002-08-18 00:00:00\n", - " 96.59974\n", - " 96.568057\n", + " 83\n", + " id_15\n", + " 2002-08-29 00:00:00\n", + " 199.315403\n", + " 199.663555\n", " 2002-08-15 00:00:00\n", - " 94.706551\n", + " 206.082864\n", " \n", " \n", - " 80\n", - " id_05\n", - " 2002-08-26 00:00:00\n", - " 257.210466\n", - " 255.908309\n", + " 103\n", + " id_17\n", + " 2002-08-21 00:00:00\n", + " 156.822598\n", + " 158.018246\n", " 2002-08-15 00:00:00\n", - " 246.051086\n", + " 152.227984\n", " \n", " \n", - " 36\n", - " id_12\n", - " 2002-08-24 00:00:00\n", - " 401.081335\n", - " 401.697836\n", + " 61\n", + " id_24\n", + " 2002-08-21 00:00:00\n", + " 136.598356\n", + " 136.576865\n", " 2002-08-15 00:00:00\n", - " 424.296882\n", + " 138.559945\n", " \n", " \n", - " 91\n", - " id_16\n", - " 2002-08-23 00:00:00\n", - " 315.036479\n", - " 315.368377\n", + " 36\n", + " id_33\n", + " 2002-08-24 00:00:00\n", + " 95.6072\n", + " 96.249354\n", " 2002-08-15 00:00:00\n", - " 300.419406\n", + " 102.068997\n", " \n", " \n", "\n", "" ], "text/plain": [ - " unique_id ds DaskXGBForecast DaskLGBMForecast \\\n", - "0 id_00 2002-08-16 00:00:00 23.192749 21.986437 \n", - "30 id_02 2002-08-18 00:00:00 96.59974 96.568057 \n", - "80 id_05 2002-08-26 00:00:00 257.210466 255.908309 \n", - "36 id_12 2002-08-24 00:00:00 401.081335 401.697836 \n", - "91 id_16 2002-08-23 00:00:00 315.036479 315.368377 \n", + " unique_id ds DaskXGBForecast DaskLGBMForecast \\\n", + "61 id_04 2002-08-21 00:00:00 68.3418 68.944539 \n", + "83 id_15 2002-08-29 00:00:00 199.315403 199.663555 \n", + "103 id_17 2002-08-21 00:00:00 156.822598 158.018246 \n", + "61 id_24 2002-08-21 00:00:00 136.598356 136.576865 \n", + "36 id_33 2002-08-24 00:00:00 95.6072 96.249354 \n", "\n", - " cutoff y \n", - "0 2002-08-15 00:00:00 11.878591 \n", - "30 2002-08-15 00:00:00 94.706551 \n", - "80 2002-08-15 00:00:00 246.051086 \n", - "36 2002-08-15 00:00:00 424.296882 \n", - "91 2002-08-15 00:00:00 300.419406 " + " cutoff y \n", + "61 2002-08-15 00:00:00 69.699857 \n", + "83 2002-08-15 00:00:00 206.082864 \n", + "103 2002-08-15 00:00:00 152.227984 \n", + "61 2002-08-15 00:00:00 138.559945 \n", + "36 2002-08-15 00:00:00 102.068997 " ] }, "execution_count": null, @@ -918,7 +918,8 @@ " ),\n", " static_features=['static_0', 'static_1'],\n", ")\n", - "assert reduced_train.groupby('unique_id').size().compute().max() == input_size" + "dropped_samples = fcst._base_ts.target_transforms[0].differences[0]\n", + "assert reduced_train.groupby('unique_id').size().compute().max() == input_size - dropped_samples" ] }, { @@ -1183,7 +1184,15 @@ "execution_count": null, "id": "d06d2230-60f5-47f1-820b-af2ca7311b41", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " \r" + ] + } + ], "source": [ "preds = fcst.predict(7, X_df=future).toPandas()" ] @@ -1226,35 +1235,35 @@ " 0\n", " id_00\n", " 2002-09-27\n", - " 15.102403\n", + " 15.053577\n", " 18.631477\n", " \n", " \n", " 1\n", " id_00\n", " 2002-09-28\n", - " 92.980261\n", + " 93.010037\n", " 93.796269\n", " \n", " \n", " 2\n", " id_00\n", " 2002-09-29\n", - " 160.090375\n", + " 160.120148\n", " 159.582315\n", " \n", " \n", " 3\n", " id_00\n", " 2002-09-30\n", - " 250.416113\n", + " 250.445885\n", " 250.861651\n", " \n", " \n", " 4\n", " id_00\n", " 2002-10-01\n", - " 323.306184\n", + " 323.335956\n", " 321.564089\n", " \n", " \n", @@ -1263,11 +1272,11 @@ ], "text/plain": [ " unique_id ds SparkLGBMForecast SparkXGBForecast\n", - "0 id_00 2002-09-27 15.102403 18.631477\n", - "1 id_00 2002-09-28 92.980261 93.796269\n", - "2 id_00 2002-09-29 160.090375 159.582315\n", - "3 id_00 2002-09-30 250.416113 250.861651\n", - "4 id_00 2002-10-01 323.306184 321.564089" + "0 id_00 2002-09-27 15.053577 18.631477\n", + "1 id_00 2002-09-28 93.010037 93.796269\n", + "2 id_00 2002-09-29 160.120148 159.582315\n", + "3 id_00 2002-09-30 250.445885 250.861651\n", + "4 id_00 2002-10-01 323.335956 321.564089" ] }, "execution_count": null, @@ -2008,48 +2017,48 @@ " \n", " \n", " 0\n", - " id_04\n", - " 2002-09-20\n", - " 118.982094\n", - " 117.577477\n", + " id_05\n", + " 2002-09-21\n", + " 108.285187\n", + " 108.619698\n", " 2002-09-12\n", - " 118.603489\n", + " 108.726387\n", " \n", " \n", " 1\n", - " id_04\n", - " 2002-09-24\n", - " 51.461491\n", - " 50.120552\n", + " id_08\n", + " 2002-09-16\n", + " 26.287956\n", + " 26.589603\n", " 2002-09-12\n", - " 52.668389\n", + " 27.980670\n", " \n", " \n", " 2\n", - " id_05\n", - " 2002-09-20\n", - " 27.594826\n", - " 24.421537\n", + " id_08\n", + " 2002-09-25\n", + " 83.210945\n", + " 84.194962\n", " 2002-09-12\n", - " 20.120710\n", + " 86.344885\n", " \n", " \n", " 3\n", - " id_05\n", - " 2002-09-25\n", - " 411.615204\n", - " 412.093384\n", + " id_11\n", + " 2002-09-22\n", + " 416.994843\n", + " 417.106506\n", " 2002-09-12\n", - " 419.621422\n", + " 425.434661\n", " \n", " \n", " 4\n", - " id_08\n", - " 2002-09-25\n", - " 83.210945\n", - " 83.842705\n", + " id_16\n", + " 2002-09-14\n", + " 377.916382\n", + " 375.421600\n", " 2002-09-12\n", - " 86.344885\n", + " 400.361977\n", " \n", " \n", "\n", @@ -2057,11 +2066,11 @@ ], "text/plain": [ " unique_id ds RayLGBMForecast RayXGBForecast cutoff y\n", - "0 id_04 2002-09-20 118.982094 117.577477 2002-09-12 118.603489\n", - "1 id_04 2002-09-24 51.461491 50.120552 2002-09-12 52.668389\n", - "2 id_05 2002-09-20 27.594826 24.421537 2002-09-12 20.120710\n", - "3 id_05 2002-09-25 411.615204 412.093384 2002-09-12 419.621422\n", - "4 id_08 2002-09-25 83.210945 83.842705 2002-09-12 86.344885" + "0 id_05 2002-09-21 108.285187 108.619698 2002-09-12 108.726387\n", + "1 id_08 2002-09-16 26.287956 26.589603 2002-09-12 27.980670\n", + "2 id_08 2002-09-25 83.210945 84.194962 2002-09-12 86.344885\n", + "3 id_11 2002-09-22 416.994843 417.106506 2002-09-12 425.434661\n", + "4 id_16 2002-09-14 377.916382 375.421600 2002-09-12 400.361977" ] }, "execution_count": null, diff --git a/nbs/target_transforms.ipynb b/nbs/target_transforms.ipynb index 589b9c57..b6163f55 100644 --- a/nbs/target_transforms.ipynb +++ b/nbs/target_transforms.ipynb @@ -732,6 +732,7 @@ " sk_boxcox.fit_transform(series[['y']])[:, 0], index=series['unique_id']\n", " ).groupby('unique_id', observed=True)\n", " .diff()\n", + " .dropna()\n", " .values\n", ")\n", "np.testing.assert_allclose(prep['y'].values, expected)" @@ -754,7 +755,7 @@ " target_transforms=[boxcox_global, single_difference]\n", ")\n", "prep_pl = fcst_pl.preprocess(series_pl, dropna=False)\n", - "pd.testing.assert_frame_equal(prep, prep_pl.to_pandas())" + "pd.testing.assert_frame_equal(prep.reset_index(drop=True), prep_pl.to_pandas())" ] } ],