From 3b5fdb89ad7da51a4f82b1f98f6048eecb917494 Mon Sep 17 00:00:00 2001 From: javiber Date: Fri, 3 May 2024 09:49:21 -0300 Subject: [PATCH 1/9] Added moving quantile operator --- temporian/core/event_set_ops.py | 18 + temporian/core/operators/window/BUILD | 15 + temporian/core/operators/window/__init__.py | 1 + temporian/core/operators/window/base.py | 12 +- .../core/operators/window/moving_quantile.py | 93 ++++ temporian/core/operators/window/test/BUILD | 13 + .../window/test/test_moving_quantile.py | 371 ++++++++++++++ .../implementation/numpy/operators/BUILD | 1 + .../numpy/operators/__init__.py | 1 + .../numpy/operators/window/BUILD | 12 + .../numpy/operators/window/moving_quantile.py | 64 +++ .../implementation/numpy_cc/operators/BUILD | 9 +- .../numpy_cc/operators/custom_heap.h | 76 +++ .../numpy_cc/operators/window.cc | 459 ++++++++++++++---- 14 files changed, 1036 insertions(+), 109 deletions(-) create mode 100644 temporian/core/operators/window/moving_quantile.py create mode 100644 temporian/core/operators/window/test/test_moving_quantile.py create mode 100644 temporian/implementation/numpy/operators/window/moving_quantile.py create mode 100644 temporian/implementation/numpy_cc/operators/custom_heap.h diff --git a/temporian/core/event_set_ops.py b/temporian/core/event_set_ops.py index 08bb70fd5..247780177 100644 --- a/temporian/core/event_set_ops.py +++ b/temporian/core/event_set_ops.py @@ -3218,6 +3218,24 @@ def moving_min( return moving_min(self, window_length=window_length, sampling=sampling) + def moving_quantile( + self: EventSetOrNode, + window_length: WindowLength, + quantile: float, + sampling: Optional[EventSetOrNode] = None, + ) -> EventSetOrNode: + """ """ + from temporian.core.operators.window.moving_quantile import ( + moving_quantile, + ) + + return moving_quantile( + self, + window_length=window_length, + quantile=quantile, + sampling=sampling, + ) + def moving_standard_deviation( self: EventSetOrNode, window_length: WindowLength, diff --git a/temporian/core/operators/window/BUILD b/temporian/core/operators/window/BUILD index 680fcda90..0d8c049e4 100644 --- a/temporian/core/operators/window/BUILD +++ b/temporian/core/operators/window/BUILD @@ -142,3 +142,18 @@ py_library( "//temporian/core/data:schema", ], ) + +py_library( + name = "moving_quantile", + srcs = ["moving_quantile.py"], + srcs_version = "PY3", + deps = [ + ":base", + "//temporian/core:compilation", + "//temporian/core:operator_lib", + "//temporian/core:typing", + "//temporian/core/data:dtype", + "//temporian/core/data:node", + "//temporian/core/data:schema", + ], +) diff --git a/temporian/core/operators/window/__init__.py b/temporian/core/operators/window/__init__.py index f4202294e..6f25f3cf5 100644 --- a/temporian/core/operators/window/__init__.py +++ b/temporian/core/operators/window/__init__.py @@ -27,3 +27,4 @@ from temporian.core.operators.window.moving_max import moving_max from temporian.core.operators.window.moving_product import cumprod from temporian.core.operators.window.moving_product import moving_product +from temporian.core.operators.window.moving_quantile import moving_quantile diff --git a/temporian/core/operators/window/base.py b/temporian/core/operators/window/base.py index 577abbf13..6590898df 100644 --- a/temporian/core/operators/window/base.py +++ b/temporian/core/operators/window/base.py @@ -15,7 +15,7 @@ """Base calendar operator class definition.""" from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, List, Optional, Mapping from temporian.core.data.duration_utils import normalize_duration @@ -34,6 +34,8 @@ class BaseWindowOperator(Operator, ABC): """Interface definition and common logic for window operators.""" + extra_attribute_def: List[Mapping[str, Any]] = [] + def __init__( self, input: EventSetNode, @@ -93,6 +95,7 @@ def __init__( creator=self, ), ) + self.add_extra_attributes() self.check() @@ -119,8 +122,14 @@ def has_sampling(self) -> bool: def has_variable_winlen(self) -> bool: return self._has_variable_winlen + def add_extra_attributes(self): + pass + @classmethod def build_op_definition(cls) -> pb.OperatorDef: + extra_attr_def = [ + pb.OperatorDef.Attribute(**attr) for attr in cls.extra_attribute_def + ] return pb.OperatorDef( key=cls.operator_def_key(), attributes=[ @@ -129,6 +138,7 @@ def build_op_definition(cls) -> pb.OperatorDef: type=pb.OperatorDef.Attribute.Type.FLOAT_64, is_optional=True, ), + *extra_attr_def, ], inputs=[ pb.OperatorDef.Input(key="input"), diff --git a/temporian/core/operators/window/moving_quantile.py b/temporian/core/operators/window/moving_quantile.py new file mode 100644 index 000000000..98dd82a59 --- /dev/null +++ b/temporian/core/operators/window/moving_quantile.py @@ -0,0 +1,93 @@ +# Copyright 2021 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Moving count operator class and public API function definition.""" + +from typing import Optional + +from temporian.core import operator_lib +from temporian.core.compilation import compile +from temporian.core.data.dtype import DType +from temporian.core.data.node import EventSetNode +from temporian.core.data.schema import FeatureSchema +from temporian.core.operators.window.base import BaseWindowOperator +from temporian.core.typing import EventSetOrNode, WindowLength +from temporian.proto import core_pb2 as pb + + +class MovingQuantileOperator(BaseWindowOperator): + extra_attribute_def = [ + { + "key": "quantile", + "is_optional": True, + "type": pb.OperatorDef.Attribute.Type.FLOAT_64, + } + ] + + def __init__( + self, + input: EventSetNode, + window_length: WindowLength, + quantile: float, + sampling: Optional[EventSetNode] = None, + ): + if quantile < 0 or quantile > 1: + raise ValueError( + ( + "`quantile` must be a float between 0 and 1. " + f"Received {quantile}" + ) + ) + self.quantile = quantile + super().__init__(input, window_length, sampling) + + def add_extra_attributes(self): + self.add_attribute("quantile", self.quantile) + + @classmethod + def operator_def_key(cls) -> str: + return "MOVING_QUANTILE" + + def get_feature_dtype(self, feature: FeatureSchema) -> DType: + if not feature.dtype.is_numerical: + raise ValueError( + "moving_quantile requires the input EventSet to contain numerical" + f" features only, but received feature {feature.name!r} with" + f" type {feature.dtype}" + ) + if feature.dtype.is_integer: + return DType.FLOAT32 + return feature.dtype + + +operator_lib.register_operator(MovingQuantileOperator) + + +@compile +def moving_quantile( + input: EventSetOrNode, + window_length: WindowLength, + quantile: float, + sampling: Optional[EventSetOrNode] = None, +) -> EventSetOrNode: + assert isinstance(input, EventSetNode) + if sampling is not None: + assert isinstance(sampling, EventSetNode) + + return MovingQuantileOperator( + input=input, + window_length=window_length, + quantile=quantile, + sampling=sampling, + ).outputs["output"] diff --git a/temporian/core/operators/window/test/BUILD b/temporian/core/operators/window/test/BUILD index ed625c49f..a6ab14761 100644 --- a/temporian/core/operators/window/test/BUILD +++ b/temporian/core/operators/window/test/BUILD @@ -134,3 +134,16 @@ py_test( "//temporian/test:utils", ], ) + +py_test( + name = "test_moving_quantile", + srcs = ["test_moving_quantile.py"], + srcs_version = "PY3", + deps = [ + # already_there/absl/testing:absltest + # already_there/absl/testing:parameterized + "//temporian/implementation/numpy/data:io", + "//temporian/core/data:duration", + "//temporian/test:utils", + ], +) diff --git a/temporian/core/operators/window/test/test_moving_quantile.py b/temporian/core/operators/window/test/test_moving_quantile.py new file mode 100644 index 000000000..702b6bae0 --- /dev/null +++ b/temporian/core/operators/window/test/test_moving_quantile.py @@ -0,0 +1,371 @@ +# Copyright 2021 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from math import nan + +import numpy as np +from absl.testing import absltest +from absl.testing.parameterized import parameters +from absl.testing.parameterized import TestCase + +from temporian.implementation.numpy.data.io import event_set +from temporian.test.utils import f32, f64, assertOperatorResult, i32, i64 +from temporian.core.data.duration import shortest + + +class QuantileTest(TestCase): + @parameters( + { # empty f32 + "timestamps": [], + "feature": f32([]), + "window": 5.0, + "output_feature": f32([]), + "quantile": 0.5, + }, + { # empty f64 + "timestamps": [], + "feature": f64([]), + "window": 5.0, + "output_feature": f64([]), + "quantile": 0.5, + }, + { # normal + "timestamps": [1, 2, 3, 5, 20], + "feature": f32([0, 1, 2, 3, 4]), + "window": 5.0, + "output_feature": f32([0.0, 0.5, 1.0, 1.5, 4.0]), + "quantile": 0.5, + }, + { # normal float 64 + "timestamps": [1, 2, 3, 5, 20], + "feature": f64([0, 1, 2, 3, 4]), + "window": 5.0, + "output_feature": f64([0.0, 0.5, 1.0, 1.5, 4.0]), + "quantile": 0.5, + }, + { # normal int32 + "timestamps": [1, 2, 3, 5, 20], + "feature": i32([0, 1, 2, 3, 4]), + "window": 5.0, + "output_feature": f32([0.0, 0.5, 1.0, 1.5, 4.0]), + "quantile": 0.5, + }, + { # normal int32 + "timestamps": [1, 2, 3, 5, 20], + "feature": i64([0, 1, 2, 3, 4]), + "window": 5.0, + "output_feature": f32([0.0, 0.5, 1.0, 1.5, 4.0]), + "quantile": 0.5, + }, + { # w nan + "timestamps": [1, 1.5, 2, 5, 20], + "feature": f32([10, nan, nan, 13, 14]), + "window": 1.0, + "output_feature": f32([10.0, 10.0, nan, 13.0, 14.0]), + "quantile": 0.5, + }, + { # other quantiles + "timestamps": [1, 2, 3, 5, 20], + "feature": f32([0, 1, 2, 3, 4]), + "window": 5.0, + "output_feature": f32([0.0, 0.0, 0.5, 1.0, 4.0]), + "quantile": 1 / 3, + }, + { # other quantiles + "timestamps": [1, 2, 3, 5, 20], + "feature": f32([0, 1, 2, 3, 4]), + "window": 5.0, + "output_feature": f32([0.0, 1.0, 2.0, 2.5, 4.0]), + "quantile": 0.75, + }, + ) + def test_without_sampling( + self, + timestamps, + feature, + window, + output_feature, + quantile, + ): + evset = event_set(timestamps=timestamps, features={"a": feature}) + + expected = event_set( + timestamps=timestamps, + features={"a": output_feature}, + same_sampling_as=evset, + ) + + result = evset.moving_quantile(window_length=window, quantile=quantile) + assertOperatorResult(self, result, expected) + + def test_numpy_equivalence(self): + size = 1_000 + max_jump = 10 + w = 10 + q = 0.8 + np.random.seed(0) + + timestamps = np.random.random_integers( + low=1, high=max_jump, size=size + ).cumsum() + feature = np.random.normal(size=size).cumsum() + + evset = event_set(timestamps=timestamps, features={"feature": feature}) + + np_quantiles = [] + for i in timestamps: + window = feature[(timestamps <= i) & (timestamps > i - w)] + np_quantiles.append( + np.quantile(window, q, method="averaged_inverted_cdf") + ) + + expected = event_set( + timestamps=timestamps, + features={"feature": np.array(np_quantiles)}, + same_sampling_as=evset, + ) + + result = evset.moving_quantile(window_length=w, quantile=q) + + assertOperatorResult(self, result, expected) + + @parameters( + { # empty f32 + "timestamps": [], + "feature": f32([]), + "window": 5.0, + "sampling_timestamps": [], + "output_feature": f32([]), + "quantile": 0.5, + }, + { # empty f64 + "timestamps": [], + "feature": f64([]), + "window": 5.0, + "sampling_timestamps": [], + "output_feature": f64([]), + "quantile": 0.5, + }, + { # normal + "timestamps": [1, 2, 3, 5, 6], + "feature": f32([10, 11, 12, 13, 14]), + "window": 3.0, + "sampling_timestamps": [-1.0, 1.0, 1.1, 3.0, 3.5, 6.0, 10.0], + "output_feature": f32([nan, 10.0, 10.0, 11.0, 11.0, 13.5, nan]), + "quantile": 0.5, + }, + { # w nan + "timestamps": [1, 2, 3, 5, 6], + "feature": f32([nan, 11, nan, 13, 14]), + "window": 1.0, + "sampling_timestamps": [1.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0], + "output_feature": f32([nan, 11.0, 11.0, nan, nan, nan, 13.0, 14]), + "quantile": 0.5, + }, + ) + def test_with_sampling( + self, + timestamps, + feature, + window, + sampling_timestamps, + output_feature, + quantile, + ): + evset = event_set(timestamps=timestamps, features={"a": feature}) + + sampling = event_set(timestamps=sampling_timestamps) + + expected = event_set( + timestamps=sampling_timestamps, + features={"a": output_feature}, + same_sampling_as=sampling, + ) + + result = evset.moving_quantile( + window_length=window, sampling=sampling, quantile=quantile + ) + assertOperatorResult(self, result, expected) + + @parameters( + { # normal + "timestamps": [0, 1, 2, 3, 5, 20], + "feature": f32([nan, 10, 11, 12, 13, 14]), + "variable_window": f64([1, 1, 1.5, 0.5, 3.5, 20]), + "output_feature": f32([nan, 10, 10.5, 12, 12, 12]), + "quantile": 0.5, + }, + { # invalid values + "timestamps": [0, 1, 2, 3, 5, 6, 20], + "feature": [nan, 10, 11, 12, 13, 14, 15], + "variable_window": [1, -20, 3, 0, 10, nan, 19], + "output_feature": [nan, nan, 10.5, nan, 11.5, nan, 13], + "quantile": 0.5, + }, + { # repeated ts + "timestamps": [0, 2, 2, 2, 2, 5], + "feature": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0], + "variable_window": [1, 3, 0.5, np.inf, -1, 5], + "output_feature": [10, 12, 12.5, 12, nan, 13], + "quantile": 0.5, + }, + { # repeated ts same winlen + "timestamps": [2, 2, 2, 2], + "feature": f64([10, 11, 12, 13]), + "variable_window": f64([0, 1, 1, 2]), + "output_feature": [nan, 11.5, 11.5, 11.5], + "quantile": 0.5, + }, + { # empty + "timestamps": f64([]), + "feature": f64([]), + "variable_window": f64([]), + "output_feature": f64([]), + "quantile": 0.5, + }, + ) + def test_with_variable_winlen_same_sampling( + self, + timestamps, + feature, + variable_window, + output_feature, + quantile, + ): + evset = event_set(timestamps=timestamps, features={"a": feature}) + + window = event_set( + timestamps=timestamps, + features={"a": variable_window}, + same_sampling_as=evset, + ) + + expected = event_set( + timestamps=timestamps, + features={"a": output_feature}, + same_sampling_as=evset, + ) + + result = evset.moving_quantile(window_length=window, quantile=quantile) + assertOperatorResult(self, result, expected) + + @parameters( + { # normal + "timestamps": [0, 1, 2, 3, 5, 20], + "feature": [nan, 10, 11, 12, 13, 14], + "window_timestamps": [-1, 1, 4, 19, 20, 20], + "variable_window": [10, 0.5, 2.5, 19, 16, np.inf], + "output_feature": [nan, 10, 11.5, 11.5, 13.5, 12], + "quantile": 0.5, + }, + { # repeated ts + "timestamps": [0, 1, 2, 3, 5, 20], + "feature": [nan, 10, 11, 12, 13, 14], + "window_timestamps": [20, 20, 20, 20, 20, 20], + "variable_window": [16, 0.001, np.inf, 0, 1, 19], + "output_feature": [13.5, 14, 12, nan, 14, 12.5], + "quantile": 0.5, + }, + { # shortest duration + "timestamps": [1.999999, 2], + "feature": [10.0, 11.0], + "window_timestamps": [2, 2, 2, 2], + "variable_window": [1, 0.001, shortest, 0], + "output_feature": [10.5, 10.5, 11, nan], + "quantile": 0.5, + }, + { # invalid values + "timestamps": [0, 1, 2, 3, 5, 20], + "feature": [nan, 10, 11, 12, 13, 14], + "window_timestamps": [2, 2, 5, 5, 20, 20], + "variable_window": [1, -10, 3, 0, nan, 19], + "output_feature": [11, nan, 12.5, nan, nan, 12.5], + "quantile": 0.5, + }, + { # repeated ts same winlen + "timestamps": [0, 1, 2, 3], + "feature": f64([10, 11, 12, 13]), + "window_timestamps": [2, 2, 2, 2], + "variable_window": f64([0, 1, 1, 2]), + "output_feature": f64([nan, 12, 12, 11.5]), + "quantile": 0.5, + }, + { # empty arrays + "timestamps": f64([1]), + "feature": f64([10]), + "variable_window": f64([]), + "window_timestamps": f64([]), + "output_feature": f64([]), + "quantile": 0.5, + }, + ) + def test_with_variable_winlen_diff_sampling( + self, + timestamps, + feature, + variable_window, + window_timestamps, + output_feature, + quantile, + ): + evset = event_set(timestamps=timestamps, features={"a": feature}) + + window = event_set( + timestamps=window_timestamps, + features={"a": variable_window}, + ) + + expected = event_set( + timestamps=window_timestamps, + features={"a": output_feature}, + same_sampling_as=window, + ) + + result = evset.moving_quantile(window_length=window, quantile=quantile) + assertOperatorResult(self, result, expected) + + def test_with_index(self): + evset = event_set( + timestamps=[1, 2, 3, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2], + features={ + "x": ["X1", "X1", "X1", "X2", "X2", "X2", "X2", "X2", "X2"], + "y": ["Y1", "Y1", "Y1", "Y1", "Y1", "Y1", "Y2", "Y2", "Y2"], + "a": [10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0], + }, + indexes=["x", "y"], + ) + expected = event_set( + timestamps=[1, 2, 3, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2], + features={ + "x": ["X1", "X1", "X1", "X2", "X2", "X2", "X2", "X2", "X2"], + "y": ["Y1", "Y1", "Y1", "Y1", "Y1", "Y1", "Y2", "Y2", "Y2"], + "a": [10.0, 10.5, 11.0, 13.0, 13.5, 14.0, 16.0, 16.5, 17.0], + }, + indexes=["x", "y"], + same_sampling_as=evset, + ) + result = evset.moving_quantile(window_length=5.0, quantile=0.5) + assertOperatorResult(self, result, expected) + + def test_error_input_bytes(self): + evset = event_set([1, 2], {"f": ["A", "B"]}) + with self.assertRaisesRegex( + ValueError, + "moving_quantile requires the input EventSet to contain", + ): + _ = evset.moving_quantile(1, quantile=0.5) + + +if __name__ == "__main__": + absltest.main() diff --git a/temporian/implementation/numpy/operators/BUILD b/temporian/implementation/numpy/operators/BUILD index bf3066249..3cc27b292 100644 --- a/temporian/implementation/numpy/operators/BUILD +++ b/temporian/implementation/numpy/operators/BUILD @@ -62,6 +62,7 @@ py_library( "//temporian/implementation/numpy/operators/window:moving_standard_deviation", "//temporian/implementation/numpy/operators/window:moving_sum", "//temporian/implementation/numpy/operators/window:simple_moving_average", + "//temporian/implementation/numpy/operators/window:moving_quantile", ], ) diff --git a/temporian/implementation/numpy/operators/__init__.py b/temporian/implementation/numpy/operators/__init__.py index d417a2851..47ab05156 100644 --- a/temporian/implementation/numpy/operators/__init__.py +++ b/temporian/implementation/numpy/operators/__init__.py @@ -46,6 +46,7 @@ from temporian.implementation.numpy.operators.window import moving_count from temporian.implementation.numpy.operators.window import moving_min from temporian.implementation.numpy.operators.window import moving_max +from temporian.implementation.numpy.operators.window import moving_quantile from temporian.implementation.numpy.operators.calendar import day_of_month from temporian.implementation.numpy.operators.calendar import day_of_week from temporian.implementation.numpy.operators.calendar import day_of_year diff --git a/temporian/implementation/numpy/operators/window/BUILD b/temporian/implementation/numpy/operators/window/BUILD index 15e956b4b..0ab842130 100644 --- a/temporian/implementation/numpy/operators/window/BUILD +++ b/temporian/implementation/numpy/operators/window/BUILD @@ -111,3 +111,15 @@ py_library( "//temporian/implementation/numpy_cc/operators:operators_cc", ], ) + +py_library( + name = "moving_quantile", + srcs = ["moving_quantile.py"], + srcs_version = "PY3", + deps = [ + ":base", + "//temporian/core/operators/window:moving_quantile", + "//temporian/implementation/numpy:implementation_lib", + "//temporian/implementation/numpy_cc/operators:operators_cc", + ], +) diff --git a/temporian/implementation/numpy/operators/window/moving_quantile.py b/temporian/implementation/numpy/operators/window/moving_quantile.py new file mode 100644 index 000000000..268294507 --- /dev/null +++ b/temporian/implementation/numpy/operators/window/moving_quantile.py @@ -0,0 +1,64 @@ +# Copyright 2021 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + + +from typing import List, Union, Optional, Dict +from temporian.core.data.duration_utils import NormalizedDuration +from temporian.core.operators.window.moving_quantile import ( + MovingQuantileOperator, +) +from temporian.implementation.numpy.data.event_set import ( + EventSet, +) +from temporian.implementation.numpy import implementation_lib +from temporian.implementation.numpy.operators.window.base import ( + BaseWindowNumpyImplementation, +) +from temporian.implementation.numpy_cc.operators import operators_cc + + +class MovingQuantileNumpyImplementation(BaseWindowNumpyImplementation): + """Numpy implementation of the simple moving average operator.""" + + def _implementation(self): + return operators_cc.moving_quantile + + def _compute( + self, + src_timestamps: np.ndarray, + src_features: List[np.ndarray], + sampling_timestamps: Optional[np.ndarray], + dst_features: List[np.ndarray], + window_length: Union[NormalizedDuration, np.ndarray], + ) -> None: + assert isinstance(self.operator, MovingQuantileOperator) + + implementation = self._implementation() + for src_ts in src_features: + kwargs = { + "evset_timestamps": src_timestamps, + "evset_values": src_ts, + "window_length": window_length, + "quantile": self.operator.quantile, + } + if sampling_timestamps is not None: + kwargs["sampling_timestamps"] = sampling_timestamps + dst_feature = implementation(**kwargs) + dst_features.append(dst_feature) + + +implementation_lib.register_operator_implementation( + MovingQuantileOperator, MovingQuantileNumpyImplementation +) diff --git a/temporian/implementation/numpy_cc/operators/BUILD b/temporian/implementation/numpy_cc/operators/BUILD index 9d676bc38..78005e851 100644 --- a/temporian/implementation/numpy_cc/operators/BUILD +++ b/temporian/implementation/numpy_cc/operators/BUILD @@ -17,11 +17,18 @@ pybind_library( deps = [":common"], ) +cc_library( + name = "custom_heap", + srcs = [], + hdrs = ["custom_heap.h"], + deps = [":common"], +) + pybind_library( name = "window", srcs = ["window.cc"], hdrs = ["window.h"], - deps = [":common"], + deps = [":common", "custom_heap"], ) pybind_library( diff --git a/temporian/implementation/numpy_cc/operators/custom_heap.h b/temporian/implementation/numpy_cc/operators/custom_heap.h new file mode 100644 index 000000000..136a9fa67 --- /dev/null +++ b/temporian/implementation/numpy_cc/operators/custom_heap.h @@ -0,0 +1,76 @@ +#include +#include +template +class CustomHeap { + private: + std::function compare; + std::list heap; + std::unordered_map::iterator> val_to_node; + + public: + CustomHeap(std::function compare) : compare(compare) {} + + void push(T value) { + heap.push_back(value); + auto it = std::prev(heap.end()); + val_to_node[value] = it; + // TODO: better sorting? + while (it != heap.begin()) { + auto parent = std::prev(it); + if (!compare(*parent, *it)) { + break; + } + // TODO: check that this swap is doing what I want + std::swap(*parent, *it); + val_to_node[*it] = it; + val_to_node[*parent] = parent; + it = parent; + } + } + + std::optional pop() { + if (heap.size() == 0) { + return {}; + } else { + auto value = heap.back(); + heap.pop_back(); + auto it = val_to_node.find(value); + val_to_node.erase(it); + return value; + } + } + + std::optional top() { + if (heap.empty()) { + return {}; + } else { + return heap.back(); + } + } + + void remove(T value) { + auto it = val_to_node.find(value); + if (it != val_to_node.end()) { + heap.erase(it->second); + val_to_node.erase(it); + } else { + // TODO: exception meant for debugging, remove it + throw std::invalid_argument("removing a value that doesn't exists"); + } + } + int size() { return heap.size(); } + int empty() { return heap.empty(); } + + void print() { + std::cout << "my_heap{" << std::endl << " ["; + std::for_each(heap.begin(), heap.end(), + [](const int n) { std::cout << n << ' '; }); + std::cout << "]" << std::endl; + + // std::cout << " {" << std::endl; + // for (const auto& pair : val_to_node) { + // std::cout << " " << pair.first << ": " << *(pair.second) << std::endl; + // } + // std::cout << " }" << std::endl; + } +}; diff --git a/temporian/implementation/numpy_cc/operators/window.cc b/temporian/implementation/numpy_cc/operators/window.cc index 99043595f..a3cda8c97 100644 --- a/temporian/implementation/numpy_cc/operators/window.cc +++ b/temporian/implementation/numpy_cc/operators/window.cc @@ -10,6 +10,8 @@ #include #include +#include "temporian/implementation/numpy_cc/operators/custom_heap.h" + namespace { namespace py = pybind11; @@ -21,6 +23,29 @@ using ArrayRef = py::detail::unchecked_reference; typedef size_t Idx; +// Note: We only use inheritance to compile check the code. +template +struct Accumulator { + Accumulator(const ArrayRef &values) : values(values) {} + + virtual ~Accumulator() = default; + virtual void Add(Idx idx) = 0; + virtual void Remove(Idx idx) = 0; + virtual OUTPUT Result() = 0; + + // // Add a value to left of the window. Relevant in deque-based accumulators. + virtual void AddLeft(Idx idx) { return Add(idx); } + + ArrayRef values; +}; + +template +struct QuantileAccumulator : public Accumulator { + QuantileAccumulator(const ArrayRef &values, float quantile) + : Accumulator(values), quantile(quantile) {} + float quantile; +}; + // NOTE: accumulate() is overloaded for the 4 possible combinations of: // - with or without external sampling // - with constant or variable window length @@ -28,10 +53,11 @@ typedef size_t Idx; // TODO: refactor to avoid code duplication where possible. // No external sampling, constant window length -template +template py::array_t accumulate(const ArrayD &evset_timestamps, const py::array_t &evset_values, - const double window_length) { + const double window_length, + Accumulator &accumulator) { // Input size const size_t n_event = evset_timestamps.shape(0); @@ -40,9 +66,6 @@ py::array_t accumulate(const ArrayD &evset_timestamps, auto v_output = output.template mutable_unchecked<1>(); auto v_timestamps = evset_timestamps.unchecked<1>(); - auto v_values = evset_values.template unchecked<1>(); - - TAccumulator accumulator(v_values); // Index of the first value in the window. size_t begin_idx = 0; @@ -86,11 +109,12 @@ py::array_t accumulate(const ArrayD &evset_timestamps, } // External sampling, constant window length -template +template py::array_t accumulate(const ArrayD &evset_timestamps, const py::array_t &evset_values, const ArrayD &sampling_timestamps, - const double window_length) { + const double window_length, + Accumulator &accumulator) { // Input size const size_t n_event = evset_timestamps.shape(0); const size_t n_sampling = sampling_timestamps.shape(0); @@ -100,11 +124,8 @@ py::array_t accumulate(const ArrayD &evset_timestamps, auto v_output = output.template mutable_unchecked<1>(); auto v_timestamps = evset_timestamps.unchecked<1>(); - auto v_values = evset_values.template unchecked<1>(); auto v_sampling = sampling_timestamps.unchecked<1>(); - TAccumulator accumulator(v_values); - size_t begin_idx = 0; size_t end_idx = 0; @@ -137,10 +158,11 @@ bool begin_moved_forward(const double ts, const double prev_ts, } // No external sampling, variable window length -template +template py::array_t accumulate(const ArrayD &evset_timestamps, const py::array_t &evset_values, - const ArrayD &window_length) { + const ArrayD &window_length, + Accumulator &accumulator) { // Input size const size_t n_event = evset_timestamps.shape(0); @@ -149,14 +171,11 @@ py::array_t accumulate(const ArrayD &evset_timestamps, auto v_output = output.template mutable_unchecked<1>(); auto v_timestamps = evset_timestamps.unchecked<1>(); - auto v_values = evset_values.template unchecked<1>(); auto v_window_length = window_length.unchecked<1>(); assert(v_timestamps.shape(0) == v_window_length.shape(0)); assert(v_timestamps.shape(0) == v_values.shape(0)); - TAccumulator accumulator(v_values); - // Index of the first value in the window. size_t begin_idx = 0; // Index of the first value outside the window. @@ -210,11 +229,12 @@ py::array_t accumulate(const ArrayD &evset_timestamps, } // External sampling, variable window length -template +template py::array_t accumulate(const ArrayD &evset_timestamps, const py::array_t &evset_values, const ArrayD &sampling_timestamps, - const ArrayD &window_length) { + const ArrayD &window_length, + Accumulator &accumulator) { // Input size const size_t n_event = evset_timestamps.shape(0); const size_t n_sampling = sampling_timestamps.shape(0); @@ -224,15 +244,12 @@ py::array_t accumulate(const ArrayD &evset_timestamps, auto v_output = output.template mutable_unchecked<1>(); auto v_timestamps = evset_timestamps.unchecked<1>(); - auto v_values = evset_values.template unchecked<1>(); auto v_sampling = sampling_timestamps.unchecked<1>(); auto v_window_length = window_length.unchecked<1>(); assert(v_timestamps.shape(0) == v_values.shape(0)); assert(v_sampling.shape(0) == v_window_length.shape(0)); - TAccumulator accumulator(v_values); - size_t begin_idx = 0; size_t end_idx = 0; @@ -276,22 +293,6 @@ py::array_t accumulate(const ArrayD &evset_timestamps, return output; } -// Note: We only use inheritance to compile check the code. -template -struct Accumulator { - Accumulator(const ArrayRef &values) : values(values) {} - - virtual ~Accumulator() = default; - virtual void Add(Idx idx) = 0; - virtual void Remove(Idx idx) = 0; - virtual OUTPUT Result() = 0; - - // // Add a value to left of the window. Relevant in deque-based accumulators. - virtual void AddLeft(Idx idx) { return Add(idx); } - - ArrayRef values; -}; - template struct SimpleMovingAverageAccumulator : public Accumulator { SimpleMovingAverageAccumulator(const ArrayRef &values) @@ -520,54 +521,187 @@ struct MovingMaxAccumulator : MovingExtremumAccumulator { // improve calculation efficiency while maintaining accuracy. template struct MovingProductAccumulator : public Accumulator { - int start_idx = 0; - int end_idx = -1; // Initialize to -1 to indicate an empty window initially + int start_idx = 0; + int end_idx = -1; // Initialize to -1 to indicate an empty window initially - MovingProductAccumulator(const ArrayRef& values) - : Accumulator(values) {} - void Add(Idx idx) override { - // Simply move the end to the given index - end_idx = idx; + MovingProductAccumulator(const ArrayRef &values) + : Accumulator(values) {} + void Add(Idx idx) override { + // Simply move the end to the given index + end_idx = idx; + } + + void Remove(Idx idx) override { + // Adjust the start index to exclude the removed value, signaling a window + // shift. + start_idx = idx + 1; + } + + OUTPUT Result() override { + if (start_idx > end_idx) { + // No valid indices to process, indicating an empty window or EventSet + return std::numeric_limits::quiet_NaN(); } - void Remove(Idx idx) override { - // Adjust the start index to exclude the removed value, signaling a window shift. - start_idx = idx + 1; + double product = 1.0; + bool hasEncounteredValidValue = false; // This will be true if any non-NaN + // and non-zero value is encountered + + for (int idx = start_idx; idx <= end_idx; ++idx) { + const INPUT value = Accumulator::values[idx]; + if (value == 0) { + return 0; // If a zero is found, return 0 immediately. + } else if (!std::isnan(value)) { + product *= value; + hasEncounteredValidValue = true; + } } - OUTPUT Result() override { - if (start_idx > end_idx) { - // No valid indices to process, indicating an empty window or EventSet - return std::numeric_limits::quiet_NaN(); - } + if (!hasEncounteredValidValue) { + return std::numeric_limits::quiet_NaN(); + } - double product = 1.0; - bool hasEncounteredValidValue = false; // This will be true if any non-NaN and non-zero value is encountered - - for (int idx = start_idx; idx <= end_idx; ++idx) { - const INPUT value = Accumulator::values[idx]; - if (value == 0) { - return 0; // If a zero is found, return 0 immediately. - } else if (!std::isnan(value)) { - product *= value; - hasEncounteredValidValue = true; - } - } + return product; + } - if (!hasEncounteredValidValue) { - return std::numeric_limits::quiet_NaN(); - } + void Reset() { + start_idx = 0; + end_idx = -1; + } +}; - return product; +template +struct MovingQuantileAccumulator : public QuantileAccumulator { + MovingQuantileAccumulator(const ArrayRef &values, float quantile) + : QuantileAccumulator(values, quantile), + bigger_idxs(std::bind(&MovingQuantileAccumulator::less, this, + std::placeholders::_1, std::placeholders::_2)), + smaller_idxs(std::bind(&MovingQuantileAccumulator::greater, this, + std::placeholders::_1, std::placeholders::_2)) {} + + // heap where the bigger values are stored, sorted with the min at the top + CustomHeap bigger_idxs; + // heap where the smaller values are stored, sorted with the max at the top + CustomHeap smaller_idxs; + + bool less(Idx idxa, Idx idxb) { + return Accumulator::values[idxa] < + Accumulator::values[idxb]; + } + + bool greater(Idx idxa, Idx idxb) { + return Accumulator::values[idxa] > + Accumulator::values[idxb]; + } + + void Add(Idx idx) override { + auto value = Accumulator::values[idx]; + if (std::isnan(value)) { + return; } - void Reset() { - start_idx = 0; - end_idx = -1; + auto ref_heap = this->GetRefHeap(); + auto ref_idx = ref_heap->top(); + + if (ref_idx.has_value()) { + // push the idx to the corresponding heap + auto ref_value = Accumulator::values[*ref_idx]; + + if (value > ref_value) { + bigger_idxs.push(idx); + } else { + smaller_idxs.push(idx); + } + this->Rebalance(); + } else { + // if we reach this point, both heaps are empty + ref_heap->push(idx); } -}; + } + + float GetCurrentRatio() { + auto total = smaller_idxs.size() + bigger_idxs.size(); + return total == 0 ? 0.0 : static_cast(smaller_idxs.size()) / total; + } + float GetCurrentPrecision() { + auto total = smaller_idxs.size() + bigger_idxs.size(); + return total == 0 ? 1.0 : 1.0 / total; + } + + CustomHeap *GetRefHeap() { + // By default use the smallers + if (smaller_idxs.empty() && bigger_idxs.empty()) { + return &smaller_idxs; + } + // to avoid unecesary rebalances the reference value (value closer to the + // quantile) can be selected from either heap according to whether we are + // over or under shoothing the desired quantile + float ratio = this->GetCurrentRatio(); + return ratio < this->quantile ? &bigger_idxs : &smaller_idxs; + } + void Remove(Idx idx) override { + auto value = Accumulator::values[idx]; + if (std::isnan(value)) { + return; + } + auto ref_heap = this->GetRefHeap(); + auto ref_idx = ref_heap->top(); + if (ref_idx.has_value()) { + if (*ref_idx == idx) { + ref_heap->remove(idx); + } else { + // remove the index from the corresponding heap + auto ref_value = Accumulator::values[*ref_idx]; + if (value > ref_value) { + bigger_idxs.remove(idx); + } else { + smaller_idxs.remove(idx); + } + } + } + this->Rebalance(); + } + + void Rebalance() { + // move items from one heap to another if and only if + // the change brings the ratio closer to the desired quantile + // considering the current precision we can expect with the amount of + // items we have. + float ratio = this->GetCurrentRatio(); + float precision = this->GetCurrentPrecision(); + if (ratio <= this->quantile - precision) { + // move the min of bigger_idxs the to the smaller_idxs + smaller_idxs.push(*bigger_idxs.pop()); + } else if (ratio >= this->quantile + precision) { + // move the max of the smaller_idxs to the bigger_idxs + bigger_idxs.push(*smaller_idxs.pop()); + } + // else: heaps are balanced, nothing to do + } + + OUTPUT Result() override { + // empty window + if (smaller_idxs.empty() && bigger_idxs.empty()) { + return std::numeric_limits::quiet_NaN(); + } + // select the item used as reference, that's the one closer to the quantile + float ratio = this->GetCurrentRatio(); + if (ratio < this->quantile) { + return Accumulator::values[*bigger_idxs.top()]; + } else if (ratio > this->quantile) { + return Accumulator::values[*smaller_idxs.top()]; + } else { + // if the heaps are perfectly balanced, use the average of the 2 best + // values. + return static_cast( + Accumulator::values[*smaller_idxs.top()] + + Accumulator::values[*bigger_idxs.top()]) / + 2; + } + } +}; // Instantiate the "accumulate" function with and without sampling, // and with and without variable window length. @@ -582,62 +716,129 @@ struct MovingProductAccumulator : public Accumulator { py::array_t NAME(const ArrayD &evset_timestamps, \ const py::array_t &evset_values, \ const double window_length) { \ - return accumulate>( \ - evset_timestamps, evset_values, window_length); \ + auto v_values = evset_values.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values); \ + return accumulate(evset_timestamps, evset_values, \ + window_length, accumulator); \ } \ \ py::array_t NAME( \ const ArrayD &evset_timestamps, const py::array_t &evset_values, \ const ArrayD &sampling_timestamps, const double window_length) { \ - return accumulate>( \ - evset_timestamps, evset_values, sampling_timestamps, window_length); \ + auto v_values = evset_values.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values); \ + return accumulate(evset_timestamps, evset_values, \ + sampling_timestamps, window_length, \ + accumulator); \ } \ \ py::array_t NAME(const ArrayD &evset_timestamps, \ const py::array_t &evset_values, \ const ArrayD &window_length) { \ - return accumulate>( \ - evset_timestamps, evset_values, window_length); \ + auto v_values = evset_values.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values); \ + return accumulate(evset_timestamps, evset_values, \ + window_length, accumulator); \ } \ \ py::array_t NAME( \ const ArrayD &evset_timestamps, const py::array_t &evset_values, \ const ArrayD &sampling_timestamps, const ArrayD &window_length) { \ - return accumulate>( \ - evset_timestamps, evset_values, sampling_timestamps, window_length); \ + auto v_values = evset_values.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values); \ + return accumulate(evset_timestamps, evset_values, \ + sampling_timestamps, window_length, \ + accumulator); \ } // Similar to REGISTER_CC_FUNC, but without inputs -#define REGISTER_CC_FUNC_NO_INPUT(NAME, OUTPUT, ACCUMULATOR) \ - \ - py::array_t NAME(const ArrayD &evset_timestamps, \ - const double window_length) { \ - return accumulate>( \ - evset_timestamps, evset_timestamps, window_length); \ - } \ - \ - py::array_t NAME(const ArrayD &evset_timestamps, \ - const ArrayD &sampling_timestamps, \ - const double window_length) { \ - return accumulate>( \ - evset_timestamps, evset_timestamps, sampling_timestamps, \ - window_length); \ - } \ - \ - py::array_t NAME(const ArrayD &evset_timestamps, \ - const ArrayD &window_length) { \ - return accumulate>( \ - evset_timestamps, evset_timestamps, window_length); \ - } \ - \ - py::array_t NAME(const ArrayD &evset_timestamps, \ - const ArrayD &sampling_timestamps, \ - const ArrayD &window_length) { \ - return accumulate>( \ - evset_timestamps, evset_timestamps, sampling_timestamps, \ - window_length); \ +#define REGISTER_CC_FUNC_NO_INPUT(NAME, OUTPUT, ACCUMULATOR) \ + \ + py::array_t NAME(const ArrayD &evset_timestamps, \ + const double window_length) { \ + auto v_values = evset_timestamps.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values); \ + return accumulate(evset_timestamps, evset_timestamps, \ + window_length, accumulator); \ + } \ + \ + py::array_t NAME(const ArrayD &evset_timestamps, \ + const ArrayD &sampling_timestamps, \ + const double window_length) { \ + auto v_values = evset_timestamps.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values); \ + return accumulate(evset_timestamps, evset_timestamps, \ + sampling_timestamps, window_length, \ + accumulator); \ + } \ + \ + py::array_t NAME(const ArrayD &evset_timestamps, \ + const ArrayD &window_length) { \ + auto v_values = evset_timestamps.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values); \ + return accumulate(evset_timestamps, evset_timestamps, \ + window_length, accumulator); \ + } \ + \ + py::array_t NAME(const ArrayD &evset_timestamps, \ + const ArrayD &sampling_timestamps, \ + const ArrayD &window_length) { \ + auto v_values = evset_timestamps.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values); \ + return accumulate(evset_timestamps, evset_timestamps, \ + sampling_timestamps, window_length, \ + accumulator); \ } +// Instantiate the "accumulate" function with and without sampling, +// and with and without variable window length. +// +// Args: +// NAME: Name of the python and c++ function. +// INPUT: Input value type. +// OUTPUT: Output value type. +// ACCUMULATOR: Accumulator class. +#define REGISTER_CC_FUNC_QUANTILE(NAME, INPUT, OUTPUT, ACCUMULATOR) \ + \ + py::array_t NAME(const ArrayD &evset_timestamps, \ + const py::array_t &evset_values, \ + const double window_length, const float quantile) { \ + auto v_values = evset_values.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values, quantile); \ + return accumulate(evset_timestamps, evset_values, \ + window_length, accumulator); \ + } \ + \ + py::array_t NAME(const ArrayD &evset_timestamps, \ + const py::array_t &evset_values, \ + const ArrayD &sampling_timestamps, \ + const double window_length, const float quantile) { \ + auto v_values = evset_values.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values, quantile); \ + return accumulate(evset_timestamps, evset_values, \ + sampling_timestamps, window_length, \ + accumulator); \ + } \ + \ + py::array_t NAME( \ + const ArrayD &evset_timestamps, const py::array_t &evset_values, \ + const ArrayD &window_length, const float quantile) { \ + auto v_values = evset_values.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values, quantile); \ + return accumulate(evset_timestamps, evset_values, \ + window_length, accumulator); \ + } \ + \ + py::array_t NAME( \ + const ArrayD &evset_timestamps, const py::array_t &evset_values, \ + const ArrayD &sampling_timestamps, const ArrayD &window_length, \ + const float quantile) { \ + auto v_values = evset_values.template unchecked<1>(); \ + auto accumulator = ACCUMULATOR(v_values, quantile); \ + return accumulate(evset_timestamps, evset_values, \ + sampling_timestamps, window_length, \ + accumulator); \ + } // Note: ";" are not needed for the code, but are required for our code // formatter. @@ -670,6 +871,15 @@ REGISTER_CC_FUNC_NO_INPUT(moving_count, int32_t, MovingCountAccumulator); REGISTER_CC_FUNC(moving_product, float, float, MovingProductAccumulator); REGISTER_CC_FUNC(moving_product, double, double, MovingProductAccumulator); + +REGISTER_CC_FUNC_QUANTILE(moving_quantile, int32_t, float, + MovingQuantileAccumulator); +REGISTER_CC_FUNC_QUANTILE(moving_quantile, int64_t, float, + MovingQuantileAccumulator); +REGISTER_CC_FUNC_QUANTILE(moving_quantile, float, float, + MovingQuantileAccumulator); +REGISTER_CC_FUNC_QUANTILE(moving_quantile, double, double, + MovingQuantileAccumulator); } // namespace // Register c++ functions to pybind with and without sampling, @@ -726,6 +936,37 @@ REGISTER_CC_FUNC(moving_product, double, double, MovingProductAccumulator); m.def(#NAME, py::overload_cast(&NAME), "", \ py::arg("evset_timestamps").noconvert(), py::arg("window_length")); +#define ADD_PY_DEF_QUANTILE(NAME, INPUT, OUTPUT) \ + m.def(#NAME, \ + py::overload_cast &, \ + const ArrayD &, double, float>(&NAME), \ + "", py::arg("evset_timestamps").noconvert(), \ + py::arg("evset_values").noconvert(), \ + py::arg("sampling_timestamps").noconvert(), py::arg("window_length"), \ + py::arg("quantile")); \ + \ + m.def(#NAME, \ + py::overload_cast &, double, \ + float>(&NAME), \ + "", py::arg("evset_timestamps").noconvert(), \ + py::arg("evset_values").noconvert(), py::arg("window_length"), \ + py::arg("quantile")); \ + \ + m.def(#NAME, \ + py::overload_cast &, \ + const ArrayD &, const ArrayD &, float>(&NAME), \ + "", py::arg("evset_timestamps").noconvert(), \ + py::arg("evset_values").noconvert(), \ + py::arg("sampling_timestamps").noconvert(), py::arg("window_length"), \ + py::arg("quantile")); \ + \ + m.def(#NAME, \ + py::overload_cast &, \ + const ArrayD &, float>(&NAME), \ + "", py::arg("evset_timestamps").noconvert(), \ + py::arg("evset_values").noconvert(), py::arg("window_length"), \ + py::arg("quantile")); + void init_window(py::module &m) { ADD_PY_DEF(simple_moving_average, float, float) ADD_PY_DEF(simple_moving_average, double, double) @@ -753,4 +994,8 @@ void init_window(py::module &m) { ADD_PY_DEF(moving_product, float, float) ADD_PY_DEF(moving_product, double, double) + ADD_PY_DEF_QUANTILE(moving_quantile, int32_t, float) + ADD_PY_DEF_QUANTILE(moving_quantile, int64_t, float) + ADD_PY_DEF_QUANTILE(moving_quantile, float, float) + ADD_PY_DEF_QUANTILE(moving_quantile, double, double) } From f3488888fd746c8429e38b7d3d4a0cfa25f646ad Mon Sep 17 00:00:00 2001 From: javiber Date: Mon, 6 May 2024 19:24:37 -0300 Subject: [PATCH 2/9] Fixed a bug with negative windows length --- temporian/implementation/numpy_cc/operators/window.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/temporian/implementation/numpy_cc/operators/window.cc b/temporian/implementation/numpy_cc/operators/window.cc index a3cda8c97..aa0fbf113 100644 --- a/temporian/implementation/numpy_cc/operators/window.cc +++ b/temporian/implementation/numpy_cc/operators/window.cc @@ -192,7 +192,7 @@ py::array_t accumulate(const ArrayD &evset_timestamps, const auto curr_ts = v_timestamps[idx]; auto curr_window_length = v_window_length[idx]; - if (std::isnan(curr_window_length)) { + if (std::isnan(curr_window_length) || curr_window_length < 0) { curr_window_length = 0; } @@ -257,7 +257,7 @@ py::array_t accumulate(const ArrayD &evset_timestamps, const auto right_limit = v_sampling[sampling_idx]; auto curr_window_length = v_window_length[sampling_idx]; - if (std::isnan(curr_window_length)) { + if (std::isnan(curr_window_length) || curr_window_length < 0) { curr_window_length = 0; } From 08d0dc0898dc2abb2d338e58babc9e8cf1398179 Mon Sep 17 00:00:00 2001 From: javiber Date: Mon, 6 May 2024 19:40:20 -0300 Subject: [PATCH 3/9] Adding docs and fixing some tests --- .../operators/window/moving_quantile.md | 1 + temporian/core/event_set_ops.py | 53 ++++++++++++++++++- .../core/operators/window/moving_quantile.py | 12 ++--- .../core/test/registered_operators_test.py | 1 + .../numpy/test/registered_operators_test.py | 1 + .../numpy_cc/operators/window.cc | 4 +- 6 files changed, 62 insertions(+), 10 deletions(-) create mode 100644 docs/src/reference/temporian/operators/window/moving_quantile.md diff --git a/docs/src/reference/temporian/operators/window/moving_quantile.md b/docs/src/reference/temporian/operators/window/moving_quantile.md new file mode 100644 index 000000000..d4723cdf8 --- /dev/null +++ b/docs/src/reference/temporian/operators/window/moving_quantile.md @@ -0,0 +1 @@ +::: temporian.EventSet.moving_quantile diff --git a/temporian/core/event_set_ops.py b/temporian/core/event_set_ops.py index 247780177..4ca71eb9e 100644 --- a/temporian/core/event_set_ops.py +++ b/temporian/core/event_set_ops.py @@ -3224,7 +3224,58 @@ def moving_quantile( quantile: float, sampling: Optional[EventSetOrNode] = None, ) -> EventSetOrNode: - """ """ + """Computes the quantile in a sliding window over an + [`EventSet`][temporian.EventSet]. + + For each t in sampling, and for each feature independently, returns at + time t the appropiated quantile for the feature in the window + (t - window_length, t]. + + `sampling` can't be specified if a variable `window_length` is + specified (i.e. if `window_length` is an EventSet). + + If `sampling` is specified or `window_length` is an EventSet, the moving + window is sampled at each timestamp in them, else it is sampled on the + input's. + + Missing values (such as NaNs) are ignored. + + If the window does not contain any values (e.g., all the values are + missing, or the window does not contain any sampling), outputs missing + values. + + The quantile calculated in each window is equivalent to numpy's + `"averaged_inverted_cdf"` method. + + Example: + ```python + >>> a = tp.event_set( + ... timestamps=[0, 1, 2, 5, 6, 7], + ... features={"value": [np.nan, 1, 5, 10, 15, 20]}, + ... ) + + >>> a.moving_quantile(tp.duration.seconds(4), quantile=0.5) + indexes: ... + (6 events): + timestamps: [0. 1. 2. 5. 6. 7.] + 'value': [ nan 1. 3. 7.5 12.5 15. ] + ... + + ``` + + See [`EventSet.moving_count()`][temporian.EventSet.moving_count] for + examples of moving window operations with external sampling and indices. + + Args: + window_length: Sliding window's length. + quantile: the desired quantile defined in the range (0, 1). + sampling: Timestamps to sample the sliding window's value at. If not + provided, timestamps in the input are used. + + Returns: + EventSet containing the moving standard deviation of each feature in + the input. + """ from temporian.core.operators.window.moving_quantile import ( moving_quantile, ) diff --git a/temporian/core/operators/window/moving_quantile.py b/temporian/core/operators/window/moving_quantile.py index 98dd82a59..caff068d6 100644 --- a/temporian/core/operators/window/moving_quantile.py +++ b/temporian/core/operators/window/moving_quantile.py @@ -44,10 +44,8 @@ def __init__( ): if quantile < 0 or quantile > 1: raise ValueError( - ( - "`quantile` must be a float between 0 and 1. " - f"Received {quantile}" - ) + "`quantile` must be a float between 0 and 1. " + f"Received {quantile}" ) self.quantile = quantile super().__init__(input, window_length, sampling) @@ -62,9 +60,9 @@ def operator_def_key(cls) -> str: def get_feature_dtype(self, feature: FeatureSchema) -> DType: if not feature.dtype.is_numerical: raise ValueError( - "moving_quantile requires the input EventSet to contain numerical" - f" features only, but received feature {feature.name!r} with" - f" type {feature.dtype}" + "moving_quantile requires the input EventSet to contain" + " numerical features only, but received feature" + f" {feature.name!r} with type {feature.dtype}" ) if feature.dtype.is_integer: return DType.FLOAT32 diff --git a/temporian/core/test/registered_operators_test.py b/temporian/core/test/registered_operators_test.py index 155790bee..3b087bbf9 100644 --- a/temporian/core/test/registered_operators_test.py +++ b/temporian/core/test/registered_operators_test.py @@ -81,6 +81,7 @@ def test_base(self): "MOVING_MAX", "MOVING_MIN", "MOVING_PRODUCT", + "MOVING_QUANTILE", "MOVING_STANDARD_DEVIATION", "MOVING_SUM", "MULTIPLICATION", diff --git a/temporian/implementation/numpy/test/registered_operators_test.py b/temporian/implementation/numpy/test/registered_operators_test.py index 6cb447f16..e29bb0bd8 100644 --- a/temporian/implementation/numpy/test/registered_operators_test.py +++ b/temporian/implementation/numpy/test/registered_operators_test.py @@ -79,6 +79,7 @@ def test_base(self): "MOVING_MAX", "MOVING_MIN", "MOVING_PRODUCT", + "MOVING_QUANTILE", "MOVING_STANDARD_DEVIATION", "MOVING_SUM", "MULTIPLICATION", diff --git a/temporian/implementation/numpy_cc/operators/window.cc b/temporian/implementation/numpy_cc/operators/window.cc index aa0fbf113..3b4d08fec 100644 --- a/temporian/implementation/numpy_cc/operators/window.cc +++ b/temporian/implementation/numpy_cc/operators/window.cc @@ -174,7 +174,7 @@ py::array_t accumulate(const ArrayD &evset_timestamps, auto v_window_length = window_length.unchecked<1>(); assert(v_timestamps.shape(0) == v_window_length.shape(0)); - assert(v_timestamps.shape(0) == v_values.shape(0)); + assert(v_timestamps.shape(0) == evset_values.shape(0)); // Index of the first value in the window. size_t begin_idx = 0; @@ -247,7 +247,7 @@ py::array_t accumulate(const ArrayD &evset_timestamps, auto v_sampling = sampling_timestamps.unchecked<1>(); auto v_window_length = window_length.unchecked<1>(); - assert(v_timestamps.shape(0) == v_values.shape(0)); + assert(v_timestamps.shape(0) == evset_values.shape(0)); assert(v_sampling.shape(0) == v_window_length.shape(0)); size_t begin_idx = 0; From 25f7938bd4f7f6e6d69792901ad8c3ae22299c1b Mon Sep 17 00:00:00 2001 From: javiber Date: Fri, 17 May 2024 12:37:02 -0300 Subject: [PATCH 4/9] Addressing changes requested --- temporian/core/event_set_ops.py | 2 +- temporian/core/operators/window/base.py | 9 ++++--- .../core/operators/window/moving_quantile.py | 26 ++++++++++++------- .../numpy_cc/operators/custom_heap.h | 26 ++++++------------- 4 files changed, 31 insertions(+), 32 deletions(-) diff --git a/temporian/core/event_set_ops.py b/temporian/core/event_set_ops.py index 4ca71eb9e..4542bdd8d 100644 --- a/temporian/core/event_set_ops.py +++ b/temporian/core/event_set_ops.py @@ -3254,7 +3254,7 @@ def moving_quantile( ... features={"value": [np.nan, 1, 5, 10, 15, 20]}, ... ) - >>> a.moving_quantile(tp.duration.seconds(4), quantile=0.5) + >>> a.moving_quantile(4, quantile=0.5) indexes: ... (6 events): timestamps: [0. 1. 2. 5. 6. 7.] diff --git a/temporian/core/operators/window/base.py b/temporian/core/operators/window/base.py index 6590898df..2d425fb9d 100644 --- a/temporian/core/operators/window/base.py +++ b/temporian/core/operators/window/base.py @@ -34,8 +34,6 @@ class BaseWindowOperator(Operator, ABC): """Interface definition and common logic for window operators.""" - extra_attribute_def: List[Mapping[str, Any]] = [] - def __init__( self, input: EventSetNode, @@ -125,10 +123,15 @@ def has_variable_winlen(self) -> bool: def add_extra_attributes(self): pass + @classmethod + def extra_attribute_def(cls) -> List[Mapping[str, Any]]: + return [] + @classmethod def build_op_definition(cls) -> pb.OperatorDef: extra_attr_def = [ - pb.OperatorDef.Attribute(**attr) for attr in cls.extra_attribute_def + pb.OperatorDef.Attribute(**attr) + for attr in cls.extra_attribute_def() ] return pb.OperatorDef( key=cls.operator_def_key(), diff --git a/temporian/core/operators/window/moving_quantile.py b/temporian/core/operators/window/moving_quantile.py index caff068d6..9cb302e19 100644 --- a/temporian/core/operators/window/moving_quantile.py +++ b/temporian/core/operators/window/moving_quantile.py @@ -14,7 +14,7 @@ """Moving count operator class and public API function definition.""" -from typing import Optional +from typing import List, Mapping, Optional, Any from temporian.core import operator_lib from temporian.core.compilation import compile @@ -27,14 +27,6 @@ class MovingQuantileOperator(BaseWindowOperator): - extra_attribute_def = [ - { - "key": "quantile", - "is_optional": True, - "type": pb.OperatorDef.Attribute.Type.FLOAT_64, - } - ] - def __init__( self, input: EventSetNode, @@ -47,9 +39,13 @@ def __init__( "`quantile` must be a float between 0 and 1. " f"Received {quantile}" ) - self.quantile = quantile + self._quantile = quantile super().__init__(input, window_length, sampling) + @property + def quantile(self) -> float: + return self._quantile + def add_extra_attributes(self): self.add_attribute("quantile", self.quantile) @@ -68,6 +64,16 @@ def get_feature_dtype(self, feature: FeatureSchema) -> DType: return DType.FLOAT32 return feature.dtype + @classmethod + def extra_attribute_def(cls) -> List[Mapping[str, Any]]: + return [ + { + "key": "quantile", + "is_optional": True, + "type": pb.OperatorDef.Attribute.Type.FLOAT_64, + } + ] + operator_lib.register_operator(MovingQuantileOperator) diff --git a/temporian/implementation/numpy_cc/operators/custom_heap.h b/temporian/implementation/numpy_cc/operators/custom_heap.h index 136a9fa67..9750af76c 100644 --- a/temporian/implementation/numpy_cc/operators/custom_heap.h +++ b/temporian/implementation/numpy_cc/operators/custom_heap.h @@ -13,14 +13,16 @@ class CustomHeap { void push(T value) { heap.push_back(value); auto it = std::prev(heap.end()); + // Notice that this breaks if a value repeats, not a problem in our case + // since we are using the Heap to store the indices val_to_node[value] = it; - // TODO: better sorting? + // TODO: there is no better way to insert in order with a list + // but exploring with trees could make this better while (it != heap.begin()) { auto parent = std::prev(it); if (!compare(*parent, *it)) { break; } - // TODO: check that this swap is doing what I want std::swap(*parent, *it); val_to_node[*it] = it; val_to_node[*parent] = parent; @@ -35,6 +37,8 @@ class CustomHeap { auto value = heap.back(); heap.pop_back(); auto it = val_to_node.find(value); + // all other pointers in val_to_node are still valid because + // heap is a double linked list val_to_node.erase(it); return value; } @@ -52,25 +56,11 @@ class CustomHeap { auto it = val_to_node.find(value); if (it != val_to_node.end()) { heap.erase(it->second); + // all other pointers in val_to_node are still valid because + // heap is a double linked list val_to_node.erase(it); - } else { - // TODO: exception meant for debugging, remove it - throw std::invalid_argument("removing a value that doesn't exists"); } } int size() { return heap.size(); } int empty() { return heap.empty(); } - - void print() { - std::cout << "my_heap{" << std::endl << " ["; - std::for_each(heap.begin(), heap.end(), - [](const int n) { std::cout << n << ' '; }); - std::cout << "]" << std::endl; - - // std::cout << " {" << std::endl; - // for (const auto& pair : val_to_node) { - // std::cout << " " << pair.first << ": " << *(pair.second) << std::endl; - // } - // std::cout << " }" << std::endl; - } }; From 99fd2ec7fd10f0b4f409e27809f42f0ab0e362d4 Mon Sep 17 00:00:00 2001 From: javiber Date: Wed, 5 Jun 2024 15:14:27 -0300 Subject: [PATCH 5/9] Added moving_quantile to table in docs/reference --- docs/src/reference/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/reference/index.md b/docs/src/reference/index.md index 00ebc8acc..40a229571 100644 --- a/docs/src/reference/index.md +++ b/docs/src/reference/index.md @@ -96,7 +96,7 @@ Check the index on the left for a more detailed description of any symbol. | Symbols | Description | | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- | -| [`EventSet.simple_moving_average()`][temporian.EventSet.simple_moving_average] [`EventSet.moving_standard_deviation()`][temporian.EventSet.moving_standard_deviation] [`EventSet.cumsum()`][temporian.EventSet.cumsum] [`EventSet.moving_sum()`][temporian.EventSet.moving_sum] [`EventSet.moving_count()`][temporian.EventSet.moving_count] [`EventSet.moving_min()`][temporian.EventSet.moving_min] [`EventSet.moving_max()`][temporian.EventSet.moving_max] [`EventSet.cumprod()`][temporian.EventSet.cumprod] [`EventSet.moving_product()`][temporian.EventSet.moving_product] | Compute an operation on the values in a sliding window over an EventSet's timestamps. | +| [`EventSet.simple_moving_average()`][temporian.EventSet.simple_moving_average] [`EventSet.moving_standard_deviation()`][temporian.EventSet.moving_standard_deviation] [`EventSet.cumsum()`][temporian.EventSet.cumsum] [`EventSet.moving_sum()`][temporian.EventSet.moving_sum] [`EventSet.moving_count()`][temporian.EventSet.moving_count] [`EventSet.moving_min()`][temporian.EventSet.moving_min] [`EventSet.moving_max()`][temporian.EventSet.moving_max] [`EventSet.cumprod()`][temporian.EventSet.cumprod] [`EventSet.moving_product()`][temporian.EventSet.moving_product] [`EventSet.moving_quantile()`][temporian.EventSet.moving_quantile] | Compute an operation on the values in a sliding window over an EventSet's timestamps. | ### Python operators From 8b9fb1d0806f3dfcfdea296bdcf6b0e6375d4f12 Mon Sep 17 00:00:00 2001 From: javiber Date: Wed, 5 Jun 2024 15:14:54 -0300 Subject: [PATCH 6/9] Added benchmark for moving_quantile --- benchmark/benchmark_time.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/benchmark/benchmark_time.py b/benchmark/benchmark_time.py index 72be3442f..d23ac5a5c 100644 --- a/benchmark/benchmark_time.py +++ b/benchmark/benchmark_time.py @@ -98,6 +98,20 @@ def benchmark_simple_moving_average(runner): ) +def benchmark_moving_quantile(runner): + runner.add_separator() + for n in [100, 10_000, 1_000_000]: + ds = _build_toy_dataset(n) + + node = ds.node() + output = node.moving_quantile(window_length=10.0, quantile=0.5) + + runner.benchmark( + f"moving_quantile (0.5):{n:_}", + lambda: tp.run(output, input={node: ds}), + ) + + def benchmark_moving_minimum(runner): runner.add_separator() for n in [1_000_000, 10_000_000]: @@ -457,6 +471,7 @@ def main(): "add_index_v2", "from_pandas_with_objects", "moving_minimum", + "moving_quantile", ] if args.functions is not None: benchmarks_to_run = args.functions From 5d5f05f3d81c9db9ba5058e41faadb2a980d917a Mon Sep 17 00:00:00 2001 From: javiber Date: Wed, 5 Jun 2024 15:26:51 -0300 Subject: [PATCH 7/9] Added note on dtypes from moving_quantile --- temporian/core/event_set_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/temporian/core/event_set_ops.py b/temporian/core/event_set_ops.py index 4542bdd8d..c25d9fe1d 100644 --- a/temporian/core/event_set_ops.py +++ b/temporian/core/event_set_ops.py @@ -3247,6 +3247,10 @@ def moving_quantile( The quantile calculated in each window is equivalent to numpy's `"averaged_inverted_cdf"` method. + This operation only accepts numeric dtypes in the input. + For `float64` the output will be `float64` but for + `float32`, `int64`, and `int32` output will be `float32`. + Example: ```python >>> a = tp.event_set( From e153534b116708b31da6b6df7f63df5bd9e4985a Mon Sep 17 00:00:00 2001 From: javiber Date: Wed, 5 Jun 2024 15:28:06 -0300 Subject: [PATCH 8/9] Removed unnecessary default value for sampling --- temporian/core/operators/window/moving_quantile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporian/core/operators/window/moving_quantile.py b/temporian/core/operators/window/moving_quantile.py index 9cb302e19..77c776c42 100644 --- a/temporian/core/operators/window/moving_quantile.py +++ b/temporian/core/operators/window/moving_quantile.py @@ -32,7 +32,7 @@ def __init__( input: EventSetNode, window_length: WindowLength, quantile: float, - sampling: Optional[EventSetNode] = None, + sampling: Optional[EventSetNode], ): if quantile < 0 or quantile > 1: raise ValueError( From 92eacd61c9b22cf5e8b0772434030f7de7cd979e Mon Sep 17 00:00:00 2001 From: javiber Date: Wed, 5 Jun 2024 15:31:42 -0300 Subject: [PATCH 9/9] Added comment explaining the odd placement of super().__init__ call --- temporian/core/operators/window/moving_quantile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/temporian/core/operators/window/moving_quantile.py b/temporian/core/operators/window/moving_quantile.py index 77c776c42..5c3a685ba 100644 --- a/temporian/core/operators/window/moving_quantile.py +++ b/temporian/core/operators/window/moving_quantile.py @@ -40,6 +40,8 @@ def __init__( f"Received {quantile}" ) self._quantile = quantile + # This line should be at the top but `BaseWindowOperator.__init__` calls + # `self.check` which fails if `this._quantile` is not set super().__init__(input, window_length, sampling) @property