diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst
index b41bf3eeb3a..c10ee6a659d 100644
--- a/doc/user-guide/groupby.rst
+++ b/doc/user-guide/groupby.rst
@@ -81,8 +81,7 @@ You can index out a particular group:

     ds.groupby("letters")["b"]

-Just like in pandas, creating a GroupBy object is cheap: it does not actually
-split the data until you access particular values.
+To group by multiple variables, see :ref:`this section <groupby.multiple>`.

 Binning
 ~~~~~~~
@@ -180,19 +179,6 @@ This last line is roughly equivalent to the following::
         results.append(group - alt.sel(letters=label))
     xr.concat(results, dim='x')

-Iterating and Squeezing
-~~~~~~~~~~~~~~~~~~~~~~~
-
-Previously, Xarray defaulted to squeezing out dimensions of size one when iterating over
-a GroupBy object. This behaviour is being removed.
-You can always squeeze explicitly later with the Dataset or DataArray
-:py:meth:`DataArray.squeeze` methods.
-
-.. ipython:: python
-
-    next(iter(arr.groupby("x", squeeze=False)))
-
-
 .. _groupby.multidim:

 Multidimensional Grouping
@@ -236,6 +222,8 @@ applying your function, and then unstacking the result:
     stacked = da.stack(gridcell=["ny", "nx"])
     stacked.groupby("gridcell").sum(...).unstack("gridcell")

+Alternatively, you can groupby both `lat` and `lon` at the :ref:`same time <groupby.multiple>`.
+
 .. _groupby.groupers:

 Grouper Objects
@@ -276,7 +264,8 @@ is identical to

     ds.groupby(x=UniqueGrouper())

-and
+
+Similarly,

 .. code-block:: python

@@ -303,3 +292,26 @@ is identical to
     from xarray.groupers import TimeResampler

     ds.resample(time=TimeResampler("ME"))
+
+
+.. _groupby.multiple:
+
+Grouping by multiple variables
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Use grouper objects to group by multiple dimensions:
+
+.. ipython:: python
+
+    from xarray.groupers import UniqueGrouper
+
+    da.groupby(lat=UniqueGrouper(), lon=UniqueGrouper()).sum()
+
+
+Different groupers can be combined to construct sophisticated GroupBy operations.
+
+.. ipython:: python
+
+    from xarray.groupers import BinGrouper
+
+    ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 3c6b7bfb58d..712ad68aeb3 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -24,6 +24,11 @@ New Features
 ~~~~~~~~~~~~
 - Make chunk manager an option in ``set_options`` (:pull:`9362`).
   By `Tom White <https://github.com/tomwhite>`_.
+- Support for :ref:`grouping by multiple variables <groupby.multiple>`.
+  This is quite new, so please check your results and report bugs.
+  Binary operations after grouping by multiple arrays are not supported yet.
+  (:issue:`1056`, :issue:`9332`, :issue:`324`, :pull:`9372`).
+  By `Deepak Cherian <https://github.com/dcherian>`_.
 - Allow data variable specific ``constant_values`` in the dataset ``pad`` function (:pull:`9353`).
   By `Tiago Sanona <https://github.com/tsanona>`_.

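To make the new docs section concrete, here is a standalone sketch mirroring the ``BinGrouper`` + ``UniqueGrouper`` example added above (not part of the patch; the data is lifted from ``test_multiple_groupers_mixed`` further down):

.. code-block:: python

    import numpy as np
    import xarray as xr
    from xarray.groupers import BinGrouper, UniqueGrouper

    ds = xr.Dataset(
        {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
        coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
    )
    # Groups are the outer product of the x bins and the letter labels;
    # combinations with no observations are filled with NaN.
    result = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
    print(result.foo.dims)  # ("y", "x_bins", "letters")
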
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 84f229bf575..1f0544c1041 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -6801,27 +6801,22 @@ def groupby(
         groupers = either_dict_or_kwargs(group, groupers, "groupby")  # type: ignore
         group = None

-        grouper: Grouper
+        rgroupers: tuple[ResolvedGrouper, ...]
         if group is not None:
             if groupers:
                 raise ValueError(
                     "Providing a combination of `group` and **groupers is not supported."
                 )
-            grouper = UniqueGrouper()
+            rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
         else:
-            if len(groupers) > 1:
-                raise ValueError("grouping by multiple variables is not supported yet.")
             if not groupers:
                 raise ValueError("Either `group` or `**groupers` must be provided.")
-            group, grouper = next(iter(groupers.items()))
-
-        rgrouper = ResolvedGrouper(grouper, group, self)
+            rgroupers = tuple(
+                ResolvedGrouper(grouper, group, self)
+                for group, grouper in groupers.items()
+            )

-        return DataArrayGroupBy(
-            self,
-            (rgrouper,),
-            restore_coord_dims=restore_coord_dims,
-        )
+        return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

     @_deprecate_positional_args("v2024.07.0")
     def groupby_bins(
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index dbc00a03025..e14176f1589 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -10397,25 +10397,22 @@ def groupby(
         groupers = either_dict_or_kwargs(group, groupers, "groupby")  # type: ignore
         group = None

+        rgroupers: tuple[ResolvedGrouper, ...]
         if group is not None:
             if groupers:
                 raise ValueError(
                     "Providing a combination of `group` and **groupers is not supported."
                 )
-            rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
+            rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
         else:
-            if len(groupers) > 1:
-                raise ValueError("Grouping by multiple variables is not supported yet.")
-            elif not groupers:
+            if not groupers:
                 raise ValueError("Either `group` or `**groupers` must be provided.")
-            for group, grouper in groupers.items():
-                rgrouper = ResolvedGrouper(grouper, group, self)
+            rgroupers = tuple(
+                ResolvedGrouper(grouper, group, self)
+                for group, grouper in groupers.items()
+            )

-        return DatasetGroupBy(
-            self,
-            (rgrouper,),
-            restore_coord_dims=restore_coord_dims,
-        )
+        return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

     @_deprecate_positional_args("v2024.07.0")
     def groupby_bins(
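A quick sketch of what the reworked dispatch above accepts and rejects (the data is hypothetical, not part of the patch):

.. code-block:: python

    import numpy as np
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    da = xr.DataArray(
        np.arange(6.0),
        dims="d",
        coords={"labels1": ("d", list("abcabc")), "labels2": ("d", list("xxyyzz"))},
    )

    da.groupby("labels1")  # plain `group`: wrapped in a single UniqueGrouper
    da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())  # multiple groupers

    try:
        da.groupby("labels1", labels2=UniqueGrouper())
    except ValueError:
        pass  # combining `group` with **groupers is rejected by the branch above
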
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 833466ffe9e..cc83b32adc8 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -1,6 +1,8 @@
 from __future__ import annotations

 import copy
+import functools
+import itertools
 import warnings
 from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
 from dataclasses import dataclass, field
@@ -15,7 +17,7 @@
     DataArrayGroupByAggregations,
     DatasetGroupByAggregations,
 )
-from xarray.core.alignment import align
+from xarray.core.alignment import align, broadcast
 from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
 from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
 from xarray.core.concat import concat
@@ -68,10 +70,11 @@ def check_reduce_dims(reduce_dims, dimensions):
         )


-def _codes_to_group_indices(inverse: np.ndarray, N: int) -> GroupIndices:
-    assert inverse.ndim == 1
+def _codes_to_group_indices(codes: np.ndarray, N: int) -> GroupIndices:
+    """Converts integer codes for groups to group indices."""
+    assert codes.ndim == 1
     groups: GroupIndices = tuple([] for _ in range(N))
-    for n, g in enumerate(inverse):
+    for n, g in enumerate(codes):
         if g >= 0:
             groups[g].append(n)
     return groups
@@ -380,6 +383,65 @@ def _resolve_group(
     return newgroup


+@dataclass
+class ComposedGrouper:
+    """
+    Helper class for multi-variable GroupBy.
+    This satisfies the Grouper interface, but is awkward to wrap in ResolvedGrouper.
+    For one, it simply re-infers a new EncodedGroups using known information
+    in existing ResolvedGroupers. So passing in a `group` (hard to define),
+    and `obj` (pointless) is not useful.
+    """
+
+    groupers: tuple[ResolvedGrouper, ...]
+
+    def factorize(self) -> EncodedGroups:
+        from xarray.groupers import EncodedGroups
+
+        groupers = self.groupers
+
+        # At this point all arrays have been factorized.
+        codes = tuple(grouper.codes for grouper in groupers)
+        shape = tuple(grouper.size for grouper in groupers)
+        # We broadcast the codes against each other
+        broadcasted_codes = broadcast(*codes)
+        # This fully broadcasted DataArray is used as a template later
+        first_codes = broadcasted_codes[0]
+        # Now we convert to a single variable GroupBy problem
+        _flatcodes = np.ravel_multi_index(
+            tuple(codes.data for codes in broadcasted_codes), shape, mode="wrap"
+        )
+        # NaNs; as well as values outside the bins are coded by -1
+        # Restore these after the raveling
+        mask = functools.reduce(np.logical_or, [(code == -1) for code in broadcasted_codes])  # type: ignore
+        _flatcodes[mask] = -1
+
+        midx = pd.MultiIndex.from_product(
+            (grouper.unique_coord.data for grouper in groupers),
+            names=tuple(grouper.name for grouper in groupers),
+        )
+        # Constructing an index from the product is wrong when there are missing groups
+        # (e.g. binning, resampling). Account for that now.
+        midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
+
+        full_index = pd.MultiIndex.from_product(
+            (grouper.full_index.values for grouper in groupers),
+            names=tuple(grouper.name for grouper in groupers),
+        )
+        dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers)
+
+        coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name)
+        for grouper in groupers:
+            coords.variables[grouper.name].attrs = grouper.group.attrs
+        return EncodedGroups(
+            codes=first_codes.copy(data=_flatcodes),
+            full_index=full_index,
+            group_indices=_codes_to_group_indices(_flatcodes.ravel(), len(full_index)),
+            unique_coord=Variable(dims=(dim_name,), data=midx.values),
+            coords=coords,
+        )
+
+
 class GroupBy(Generic[T_Xarray]):
     """A object that implements the split-apply-combine pattern.

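The core encoding trick in ``ComposedGrouper.factorize`` can be seen in isolation with plain numpy (a minimal sketch with made-up codes):

.. code-block:: python

    import numpy as np

    # Factorized codes for two groupers over the same samples; -1 marks NaN
    # or out-of-bin values, exactly as in the method above.
    codes1 = np.array([0, 1, 2, 2, 1, -1])
    codes2 = np.array([0, 1, 2, 2, 1, 0])
    shape = (3, 3)  # (number of groups in grouper 1, number in grouper 2)

    # mode="wrap" stops the -1 sentinels from raising; mask them back afterwards
    flat = np.ravel_multi_index((codes1, codes2), shape, mode="wrap")
    flat[(codes1 == -1) | (codes2 == -1)] = -1
    print(flat)  # [ 0  4  8  8  4 -1]: one integer code per (group1, group2) pair
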
@@ -417,12 +479,12 @@ class GroupBy(Generic[T_Xarray]):
         "encoded",
     )
     _obj: T_Xarray
-    groupers: tuple[ResolvedGrouper]
+    groupers: tuple[ResolvedGrouper, ...]
     _restore_coord_dims: bool

     _original_obj: T_Xarray
     _group_indices: GroupIndices
-    _codes: DataArray
+    _codes: tuple[DataArray, ...]
     _group_dim: Hashable

     _groups: dict[GroupKey, GroupIndex] | None
@@ -440,7 +502,7 @@ class GroupBy(Generic[T_Xarray]):
     def __init__(
         self,
         obj: T_Xarray,
-        groupers: tuple[ResolvedGrouper],
+        groupers: tuple[ResolvedGrouper, ...],
         restore_coord_dims: bool = True,
     ) -> None:
         """Create a GroupBy object
@@ -459,8 +521,19 @@ def __init__(
         self._restore_coord_dims = restore_coord_dims
         self.groupers = groupers

-        (grouper,) = groupers
-        self.encoded = grouper.encoded
+        if len(groupers) == 1:
+            (grouper,) = groupers
+            self.encoded = grouper.encoded
+        else:
+            if any(
+                isinstance(obj._indexes.get(grouper.name, None), PandasMultiIndex)
+                for grouper in groupers
+            ):
+                raise NotImplementedError(
+                    "Grouping by multiple variables, one of which "
+                    "wraps a Pandas MultiIndex, is not supported yet."
+                )
+            self.encoded = ComposedGrouper(groupers).factorize()

         # specification for the groupby operation
         # TODO: handle obj having variables that are not present on any of the groupers
@@ -514,6 +587,12 @@ def reduce(
     ) -> T_Xarray:
         raise NotImplementedError()

+    def _raise_if_not_single_group(self):
+        if len(self.groupers) != 1:
+            raise NotImplementedError(
+                "This method is not supported for grouping by multiple variables yet."
+            )
+
     @property
     def groups(self) -> dict[GroupKey, GroupIndex]:
         """
@@ -539,13 +618,16 @@ def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]:
         return zip(self.encoded.unique_coord.data, self._iter_grouped())

     def __repr__(self) -> str:
-        (grouper,) = self.groupers
-        return "{}, grouped over {!r}\n{!r} groups with labels {}.".format(
-            self.__class__.__name__,
-            grouper.name,
-            grouper.full_index.size,
-            ", ".join(format_array_flat(grouper.full_index, 30).split()),
+        text = (
+            f"<{self.__class__.__name__}, "
+            f"grouped over {len(self.groupers)} grouper(s),"
+            f" {self._len} groups in total:"
         )
+        for grouper in self.groupers:
+            coord = grouper.unique_coord
+            labels = ", ".join(format_array_flat(coord, 30).split())
+            text += f"\n\t{grouper.name!r}: {coord.size} groups with labels {labels}"
+        return text + ">"

     def _iter_grouped(self) -> Iterator[T_Xarray]:
         """Iterate over each element in this group"""
@@ -554,7 +636,6 @@ def _iter_grouped(self) -> Iterator[T_Xarray]:
         yield self._obj.isel({self._group_dim: indices})

     def _infer_concat_args(self, applied_example):
-
         if self._group_dim in applied_example.dims:
             coord = self.group1d
             positions = self.encoded.group_indices
@@ -570,6 +651,7 @@ def _binary_op(self, other, f, reflexive=False):

         g = f if not reflexive else lambda x, y: f(y, x)

+        self._raise_if_not_single_group()
         (grouper,) = self.groupers
         obj = self._original_obj
         name = grouper.name
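``_raise_if_not_single_group`` backs the limitation noted in whats-new: binary operations after grouping by multiple arrays are not supported yet. A sketch with hypothetical data:

.. code-block:: python

    import numpy as np
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    da = xr.DataArray(
        np.arange(4.0),
        dims="d",
        coords={"a": ("d", [0, 0, 1, 1]), "b": ("d", [0, 1, 0, 1])},
    )
    gb = da.groupby(a=UniqueGrouper(), b=UniqueGrouper())
    try:
        gb - gb.mean()  # hits _raise_if_not_single_group via _binary_op
    except NotImplementedError:
        pass
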
@@ -662,27 +744,44 @@ def _maybe_restore_empty_groups(self, combined):
         """
         from xarray.groupers import BinGrouper, TimeResampler

-        (grouper,) = self.groupers
-        if (
-            isinstance(grouper.grouper, BinGrouper | TimeResampler)
-            and grouper.name in combined.dims
-        ):
-            indexers = {grouper.name: grouper.full_index}
+        indexers = {}
+        for grouper in self.groupers:
+            if (
+                isinstance(grouper.grouper, BinGrouper | TimeResampler)
+                and grouper.name in combined.dims
+            ):
+                indexers[grouper.name] = grouper.full_index
+        if indexers:
             combined = combined.reindex(**indexers)
         return combined

     def _maybe_unstack(self, obj):
         """This gets called if we are applying on an array with a multidimensional group."""
-        (grouper,) = self.groupers
+        from xarray.groupers import UniqueGrouper
+
         stacked_dim = self._stacked_dim
-        inserted_dims = self._inserted_dims
         if stacked_dim is not None and stacked_dim in obj.dims:
+            inserted_dims = self._inserted_dims
             obj = obj.unstack(stacked_dim)
             for dim in inserted_dims:
                 if dim in obj.coords:
                     del obj.coords[dim]
             obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords))
+        elif len(self.groupers) > 1:
+            # TODO: we could clean this up by setting the appropriate `stacked_dim`
+            #       and `inserted_dims`
+            # if multiple groupers all share the same single dimension, then
+            # we don't stack/unstack. Do that manually now.
+            obj = obj.unstack(*self.encoded.unique_coord.dims)
+            to_drop = [
+                grouper.name
+                for grouper in self.groupers
+                if isinstance(grouper.group, _DummyGroup)
+                and isinstance(grouper.grouper, UniqueGrouper)
+            ]
+            obj = obj.drop_vars(to_drop)
+
         return obj

     def _flox_reduce(
@@ -699,9 +798,15 @@ def _flox_reduce(
         from xarray.groupers import BinGrouper

         obj = self._original_obj
-        (grouper,) = self.groupers
-        name = grouper.name
-        isbin = isinstance(grouper.grouper, BinGrouper)
+        variables = (
+            {k: v.variable for k, v in obj.data_vars.items()}
+            if isinstance(obj, Dataset)
+            else obj._coords
+        )
+
+        any_isbin = any(
+            isinstance(grouper.grouper, BinGrouper) for grouper in self.groupers
+        )

         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=True)
@@ -712,12 +817,27 @@ def _flox_reduce(
         # flox >=0.9 will choose this on its own.
         kwargs.setdefault("method", "cohorts")

-        numeric_only = kwargs.pop("numeric_only", None)
-        if numeric_only:
+        midx_grouping_vars: tuple[Hashable, ...] = ()
+        for grouper in self.groupers:
+            name = grouper.name
+            maybe_midx = obj._indexes.get(name, None)
+            if isinstance(maybe_midx, PandasMultiIndex):
+                midx_grouping_vars += tuple(maybe_midx.index.names) + (name,)
+
+        # For datasets, running a numeric-only reduction on non-numeric
+        # variable will just drop it.
+        non_numeric: dict[Hashable, Variable]
+        if kwargs.pop("numeric_only", None):
             non_numeric = {
                 name: var
-                for name, var in obj.data_vars.items()
-                if not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_))
+                for name, var in variables.items()
+                if (
+                    not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_))
+                    # this avoids dropping any levels of a MultiIndex, which raises
+                    # a warning
+                    and name not in midx_grouping_vars
+                    and name not in obj.dims
+                )
             }
         else:
             non_numeric = {}
@@ -729,15 +849,25 @@ def _flox_reduce(
             # set explicitly to avoid unnecessarily accumulating count
             kwargs["min_count"] = 0

-        unindexed_dims: tuple[Hashable, ...] = tuple()
-        if isinstance(grouper.group, _DummyGroup) and not isbin:
-            unindexed_dims = (name,)
+        unindexed_dims: tuple[Hashable, ...] = tuple(
+            grouper.name
+            for grouper in self.groupers
+            if isinstance(grouper.group, _DummyGroup)
+            and not isinstance(grouper.grouper, BinGrouper)
+        )

         parsed_dim: tuple[Hashable, ...]
         if isinstance(dim, str):
             parsed_dim = (dim,)
         elif dim is None:
-            parsed_dim = grouper.group.dims
+            parsed_dim_list = list()
+            # preserve order
+            for dim_ in itertools.chain(
+                *(grouper.group.dims for grouper in self.groupers)
+            ):
+                if dim_ not in parsed_dim_list:
+                    parsed_dim_list.append(dim_)
+            parsed_dim = tuple(parsed_dim_list)
         elif dim is ...:
             parsed_dim = tuple(obj.dims)
         else:
@@ -745,12 +875,15 @@ def _flox_reduce(

         # Do this so we raise the same error message whether flox is present or not.
         # Better to control it here than in flox.
-        if any(d not in grouper.group.dims and d not in obj.dims for d in parsed_dim):
-            raise ValueError(f"cannot reduce over dimensions {dim}.")
+        for grouper in self.groupers:
+            if any(
+                d not in grouper.group.dims and d not in obj.dims for d in parsed_dim
+            ):
+                raise ValueError(f"cannot reduce over dimensions {dim}.")

         if kwargs["func"] not in ["all", "any", "count"]:
             kwargs.setdefault("fill_value", np.nan)
-        if isbin and kwargs["func"] == "count":
+        if any_isbin and kwargs["func"] == "count":
             # This is an annoying hack. Xarray returns np.nan
             # when there are no observations in a bin, instead of 0.
             # We can fake that here by forcing min_count=1.
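The ``min_count`` hack at the end of this hunk preserves xarray's convention that a bin with no observations counts as NaN rather than 0. A sketch with hypothetical data and a single ``BinGrouper``:

.. code-block:: python

    import numpy as np
    import xarray as xr
    from xarray.groupers import BinGrouper

    da = xr.DataArray([1.0, 2.0, 3.0], dims="x", coords={"x": [1, 2, 3]})
    counts = da.groupby(x=BinGrouper(bins=[0, 2, 4, 6])).count()
    print(counts.values)  # [2. 1. nan]; the empty (4, 6] bin is NaN, not 0
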
@@ -759,13 +892,17 @@
             kwargs.setdefault("fill_value", np.nan)
             kwargs.setdefault("min_count", 1)

-        output_index = grouper.full_index
+        # pass RangeIndex as a hint to flox that `by` is already factorized
+        expected_groups = tuple(
+            pd.RangeIndex(len(grouper)) for grouper in self.groupers
+        )
+
+        codes = tuple(g.codes for g in self.groupers)
         result = xarray_reduce(
             obj.drop_vars(non_numeric.keys()),
-            self.encoded.codes,
+            *codes,
             dim=parsed_dim,
-            # pass RangeIndex as a hint to flox that `by` is already factorized
-            expected_groups=(pd.RangeIndex(len(output_index)),),
+            expected_groups=expected_groups,
             isbin=False,
             keep_attrs=keep_attrs,
             **kwargs,
@@ -795,12 +932,28 @@ def _flox_reduce(
                 Coordinates(new_coords, new_indexes)
             ).drop_vars(unindexed_dims)

-        # broadcast and restore non-numeric data variables (backcompat)
-        for name, var in non_numeric.items():
-            if all(d not in var.dims for d in parsed_dim):
-                result[name] = var.variable.set_dims(
-                    (name,) + var.dims, (result.sizes[name],) + var.shape
+        # broadcast any non-dim coord variables that don't
+        # share all dimensions with the grouper
+        result_variables = (
+            result._variables if isinstance(result, Dataset) else result._coords
+        )
+        to_broadcast: dict[Hashable, Variable] = {}
+        for name, var in variables.items():
+            dims_set = set(var.dims)
+            if (
+                dims_set <= set(parsed_dim)
+                and (dims_set & set(result.dims))
+                and name not in result_variables
+            ):
+                to_broadcast[name] = var
+        for name, var in to_broadcast.items():
+            if new_dims := tuple(d for d in parsed_dim if d not in var.dims):
+                new_sizes = tuple(
+                    result.sizes.get(dim, obj.sizes.get(dim)) for dim in new_dims
                 )
+                result[name] = var.set_dims(
+                    new_dims + var.dims, new_sizes + var.shape
+                ).transpose(..., *result.dims)

         if not isinstance(result, Dataset):
             # only restore dimension order for arrays
@@ -962,8 +1115,7 @@ def quantile(
             The American Statistician, 50(4), pp. 361-365, 1996
         """
         if dim is None:
-            (grouper,) = self.groupers
-            dim = self.group1d.dims
+            dim = (self._group_dim,)

         # Dataset.quantile does this, do it for flox to ensure same output.
         q = np.asarray(q, dtype=np.float64)
@@ -1074,7 +1226,8 @@ def _iter_grouped_shortcut(self):
         """
         var = self._obj.variable
         for idx, indices in enumerate(self.encoded.group_indices):
-            yield var[{self._group_dim: indices}]
+            if indices:
+                yield var[{self._group_dim: indices}]

     def _concat_shortcut(self, applied, dim, positions=None):
         # nb. don't worry too much about maintaining this method -- it does
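The non-dim coordinate broadcasting above is exercised by one of the new tests below: grouping a 2D array over both of its dimensions keeps 1D coordinates like ``a(x)`` and ``b(y)``, broadcast against the result. A standalone sketch:

.. code-block:: python

    import numpy as np
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
    square = xr.DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
    actual = square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean()
    print(actual.a.dims, actual.b.dims)  # both broadcast against ("x", "y")
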
@@ -1088,12 +1241,11 @@ def _concat_shortcut(self, applied, dim, positions=None):
         return self._obj._replace_maybe_drop_dims(reordered)

     def _restore_dim_order(self, stacked: DataArray) -> DataArray:
-        (grouper,) = self.groupers
-        group = self.group1d

         def lookup_order(dimension):
-            if dimension == grouper.name:
-                (dimension,) = group.dims
+            for grouper in self.groupers:
+                if dimension == grouper.name and grouper.group.ndim == 1:
+                    (dimension,) = grouper.group.dims
             if dimension in self._obj.dims:
                 axis = self._obj.get_axis_num(dimension)
             else:
@@ -1101,7 +1253,10 @@ def lookup_order(dimension):
             return axis

         new_order = sorted(stacked.dims, key=lookup_order)
-        return stacked.transpose(*new_order, transpose_coords=self._restore_coord_dims)
+        stacked = stacked.transpose(
+            *new_order, transpose_coords=self._restore_coord_dims
+        )
+        return stacked

     def map(
         self,
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 7dbb0d5e59c..fc04b49fabc 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -11,9 +11,16 @@

 import xarray as xr
 from xarray import DataArray, Dataset, Variable
+from xarray.core.alignment import broadcast
 from xarray.core.groupby import _consolidate_slices
 from xarray.core.types import InterpOptions
-from xarray.groupers import BinGrouper, EncodedGroups, Grouper, UniqueGrouper
+from xarray.groupers import (
+    BinGrouper,
+    EncodedGroups,
+    Grouper,
+    TimeResampler,
+    UniqueGrouper,
+)
 from xarray.tests import (
     InaccessibleArray,
     assert_allclose,
@@ -119,6 +126,15 @@ def test_multi_index_groupby_sum() -> None:
     actual = ds.stack(space=["x", "y"]).groupby("space").sum("z").unstack("space")
     assert_equal(expected, actual)

+    with pytest.raises(NotImplementedError):
+        actual = (
+            ds.stack(space=["x", "y"])
+            .groupby(space=UniqueGrouper(), z=UniqueGrouper())
+            .sum("z")
+            .unstack("space")
+        )
+        assert_equal(expected, ds)
+
     if not has_pandas_ge_2_1:
         # the next line triggers a mysterious multiindex error on pandas 2.0
         return
+ expected = f"<{obj.__class__.__name__}GroupBy" + expected += ", grouped over 1 grouper(s), 12 groups in total:\n" + expected += "\t'month': 12 groups with labels " + expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>" assert actual == expected @@ -2605,3 +2622,139 @@ def test_weather_data_resample(use_flox): with xr.set_options(use_flox=use_flox): actual = ds.resample(time="1MS").mean() assert "location" in actual._indexes + + gb = ds.groupby(time=TimeResampler(freq="1MS"), location=UniqueGrouper()) + with xr.set_options(use_flox=use_flox): + actual = gb.mean() + expected = ds.resample(time="1MS").mean().sortby("location") + assert_allclose(actual, expected) + assert actual.time.attrs == ds.time.attrs + assert actual.location.attrs == ds.location.attrs + + assert expected.time.attrs == ds.time.attrs + assert expected.location.attrs == ds.location.attrs + + +@pytest.mark.parametrize("use_flox", [True, False]) +def test_multiple_groupers(use_flox) -> None: + da = DataArray( + np.array([1, 2, 3, 0, 2, np.nan]), + dims="d", + coords=dict( + labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])), + labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])), + ), + name="foo", + ) + + gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()) + repr(gb) + + expected = DataArray( + np.array([[1.0, np.nan, np.nan], [np.nan, 2.0, np.nan], [np.nan, np.nan, 1.5]]), + dims=("labels1", "labels2"), + coords={ + "labels1": np.array(["a", "b", "c"], dtype=object), + "labels2": np.array(["x", "y", "z"], dtype=object), + }, + name="foo", + ) + with xr.set_options(use_flox=use_flox): + actual = gb.mean() + assert_identical(actual, expected) + + # ------- + coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])} + square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"]) + gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper()) + repr(gb) + with xr.set_options(use_flox=use_flox): + actual = gb.mean() + expected = DataArray( + np.array([[2.5, 4.5], [10.5, 12.5]]), + dims=("a", "b"), + coords={"a": [0, 1], "b": [0, 1]}, + ) + assert_identical(actual, expected) + + expected = square.astype(np.float64) + expected["a"], expected["b"] = broadcast(square.a, square.b) + with xr.set_options(use_flox=use_flox): + assert_identical( + square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean(), expected + ) + + b = xr.DataArray( + np.random.RandomState(0).randn(2, 3, 4), + coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])}, + dims=["x", "y", "z"], + ) + gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper()) + repr(gb) + with xr.set_options(use_flox=use_flox): + assert_identical(gb.mean("z"), b.mean("z")) + + gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper()) + repr(gb) + with xr.set_options(use_flox=use_flox): + actual = gb.mean() + expected = b.drop_vars("xy").rename({"y": "xy"}).copy(deep=True) + newval = b.isel(x=1, y=slice(1, None)).mean("y").data + expected.loc[dict(x=1, xy=1)] = expected.sel(x=1, xy=0).data + expected.loc[dict(x=1, xy=0)] = np.nan + expected.loc[dict(x=1, xy=2)] = newval + expected["xy"] = ("xy", ["a", "b", "c"]) + # TODO: is order of dims correct? 
@@ -2605,3 +2622,139 @@ def test_weather_data_resample(use_flox):
     with xr.set_options(use_flox=use_flox):
         actual = ds.resample(time="1MS").mean()
         assert "location" in actual._indexes
+
+    gb = ds.groupby(time=TimeResampler(freq="1MS"), location=UniqueGrouper())
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean()
+        expected = ds.resample(time="1MS").mean().sortby("location")
+        assert_allclose(actual, expected)
+        assert actual.time.attrs == ds.time.attrs
+        assert actual.location.attrs == ds.location.attrs
+
+        assert expected.time.attrs == ds.time.attrs
+        assert expected.location.attrs == ds.location.attrs
+
+
+@pytest.mark.parametrize("use_flox", [True, False])
+def test_multiple_groupers(use_flox) -> None:
+    da = DataArray(
+        np.array([1, 2, 3, 0, 2, np.nan]),
+        dims="d",
+        coords=dict(
+            labels1=("d", np.array(["a", "b", "c", "c", "b", "a"])),
+            labels2=("d", np.array(["x", "y", "z", "z", "y", "x"])),
+        ),
+        name="foo",
+    )
+
+    gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
+    repr(gb)
+
+    expected = DataArray(
+        np.array([[1.0, np.nan, np.nan], [np.nan, 2.0, np.nan], [np.nan, np.nan, 1.5]]),
+        dims=("labels1", "labels2"),
+        coords={
+            "labels1": np.array(["a", "b", "c"], dtype=object),
+            "labels2": np.array(["x", "y", "z"], dtype=object),
+        },
+        name="foo",
+    )
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean()
+    assert_identical(actual, expected)
+
+    # -------
+    coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
+    square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
+    gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper())
+    repr(gb)
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean()
+    expected = DataArray(
+        np.array([[2.5, 4.5], [10.5, 12.5]]),
+        dims=("a", "b"),
+        coords={"a": [0, 1], "b": [0, 1]},
+    )
+    assert_identical(actual, expected)
+
+    expected = square.astype(np.float64)
+    expected["a"], expected["b"] = broadcast(square.a, square.b)
+    with xr.set_options(use_flox=use_flox):
+        assert_identical(
+            square.groupby(x=UniqueGrouper(), y=UniqueGrouper()).mean(), expected
+        )
+
+    b = xr.DataArray(
+        np.random.RandomState(0).randn(2, 3, 4),
+        coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])},
+        dims=["x", "y", "z"],
+    )
+    gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
+    repr(gb)
+    with xr.set_options(use_flox=use_flox):
+        assert_identical(gb.mean("z"), b.mean("z"))
+
+    gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
+    repr(gb)
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.mean()
+    expected = b.drop_vars("xy").rename({"y": "xy"}).copy(deep=True)
+    newval = b.isel(x=1, y=slice(1, None)).mean("y").data
+    expected.loc[dict(x=1, xy=1)] = expected.sel(x=1, xy=0).data
+    expected.loc[dict(x=1, xy=0)] = np.nan
+    expected.loc[dict(x=1, xy=2)] = newval
+    expected["xy"] = ("xy", ["a", "b", "c"])
+    # TODO: is order of dims correct?
+    assert_identical(actual, expected.transpose("z", "x", "xy"))
+
+
+@pytest.mark.parametrize("use_flox", [True, False])
+def test_multiple_groupers_mixed(use_flox) -> None:
+    # This groupby has missing groups
+    ds = xr.Dataset(
+        {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
+        coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
+    )
+    gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper())
+    expected_data = np.array(
+        [
+            [[0.0, np.nan], [np.nan, 3.0]],
+            [[1.0, np.nan], [np.nan, 4.0]],
+            [[2.0, np.nan], [np.nan, 5.0]],
+        ]
+    )
+    expected = xr.Dataset(
+        {"foo": (("y", "x_bins", "letters"), expected_data)},
+        coords={
+            "x_bins": (
+                "x_bins",
+                np.array(
+                    [
+                        pd.Interval(5, 15, closed="right"),
+                        pd.Interval(15, 25, closed="right"),
+                    ],
+                    dtype=object,
+                ),
+            ),
+            "letters": ("letters", np.array(["a", "b"], dtype=object)),
+        },
+    )
+    with xr.set_options(use_flox=use_flox):
+        actual = gb.sum()
+    assert_identical(actual, expected)
+
+    # assert_identical(
+    #     b.groupby(['x', 'y']).apply(lambda x: x - x.mean()),
+    #     b - b.mean("z"),
+    # )
+
+    # gb = square.groupby(x=UniqueGrouper(), y=UniqueGrouper())
+    # gb - gb.mean()
+
+    # ------
+
+
+# Possible property tests
+# 1. lambda x: x
+# 2. grouped-reduce on unique coords is identical to array
+# 3. group_over == groupby-reduce along other dimensions
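A minimal sketch of how property test 2 from the list above might start (plain pytest; the test name and data are hypothetical):

.. code-block:: python

    import numpy as np
    import xarray as xr
    from xarray.groupers import UniqueGrouper


    def test_grouped_reduce_on_unique_coords_is_identity() -> None:
        # every element is its own group, so a grouped mean changes no values
        da = xr.DataArray(
            np.arange(4.0), dims="x", coords={"labels": ("x", list("abcd"))}
        )
        actual = da.groupby(labels=UniqueGrouper()).mean()
        np.testing.assert_array_equal(actual.values, da.values)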