Skip to content

Commit

Permalink
Backports for v0.13.3 (#2954)
Browse files Browse the repository at this point in the history
* Fix JsonLinesFile slicing. (#2925)

* Zebras: Fix index handling of SplitFrame.resize. (#2938)

* Docs: fix missing values use-case in `PandasDataset` docs (#2941)

* Ignore F403 errors in preludes. (#2948)

* Fix: prevent accumulation of `SelectFields` in `PyTorchPredictor` (#2951)

* Prevent redundant accumulation of fields

* update fix

---------

Co-authored-by: Cameronwood611 <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

* [Docs] fix link to NPTS implementation (#2953)

* Revert "Fix JsonLinesFile slicing. (#2925)"

This reverts commit fa7f9a0.

---------

Co-authored-by: Jasper <[email protected]>
Co-authored-by: cneely33 <[email protected]>
Co-authored-by: cameronwood611 <[email protected]>
Co-authored-by: Cameronwood611 <[email protected]>
5 people authored Aug 7, 2023
1 parent e1c33cd commit 48d22d7
Showing 6 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/getting_started/models.md
Original file line number Diff line number Diff line change
@@ -78,4 +78,4 @@ NPTS | Local | Un
[Prophet]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/ext/prophet/_predictor.py
[NaiveSeasonal]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/model/seasonal_naive/_predictor.py
[Naive2]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/ext/naive_2/_predictor.py
[NPTS]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/ext/npts/_predictor.py
[NPTS]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/model/npts/_predictor.py
4 changes: 2 additions & 2 deletions docs/tutorials/data_manipulation/pandasdataframes.md.template
Original file line number Diff line number Diff line change
@@ -123,8 +123,8 @@ from gluonts.dataset.pandas import PandasDataset

max_end = max(df.groupby("item_id").apply(lambda _df: _df.index[-1]))
dfs_dict = {}
for item_id, gdf in df.groupby("item_id"):
new_index = pd.date_range(gdf.index[0], end=max_end, freq="1D")
for item_id, gdf in df_missing_val.groupby("item_id"):
new_index = pd.date_range(gdf.index[0], end=max_end, freq="1H")
dfs_dict[item_id] = gdf.reindex(new_index).drop("item_id", axis=1)

ds = PandasDataset(dfs_dict, target="target")
2 changes: 1 addition & 1 deletion src/gluonts/mx/prelude.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# flake8: noqa: F401
# flake8: noqa: F401, F403

from .component import *
from .serde import *
8 changes: 4 additions & 4 deletions src/gluonts/torch/model/predictor.py
Original file line number Diff line number Diff line change
@@ -73,12 +73,12 @@ def network(self) -> nn.Module:
def predict(
self, dataset: Dataset, num_samples: Optional[int] = None
) -> Iterator[Forecast]:
self.input_transform += SelectFields(
self.input_names + self.required_fields, allow_missing=True
)
inference_data_loader = InferenceDataLoader(
dataset,
transform=self.input_transform,
transform=self.input_transform
+ SelectFields(
self.input_names + self.required_fields, allow_missing=True
),
batch_size=self.batch_size,
stack_fn=lambda data: batchify(data, self.device),
)
2 changes: 1 addition & 1 deletion src/gluonts/torch/prelude.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# flake8: noqa: F401
# flake8: noqa: F401, F403

from .component import *
from .model.forecast_generator import *
5 changes: 4 additions & 1 deletion src/gluonts/zebras/_split_frame.py
Original file line number Diff line number Diff line change
@@ -180,7 +180,10 @@ def resize(
future_length = maybe.unwrap_or(future_length, self.future_length)

if index is not None:
start = index[0] + (self.past_length + past_length)
# Calculate new start. If current past_length is larger than the
# the new one, we shift it to the right, if it's smaller, we need
# to go further into the past (shift to the left)
start = index[0] + (self.past_length - past_length)
index = start.periods(past_length + future_length)

return _replace(

0 comments on commit 48d22d7

Please sign in to comment.