From e869465e183ad647e6319aa0b88b14b1bee2bfca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Tue, 12 Nov 2024 12:37:47 -0600 Subject: [PATCH] enh(distributed): propagate null features in spark (#448) --- mlforecast/distributed/forecast.py | 4 +++- nbs/distributed.forecast.ipynb | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py index 401a1cde..defef0ad 100644 --- a/mlforecast/distributed/forecast.py +++ b/mlforecast/distributed/forecast.py @@ -377,7 +377,9 @@ def _fit( ] self.models_ = {} if SPARK_INSTALLED and isinstance(data, SparkDataFrame): - featurizer = VectorAssembler(inputCols=features, outputCol="features") + featurizer = VectorAssembler( + inputCols=features, outputCol="features", handleInvalid="keep" + ) train_data = featurizer.transform(prep)[target_col, "features"] for name, model in self.models.items(): trained_model = model._pre_fit(target_col).fit(train_data) diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb index 1ccd9418..e287d034 100644 --- a/nbs/distributed.forecast.ipynb +++ b/nbs/distributed.forecast.ipynb @@ -431,7 +431,9 @@ " features = [x for x in fa.get_column_names(prep) if x not in {id_col, time_col, target_col}]\n", " self.models_ = {}\n", " if SPARK_INSTALLED and isinstance(data, SparkDataFrame):\n", - " featurizer = VectorAssembler(inputCols=features, outputCol=\"features\")\n", + " featurizer = VectorAssembler(\n", + " inputCols=features, outputCol=\"features\", handleInvalid=\"keep\"\n", + " )\n", " train_data = featurizer.transform(prep)[target_col, \"features\"]\n", " for name, model in self.models.items():\n", " trained_model = model._pre_fit(target_col).fit(train_data)\n",