From 95f4802ab32b161d6f90921b48fd8cc51a57f01a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 18 Sep 2024 17:41:19 -0600 Subject: [PATCH 01/17] GroupBy(chunked-array) Closes #757 Closes #2852 --- xarray/groupers.py | 102 ++++++++++++++++++++++++++++------- xarray/tests/test_groupby.py | 4 +- 2 files changed, 86 insertions(+), 20 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index e4cb884e6de..97de79602d4 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -16,8 +16,10 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops +from xarray.core.computation import apply_ufunc from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray +from xarray.core.duck_array_ops import isnull from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper @@ -29,6 +31,7 @@ SideOptions, ) from xarray.core.variable import Variable +from xarray.namedarray.pycompat import is_chunked_array __all__ = [ "EncodedGroups", @@ -96,7 +99,7 @@ def __init__( assert isinstance(full_index, pd.Index) self.full_index = full_index - if group_indices is None: + if group_indices is None and not is_chunked_array(codes.data): self.group_indices = tuple( g for g in _codes_to_group_indices(codes.data.ravel(), len(full_index)) @@ -155,10 +158,17 @@ class UniqueGrouper(Grouper): """Grouper object for grouping by a categorical variable.""" _group_as_index: pd.Index | None = field(default=None, repr=False) + labels: np.ndarray | None = field(default=None) + + def __post_init__(self) -> None: + if self.labels is not None: + self.labels = np.sort(self.labels) @property def group_as_index(self) -> pd.Index: """Caches the group DataArray as a pandas Index.""" + if is_chunked_array(self.group): + raise ValueError("Please call compute manually.") if self._group_as_index is None: if self.group.ndim == 1: self._group_as_index = self.group.to_index() @@ -169,6 +179,11 @@ def group_as_index(self) -> pd.Index: def factorize(self, group: T_Group) -> EncodedGroups: self.group = group + if is_chunked_array(group.data) and self.labels is None: + raise ValueError("When grouping by a dask array, `labels` must be passed.") + if self.labels is not None: + return self._factorize_given_labels(group) + index = self.group_as_index is_unique_and_monotonic = isinstance(self.group, _DummyGroup) or ( index.is_unique @@ -182,6 +197,24 @@ def factorize(self, group: T_Group) -> EncodedGroups: else: return self._factorize_unique() + def _factorize_given_labels(self, group: T_Group) -> EncodedGroups: + codes = apply_ufunc( + _factorize_given_labels, + group, + kwargs={"labels": self.labels}, + dask="parallelized", + output_dtypes=[np.int64], + ) + return EncodedGroups( + codes=codes, + full_index=pd.Index(self.labels), + unique_coord=Variable( + dims=codes.name, + data=self.labels, + attrs=self.group.attrs, + ), + ) + def _factorize_unique(self) -> EncodedGroups: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) @@ -291,13 +324,9 @@ def __post_init__(self) -> None: if duck_array_ops.isnull(self.bins).all(): raise ValueError("All bin edges are NaN.") - def factorize(self, group: T_Group) -> EncodedGroups: - from xarray.core.dataarray import DataArray - - data = np.asarray(group.data) # Cast _DummyGroup data to array - - binned, self.bins = pd.cut( # type: ignore [call-overload] - 
data.ravel(), + def _cut(self, data): + return pd.cut( # type: ignore [call-overload] + np.asarray(data).ravel(), bins=self.bins, right=self.right, labels=self.labels, @@ -307,23 +336,43 @@ def factorize(self, group: T_Group) -> EncodedGroups: retbins=True, ) - binned_codes = binned.codes - if (binned_codes == -1).all(): + def _factorize_lazy(self, group: T_Group) -> DataArray: + def _wrapper(data, **kwargs): + binned, bins = self._cut(data) + if isinstance(self.bins, int): + # we are running eagerly, update self.bins with actual edges instead + self.bins = bins + return binned.codes.reshape(data.shape) + + return apply_ufunc(_wrapper, group, dask="parallelized") + + def factorize(self, group: T_Group) -> EncodedGroups: + if isinstance(group, _DummyGroup): + group = DataArray(group.data, dims=group.dims, name=group.name) + by_is_chunked = is_chunked_array(group.data) + if isinstance(self.bins, int) and by_is_chunked: + raise ValueError( + f"Bin edges must be provided when grouping by chunked arrays. Received {self.bins=!r} instead" + ) + codes = self._factorize_lazy(group) + if not by_is_chunked and (codes == -1).all(): raise ValueError( f"None of the data falls within bins with edges {self.bins!r}" ) new_dim_name = f"{group.name}_bins" + codes.name = new_dim_name + + # This seems silly, but it lets us have Pandas handle the complexity + # of labels, precision, and include_lowest, even when group is a chunked array + dummy, _ = self._cut(np.array([1, 2, 3]).astype(group.dtype)) + full_index = dummy.categories + if not by_is_chunked: + uniques = np.sort(pd.unique(codes.data.ravel())) + unique_values = full_index[uniques[uniques != -1]] + else: + unique_values = full_index - full_index = binned.categories - uniques = np.sort(pd.unique(binned_codes)) - unique_values = full_index[uniques[uniques != -1]] - - codes = DataArray( - binned_codes.reshape(group.shape), - getattr(group, "coords", None), - name=new_dim_name, - ) unique_coord = Variable( dims=new_dim_name, data=unique_values, attrs=group.attrs ) @@ -461,6 +510,21 @@ def factorize(self, group: T_Group) -> EncodedGroups: ) +def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray: + # Copied from flox + sort = False # use labels as provided + sorter = np.argsort(labels) + codes = np.searchsorted(labels, data, sorter=sorter) + mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels)) + if not sort: + # codes is the index in to the sorted array. 
+ # if we didn't want sorting, unsort it back + codes[(codes == len(labels),)] = -1 + codes = sorter[(codes,)] + codes[mask] = -1 + return codes + + def unique_value_groups( ar, sort: bool = True ) -> tuple[np.ndarray | pd.Index, np.ndarray]: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index dc869cc3a34..bdf017e2be9 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2583,7 +2583,9 @@ def test_groupby_math_auto_chunk() -> None: sub = xr.DataArray( InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]} ) - actual = da.chunk(x=1, y=2).groupby("label") - sub + chunked = da.chunk(x=1, y=2) + chunked.label.load() + actual = chunked.groupby("label") - sub assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)} From e022231186c60ab2634d9a1499adf4193f729771 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 18 Sep 2024 22:08:39 -0600 Subject: [PATCH 02/17] Optimizations --- xarray/core/groupby.py | 4 +++- xarray/core/utils.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 58971435018..eb8b061e8b0 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -249,7 +249,9 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) # these dimensions get created by the stack operation inserted_dims = [dim for dim in group.dims if dim not in group.coords] - newgroup = group.stack({stacked_dim: orig_dims}) + # `newgroup` construction is optimized so we don't create an index unnecessarily, + # or stack any non-dim coords unnecessarily + newgroup = DataArray(group.variable.stack({stacked_dim: orig_dims})) newobj = obj.stack({stacked_dim: orig_dims}) return newgroup, newobj, stacked_dim, inserted_dims diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 3c1dee7a36d..53cb9c160fb 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -1017,6 +1017,7 @@ def contains_only_chunked_or_numpy(obj) -> bool: Expects obj to be Dataset or DataArray""" from xarray.core.dataarray import DataArray + from xarray.core.indexing import ExplicitlyIndexed from xarray.namedarray.pycompat import is_chunked_array if isinstance(obj, DataArray): @@ -1024,8 +1025,10 @@ def contains_only_chunked_or_numpy(obj) -> bool: return all( [ - isinstance(var.data, np.ndarray) or is_chunked_array(var.data) - for var in obj.variables.values() + isinstance(var._data, ExplicitlyIndexed) + or isinstance(var._data, np.ndarray) + or is_chunked_array(var._data) + for var in obj._variables.values() ] ) From d5d8ef21e3b5758f4b83942c546b47ff25633cf0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 18 Sep 2024 22:15:41 -0600 Subject: [PATCH 03/17] Optimize multi-index construction --- xarray/core/indexes.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5abc2129e3e..49cac752532 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1031,13 +1031,21 @@ def stack( f"from variable {name!r} that wraps a multi-index" ) - split_labels, levels = zip( - *[lev.factorize() for lev in level_indexes], strict=True - ) - labels_mesh = np.meshgrid(*split_labels, indexing="ij") - labels = [x.ravel() for x in labels_mesh] + # from_product sorts by default, so we can't use that always + # https://github.com/pydata/xarray/issues/980 + # https://github.com/pandas-dev/pandas/issues/14672 + if 
all(index.is_monotonic_increasing for index in level_indexes): + index = pd.MultiIndex.from_product( + level_indexes, sortorder=0, names=variables.keys() + ) + else: + split_labels, levels = zip( + *[lev.factorize() for lev in level_indexes], strict=True + ) + labels_mesh = np.meshgrid(*split_labels, indexing="ij") + labels = [x.ravel() for x in labels_mesh] - index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) + index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys()) level_coords_dtype = {k: var.dtype for k, var in variables.items()} return cls(index, dim, level_coords_dtype=level_coords_dtype) From a1e0d6fedd419bcdc404b4d2938536f564c0f061 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Sep 2024 08:03:39 -0600 Subject: [PATCH 04/17] Add tests --- xarray/core/groupby.py | 1 + xarray/groupers.py | 18 ++++------- xarray/tests/test_groupby.py | 63 ++++++++++++++++++++++++++++++++---- 3 files changed, 65 insertions(+), 17 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index eb8b061e8b0..dd6a6671a6a 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -585,6 +585,7 @@ def __init__( # specification for the groupby operation # TODO: handle obj having variables that are not present on any of the groupers # simple broadcasting fails for ExtensionArrays. + # FIXME: Skip this stacking when grouping by a dask array, it's useless in that case. (self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = _ensure_1d( group=self.encoded.codes, obj=obj ) diff --git a/xarray/groupers.py b/xarray/groupers.py index 97de79602d4..c4bf2d02d9b 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -160,10 +160,6 @@ class UniqueGrouper(Grouper): _group_as_index: pd.Index | None = field(default=None, repr=False) labels: np.ndarray | None = field(default=None) - def __post_init__(self) -> None: - if self.labels is not None: - self.labels = np.sort(self.labels) - @property def group_as_index(self) -> pd.Index: """Caches the group DataArray as a pandas Index.""" @@ -364,8 +360,8 @@ def factorize(self, group: T_Group) -> EncodedGroups: codes.name = new_dim_name # This seems silly, but it lets us have Pandas handle the complexity - # of labels, precision, and include_lowest, even when group is a chunked array - dummy, _ = self._cut(np.array([1, 2, 3]).astype(group.dtype)) + # of `labels`, `precision`, and `include_lowest`, even when group is a chunked array + dummy, _ = self._cut(np.array([0]).astype(group.dtype)) full_index = dummy.categories if not by_is_chunked: uniques = np.sort(pd.unique(codes.data.ravel())) @@ -512,14 +508,14 @@ def factorize(self, group: T_Group) -> EncodedGroups: def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray: # Copied from flox - sort = False # use labels as provided sorter = np.argsort(labels) + is_sorted = (sorter == np.arange(sorter.size)).all() codes = np.searchsorted(labels, data, sorter=sorter) mask = ~np.isin(data, labels) | isnull(data) | (codes == len(labels)) - if not sort: - # codes is the index in to the sorted array. - # if we didn't want sorting, unsort it back - codes[(codes == len(labels),)] = -1 + # codes is the index in to the sorted array. 
+ # if we didn't want sorting, unsort it back + if not is_sorted: + codes[codes == len(labels)] = -1 codes = sorter[(codes,)] codes[mask] = -1 return codes diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index bdf017e2be9..f4df37252b6 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -31,6 +31,7 @@ has_cftime, has_flox, has_pandas_ge_2_2, + raise_if_dask_computes, requires_cftime, requires_dask, requires_flox, @@ -2917,12 +2918,6 @@ def test_gappy_resample_reductions(reduction): assert_identical(expected, actual) -# Possible property tests -# 1. lambda x: x -# 2. grouped-reduce on unique coords is identical to array -# 3. group_over == groupby-reduce along other dimensions - - def test_groupby_transpose(): # GH5361 data = xr.DataArray( @@ -2934,3 +2929,59 @@ def test_groupby_transpose(): second = data.groupby("x").sum() assert_identical(first, second.transpose(*first.dims)) + + +@requires_dask +@pytest.mark.parametrize( + "grouper, expect_index", + [ + [UniqueGrouper(labels=np.arange(1, 5)), pd.Index(np.arange(1, 5))], + [UniqueGrouper(labels=np.arange(1, 5)[::-1]), pd.Index(np.arange(1, 5)[::-1])], + [ + BinGrouper(bins=np.arange(1, 5)), + pd.IntervalIndex.from_breaks(np.arange(1, 5)), + ], + ], +) +def test_lazy_grouping(grouper, expect_index): + import dask.array + + data = DataArray( + dims=("x", "y"), + data=dask.array.arange(20, chunks=3).reshape((4, 5)), + name="zoo", + ) + with raise_if_dask_computes(): + encoded = grouper.factorize(data) + assert encoded.codes.ndim == data.ndim + pd.testing.assert_index_equal(encoded.full_index, expect_index) + np.testing.assert_array_equal(encoded.unique_coord.values, np.array(expect_index)) + + lazy = xr.Dataset({"foo": data}, coords={"zoo": data}).groupby(zoo=grouper).count() + eager = ( + xr.Dataset({"foo": data}, coords={"zoo": data.compute()}) + .groupby(zoo=grouper) + .count() + ) + expected = Dataset( + {"foo": (encoded.codes.name, np.ones(encoded.full_index.size))}, + coords={encoded.codes.name: expect_index}, + ) + assert_identical(eager, lazy) + assert_identical(eager, expected) + + +@requires_dask +def test_lazy_int_bins_error(): + import dask.array + + with pytest.raises(ValueError, match="Bin edges must be provided"): + with raise_if_dask_computes(): + _ = BinGrouper(bins=4).factorize(DataArray(dask.array.arange(3))) + + +# Possible property tests +# 1. lambda x: x +# 2. grouped-reduce on unique coords is identical to array +# 3. group_over == groupby-reduce along other dimensions +# 4. result is equivalent for transposed input From adf29437a4fd6fcebf79755ab8e0d908d7ffcb5e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Sep 2024 10:40:52 -0600 Subject: [PATCH 05/17] Add whats-new --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e4b2a06a3e7..67550522322 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -32,6 +32,9 @@ New Features `Tom Nicholas `_. - Added zarr backends for :py:func:`open_groups` (:issue:`9430`, :pull:`9469`). By `Eni Awowale `_. +- Support lazy grouping by dask arrays, and allow specifying groups with ``UniqueGrouper(labels=["a", "b", "c"])`` + (:issue:`2852`, :issue:`757`). + By `Deepak Cherian `_. 
Breaking changes ~~~~~~~~~~~~~~~~ From f56dc8543f64cc2d0c422a4e4e31a0c2a1f1ff1a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Sep 2024 13:25:57 -0600 Subject: [PATCH 06/17] Raise errors --- xarray/core/groupby.py | 29 ++++++++++++++++++++++++++++- xarray/tests/test_groupby.py | 24 ++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index dd6a6671a6a..94101874b5d 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -47,6 +47,7 @@ peek_at, ) from xarray.core.variable import IndexVariable, Variable +from xarray.namedarray.pycompat import is_chunked_array from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: @@ -533,6 +534,7 @@ class GroupBy(Generic[T_Xarray]): _group_indices: GroupIndices _codes: tuple[DataArray, ...] _group_dim: Hashable + _by_chunked: bool _groups: dict[GroupKey, GroupIndex] | None _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None @@ -596,6 +598,7 @@ def __init__( self._dims = None self._sizes = None self._len = len(self.encoded.full_index) + self._by_chunked = is_chunked_array(self.encoded.codes.data) @property def sizes(self) -> Mapping[Hashable, int]: @@ -635,6 +638,14 @@ def reduce( ) -> T_Xarray: raise NotImplementedError() + def _raise_if_by_is_chunked(self): + if self._by_chunked: + raise ValueError( + "This method is not supported when lazily grouping by a chunked array. " + "Either load the array in to memory prior to grouping, or explore another " + "way of applying your function, potentially using the `flox` package." + ) + def _raise_if_not_single_group(self): if len(self.groupers) != 1: raise NotImplementedError( @@ -683,6 +694,7 @@ def __repr__(self) -> str: def _iter_grouped(self) -> Iterator[T_Xarray]: """Iterate over each element in this group""" + self._raise_if_by_is_chunked() for indices in self.encoded.group_indices: if indices: yield self._obj.isel({self._group_dim: indices}) @@ -857,7 +869,7 @@ def _flox_reduce( if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - if Version(flox.__version__) < Version("0.9"): + if Version(flox.__version__) < Version("0.9") and not self._by_chunked: # preserve current strategy (approximately) for dask groupby # on older flox versions to prevent surprises. # flox >=0.9 will choose this on its own. @@ -1270,6 +1282,7 @@ def _iter_grouped_shortcut(self): """Fast version of `_iter_grouped` that yields Variables without metadata """ + self._raise_if_by_is_chunked() var = self._obj.variable for _idx, indices in enumerate(self.encoded.group_indices): if indices: @@ -1431,6 +1444,12 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ + if self._by_chunked: + raise ValueError( + "This method is not supported when lazily grouping by a chunked array. " + "Try installing the `flox` package if you are using one of the standard " + "reductions (e.g. `mean`). " + ) if dim is None: dim = [self._group_dim] @@ -1582,6 +1601,14 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ + + if self._by_chunked: + raise ValueError( + "This method is not supported when lazily grouping by a chunked array. " + "Try installing the `flox` package if you are using one of the standard " + "reductions (e.g. `mean`). 
" + ) + if dim is None: dim = [self._group_dim] diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index f4df37252b6..8b8d01baee3 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2971,6 +2971,30 @@ def test_lazy_grouping(grouper, expect_index): assert_identical(eager, expected) +@requires_dask +def test_lazy_grouping_errors(): + import dask.array + + data = DataArray( + dims=("x",), + data=dask.array.arange(20, chunks=3), + name="foo", + coords={"y": ("x", dask.array.arange(20, chunks=3))}, + ) + + gb = data.groupby(y=UniqueGrouper(labels=np.arange(5, 10))) + message = "not supported when lazily grouping by" + with pytest.raises(ValueError, match=message): + gb.map(lambda x: x) + + with pytest.raises(ValueError, match=message): + gb.reduce(np.mean) + + with pytest.raises(ValueError, match=message): + for _, _ in gb: + pass + + @requires_dask def test_lazy_int_bins_error(): import dask.array From 17b7f2f2406a5767d6a21d6c00c4d82527dd4da3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Sep 2024 13:32:43 -0600 Subject: [PATCH 07/17] Add docstring --- doc/user-guide/groupby.rst | 9 +++++++++ xarray/groupers.py | 12 +++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst index 98bd7b4833b..069c7e0cb10 100644 --- a/doc/user-guide/groupby.rst +++ b/doc/user-guide/groupby.rst @@ -294,6 +294,15 @@ is identical to ds.resample(time=TimeResampler("ME")) +The :py:class:`groupers.UniqueGrouper` accepts an optional ``labels`` kwarg that is not present +in :py:meth:`DataArray.groupby` or :py:meth:`Dataset.groupby`. +Specifying ``labels`` is required when grouping by a lazy array type (e.g. dask or cubed). +The ``labels`` are used to construct the output coordinate (say for a reduction), and aggregations +will only be run over the specified labels. +You may use ``labels`` to also specify the ordering of groups to be used during iteration. +The order will be preserved in the output. + + .. _groupby.multiple: Grouping by multiple variables diff --git a/xarray/groupers.py b/xarray/groupers.py index c4bf2d02d9b..c930cfac846 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -155,7 +155,17 @@ class Resampler(Grouper): @dataclass class UniqueGrouper(Grouper): - """Grouper object for grouping by a categorical variable.""" + """ + Grouper object for grouping by a categorical variable. + + Parameters + ---------- + labels: array-like, optional + Group labels to aggregate on. This is required when grouping by a chunked array type + (e.g. dask or cubed) since it is used to construct the coordinate on the output. + Grouped operations will only be run on the specified group labels. Any group that is not + present in ``labels`` will be ignored. 
+ """ _group_as_index: pd.Index | None = field(default=None, repr=False) labels: np.ndarray | None = field(default=None) From 339ed3a80577b2a3d18562b95215149aa592e491 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Sep 2024 13:41:13 -0600 Subject: [PATCH 08/17] preserve attrs --- xarray/groupers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index c930cfac846..829ebc298ca 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -210,6 +210,7 @@ def _factorize_given_labels(self, group: T_Group) -> EncodedGroups: kwargs={"labels": self.labels}, dask="parallelized", output_dtypes=[np.int64], + keep_attrs=True, ) return EncodedGroups( codes=codes, @@ -350,7 +351,7 @@ def _wrapper(data, **kwargs): self.bins = bins return binned.codes.reshape(data.shape) - return apply_ufunc(_wrapper, group, dask="parallelized") + return apply_ufunc(_wrapper, group, dask="parallelized", keep_attrs=True) def factorize(self, group: T_Group) -> EncodedGroups: if isinstance(group, _DummyGroup): From 93e786b53ffed69f07ccbb1e6394a9efbce81b75 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Sep 2024 20:27:18 -0600 Subject: [PATCH 09/17] Add test for #757 --- xarray/tests/test_groupby.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 8b8d01baee3..2a81dd460e2 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3004,6 +3004,15 @@ def test_lazy_int_bins_error(): _ = BinGrouper(bins=4).factorize(DataArray(dask.array.arange(3))) +def test_time_grouping_seasons_specified(): + time = xr.date_range("2001-01-01", "2002-01-01", freq="D") + ds = xr.Dataset({"foo": np.arange(time.size)}, coords={"time": ("time", time)}) + labels = ["DJF", "MAM", "JJA", "SON"] + actual = ds.groupby({"time.season": UniqueGrouper(labels=labels)}).sum() + expected = ds.groupby("time.season").sum() + assert_identical(actual, expected.reindex(season=labels)) + + # Possible property tests # 1. lambda x: x # 2. 
grouped-reduce on unique coords is identical to array From dfdc96a10865f1eca795ea708183408048b3100e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Sep 2024 20:37:58 -0600 Subject: [PATCH 10/17] Typing fixes --- xarray/core/groupby.py | 5 +++-- xarray/groupers.py | 27 ++++++++++++++++----------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 94101874b5d..a2c77732692 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -191,8 +191,8 @@ def values(self) -> range: return range(self.size) @property - def data(self) -> range: - return range(self.size) + def data(self) -> np.ndarray: + return np.arange(self.size, dtype=int) def __array__(self) -> np.ndarray: return np.arange(self.size) @@ -517,6 +517,7 @@ class GroupBy(Generic[T_Xarray]): "_dims", "_sizes", "_len", + "_by_chunked", # Save unstacked object for flox "_original_obj", "_codes", diff --git a/xarray/groupers.py b/xarray/groupers.py index 829ebc298ca..75f8a5b4305 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd +from numpy.typing import ArrayLike from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops @@ -99,12 +100,18 @@ def __init__( assert isinstance(full_index, pd.Index) self.full_index = full_index - if group_indices is None and not is_chunked_array(codes.data): - self.group_indices = tuple( - g - for g in _codes_to_group_indices(codes.data.ravel(), len(full_index)) - if g - ) + if group_indices is None: + if not is_chunked_array(codes.data): + self.group_indices = tuple( + g + for g in _codes_to_group_indices( + codes.data.ravel(), len(full_index) + ) + if g + ) + else: + # We will not use this when grouping by a chunked array + self.group_indices = tuple() else: self.group_indices = group_indices @@ -168,13 +175,11 @@ class UniqueGrouper(Grouper): """ _group_as_index: pd.Index | None = field(default=None, repr=False) - labels: np.ndarray | None = field(default=None) + labels: ArrayLike | None = field(default=None) @property def group_as_index(self) -> pd.Index: """Caches the group DataArray as a pandas Index.""" - if is_chunked_array(self.group): - raise ValueError("Please call compute manually.") if self._group_as_index is None: if self.group.ndim == 1: self._group_as_index = self.group.to_index() @@ -214,7 +219,7 @@ def _factorize_given_labels(self, group: T_Group) -> EncodedGroups: ) return EncodedGroups( codes=codes, - full_index=pd.Index(self.labels), + full_index=pd.Index(self.labels), # type: ignore[arg-type] unique_coord=Variable( dims=codes.name, data=self.labels, @@ -332,7 +337,7 @@ def __post_init__(self) -> None: raise ValueError("All bin edges are NaN.") def _cut(self, data): - return pd.cut( # type: ignore [call-overload] + return pd.cut( np.asarray(data).ravel(), bins=self.bins, right=self.right, From a15b04dff6246985db191efd94b889c679385985 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Sep 2024 20:49:52 -0600 Subject: [PATCH 11/17] Handle multiple groupers --- xarray/core/groupby.py | 27 +++++++++++++++++---------- xarray/groupers.py | 5 ++++- xarray/tests/test_groupby.py | 20 ++++++++++++++++++-- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a2c77732692..bbbf65f3529 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -22,6 +22,7 @@ from xarray.core.common import ImplementsArrayReduce, 
ImplementsDatasetReduce from xarray.core.concat import concat from xarray.core.coordinates import Coordinates +from xarray.core.duck_array_ops import where from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( PandasIndex, @@ -462,20 +463,26 @@ def factorize(self) -> EncodedGroups: # NaNs; as well as values outside the bins are coded by -1 # Restore these after the raveling mask = functools.reduce(np.logical_or, [(code == -1) for code in broadcasted_codes]) # type: ignore[arg-type] - _flatcodes[mask] = -1 - - midx = pd.MultiIndex.from_product( - (grouper.unique_coord.data for grouper in groupers), - names=tuple(grouper.name for grouper in groupers), - ) - # Constructing an index from the product is wrong when there are missing groups - # (e.g. binning, resampling). Account for that now. - midx = midx[np.sort(pd.unique(_flatcodes[~mask]))] + _flatcodes = where(mask, -1, _flatcodes) full_index = pd.MultiIndex.from_product( (grouper.full_index.values for grouper in groupers), names=tuple(grouper.name for grouper in groupers), ) + # This will be unused when grouping by dask arrays, so skip.. + if not is_chunked_array(_flatcodes): + midx = pd.MultiIndex.from_product( + (grouper.unique_coord.data for grouper in groupers), + names=tuple(grouper.name for grouper in groupers), + ) + # Constructing an index from the product is wrong when there are missing groups + # (e.g. binning, resampling). Account for that now. + midx = midx[np.sort(pd.unique(_flatcodes[~mask]))] + group_indices = _codes_to_group_indices(_flatcodes.ravel(), len(full_index)) + else: + midx = full_index + group_indices = None + dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers) coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name) @@ -484,7 +491,7 @@ def factorize(self) -> EncodedGroups: return EncodedGroups( codes=first_codes.copy(data=_flatcodes), full_index=full_index, - group_indices=_codes_to_group_indices(_flatcodes.ravel(), len(full_index)), + group_indices=group_indices, unique_coord=Variable(dims=(dim_name,), data=midx.values), coords=coords, ) diff --git a/xarray/groupers.py b/xarray/groupers.py index 75f8a5b4305..8168fed307e 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -191,7 +191,10 @@ def factorize(self, group: T_Group) -> EncodedGroups: self.group = group if is_chunked_array(group.data) and self.labels is None: - raise ValueError("When grouping by a dask array, `labels` must be passed.") + raise ValueError( + "When grouping by a dask array, `labels` must be passed using " + "a UniqueGrouper object." 
+            )
         if self.labels is not None:
             return self._factorize_given_labels(group)

diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index 2a81dd460e2..adccf5da132 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -22,6 +22,7 @@
     TimeResampler,
     UniqueGrouper,
 )
+from xarray.namedarray.pycompat import is_chunked_array
 from xarray.tests import (
     InaccessibleArray,
     assert_allclose,
@@ -29,6 +30,7 @@
     assert_identical,
     create_test_data,
     has_cftime,
+    has_dask,
     has_flox,
     has_pandas_ge_2_2,
     raise_if_dask_computes,
@@ -2796,7 +2798,7 @@ def test_multiple_groupers(use_flox) -> None:
 
     b = xr.DataArray(
         np.random.RandomState(0).randn(2, 3, 4),
-        coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]])},
+        coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})},
         dims=["x", "y", "z"],
     )
     gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
@@ -2813,10 +2815,24 @@ def test_multiple_groupers(use_flox) -> None:
     expected.loc[dict(x=1, xy=1)] = expected.sel(x=1, xy=0).data
     expected.loc[dict(x=1, xy=0)] = np.nan
     expected.loc[dict(x=1, xy=2)] = newval
-    expected["xy"] = ("xy", ["a", "b", "c"])
+    expected["xy"] = ("xy", ["a", "b", "c"], {"foo": "bar"})
     # TODO: is order of dims correct?
     assert_identical(actual, expected.transpose("z", "x", "xy"))
 
+    if has_dask:
+        b["xy"] = b["xy"].chunk()
+        with raise_if_dask_computes():
+            gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"]))
+
+        expected = xr.DataArray(
+            [[[1, 1, 1], [0, 1, 2]]] * 4,
+            dims=("z", "x", "xy"),
+            coords={"xy": ("xy", ["a", "b", "c"], {"foo": "bar"})},
+        )
+        assert_identical(gb.count(), expected)
+        assert is_chunked_array(gb.encoded.codes.data)
+        assert not gb.encoded.group_indices
+
 
 @pytest.mark.parametrize("use_flox", [True, False])
 def test_multiple_groupers_mixed(use_flox) -> None:

From b295193819395e67890fe72ddba5fa857d8a7875 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 22 Oct 2024 07:04:01 -0600
Subject: [PATCH 12/17] Backcompat

---
 doc/whats-new.rst            |  7 ++++++-
 xarray/core/common.py        |  2 +-
 xarray/core/dataarray.py     | 10 ++++++++--
 xarray/core/dataset.py       | 10 ++++++++--
 xarray/core/groupby.py       | 30 +++++++++++++++++++++++++++---
 xarray/tests/test_groupby.py | 15 ++++++++++++---
 6 files changed, 62 insertions(+), 12 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 3e9d9a76b89..eede5b88e0a 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -34,7 +34,7 @@ New Features
   By `Tom Nicholas `_.
 - Added zarr backends for :py:func:`open_groups` (:issue:`9430`, :pull:`9469`).
   By `Eni Awowale `_.
-- Support lazy grouping by dask arrays, and allow specifying groups with ``UniqueGrouper(labels=["a", "b", "c"])``
+- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
   (:issue:`2852`, :issue:`757`).
   By `Deepak Cherian `_.
 - Added support for vectorized interpolation using additional interpolators
@@ -49,6 +49,11 @@ Breaking changes
 
 Deprecations
 ~~~~~~~~~~~~
+- Grouping by a chunked array (e.g. dask or cubed) currently eagerly loads that variable into
+  memory. This behaviour is deprecated. If eager loading was intended, please load such arrays
+  manually using ``.load()`` or ``.compute()``. Else pass ``eagerly_compute_group=False``, and
+  provide expected group labels using the ``labels`` kwarg to a grouper object such as
+  :py:class:`groupers.UniqueGrouper` or :py:class:`groupers.BinGrouper`.
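+  For example, opting in to the new behaviour might look like this (``ds`` here is an
+  illustrative dataset with a chunked variable ``x``)::
+
+      ds.groupby(x=UniqueGrouper(labels=[1, 2, 3]), eagerly_compute_group=False)
+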
Bug fixes diff --git a/xarray/core/common.py b/xarray/core/common.py index 9a6807faad2..6966b1723d3 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1094,7 +1094,7 @@ def _resample( f"Received {type(freq)} instead." ) - rgrouper = ResolvedGrouper(grouper, group, self) + rgrouper = ResolvedGrouper(grouper, group, self, eagerly_compute_group=False) return resample_cls( self, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 9b5291fc553..a4d03e3ce73 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6733,6 +6733,7 @@ def groupby( *, squeeze: Literal[False] = False, restore_coord_dims: bool = False, + eagerly_compute_group: bool = True, **groupers: Grouper, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. @@ -6862,7 +6863,9 @@ def groupby( ) _validate_groupby_squeeze(squeeze) - rgroupers = _parse_group_and_groupers(self, group, groupers) + rgroupers = _parse_group_and_groupers( + self, group, groupers, eagerly_compute_group=eagerly_compute_group + ) return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims) @_deprecate_positional_args("v2024.07.0") @@ -6877,6 +6880,7 @@ def groupby_bins( squeeze: Literal[False] = False, restore_coord_dims: bool = False, duplicates: Literal["raise", "drop"] = "raise", + eagerly_compute_group: bool = True, ) -> DataArrayGroupBy: """Returns a DataArrayGroupBy object for performing grouped operations. @@ -6950,7 +6954,9 @@ def groupby_bins( precision=precision, include_lowest=include_lowest, ) - rgrouper = ResolvedGrouper(grouper, group, self) + rgrouper = ResolvedGrouper( + grouper, group, self, eagerly_compute_group=eagerly_compute_group + ) return DataArrayGroupBy( self, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index f8cf23d188c..5cdeb8049d7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10366,6 +10366,7 @@ def groupby( *, squeeze: Literal[False] = False, restore_coord_dims: bool = False, + eagerly_compute_group: bool = True, **groupers: Grouper, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. @@ -10463,7 +10464,9 @@ def groupby( ) _validate_groupby_squeeze(squeeze) - rgroupers = _parse_group_and_groupers(self, group, groupers) + rgroupers = _parse_group_and_groupers( + self, group, groupers, eagerly_compute_group=eagerly_compute_group + ) return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims) @@ -10479,6 +10482,7 @@ def groupby_bins( squeeze: Literal[False] = False, restore_coord_dims: bool = False, duplicates: Literal["raise", "drop"] = "raise", + eagerly_compute_group: bool = True, ) -> DatasetGroupBy: """Returns a DatasetGroupBy object for performing grouped operations. 
@@ -10552,7 +10556,9 @@ def groupby_bins( precision=precision, include_lowest=include_lowest, ) - rgrouper = ResolvedGrouper(grouper, group, self) + rgrouper = ResolvedGrouper( + grouper, group, self, eagerly_compute_group=eagerly_compute_group + ) return DatasetGroupBy( self, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 180dfdb8a4a..e63243eae70 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -41,6 +41,7 @@ FrozenMappingWarningOnValuesAccess, contains_only_chunked_or_numpy, either_dict_or_kwargs, + emit_user_level_warning, hashable, is_scalar, maybe_wrap_array, @@ -284,6 +285,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): grouper: Grouper group: T_Group obj: T_DataWithCoords + eagerly_compute_group: bool = field(repr=False) # returned by factorize: encoded: EncodedGroups = field(init=False, repr=False) @@ -310,6 +312,18 @@ def __post_init__(self) -> None: self.group = _resolve_group(self.obj, self.group) + if ( + self.eagerly_compute_group + and not isinstance(self.group, _DummyGroup) + and is_chunked_array(self.group.variable._data) + ): + emit_user_level_warning( + f"Eagerly computing the DataArray you're grouping by ({self.group.name!r}) " + "is deprecated and will be removed in v2025.05.0. " + "Please load this array's data manually using `.compute` or `.load`.", + DeprecationWarning, + ) + self.encoded = self.grouper.factorize(self.group) @property @@ -330,7 +344,11 @@ def __len__(self) -> int: def _parse_group_and_groupers( - obj: T_Xarray, group: GroupInput, groupers: dict[str, Grouper] + obj: T_Xarray, + group: GroupInput, + groupers: dict[str, Grouper], + *, + eagerly_compute_group: bool, ) -> tuple[ResolvedGrouper, ...]: from xarray.core.dataarray import DataArray from xarray.core.variable import Variable @@ -355,7 +373,11 @@ def _parse_group_and_groupers( rgroupers: tuple[ResolvedGrouper, ...] 
if isinstance(group, DataArray | Variable): - rgroupers = (ResolvedGrouper(UniqueGrouper(), group, obj),) + rgroupers = ( + ResolvedGrouper( + UniqueGrouper(), group, obj, eagerly_compute_group=eagerly_compute_group + ), + ) else: if group is not None: if TYPE_CHECKING: @@ -368,7 +390,9 @@ def _parse_group_and_groupers( grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers) rgroupers = tuple( - ResolvedGrouper(grouper, group, obj) + ResolvedGrouper( + grouper, group, obj, eagerly_compute_group=eagerly_compute_group + ) for group, grouper in grouper_mapping.items() ) return rgroupers diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 95b963f6cb3..edf0f9df176 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2843,7 +2843,10 @@ def test_multiple_groupers(use_flox) -> None: if has_dask: b["xy"] = b["xy"].chunk() with raise_if_dask_computes(): - gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"])) + with pytest.warns(DeprecationWarning): + gb = b.groupby( + x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"]) + ) expected = xr.DataArray( [[[1, 1, 1], [0, 1, 2]]] * 4, @@ -2994,7 +2997,11 @@ def test_lazy_grouping(grouper, expect_index): pd.testing.assert_index_equal(encoded.full_index, expect_index) np.testing.assert_array_equal(encoded.unique_coord.values, np.array(expect_index)) - lazy = xr.Dataset({"foo": data}, coords={"zoo": data}).groupby(zoo=grouper).count() + lazy = ( + xr.Dataset({"foo": data}, coords={"zoo": data}) + .groupby(zoo=grouper, eagerly_compute_group=False) + .count() + ) eager = ( xr.Dataset({"foo": data}, coords={"zoo": data.compute()}) .groupby(zoo=grouper) @@ -3019,7 +3026,9 @@ def test_lazy_grouping_errors(): coords={"y": ("x", dask.array.arange(20, chunks=3))}, ) - gb = data.groupby(y=UniqueGrouper(labels=np.arange(5, 10))) + gb = data.groupby( + y=UniqueGrouper(labels=np.arange(5, 10)), eagerly_compute_group=False + ) message = "not supported when lazily grouping by" with pytest.raises(ValueError, match=message): gb.map(lambda x: x) From f826b65b16a57f206b4e9489712eaab38161eb01 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 23 Oct 2024 09:31:00 -0600 Subject: [PATCH 13/17] better backcompat --- xarray/core/dataarray.py | 10 ++++++++ xarray/core/dataset.py | 10 ++++++++ xarray/core/groupby.py | 49 +++++++++++++++++++++++++++--------- xarray/tests/test_groupby.py | 35 ++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 12 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index a4d03e3ce73..77c00bff55d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6749,6 +6749,11 @@ def groupby( restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. + eagerly_compute_group: bool + Whether to eagerly compute ``group`` when it is a chunked array. + This option is to maintain backwards compatibility. Set to False + to opt-in to future behaviour, where ``group`` is not automatically loaded + into memory. **groupers : Mapping of str to Grouper or Resampler Mapping of variable name to group by to :py:class:`Grouper` or :py:class:`Resampler` object. One of ``group`` or ``groupers`` must be provided. @@ -6917,6 +6922,11 @@ def groupby_bins( coordinates. duplicates : {"raise", "drop"}, default: "raise" If bin edges are not unique, raise ValueError or drop non-uniques. + eagerly_compute_group: bool + Whether to eagerly compute ``group`` when it is a chunked array. 
+            This option is to maintain backwards compatibility. Set to False
+            to opt-in to future behaviour, where ``group`` is not automatically loaded
+            into memory.
 
         Returns
         -------
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 5cdeb8049d7..415bf350e7d 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -10382,6 +10382,11 @@ def groupby(
         restore_coord_dims : bool, default: False
             If True, also restore the dimension order of multi-dimensional
             coordinates.
+        eagerly_compute_group: bool
+            Whether to eagerly compute ``group`` when it is a chunked array.
+            This option is to maintain backwards compatibility. Set to False
+            to opt-in to future behaviour, where ``group`` is not automatically loaded
+            into memory.
         **groupers : Mapping of str to Grouper or Resampler
             Mapping of variable name to group by to :py:class:`Grouper` or :py:class:`Resampler` object.
             One of ``group`` or ``groupers`` must be provided.
@@ -10524,6 +10524,11 @@ def groupby_bins(
             coordinates.
         duplicates : {"raise", "drop"}, default: "raise"
             If bin edges are not unique, raise ValueError or drop non-uniques.
+        eagerly_compute_group: bool
+            Whether to eagerly compute ``group`` when it is a chunked array.
+            This option is to maintain backwards compatibility. Set to False
+            to opt-in to future behaviour, where ``group`` is not automatically loaded
+            into memory.
 
         Returns
         -------
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index e63243eae70..68362ccd187 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -308,21 +308,45 @@ def __post_init__(self) -> None:
         # of pd.cut
         # We do not want to modify the original object, since the same grouper
         # might be used multiple times.
+        from xarray.groupers import BinGrouper, UniqueGrouper
+
         self.grouper = copy.deepcopy(self.grouper)
 
         self.group = _resolve_group(self.obj, self.group)
 
-        if (
-            self.eagerly_compute_group
-            and not isinstance(self.group, _DummyGroup)
-            and is_chunked_array(self.group.variable._data)
+        if not isinstance(self.group, _DummyGroup) and is_chunked_array(
+            self.group.variable._data
         ):
-            emit_user_level_warning(
-                f"Eagerly computing the DataArray you're grouping by ({self.group.name!r}) "
-                "is deprecated and will be removed in v2025.05.0. "
-                "Please load this array's data manually using `.compute` or `.load`.",
-                DeprecationWarning,
-            )
+            if self.eagerly_compute_group is False:
+                # This requires a pass to discover the groups present
+                if (
+                    isinstance(self.grouper, UniqueGrouper)
+                    and self.grouper.labels is None
+                ):
+                    raise ValueError(
+                        "Please pass `labels` to UniqueGrouper when grouping by a chunked array."
+                    )
+                # this requires a pass to compute the bin edges
+                if isinstance(self.grouper, BinGrouper) and isinstance(
+                    self.grouper.bins, int
+                ):
+                    raise ValueError(
+                        "Please pass explicit bin edges to BinGrouper using the ``bins`` kwarg "
+                        "when grouping by a chunked array."
+                    )
+
+            if self.eagerly_compute_group:
+                emit_user_level_warning(
+                    f"""Eagerly computing the DataArray you're grouping by ({self.group.name!r})
+                    is deprecated and will raise an error in v2025.05.0.
+                    Please load this array's data manually using `.compute` or `.load`.
+                    To intentionally avoid eager loading, either (1) specify
+                    `.groupby({self.group.name}=UniqueGrouper(labels=...), eagerly_compute_group=False)`
+                    or (2) pass explicit bin edges using `.groupby({self.group.name}=BinGrouper(bins=...),
+                    eagerly_compute_group=False)`; as appropriate.""",
+                    DeprecationWarning,
+                )
+                self.group = self.group.compute()
 
         self.encoded = self.grouper.factorize(self.group)
 
@@ -678,8 +702,9 @@ def _raise_if_by_is_chunked(self):
         if self._by_chunked:
             raise ValueError(
                 "This method is not supported when lazily grouping by a chunked array. "
-                "Either load the array in to memory prior to grouping, or explore another "
-                "way of applying your function, potentially using the `flox` package."
+                "Either load the array into memory prior to grouping using .load or .compute, "
+                "or explore another way of applying your function, "
+                "potentially using the `flox` package."
             )
diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index edf0f9df176..b69682c18c7 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -3091,6 +3091,41 @@ def test_groupby_multiple_bin_grouper_missing_groups():
     assert_identical(actual, expected)
 
 
+@requires_dask
+def test_groupby_dask_eager_load_warnings():
+    ds = xr.Dataset(
+        {"foo": (("z"), np.arange(12))},
+        coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))},
+    ).chunk(z=6)
+
+    with pytest.warns(DeprecationWarning):
+        ds.groupby(x=UniqueGrouper())
+
+    with pytest.warns(DeprecationWarning):
+        ds.groupby("x")
+
+    with pytest.warns(DeprecationWarning):
+        ds.groupby(ds.x)
+
+    with pytest.raises(ValueError, match="Please pass"):
+        ds.groupby("x", eagerly_compute_group=False)
+
+    # This is technically fine but anyone iterating over the groupby object
+    # will see an error, so let's warn and have them opt in.
+    with pytest.warns(DeprecationWarning):
+        ds.groupby(x=UniqueGrouper(labels=[1, 2, 3]))
+
+    ds.groupby(x=UniqueGrouper(labels=[1, 2, 3]), eagerly_compute_group=False)
+
+    with pytest.warns(DeprecationWarning):
+        ds.groupby_bins("x", bins=3)
+    with pytest.raises(ValueError, match="Please pass"):
+        ds.groupby_bins("x", bins=3, eagerly_compute_group=False)
+    with pytest.warns(DeprecationWarning):
+        ds.groupby_bins("x", bins=[1, 2, 3])
+    ds.groupby_bins("x", bins=[1, 2, 3], eagerly_compute_group=False)
+
+
 # Possible property tests
 # 1. lambda x: x
 # 2. grouped-reduce on unique coords is identical to array
 # 3. group_over == groupby-reduce along other dimensions
 # 4. result is equivalent for transposed input

From aada75dd0682898f1335a0cf533c50b412f99cf8 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 29 Oct 2024 16:35:48 -0700
Subject: [PATCH 14/17] fix

---
 xarray/tests/test_groupby.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
index b69682c18c7..b3b2f0389d9 100644
--- a/xarray/tests/test_groupby.py
+++ b/xarray/tests/test_groupby.py
@@ -2842,7 +2842,7 @@ def test_multiple_groupers(use_flox) -> None:
 
     if has_dask:
         b["xy"] = b["xy"].chunk()
-        with raise_if_dask_computes():
+        with raise_if_dask_computes(max_computes=1):
             with pytest.warns(DeprecationWarning):
                 gb = b.groupby(
                     x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"])

From 3e40605754619fb0e7a9251d196b7e2e0a683cef Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Thu, 31 Oct 2024 17:28:26 -0700
Subject: [PATCH 15/17] Handle edge case

---
 xarray/core/groupby.py       |  6 +++++-
 xarray/tests/test_groupby.py | 34 ++++++++++++++++++++--------------
 2 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 68362ccd187..5c4633c1612 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -995,11 +995,15 @@ def _flox_reduce(
         has_missing_groups = (
             self.encoded.unique_coord.size != self.encoded.full_index.size
         )
-        if has_missing_groups or kwargs.get("min_count", 0) > 0:
+        if self._by_chunked or has_missing_groups or kwargs.get("min_count", 0) > 0:
             # Xarray *always* returns np.nan when there are no observations in a group,
             # We can fake that here by forcing min_count=1 when it is not set.
             # This handles boolean reductions, and count
             # See GH8090, GH9398
+            # Note that `has_missing_groups=False` when `self._by_chunked is True`.
+            # We *choose* to always do the masking, so that behaviour is predictable
+            # in some way. The real solution is to expose fill_value as a kwarg,
+            # and set appropriate defaults :/.
kwargs.setdefault("fill_value", np.nan) kwargs.setdefault("min_count", 1) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b3b2f0389d9..e8a81460ab9 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2842,20 +2842,26 @@ def test_multiple_groupers(use_flox) -> None: if has_dask: b["xy"] = b["xy"].chunk() - with raise_if_dask_computes(max_computes=1): - with pytest.warns(DeprecationWarning): - gb = b.groupby( - x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"]) - ) - - expected = xr.DataArray( - [[[1, 1, 1], [0, 1, 2]]] * 4, - dims=("z", "x", "xy"), - coords={"xy": ("xy", ["a", "b", "c"], {"foo": "bar"})}, - ) - assert_identical(gb.count(), expected) - assert is_chunked_array(gb.encoded.codes.data) - assert not gb.encoded.group_indices + for eagerly_compute_group in [True, False]: + kwargs = dict( + x=UniqueGrouper(), + xy=UniqueGrouper(labels=["a", "b", "c"]), + eagerly_compute_group=eagerly_compute_group, + ) + with raise_if_dask_computes(max_computes=1): + if eagerly_compute_group: + with pytest.warns(DeprecationWarning): + gb = b.groupby(**kwargs) + else: + gb = b.groupby(**kwargs) + assert is_chunked_array(gb.encoded.codes.data) + assert not gb.encoded.group_indices + expected = xr.DataArray( + [[[1, 1, 1], [np.nan, 1, 2]]] * 4, + dims=("z", "x", "xy"), + coords={"xy": ("xy", ["a", "b", "c"], {"foo": "bar"})}, + ) + assert_identical(gb.count(), expected) @pytest.mark.parametrize("use_flox", [True, False]) From 295d6dd930cedb4b084e5ee428b7adeac649830e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 31 Oct 2024 17:28:57 -0700 Subject: [PATCH 16/17] comment --- xarray/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e8a81460ab9..14974ae1dd2 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3132,7 +3132,7 @@ def test_groupby_dask_eager_load_warnings(): ds.groupby_bins("x", bins=[1, 2, 3], eagerly_compute_group=False) -# Possible property tests +# TODO: Possible property tests to add to this module # 1. lambda x: x # 2. grouped-reduce on unique coords is identical to array # 3. group_over == groupby-reduce along other dimensions From a4fed4d1f5fd28986975a834b3e7ef041fec9c83 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 31 Oct 2024 18:00:38 -0700 Subject: [PATCH 17/17] type: ignore --- xarray/tests/test_groupby.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 14974ae1dd2..95ba0a3384e 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2851,9 +2851,9 @@ def test_multiple_groupers(use_flox) -> None: with raise_if_dask_computes(max_computes=1): if eagerly_compute_group: with pytest.warns(DeprecationWarning): - gb = b.groupby(**kwargs) + gb = b.groupby(**kwargs) # type: ignore[arg-type] else: - gb = b.groupby(**kwargs) + gb = b.groupby(**kwargs) # type: ignore[arg-type] assert is_chunked_array(gb.encoded.codes.data) assert not gb.encoded.group_indices expected = xr.DataArray(
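
A minimal end-to-end sketch of the workflow this series enables (illustrative
only: the array and coordinate names are hypothetical, and it assumes `dask`
and `flox` are installed so the lazy reductions below have an execution path):

    import dask.array
    import numpy as np
    import xarray as xr
    from xarray.groupers import BinGrouper, UniqueGrouper

    # A DataArray whose grouping coordinate is itself a chunked (dask) array.
    data = xr.DataArray(
        dask.array.arange(20, chunks=3),
        dims="x",
        coords={"y": ("x", dask.array.from_array(np.repeat([0.0, 1.0], 10), chunks=3))},
    )

    # Lazy grouping: the coordinate is never loaded, because the expected
    # group labels are supplied up front via UniqueGrouper.
    counts = data.groupby(
        y=UniqueGrouper(labels=[0.0, 1.0]), eagerly_compute_group=False
    ).count()

    # Binning works the same way, provided explicit bin edges are given;
    # BinGrouper(bins=<int>) would raise, since computing edges needs the data.
    totals = data.groupby(
        y=BinGrouper(bins=[-0.5, 0.5, 1.5]), eagerly_compute_group=False
    ).sum()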