diff --git a/doc/api.rst b/doc/api.rst
index 87f116514cc..2ac2e40fa88 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -250,6 +250,7 @@ Reshaping and reorganizing
    Dataset.roll
    Dataset.pad
    Dataset.sortby
+   Dataset.shuffle_by
    Dataset.broadcast_like
 
 DataArray
@@ -590,6 +591,7 @@ Reshaping and reorganizing
    DataArray.roll
    DataArray.pad
    DataArray.sortby
+   DataArray.shuffle_by
    DataArray.broadcast_like
 
 DataTree
@@ -1083,6 +1085,7 @@ Dataset
    DatasetGroupBy.var
    DatasetGroupBy.dims
    DatasetGroupBy.groups
+   DatasetGroupBy.shuffle
 
 DataArray
 ---------
@@ -1114,6 +1117,7 @@ DataArray
    DataArrayGroupBy.var
    DataArrayGroupBy.dims
    DataArrayGroupBy.groups
+   DataArrayGroupBy.shuffle
 
 Grouper Objects
 ---------------
diff --git a/xarray/core/common.py b/xarray/core/common.py
index f043b7be3dd..f06d64c3e21 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -52,7 +52,7 @@
         T_Variable,
     )
     from xarray.core.variable import Variable
-    from xarray.groupers import Resampler
+    from xarray.groupers import Grouper, Resampler
 
     DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]]
 
@@ -888,6 +888,68 @@ def rolling_exp(
 
         return rolling_exp.RollingExp(self, window, window_type)
 
+    def shuffle_by(
+        self,
+        group: Hashable | DataArray | Mapping[Any, Grouper] | None = None,
+        chunks: T_Chunks = None,
+        **groupers: Grouper,
+    ) -> Self:
+        """
+        Sort or "shuffle" this object by a Grouper.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed).
+        For chunked array types, the order of appearance is not guaranteed, but will depend on
+        the input chunking.
+
+        Parameters
+        ----------
+        group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
+            Array whose unique values should be used to group this array. If a
+            Hashable, must be the name of a coordinate contained in this object. If a dictionary,
+            must map an existing variable name to a :py:class:`Grouper` instance.
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+        **groupers : Grouper
+            Grouper objects, passed as keyword arguments, with which to shuffle the data.
+
+        Returns
+        -------
+        DataArray or Dataset
+            The same type as this object.
+
+        See Also
+        --------
+        DataArrayGroupBy.shuffle
+        DatasetGroupBy.shuffle
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+
+        Examples
+        --------
+        >>> import dask
+        >>> from xarray.groupers import UniqueGrouper
+        >>> da = xr.DataArray(
+        ...     dims="x",
+        ...     data=dask.array.arange(10, chunks=1),
+        ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        ...     name="a",
+        ... )
+        >>> da
+        <xarray.DataArray 'a' (x: 10)> Size: 80B
+        dask.array<arange, shape=(10,), dtype=int64, chunksize=(1,), chunktype=numpy.ndarray>
+        Coordinates:
+          * x        (x) int64 80B 1 2 3 1 2 3 1 2 3 0
+
+        >>> da.shuffle_by(x=UniqueGrouper())
+        <xarray.DataArray 'a' (x: 10)> Size: 80B
+        dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
+        Coordinates:
+          * x        (x) int64 80B 0 1 1 1 2 2 2 3 3 3
+        """
+        return self.groupby(group=group, **groupers)._shuffle_obj(chunks)
+
     def _resample(
         self,
         resample_cls: type[T_Resample],
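Reviewer note: a minimal, self-contained sketch of the behaviour the new ``shuffle_by`` method promises, assuming this branch plus dask >= 2024.08.1 are installed. For a ``UniqueGrouper``, the shuffled result should be observably identical to an eager ``sortby``, mirroring the new test in ``test_dask.py`` below:

```python
import dask.array
import xarray as xr
from xarray.groupers import UniqueGrouper

# Chunked array whose "x" coordinate interleaves its groups.
da = xr.DataArray(
    dask.array.arange(10, chunks=3),
    dims="x",
    coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
    name="a",
)

# shuffle_by makes each group's members contiguous (and co-chunked);
# for a UniqueGrouper that matches sorting by "x".
shuffled = da.shuffle_by(x=UniqueGrouper())
xr.testing.assert_identical(shuffled, da.sortby("x"))
```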
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index a5e520b98b6..7193e71ca9d 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -47,6 +47,7 @@
     peek_at,
 )
 from xarray.core.variable import IndexVariable, Variable
+from xarray.namedarray.pycompat import is_chunked_array
 from xarray.util.deprecation_helpers import _deprecate_positional_args
 
 if TYPE_CHECKING:
@@ -54,7 +55,13 @@
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
-    from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
+    from xarray.core.types import (
+        GroupIndex,
+        GroupIndices,
+        GroupInput,
+        GroupKey,
+        T_Chunks,
+    )
     from xarray.core.utils import Frozen
     from xarray.groupers import EncodedGroups, Grouper
 
@@ -610,6 +617,100 @@ def sizes(self) -> Mapping[Hashable, int]:
             self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes
 
+    def shuffle(self, chunks: T_Chunks = None):
+        """
+        Sort or "shuffle" the underlying object.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed),
+        especially when you need to map a function that requires all members of a group
+        to be present in a single chunk. For chunked array types, the order of appearance
+        is not guaranteed, but will depend on the input chunking.
+
+        Parameters
+        ----------
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+
+        Returns
+        -------
+        DataArrayGroupBy or DatasetGroupBy
+
+        See Also
+        --------
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+
+        Examples
+        --------
+        >>> import dask
+        >>> da = xr.DataArray(
+        ...     dims="x",
+        ...     data=dask.array.arange(10, chunks=3),
+        ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        ...     name="a",
+        ... )
+        >>> shuffled = da.groupby("x").shuffle()
+        >>> shuffled.quantile(q=0.5).compute()
+        <xarray.DataArray 'a' (x: 4)> Size: 32B
+        array([9., 3., 4., 5.])
+        Coordinates:
+            quantile  float64 8B 0.5
+          * x         (x) int64 32B 0 1 2 3
+        """
+        new_groupers = {
+            # Using group.name handles the BinGrouper case.
+            # It does *not* handle the TimeResampler case,
+            # so we just override this method in Resample.
+            grouper.group.name: grouper.grouper.reset()
+            for grouper in self.groupers
+        }
+        return self._shuffle_obj(chunks).groupby(
+            new_groupers,
+            restore_coord_dims=self._restore_coord_dims,
+        )
+
+    def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
+        from xarray.core.dataarray import DataArray
+
+        dim = self._group_dim
+        size = self._obj.sizes[dim]
+        was_array = isinstance(self._obj, DataArray)
+        as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
+        no_slices: list[list[int]] = [
+            list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
+            for idx in self.encoded.group_indices
+        ]
+        no_slices = [idx for idx in no_slices if idx]
+
+        for grouper in self.groupers:
+            if grouper.name not in as_dataset._variables:
+                as_dataset.coords[grouper.name] = grouper.group
+
+        # Shuffling is only different from `isel` for chunked arrays.
+        # Extract those out and treat them specially; the rest we route
+        # through isel. This makes it easy to ensure correct handling of indexes.
+        is_chunked = {
+            name: var
+            for name, var in as_dataset._variables.items()
+            if is_chunked_array(var._data)
+        }
+        subset = as_dataset[
+            [name for name in as_dataset._variables if name not in is_chunked]
+        ]
+
+        shuffled = subset.isel({dim: np.concatenate(no_slices)})
+        for name, var in is_chunked.items():
+            shuffled[name] = var._shuffle(
+                indices=[idx for idx in self.encoded.group_indices if idx],
+                dim=dim,
+                chunks=chunks,
+            )
+        shuffled = self._maybe_unstack(shuffled)
+        new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled
+        return new_obj
+
     def map(
         self,
         func: Callable,
@@ -822,7 +923,9 @@ def _maybe_unstack(self, obj):
         # and `inserted_dims`
         # if multiple groupers all share the same single dimension, then
         # we don't stack/unstack. Do that manually now.
-        obj = obj.unstack(*self.encoded.unique_coord.dims)
+        dims_to_unstack = self.encoded.unique_coord.dims
+        if all(dim in obj.dims for dim in dims_to_unstack):
+            obj = obj.unstack(*dims_to_unstack)
         to_drop = [
             grouper.name
             for grouper in self.groupers
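Reviewer note: the motivating pattern for ``GroupBy.shuffle`` is group-wise work that needs whole groups materialized at once, e.g. quantiles. A sketch using the numbers from the docstring example above (requires dask):

```python
import dask.array
import numpy as np
import xarray as xr

da = xr.DataArray(
    dask.array.arange(10, chunks=3),
    dims="x",
    coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
    name="a",
)

# After .shuffle() each group is contiguous within a chunk, so the
# subsequent quantile sees complete groups.
result = da.groupby("x").shuffle().quantile(q=0.5).compute()
# Groups: x=0 -> [9], x=1 -> [0, 3, 6], x=2 -> [1, 4, 7], x=3 -> [2, 5, 8]
np.testing.assert_allclose(result.data, [9.0, 3.0, 4.0, 5.0])
```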
diff --git a/xarray/core/resample.py b/xarray/core/resample.py
index 677de48f0b6..8e0c258debb 100644
--- a/xarray/core/resample.py
+++ b/xarray/core/resample.py
@@ -2,7 +2,7 @@
 
 import warnings
 from collections.abc import Callable, Hashable, Iterable, Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 from xarray.core._aggregations import (
     DataArrayResampleAggregations,
@@ -14,6 +14,8 @@
 if TYPE_CHECKING:
     from xarray.core.dataarray import DataArray
     from xarray.core.dataset import Dataset
+    from xarray.core.types import T_Chunks
+    from xarray.groupers import Resampler
 
 from xarray.groupers import RESAMPLE_DIM
 
@@ -58,6 +60,60 @@ def _flox_reduce(
             result = result.rename({RESAMPLE_DIM: self._group_dim})
         return result
 
+    def shuffle(self, chunks: T_Chunks = None):
+        """
+        Sort or "shuffle" the underlying object.
+
+        "Shuffle" means the object is sorted so that all group members occur sequentially,
+        in the same chunk. Multiple groups may occur in the same chunk.
+        This method is particularly useful for chunked arrays (e.g. dask, cubed),
+        especially when you need to map a function that requires all members of a group
+        to be present in a single chunk. For chunked array types, the order of appearance
+        is not guaranteed, but will depend on the input chunking.
+
+        .. warning::
+
+            When resampling, prefer ``.chunk`` over ``.shuffle``: a time coordinate
+            can only be resampled if it is already sorted, so the groups are already
+            contiguous and rechunking achieves the same layout.
+
+        Parameters
+        ----------
+        chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
+            How to adjust chunks along dimensions not present in the array being grouped by.
+
+        Returns
+        -------
+        DataArrayResample or DatasetResample
+
+        See Also
+        --------
+        dask.dataframe.DataFrame.shuffle
+        dask.array.shuffle
+
+        Examples
+        --------
+        >>> import dask
+        >>> da = xr.DataArray(
+        ...     dims="x",
+        ...     data=dask.array.arange(10, chunks=3),
+        ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        ...     name="a",
+        ... )
+        >>> shuffled = da.groupby("x").shuffle()
+        >>> shuffled.quantile(q=0.5).compute()
+        <xarray.DataArray 'a' (x: 4)> Size: 32B
+        array([9., 3., 4., 5.])
+        Coordinates:
+            quantile  float64 8B 0.5
+          * x         (x) int64 32B 0 1 2 3
+        """
+        (grouper,) = self.groupers
+        shuffled = self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)
+        return shuffled.resample(
+            {self._group_dim: cast("Resampler", grouper.grouper.reset())},
+            restore_coord_dims=self._restore_coord_dims,
+        )
+
     def _drop_coords(self) -> T_Xarray:
         """Drop non-dimension coordinates along the resampled dimension."""
         obj = self._obj
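Reviewer note: per the warning above, shuffling a resample should change only the chunk layout, never the result. A small equivalence sketch (requires dask; mirrors the updated ``test_resample``):

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2000-01-01", freq="6h", periods=10)
da = xr.DataArray(np.arange(10), coords=[("time", times)], name="a").chunk(time=3)

rs = da.resample(time="24h")
# The shuffled aggregation must match the plain one exactly.
xr.testing.assert_identical(rs.mean(), rs.shuffle().mean())
```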
diff --git a/xarray/core/types.py b/xarray/core/types.py
index 34b6029ee15..2a33b85e27f 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -313,7 +313,7 @@ def copy(
 ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]
 
 GroupKey = Any
-GroupIndex = Union[int, slice, list[int]]
+GroupIndex = Union[slice, list[int]]
 GroupIndices = tuple[GroupIndex, ...]
 
 Bins = Union[
     int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index d84a03c3677..bcf6c0946af 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -45,7 +45,13 @@
     maybe_coerce_to_str,
 )
 from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
-from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
+from xarray.namedarray.parallelcompat import get_chunked_array_type
+from xarray.namedarray.pycompat import (
+    integer_types,
+    is_0d_dask_array,
+    is_chunked_array,
+    to_duck_array,
+)
 from xarray.util.deprecation_helpers import deprecate_dims
 
 NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
@@ -1002,6 +1008,24 @@ def compute(self, **kwargs):
         new = self.copy(deep=False)
         return new.load(**kwargs)
 
+    def _shuffle(
+        self, indices: list[list[int]], dim: Hashable, chunks: T_Chunks
+    ) -> Self:
+        array = self._data
+        if is_chunked_array(array):
+            chunkmanager = get_chunked_array_type(array)
+            return self._replace(
+                data=chunkmanager.shuffle(
+                    array,
+                    indexer=indices,
+                    axis=self.get_axis_num(dim),
+                    chunks=chunks,
+                )
+            )
+        else:
+            # The caller (GroupBy._shuffle_obj) routes only chunked variables
+            # here; everything else goes through `isel`.
+            raise AssertionError("Variable._shuffle called on a non-chunked array")
+
     def isel(
         self,
         indexers: Mapping[Any, Any] | None = None,
diff --git a/xarray/groupers.py b/xarray/groupers.py
index e4cb884e6de..d07dd4b7b94 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -26,6 +26,7 @@
     DatetimeLike,
     GroupIndices,
     ResampleCompatible,
+    Self,
     SideOptions,
 )
 from xarray.core.variable import Variable
@@ -139,6 +140,13 @@ def factorize(self, group: T_Group) -> EncodedGroups:
         """
         pass
 
+    @abstractmethod
+    def reset(self) -> Self:
+        """
+        Create a new version of this Grouper, clearing any caches.
+        """
+        pass
+
 
 class Resampler(Grouper):
     """
@@ -166,6 +174,9 @@ def group_as_index(self) -> pd.Index:
             self._group_as_index = pd.Index(np.array(self.group).ravel())
         return self._group_as_index
 
+    def reset(self) -> Self:
+        return type(self)()
+
     def factorize(self, group: T_Group) -> EncodedGroups:
         self.group = group
 
@@ -287,6 +298,16 @@ class BinGrouper(Grouper):
     include_lowest: bool = False
     duplicates: Literal["raise", "drop"] = "raise"
 
+    def reset(self) -> Self:
+        return type(self)(
+            bins=self.bins,
+            right=self.right,
+            labels=self.labels,
+            precision=self.precision,
+            include_lowest=self.include_lowest,
+            duplicates=self.duplicates,
+        )
+
     def __post_init__(self) -> None:
         if duck_array_ops.isnull(self.bins).all():
             raise ValueError("All bin edges are NaN.")
@@ -373,6 +394,15 @@ class TimeResampler(Resampler):
     index_grouper: CFTimeGrouper | pd.Grouper = field(init=False, repr=False)
     group_as_index: pd.Index = field(init=False, repr=False)
 
+    def reset(self) -> Self:
+        return type(self)(
+            freq=self.freq,
+            closed=self.closed,
+            label=self.label,
+            origin=self.origin,
+            offset=self.offset,
+        )
+
     def _init_properties(self, group: T_Group) -> None:
         from xarray import CFTimeIndex
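Reviewer note: ``reset`` is a new abstract method, so third-party ``Grouper`` subclasses must now implement it. A minimal sketch for a stateless custom grouper; the ``YearGrouper`` name and the ``factorize`` body are illustrative (loosely following the test helper further down this diff), and only ``reset`` is the new requirement:

```python
from dataclasses import dataclass

import pandas as pd
from xarray.groupers import EncodedGroups, Grouper


@dataclass
class YearGrouper(Grouper):
    """Group a datetime coordinate by calendar year."""

    def factorize(self, group) -> EncodedGroups:
        # Integer codes into the unique years of the group coordinate.
        codes_, uniques = pd.factorize(group.dt.year.data)
        codes = group.copy(data=codes_).rename("year")
        return EncodedGroups(codes=codes, full_index=pd.Index(uniques))

    def reset(self):
        # Nothing is cached on this grouper, so a fresh instance is
        # equivalent to "clearing any caches".
        return type(self)()
```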
diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py
index a7d7ed7994f..b4c11bc63a5 100644
--- a/xarray/namedarray/_typing.py
+++ b/xarray/namedarray/_typing.py
@@ -78,7 +78,8 @@ def dtype(self) -> _DType_co: ...
 _Chunks = tuple[_Shape, ...]
 _NormalizedChunks = tuple[tuple[int, ...], ...]
 # FYI in some cases we don't allow `None`, which this doesn't take account of.
-T_ChunkDim: TypeAlias = int | Literal["auto"] | None | tuple[int, ...]
+# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
+T_ChunkDim: TypeAlias = str | int | Literal["auto"] | None | tuple[int, ...]
 # We allow the tuple form of this (though arguably we could transition to named dims only)
 T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]
 
diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py
index a056f4e00bd..82ceadf548b 100644
--- a/xarray/namedarray/daskmanager.py
+++ b/xarray/namedarray/daskmanager.py
@@ -251,3 +251,18 @@ def store(
             targets=targets,
             **kwargs,
         )
+
+    def shuffle(
+        self, x: DaskArray, indexer: list[list[int]], axis: int, chunks: T_Chunks
+    ) -> DaskArray:
+        import dask.array
+
+        if not module_available("dask", minversion="2024.08.1"):
+            raise ValueError(
+                "Shuffling is very inefficient on dask < 2024.08.1, "
+                "so it is disabled for older versions. Please upgrade."
+            )
+        if chunks is None:
+            chunks = "auto"
+        if chunks != "auto":
+            raise NotImplementedError("Only chunks='auto' is supported at present.")
+        return dask.array.shuffle(x, indexer, axis, chunks="auto")
diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py
index b90e0f99782..a1edfe41b13 100644
--- a/xarray/namedarray/parallelcompat.py
+++ b/xarray/namedarray/parallelcompat.py
@@ -20,6 +20,7 @@
 
 if TYPE_CHECKING:
     from xarray.namedarray._typing import (
+        T_Chunks,
         _Chunks,
         _DType,
         _DType_co,
@@ -357,6 +358,11 @@ def compute(
         """
         raise NotImplementedError()
 
+    def shuffle(
+        self, x: T_ChunkedArray, indexer: list[list[int]], axis: int, chunks: T_Chunks
+    ) -> T_ChunkedArray:
+        raise NotImplementedError()
+
     @property
     def array_api(self) -> Any:
         """
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 0e43738ed99..da49329104b 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -106,6 +106,7 @@ def _importorskip(
 has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
 has_cftime, requires_cftime = _importorskip("cftime")
 has_dask, requires_dask = _importorskip("dask")
+has_dask_ge_2024_08_1, _ = _importorskip("dask", minversion="2024.08.1")
 with warnings.catch_warnings():
     warnings.filterwarnings(
         "ignore",
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index 062f0525593..3c7f0321acc 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1803,3 +1803,27 @@ def test_minimize_graph_size():
         # all the other dimensions.
         # e.g. previously for 'x', actual == numchunks['y'] * numchunks['z']
         assert actual == numchunks[var], (actual, numchunks[var])
+
+
+@pytest.mark.parametrize(
+    "chunks, expected_chunks",
+    [
+        ((1,), (1, 3, 3, 3)),
+        ((10,), (10,)),
+    ],
+)
+def test_shuffle_by(chunks, expected_chunks):
+    from xarray.groupers import UniqueGrouper
+
+    da = xr.DataArray(
+        dims="x",
+        data=dask.array.arange(10, chunks=chunks),
+        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        name="a",
+    )
+    ds = da.to_dataset()
+
+    for obj in [ds, da]:
+        actual = obj.shuffle_by(x=UniqueGrouper())
+        assert_identical(actual, obj.sortby("x"))
+        assert actual.chunksizes["x"] == expected_chunks
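Reviewer note: the parametrized ``expected_chunks`` in ``test_shuffle_by`` above follow directly from the group sizes. With every element in its own chunk, the shuffle gathers each group of ``x`` (sizes 1, 3, 3, 3 for labels 0, 1, 2, 3) into one chunk; with a single 10-element chunk, nothing needs to move. A quick arithmetic check:

```python
import numpy as np

x = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3, 0])
# Group sizes in sorted-label order equal the expected output chunks
# for the fully split (chunks=(1,)) input.
sizes = tuple(int((x == label).sum()) for label in np.unique(x))
assert sizes == (1, 3, 3, 3)
```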
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index fa6172c5d66..c765c718ee8 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -3,6 +3,7 @@
 import datetime
 import operator
 import warnings
+from typing import Literal
 from unittest import mock
 
 import numpy as np
@@ -29,8 +30,11 @@
     assert_identical,
     create_test_data,
     has_cftime,
+    has_dask,
+    has_dask_ge_2024_08_1,
     has_flox,
     has_pandas_ge_2_2,
+    raise_if_dask_computes,
     requires_cftime,
     requires_dask,
     requires_flox,
@@ -608,7 +612,22 @@ def test_groupby_repr_datetime(obj) -> None:
 
 @pytest.mark.filterwarnings("ignore:Converting non-nanosecond")
 @pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
-def test_groupby_drops_nans() -> None:
+@pytest.mark.parametrize("shuffle", [True, False])
+@pytest.mark.parametrize(
+    "chunk",
+    [
+        pytest.param(
+            dict(lat=1), marks=pytest.mark.skipif(not has_dask, reason="no dask")
+        ),
+        pytest.param(
+            dict(lat=2, lon=2), marks=pytest.mark.skipif(not has_dask, reason="no dask")
+        ),
+        False,
+    ],
+)
+def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None:
+    if shuffle and chunk and not has_dask_ge_2024_08_1:
+        pytest.skip("shuffle of chunked arrays requires dask >= 2024.08.1")
     # GH2383
     # nan in 2D data variable (requires stacking)
     ds = xr.Dataset(
@@ -623,13 +642,17 @@
     ds["id"].values[3, 0] = np.nan
     ds["id"].values[-1, -1] = np.nan
 
+    if chunk:
+        ds = ds.chunk(chunk)
     grouped = ds.groupby(ds.id)
+    if shuffle:
+        grouped = grouped.shuffle()
 
     # non reduction operation
     expected1 = ds.copy()
-    expected1.variable.values[0, 0, :] = np.nan
-    expected1.variable.values[-1, -1, :] = np.nan
-    expected1.variable.values[3, 0, :] = np.nan
+    expected1.variable.data[0, 0, :] = np.nan
+    expected1.variable.data[-1, -1, :] = np.nan
+    expected1.variable.data[3, 0, :] = np.nan
     actual1 = grouped.map(lambda x: x).transpose(*ds.variable.dims)
     assert_identical(actual1, expected1)
 
@@ -1328,11 +1351,27 @@ def test_groupby_sum(self) -> None:
         assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y"))
         assert_allclose(expected_sum_axis1, grouped.sum("y"))
 
+    @pytest.mark.parametrize("use_flox", [True, False])
+    @pytest.mark.parametrize("shuffle", [True, False])
+    @pytest.mark.parametrize(
+        "chunk",
+        [
+            pytest.param(
+                True, marks=pytest.mark.skipif(not has_dask, reason="no dask")
+            ),
+            False,
+        ],
+    )
     @pytest.mark.parametrize("method", ["sum", "mean", "median"])
-    def test_groupby_reductions(self, method) -> None:
-        array = self.da
-        grouped = array.groupby("abc")
+    def test_groupby_reductions(
+        self, use_flox: bool, method: str, shuffle: bool, chunk: bool
+    ) -> None:
+        if shuffle and chunk and not has_dask_ge_2024_08_1:
+            pytest.skip("shuffle of chunked arrays requires dask >= 2024.08.1")
 
+        array = self.da
+        if chunk:
+            array.data = array.chunk({"y": 5}).data
         reduction = getattr(np, method)
         expected = Dataset(
             {
@@ -1350,14 +1389,14 @@ def test_groupby_reductions(self, method) -> None:
             }
         )["foo"]
 
-        with xr.set_options(use_flox=False):
-            actual_legacy = getattr(grouped, method)(dim="y")
+        with raise_if_dask_computes():
+            grouped = array.groupby("abc")
+            if shuffle:
+                grouped = grouped.shuffle()
 
-        with xr.set_options(use_flox=True):
-            actual_npg = getattr(grouped, method)(dim="y")
-
-        assert_allclose(expected, actual_legacy)
-        assert_allclose(expected, actual_npg)
+        with xr.set_options(use_flox=use_flox):
+            actual = getattr(grouped, method)(dim="y")
+        assert_allclose(expected, actual)
 
     def test_groupby_count(self) -> None:
         array = DataArray(
@@ -1621,13 +1660,14 @@ def test_groupby_bins(
     )
 
     with xr.set_options(use_flox=use_flox):
-        actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).sum()
+        gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs)
+        actual = gb.sum()
         assert_identical(expected, actual)
+        assert_identical(expected, gb.shuffle().sum())
 
-        actual = array.groupby_bins("dim_0", bins=bins, **cut_kwargs).map(
-            lambda x: x.sum()
-        )
+        actual = gb.map(lambda x: x.sum())
         assert_identical(expected, actual)
+        assert_identical(expected, gb.shuffle().map(lambda x: x.sum()))
 
         # make sure original array dims are unchanged
         assert len(array.dim_0) == 4
@@ -1772,6 +1812,7 @@ def test_groupby_fastpath_for_monotonic(self, use_flox: bool) -> None:
 
 
 class TestDataArrayResample:
+    @pytest.mark.parametrize("shuffle", [True, False])
     @pytest.mark.parametrize("use_cftime", [True, False])
     @pytest.mark.parametrize(
         "resample_freq",
@@ -1786,7 +1827,7 @@ class TestDataArrayResample:
         ],
     )
     def test_resample(
-        self, use_cftime: bool, resample_freq: ResampleCompatible
+        self, use_cftime: bool, shuffle: bool, resample_freq: ResampleCompatible
     ) -> None:
         if use_cftime and not has_cftime:
             pytest.skip()
@@ -1809,16 +1850,21 @@ def resample_as_pandas(array, *args, **kwargs):
 
         array = DataArray(np.arange(10), [("time", times)])
 
-        actual = array.resample(time=resample_freq).mean()
+        rs = array.resample(time=resample_freq)
+        actual = rs.mean()
         expected = resample_as_pandas(array, resample_freq)
         assert_identical(expected, actual)
+        assert_identical(expected, rs.shuffle().mean())
 
-        actual = array.resample(time=resample_freq).reduce(np.mean)
-        assert_identical(expected, actual)
+        assert_identical(expected, rs.reduce(np.mean))
+        assert_identical(expected, rs.shuffle().reduce(np.mean))
 
-        actual = array.resample(time=resample_freq, closed="right").mean()
-        expected = resample_as_pandas(array, resample_freq, closed="right")
+        rs = array.resample(time="24h", closed="right")
+        actual = rs.mean()
+        shuffled = rs.shuffle().mean()
+        expected = resample_as_pandas(array, "24h", closed="right")
         assert_identical(expected, actual)
+        assert_identical(expected, shuffled)
 
         with pytest.raises(ValueError, match=r"Index must be monotonic"):
             array[[2, 0, 1]].resample(time=resample_freq)
@@ -2644,6 +2690,9 @@ def factorize(self, group) -> EncodedGroups:
             codes = group.copy(data=codes_).rename("year")
             return EncodedGroups(codes=codes, full_index=pd.Index(uniques))
 
+        def reset(self):
+            return type(self)()
+
     da = xr.DataArray(
         dims="time",
         data=np.arange(20),
@@ -2740,8 +2789,9 @@ def test_multiple_groupers_string(as_dataset) -> None:
         obj.groupby("labels1", foo=UniqueGrouper())
 
 
+@pytest.mark.parametrize("shuffle", [True, False])
 @pytest.mark.parametrize("use_flox", [True, False])
-def test_multiple_groupers(use_flox) -> None:
+def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
     da = DataArray(
         np.array([1, 2, 3, 0, 2, np.nan]),
         dims="d",
@@ -2753,6 +2803,8 @@ def test_multiple_groupers(use_flox) -> None:
     )
 
     gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     repr(gb)
 
     expected = DataArray(
@@ -2772,6 +2824,8 @@ def test_multiple_groupers(use_flox) -> None:
     coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
     square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
     gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2795,11 +2849,15 @@ def test_multiple_groupers(use_flox) -> None:
         dims=["x", "y", "z"],
     )
     gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         assert_identical(gb.mean("z"), b.mean("z"))
 
     gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2814,13 +2872,16 @@ def test_multiple_groupers(use_flox) -> None:
 
 
 @pytest.mark.parametrize("use_flox", [True, False])
-def test_multiple_groupers_mixed(use_flox) -> None:
+@pytest.mark.parametrize("shuffle", [True, False])
+def test_multiple_groupers_mixed(use_flox: bool, shuffle: bool) -> None:
     # This groupby has missing groups
     ds = xr.Dataset(
         {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
         coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
    )
     gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper())
+    if shuffle:
+        gb = gb.shuffle()
     expected_data = np.array(
         [
             [[0.0, np.nan], [np.nan, 3.0]],