Skip to content

Commit

Permalink
Use map_overlap for rolling reductions with Dask (#9770)
Browse files Browse the repository at this point in the history
* Use ``map_overlap`` for rolling reducers with Dask

* Enable argmin test

* Update
  • Loading branch information
phofl authored Nov 13, 2024
1 parent 568dd6f commit b16a104
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 30 deletions.
26 changes: 8 additions & 18 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,15 @@

def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
"""Wrapper to apply bottleneck moving window funcs on dask arrays"""
import dask.array as da

dtype, fill_value = dtypes.maybe_promote(a.dtype)
a = a.astype(dtype)
# inputs for overlap
if axis < 0:
axis = a.ndim + axis
depth = {d: 0 for d in range(a.ndim)}
depth[axis] = (window + 1) // 2
boundary = {d: fill_value for d in range(a.ndim)}
# Create overlap array.
ag = da.overlap.overlap(a, depth=depth, boundary=boundary)
# apply rolling func
out = da.map_blocks(
moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype
dtype, _ = dtypes.maybe_promote(a.dtype)
return a.data.map_overlap(
moving_func,
depth={axis: (window - 1, 0)},
axis=axis,
dtype=dtype,
window=window,
min_count=min_count,
)
# trim array
result = da.overlap.trim_internal(out, depth)
return result


def least_squares(lhs, rhs, rcond=None, skipna=False):
Expand Down
24 changes: 13 additions & 11 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from packaging.version import Version

from xarray.core import dtypes, duck_array_ops, utils
from xarray.core import dask_array_ops, dtypes, duck_array_ops, utils
from xarray.core.arithmetic import CoarsenArithmetic
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import CoarsenBoundaryOptions, SideOptions, T_Xarray
Expand Down Expand Up @@ -597,16 +597,18 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")

if is_duck_dask_array(padded.data):
raise AssertionError("should not be reachable")
values = dask_array_ops.dask_rolling_wrapper(
func, padded, axis=axis, window=self.window[0], min_count=min_count
)
else:
values = func(
padded.data, window=self.window[0], min_count=min_count, axis=axis
)
# index 0 is at the rightmost edge of the window
# need to reverse index here
# see GH #8541
if func in [bottleneck.move_argmin, bottleneck.move_argmax]:
values = self.window[0] - 1 - values
# index 0 is at the rightmost edge of the window
# need to reverse index here
# see GH #8541
if func in [bottleneck.move_argmin, bottleneck.move_argmax]:
values = self.window[0] - 1 - values

if self.center[0]:
values = values[valid]
Expand Down Expand Up @@ -669,12 +671,12 @@ def _array_reduce(
if (
OPTIONS["use_bottleneck"]
and bottleneck_move_func is not None
and not is_duck_dask_array(self.obj.data)
and (
not is_duck_dask_array(self.obj.data)
or module_available("dask", "2024.11.0")
)
and self.ndim == 1
):
# TODO: re-enable bottleneck with dask after the issues
# underlying https://github.com/pydata/xarray/issues/2940 are
# fixed.
return self._bottleneck_reduce(
bottleneck_move_func, keep_attrs=keep_attrs, **kwargs
)
Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,13 @@ def test_rolling_properties(self, da) -> None:
):
da.rolling(foo=2)

@requires_dask
@pytest.mark.parametrize(
"name", ("sum", "mean", "std", "min", "max", "median", "argmin", "argmax")
)
@pytest.mark.parametrize("center", (True, False, None))
@pytest.mark.parametrize("min_periods", (1, None))
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
@pytest.mark.parametrize("backend", ["numpy", "dask"], indirect=True)
def test_rolling_wrapped_bottleneck(
self, da, name, center, min_periods, compute_backend
) -> None:
Expand Down

0 comments on commit b16a104

Please sign in to comment.