
GroupBy(chunked-array) #9522
Merged (20 commits) on Nov 4, 2024
9 changes: 9 additions & 0 deletions doc/user-guide/groupby.rst
@@ -294,6 +294,15 @@ is identical to
ds.resample(time=TimeResampler("ME"))


The :py:class:`groupers.UniqueGrouper` accepts an optional ``labels`` kwarg that is not present
in :py:meth:`DataArray.groupby` or :py:meth:`Dataset.groupby`.
Specifying ``labels`` is required when grouping by a lazy array type (e.g. dask or cubed):
the ``labels`` are used to construct the output coordinate (say, for a reduction), and
aggregations will only be run over the specified labels.
You may also use ``labels`` to specify the order in which groups are iterated over;
that order is preserved in the output.
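
For instance, a minimal sketch of what this enables (the data and names here are
illustrative, and ``flox`` is assumed to be installed so the lazy reduction is supported):

import dask.array
import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

letters = dask.array.from_array(np.array(list("abcabcabcabc")), chunks=4)
ds = xr.Dataset(
    {"foo": ("x", dask.array.ones(12, chunks=4))},
    coords={"letters": ("x", letters)},
)
# ``labels`` is required because ``letters`` is chunked; it also fixes
# the order of the groups in the output coordinate.
result = ds.groupby(letters=UniqueGrouper(labels=["a", "b", "c"])).mean()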


.. _groupby.multiple:

Grouping by multiple variables
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -32,6 +32,9 @@ New Features
`Tom Nicholas <https://github.com/TomNicholas>`_.
- Added zarr backends for :py:func:`open_groups` (:issue:`9430`, :pull:`9469`).
By `Eni Awowale <https://github.com/eni-awowale>`_.
- Support lazy grouping by dask arrays, and allow specifying groups with ``UniqueGrouper(labels=["a", "b", "c"])``
(:issue:`2852`, :issue:`757`).
By `Deepak Cherian <https://github.com/dcherian>`_.

Breaking changes
~~~~~~~~~~~~~~~~
66 changes: 52 additions & 14 deletions xarray/core/groupby.py
@@ -22,6 +22,7 @@
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.concat import concat
from xarray.core.coordinates import Coordinates
from xarray.core.duck_array_ops import where
from xarray.core.formatting import format_array_flat
from xarray.core.indexes import (
PandasIndex,
@@ -47,6 +48,7 @@
peek_at,
)
from xarray.core.variable import IndexVariable, Variable
from xarray.namedarray.pycompat import is_chunked_array
from xarray.util.deprecation_helpers import _deprecate_positional_args

if TYPE_CHECKING:
@@ -190,8 +192,8 @@ def values(self) -> range:
return range(self.size)

@property
def data(self) -> range:
return range(self.size)
def data(self) -> np.ndarray:
Comment from Contributor Author: for typing
return np.arange(self.size, dtype=int)

def __array__(self) -> np.ndarray:
return np.arange(self.size)
@@ -249,7 +251,9 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
stacked_dim = "stacked_" + "_".join(map(str, orig_dims))
# these dimensions get created by the stack operation
inserted_dims = [dim for dim in group.dims if dim not in group.coords]
newgroup = group.stack({stacked_dim: orig_dims})
# `newgroup` is constructed from the Variable directly so that we don't
# create an index or stack any non-dim coords unnecessarily
newgroup = DataArray(group.variable.stack({stacked_dim: orig_dims}))
newobj = obj.stack({stacked_dim: orig_dims})
return newgroup, newobj, stacked_dim, inserted_dims
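
To see why this is cheaper, compare the two stacking paths (a standalone sketch with
illustrative names): ``DataArray.stack`` eagerly builds a ``pandas.MultiIndex`` for the
new dimension, while stacking the underlying ``Variable`` is just a reshape with no
index creation.

import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((2, 3)), dims=("y", "x"), coords={"x": [10, 20, 30]})

stacked = da.stack(z=("y", "x"))
print("z" in stacked.indexes)  # True: a MultiIndex was created

unindexed = xr.DataArray(da.variable.stack(z=("y", "x")))
print(len(unindexed.indexes))  # 0: no index was created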

@@ -459,20 +463,26 @@ def factorize(self) -> EncodedGroups:
# NaNs; as well as values outside the bins are coded by -1
# Restore these after the raveling
mask = functools.reduce(np.logical_or, [(code == -1) for code in broadcasted_codes]) # type: ignore[arg-type]
_flatcodes[mask] = -1

midx = pd.MultiIndex.from_product(
(grouper.unique_coord.data for grouper in groupers),
names=tuple(grouper.name for grouper in groupers),
)
# Constructing an index from the product is wrong when there are missing groups
# (e.g. binning, resampling). Account for that now.
midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
_flatcodes = where(mask, -1, _flatcodes)

full_index = pd.MultiIndex.from_product(
(grouper.full_index.values for grouper in groupers),
names=tuple(grouper.name for grouper in groupers),
)
# This will be unused when grouping by dask arrays, so skip building it.
if not is_chunked_array(_flatcodes):
midx = pd.MultiIndex.from_product(
(grouper.unique_coord.data for grouper in groupers),
names=tuple(grouper.name for grouper in groupers),
)
# Constructing an index from the product is wrong when there are missing groups
# (e.g. binning, resampling). Account for that now.
midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
group_indices = _codes_to_group_indices(_flatcodes.ravel(), len(full_index))
else:
midx = full_index
group_indices = None

dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers)

coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name)
@@ -481,7 +491,7 @@
return EncodedGroups(
codes=first_codes.copy(data=_flatcodes),
full_index=full_index,
group_indices=_codes_to_group_indices(_flatcodes.ravel(), len(full_index)),
group_indices=group_indices,
unique_coord=Variable(dims=(dim_name,), data=midx.values),
coords=coords,
)
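
The switch from in-place assignment (``_flatcodes[mask] = -1``) to ``duck_array_ops.where``
is what lets the codes stay lazy: ``where`` dispatches to the underlying array library and
returns a new (possibly lazy) array instead of mutating, which is not guaranteed to be
supported across chunked array types. A standalone sketch using dask directly:

import dask.array
import numpy as np

codes = dask.array.from_array(np.array([0, 1, 5, 2]), chunks=2)
mask = codes == 5

# where() builds a lazy graph rather than mutating `codes` in place
masked = dask.array.where(mask, -1, codes)
print(masked.compute())  # [ 0  1 -1  2]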
@@ -514,6 +524,7 @@ class GroupBy(Generic[T_Xarray]):
"_dims",
"_sizes",
"_len",
"_by_chunked",
# Save unstacked object for flox
"_original_obj",
"_codes",
@@ -531,6 +542,7 @@ class GroupBy(Generic[T_Xarray]):
_group_indices: GroupIndices
_codes: tuple[DataArray, ...]
_group_dim: Hashable
_by_chunked: bool

_groups: dict[GroupKey, GroupIndex] | None
_dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None
@@ -583,6 +595,7 @@ def __init__(
# specification for the groupby operation
# TODO: handle obj having variables that are not present on any of the groupers
# simple broadcasting fails for ExtensionArrays.
# FIXME: Skip this stacking when grouping by a dask array, it's useless in that case.
(self.group1d, self._obj, self._stacked_dim, self._inserted_dims) = _ensure_1d(
group=self.encoded.codes, obj=obj
)
@@ -593,6 +606,7 @@
self._dims = None
self._sizes = None
self._len = len(self.encoded.full_index)
self._by_chunked = is_chunked_array(self.encoded.codes.data)

@property
def sizes(self) -> Mapping[Hashable, int]:
@@ -632,6 +646,14 @@ def reduce(
) -> T_Xarray:
raise NotImplementedError()

def _raise_if_by_is_chunked(self):
if self._by_chunked:
raise ValueError(
"This method is not supported when lazily grouping by a chunked array. "
"Either load the array in to memory prior to grouping, or explore another "
"way of applying your function, potentially using the `flox` package."
)
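
An illustrative sketch (hypothetical data) of where this guard surfaces: iterating over a
lazily-grouped object needs concrete group indices, so it raises immediately.

import dask.array
import numpy as np
import xarray as xr
from xarray.groupers import UniqueGrouper

labels = dask.array.from_array(np.array(list("aabbcc")), chunks=3)
da = xr.DataArray(dask.array.ones(6, chunks=3), dims="x", coords={"label": ("x", labels)})
gb = da.groupby(label=UniqueGrouper(labels=["a", "b", "c"]))

try:
    next(iter(gb))  # iteration requires concrete group indices
except ValueError as e:
    print(e)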

def _raise_if_not_single_group(self):
if len(self.groupers) != 1:
raise NotImplementedError(
@@ -680,6 +702,7 @@ def __repr__(self) -> str:

def _iter_grouped(self) -> Iterator[T_Xarray]:
"""Iterate over each element in this group"""
self._raise_if_by_is_chunked()
for indices in self.encoded.group_indices:
if indices:
yield self._obj.isel({self._group_dim: indices})
@@ -854,7 +877,7 @@ def _flox_reduce(
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)

if Version(flox.__version__) < Version("0.9"):
if Version(flox.__version__) < Version("0.9") and not self._by_chunked:
# preserve current strategy (approximately) for dask groupby
# on older flox versions to prevent surprises.
# flox >=0.9 will choose this on its own.
@@ -1267,6 +1290,7 @@ def _iter_grouped_shortcut(self):
"""Fast version of `_iter_grouped` that yields Variables without
metadata
"""
self._raise_if_by_is_chunked()
var = self._obj.variable
for _idx, indices in enumerate(self.encoded.group_indices):
if indices:
@@ -1428,6 +1452,12 @@ def reduce(
Array with summarized data and the indicated dimension(s)
removed.
"""
if self._by_chunked:
raise ValueError(
"This method is not supported when lazily grouping by a chunked array. "
"Try installing the `flox` package if you are using one of the standard "
"reductions (e.g. `mean`). "
)
if dim is None:
dim = [self._group_dim]

@@ -1579,6 +1609,14 @@ def reduce(
Array with summarized data and the indicated dimension(s)
removed.
"""

if self._by_chunked:
raise ValueError(
"This method is not supported when lazily grouping by a chunked array. "
"Try installing the `flox` package if you are using one of the standard "
"reductions (e.g. `mean`). "
)

if dim is None:
dim = [self._group_dim]

20 changes: 14 additions & 6 deletions xarray/core/indexes.py
@@ -1031,13 +1031,21 @@ def stack(
f"from variable {name!r} that wraps a multi-index"
)

split_labels, levels = zip(
*[lev.factorize() for lev in level_indexes], strict=True
)
labels_mesh = np.meshgrid(*split_labels, indexing="ij")
labels = [x.ravel() for x in labels_mesh]
# from_product sorts by default, so we can't use that always
# https://github.com/pydata/xarray/issues/980
# https://github.com/pandas-dev/pandas/issues/14672
if all(index.is_monotonic_increasing for index in level_indexes):
index = pd.MultiIndex.from_product(
level_indexes, sortorder=0, names=variables.keys()
)
else:
split_labels, levels = zip(
*[lev.factorize() for lev in level_indexes], strict=True
)
labels_mesh = np.meshgrid(*split_labels, indexing="ij")
labels = [x.ravel() for x in labels_mesh]

index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys())
index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys())
level_coords_dtype = {k: var.dtype for k, var in variables.items()}

return cls(index, dim, level_coords_dtype=level_coords_dtype)
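
The monotonicity guard exists because ``from_product`` sorts each level's values (via
``Categorical`` factorization; see the linked issues), which would silently reorder
non-monotonic levels. A quick standalone demonstration:

import pandas as pd

midx = pd.MultiIndex.from_product([[3, 1, 2], ["x", "y"]])
print(list(midx.levels[0]))  # [1, 2, 3] -- sorted, unlike the input

# The factorize() branch above keeps the original ordering instead,
# at the cost of building the codes manually.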
7 changes: 5 additions & 2 deletions xarray/core/utils.py
@@ -1017,15 +1017,18 @@ def contains_only_chunked_or_numpy(obj) -> bool:

Expects obj to be Dataset or DataArray"""
from xarray.core.dataarray import DataArray
from xarray.core.indexing import ExplicitlyIndexed
from xarray.namedarray.pycompat import is_chunked_array

if isinstance(obj, DataArray):
obj = obj._to_temp_dataset()

return all(
[
isinstance(var.data, np.ndarray) or is_chunked_array(var.data)
for var in obj.variables.values()
isinstance(var._data, ExplicitlyIndexed)
or isinstance(var._data, np.ndarray)
or is_chunked_array(var._data)
for var in obj._variables.values()
]
)
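
A short sketch of the predicate's behaviour (it is an internal helper, so this is
illustration only): numpy-backed and chunked objects both pass, and inspecting
``var._data`` instead of ``var.data`` avoids loading backend arrays wrapped in lazy
``ExplicitlyIndexed`` adapters.

import numpy as np
import xarray as xr
from xarray.core.utils import contains_only_chunked_or_numpy

ds = xr.Dataset({"a": ("x", np.arange(4))})
print(contains_only_chunked_or_numpy(ds))          # True: numpy-backed
print(contains_only_chunked_or_numpy(ds.chunk()))  # True: chunked (dask)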
