Skip to content

Commit

Permalink
Backcompat
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Oct 22, 2024
1 parent c8c27f7 commit b295193
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 12 deletions.
7 changes: 6 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ New Features
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Added zarr backends for :py:func:`open_groups` (:issue:`9430`, :pull:`9469`).
By `Eni Awowale <https://github.com/eni-awowale>`_.
- Support lazy grouping by dask arrays, and allow specifying groups with ``UniqueGrouper(labels=["a", "b", "c"])``
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
(:issue:`2852`, :issue:`757`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Added support for vectorized interpolation using additional interpolators
Expand All @@ -49,6 +49,11 @@ Breaking changes

Deprecations
~~~~~~~~~~~~
- Grouping by a chunked array (e.g. dask or cubed) currently eagerly loads that variable into
memory. This behaviour is deprecated. If eager loading was intended, please load such arrays
manually using ``.load()`` or ``.compute()``. Otherwise, pass ``eagerly_compute_group=False``, and
provide expected group labels using the ``labels`` kwarg to a grouper object such as
:py:class:`grouper.UniqueGrouper` or :py:class:`grouper.BinGrouper`.


Bug fixes
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def _resample(
f"Received {type(freq)} instead."
)

rgrouper = ResolvedGrouper(grouper, group, self)
rgrouper = ResolvedGrouper(grouper, group, self, eagerly_compute_group=False)

return resample_cls(
self,
Expand Down
10 changes: 8 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6733,6 +6733,7 @@ def groupby(
*,
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
eagerly_compute_group: bool = True,
**groupers: Grouper,
) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
Expand Down Expand Up @@ -6862,7 +6863,9 @@ def groupby(
)

_validate_groupby_squeeze(squeeze)
rgroupers = _parse_group_and_groupers(self, group, groupers)
rgroupers = _parse_group_and_groupers(
self, group, groupers, eagerly_compute_group=eagerly_compute_group
)
return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

@_deprecate_positional_args("v2024.07.0")
Expand All @@ -6877,6 +6880,7 @@ def groupby_bins(
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
duplicates: Literal["raise", "drop"] = "raise",
eagerly_compute_group: bool = True,
) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
Expand Down Expand Up @@ -6950,7 +6954,9 @@ def groupby_bins(
precision=precision,
include_lowest=include_lowest,
)
rgrouper = ResolvedGrouper(grouper, group, self)
rgrouper = ResolvedGrouper(
grouper, group, self, eagerly_compute_group=eagerly_compute_group
)

return DataArrayGroupBy(
self,
Expand Down
10 changes: 8 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10366,6 +10366,7 @@ def groupby(
*,
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
eagerly_compute_group: bool = True,
**groupers: Grouper,
) -> DatasetGroupBy:
"""Returns a DatasetGroupBy object for performing grouped operations.
Expand Down Expand Up @@ -10463,7 +10464,9 @@ def groupby(
)

_validate_groupby_squeeze(squeeze)
rgroupers = _parse_group_and_groupers(self, group, groupers)
rgroupers = _parse_group_and_groupers(
self, group, groupers, eagerly_compute_group=eagerly_compute_group
)

return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

Expand All @@ -10479,6 +10482,7 @@ def groupby_bins(
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
duplicates: Literal["raise", "drop"] = "raise",
eagerly_compute_group: bool = True,
) -> DatasetGroupBy:
"""Returns a DatasetGroupBy object for performing grouped operations.
Expand Down Expand Up @@ -10552,7 +10556,9 @@ def groupby_bins(
precision=precision,
include_lowest=include_lowest,
)
rgrouper = ResolvedGrouper(grouper, group, self)
rgrouper = ResolvedGrouper(
grouper, group, self, eagerly_compute_group=eagerly_compute_group
)

return DatasetGroupBy(
self,
Expand Down
30 changes: 27 additions & 3 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
FrozenMappingWarningOnValuesAccess,
contains_only_chunked_or_numpy,
either_dict_or_kwargs,
emit_user_level_warning,
hashable,
is_scalar,
maybe_wrap_array,
Expand Down Expand Up @@ -284,6 +285,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
grouper: Grouper
group: T_Group
obj: T_DataWithCoords
eagerly_compute_group: bool = field(repr=False)

# returned by factorize:
encoded: EncodedGroups = field(init=False, repr=False)
Expand All @@ -310,6 +312,18 @@ def __post_init__(self) -> None:

self.group = _resolve_group(self.obj, self.group)

if (
self.eagerly_compute_group
and not isinstance(self.group, _DummyGroup)
and is_chunked_array(self.group.variable._data)
):
emit_user_level_warning(
f"Eagerly computing the DataArray you're grouping by ({self.group.name!r}) "
"is deprecated and will be removed in v2025.05.0. "
"Please load this array's data manually using `.compute` or `.load`.",
DeprecationWarning,
)

self.encoded = self.grouper.factorize(self.group)

@property
Expand All @@ -330,7 +344,11 @@ def __len__(self) -> int:


def _parse_group_and_groupers(
obj: T_Xarray, group: GroupInput, groupers: dict[str, Grouper]
obj: T_Xarray,
group: GroupInput,
groupers: dict[str, Grouper],
*,
eagerly_compute_group: bool,
) -> tuple[ResolvedGrouper, ...]:
from xarray.core.dataarray import DataArray
from xarray.core.variable import Variable
Expand All @@ -355,7 +373,11 @@ def _parse_group_and_groupers(

rgroupers: tuple[ResolvedGrouper, ...]
if isinstance(group, DataArray | Variable):
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, obj),)
rgroupers = (
ResolvedGrouper(
UniqueGrouper(), group, obj, eagerly_compute_group=eagerly_compute_group
),
)
else:
if group is not None:
if TYPE_CHECKING:
Expand All @@ -368,7 +390,9 @@ def _parse_group_and_groupers(
grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers)

rgroupers = tuple(
ResolvedGrouper(grouper, group, obj)
ResolvedGrouper(
grouper, group, obj, eagerly_compute_group=eagerly_compute_group
)
for group, grouper in grouper_mapping.items()
)
return rgroupers
Expand Down
15 changes: 12 additions & 3 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2843,7 +2843,10 @@ def test_multiple_groupers(use_flox) -> None:
if has_dask:
b["xy"] = b["xy"].chunk()
with raise_if_dask_computes():
gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"]))
with pytest.warns(DeprecationWarning):
gb = b.groupby(
x=UniqueGrouper(), xy=UniqueGrouper(labels=["a", "b", "c"])
)

expected = xr.DataArray(
[[[1, 1, 1], [0, 1, 2]]] * 4,
Expand Down Expand Up @@ -2994,7 +2997,11 @@ def test_lazy_grouping(grouper, expect_index):
pd.testing.assert_index_equal(encoded.full_index, expect_index)
np.testing.assert_array_equal(encoded.unique_coord.values, np.array(expect_index))

lazy = xr.Dataset({"foo": data}, coords={"zoo": data}).groupby(zoo=grouper).count()
lazy = (
xr.Dataset({"foo": data}, coords={"zoo": data})
.groupby(zoo=grouper, eagerly_compute_group=False)
.count()
)
eager = (
xr.Dataset({"foo": data}, coords={"zoo": data.compute()})
.groupby(zoo=grouper)
Expand All @@ -3019,7 +3026,9 @@ def test_lazy_grouping_errors():
coords={"y": ("x", dask.array.arange(20, chunks=3))},
)

gb = data.groupby(y=UniqueGrouper(labels=np.arange(5, 10)))
gb = data.groupby(
y=UniqueGrouper(labels=np.arange(5, 10)), eagerly_compute_group=False
)
message = "not supported when lazily grouping by"
with pytest.raises(ValueError, match=message):
gb.map(lambda x: x)
Expand Down

0 comments on commit b295193

Please sign in to comment.