Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add re-exporting mechanism. #3002

Draft
wants to merge 4 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/gluonts/meta/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.


from importlib import import_module
from types import ModuleType

from dataclasses import dataclass, field


@dataclass
class Exports:
name: str
module: ModuleType = field(init=False)
exports: dict = field(default_factory=dict)

def __post_init__(self):
self.module = import_module(self.name)

def re_export(self, *paths):
for path in paths:
path = ".".join([self.name, path])

parts = path.rsplit(":", 1)
if len(parts) == 1:
module_name = parts[0]
name = ""
else:
module_name, name = parts

module = import_module(module_name)

if name:
value = getattr(module, name)
if hasattr(value, "__module__"):
value.__module__ = self.name
else:
value = module

setattr(self.module, name, value)

self.exports[name] = value

def definition(self, cls):
self.exports[cls.__name__] = cls
return cls

def register(self, name, value):
self.exports[name] = value
return value

def __iter__(self):
yield from self.exports

def __len__(self):
return len(self.exports)

def __getitem__(self, index):
return list(self.exports)[index]


def re_export(module, *names, **kwargs):
"""Import sub-paths and assign `__module__` on definitions.

Often there is a difference between the intended external interface and
internal code structure. For example, the ``Predictor`` class is defined
in ``gluonts/model/predictor.py``, but the canonical path is
``gluonts.model.Predictor`` and not ``gluonts.model.predictor.Predictor``.

The desired behaviour can be achieved using the following in
``model/__init__.py``::

__all__ = re_export(
__name__,
"estimator:Estimator",
"predictor:Predictor",
...
)

"""

names = list(names)

for base, paths in kwargs.items():
for path in paths:
names.append(":".join([base, path]))

exports = Exports(module)
for name in names:
exports.re_export(name)
return exports
38 changes: 21 additions & 17 deletions src/gluonts/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,25 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__all__ = [
"Estimator",
"IncrementallyTrainable",
"Predictor",
"Forecast",
"SampleForecast",
"QuantileForecast",
"Input",
"InputSpec",
"evaluate_forecasts",
"evaluate_model",
]
from gluonts.meta.export import re_export

from .estimator import Estimator, IncrementallyTrainable
from .predictor import Predictor
from .forecast import Forecast, SampleForecast, QuantileForecast
from .inputs import Input, InputSpec
from .evaluation import evaluate_forecasts, evaluate_model
__all__ = re_export(
__name__,
# NOTE: order is important! we need to import from predictor and forecast
# first, otherwise we get circular imports
predictor=["Predictor"],
forecast=[
"Forecast",
"SampleForecast",
"QuantileForecast",
],
estimator=[
"Estimator",
"IncrementallyTrainable",
],
evaluation=[
"evaluate_forecasts",
"evaluate_model",
],
inputs=["Input", "InputSpec"],
)
88 changes: 37 additions & 51 deletions src/gluonts/time_feature/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,41 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from ._base import (
Constant,
day_of_month,
day_of_month_index,
day_of_week,
day_of_week_index,
day_of_year,
day_of_year_index,
hour_of_day,
hour_of_day_index,
minute_of_hour,
minute_of_hour_index,
month_of_year,
month_of_year_index,
second_of_minute,
second_of_minute_index,
TimeFeature,
week_of_year,
week_of_year_index,
norm_freq_str,
time_features_from_frequency_str,
)
from .holiday import SPECIAL_DATE_FEATURES, SpecialDateFeatureSet
from .lag import get_lags_for_frequency
from .seasonality import get_seasonality
from gluonts.meta.export import re_export


__all__ = [
"Constant",
"day_of_month",
"day_of_month_index",
"day_of_week",
"day_of_week_index",
"day_of_year",
"day_of_year_index",
"get_lags_for_frequency",
"get_seasonality",
"hour_of_day",
"hour_of_day_index",
"minute_of_hour",
"minute_of_hour_index",
"month_of_year",
"month_of_year_index",
"second_of_minute",
"second_of_minute_index",
"norm_freq_str",
"SPECIAL_DATE_FEATURES",
"SpecialDateFeatureSet",
"time_features_from_frequency_str",
"TimeFeature",
"week_of_year",
"week_of_year_index",
]
__all__ = re_export(
__name__,
_base=[
"Constant",
"day_of_month",
"day_of_month_index",
"day_of_week",
"day_of_week_index",
"day_of_year",
"day_of_year_index",
"hour_of_day",
"hour_of_day_index",
"minute_of_hour",
"minute_of_hour_index",
"month_of_year",
"month_of_year_index",
"second_of_minute",
"second_of_minute_index",
"TimeFeature",
"week_of_year",
"week_of_year_index",
"norm_freq_str",
"time_features_from_frequency_str",
],
holiday=[
"SPECIAL_DATE_FEATURES",
"SpecialDateFeatureSet",
],
lag=[
"get_lags_for_frequency",
],
seasonality=[
"get_seasonality",
],
)
39 changes: 17 additions & 22 deletions src/gluonts/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,24 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__all__ = [
"PyTorchLightningEstimator",
"PyTorchPredictor",
"DeepNPTSEstimator",
"DeepAREstimator",
"SimpleFeedForwardEstimator",
"TemporalFusionTransformerEstimator",
"WaveNetEstimator",
"DLinearEstimator",
"PatchTSTEstimator",
"LagTSTEstimator",
]

import torch
from .model.estimator import PyTorchLightningEstimator
from .model.predictor import PyTorchPredictor
from .model.deep_npts import DeepNPTSEstimator
from .model.deepar import DeepAREstimator
from .model.simple_feedforward import SimpleFeedForwardEstimator
from .model.tft import TemporalFusionTransformerEstimator
from .model.wavenet import WaveNetEstimator
from .model.d_linear import DLinearEstimator
from .model.patch_tst import PatchTSTEstimator
from .model.lag_tst import LagTSTEstimator

from gluonts.meta.export import re_export

__all__ = re_export(
__name__,
".model.estimator:PyTorchLightningEstimator",
".model.predictor:PyTorchPredictor",
".model.deep_npts:DeepNPTSEstimator",
".model.deepar:DeepAREstimator",
".model.simple_feedforward:SimpleFeedForwardEstimator",
".model.tft:TemporalFusionTransformerEstimator",
".model.wavenet:WaveNetEstimator",
".model.d_linear:DLinearEstimator",
".model.patch_tst:PatchTSTEstimator",
".model.lag_tst:LagTSTEstimator",
)


from . import prelude as _ # noqa
Expand Down
37 changes: 15 additions & 22 deletions src/gluonts/zebras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,26 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from gluonts.meta.export import re_export


__all__ = [
"Freq",
"freq",
"Period",
"Periods",
"period",
"periods",
"BatchTimeFrame",
"TimeFrame",
"time_frame",
"BatchSplitFrame",
"SplitFrame",
"split_frame",
"time_series",
"BatchTimeSeries",
"TimeSeries",
"schema",
"batch",
"from_pandas",
*re_export(
__name__,
"schema",
_freq=["Freq", "freq"],
_period=["Period", "Periods", "period", "periods"],
_time_frame=["BatchTimeFrame", "TimeFrame", "time_frame"],
_split_frame=["BatchSplitFrame", "SplitFrame", "split_frame"],
_time_series=["time_series", "BatchTimeSeries", "TimeSeries"],
),
]

from typing import TypeVar

from ._freq import Freq, freq
from ._period import period, Period, periods, Periods
from ._split_frame import split_frame, SplitFrame, BatchSplitFrame
from ._time_frame import time_frame, TimeFrame, BatchTimeFrame
from ._time_series import time_series, TimeSeries, BatchTimeSeries
from . import schema
from . import TimeSeries, TimeFrame, SplitFrame, Periods, Freq # type: ignore

Batchable = TypeVar("Batchable", TimeSeries, TimeFrame, SplitFrame)

Expand Down