diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a2c77732692..bbbf65f3529 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -22,6 +22,7 @@ from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce from xarray.core.concat import concat from xarray.core.coordinates import Coordinates +from xarray.core.duck_array_ops import where from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( PandasIndex, @@ -462,20 +463,26 @@ def factorize(self) -> EncodedGroups: # NaNs; as well as values outside the bins are coded by -1 # Restore these after the raveling mask = functools.reduce(np.logical_or, [(code == -1) for code in broadcasted_codes]) # type: ignore[arg-type] - _flatcodes[mask] = -1 - - midx = pd.MultiIndex.from_product( - (grouper.unique_coord.data for grouper in groupers), - names=tuple(grouper.name for grouper in groupers), - ) - # Constructing an index from the product is wrong when there are missing groups - # (e.g. binning, resampling). Account for that now. - midx = midx[np.sort(pd.unique(_flatcodes[~mask]))] + _flatcodes = where(mask, -1, _flatcodes) full_index = pd.MultiIndex.from_product( (grouper.full_index.values for grouper in groupers), names=tuple(grouper.name for grouper in groupers), ) + # This will be unused when grouping by dask arrays, so skip.. + if not is_chunked_array(_flatcodes): + midx = pd.MultiIndex.from_product( + (grouper.unique_coord.data for grouper in groupers), + names=tuple(grouper.name for grouper in groupers), + ) + # Constructing an index from the product is wrong when there are missing groups + # (e.g. binning, resampling). Account for that now. + midx = midx[np.sort(pd.unique(_flatcodes[~mask]))] + group_indices = _codes_to_group_indices(_flatcodes.ravel(), len(full_index)) + else: + midx = full_index + group_indices = None + dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers) coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name) @@ -484,7 +491,7 @@ def factorize(self) -> EncodedGroups: return EncodedGroups( codes=first_codes.copy(data=_flatcodes), full_index=full_index, - group_indices=_codes_to_group_indices(_flatcodes.ravel(), len(full_index)), + group_indices=group_indices, unique_coord=Variable(dims=(dim_name,), data=midx.values), coords=coords, ) diff --git a/xarray/groupers.py b/xarray/groupers.py index 75f8a5b4305..8168fed307e 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -191,7 +191,10 @@ def factorize(self, group: T_Group) -> EncodedGroups: self.group = group if is_chunked_array(group.data) and self.labels is None: - raise ValueError("When grouping by a dask array, `labels` must be passed.") + raise ValueError( + "When grouping by a dask array, `labels` must be passed using " + "a UniqueGrouper object." + ) if self.labels is not None: return self._factorize_given_labels(group) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 2a81dd460e2..adccf5da132 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -22,6 +22,7 @@ TimeResampler, UniqueGrouper, ) +from xarray.namedarray.pycompat import is_chunked_array from xarray.tests import ( InaccessibleArray, assert_allclose, @@ -29,6 +30,7 @@ assert_identical, create_test_data, has_cftime, + has_dask, has_flox, has_pandas_ge_2_2, raise_if_dask_computes, @@ -2796,7 +2798,7 @@ def test_multiple_groupers(use_flox) -> None: b = xr.DataArray( np.random.RandomState(0).randn(2, 3, 4), - coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])}, + coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})}, dims=["x", "y", "z"], ) gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper()) @@ -2813,10 +2815,24 @@ def test_multiple_groupers(use_flox) -> None: expected.loc[dict(x=1, xy=1)] = expected.sel(x=1, xy=0).data expected.loc[dict(x=1, xy=0)] = np.nan expected.loc[dict(x=1, xy=2)] = newval - expected["xy"] = ("xy", ["a", "b", "c"]) + expected["xy"] = ("xy", ["a", "b", "c"], {"foo": "bar"}) # TODO: is order of dims correct? assert_identical(actual, expected.transpose("z", "x", "xy")) + if has_dask: + b["xy"] = b["xy"].chunk() + with raise_if_dask_computes(): + gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"])) + + expected = xr.DataArray( + [[[1, 1, 1], [0, 1, 2]]] * 4, + dims=("z", "x", "xy"), + coords={"xy": ("xy", ["a", "b", "c"], {"foo": "bar"})}, + ) + assert_identical(gb.count(), expected) + assert is_chunked_array(gb.encoded.codes.data) + assert not gb.encoded.group_indices + @pytest.mark.parametrize("use_flox", [True, False]) def test_multiple_groupers_mixed(use_flox) -> None: