Grouper, Resampler as public api
dcherian committed Mar 15, 2024
1 parent c9d3084 commit 3c6d2d5
Showing 5 changed files with 77 additions and 29 deletions.
24 changes: 14 additions & 10 deletions xarray/core/common.py
@@ -1049,7 +1049,7 @@ def _resample(
# TODO support non-string indexer after removing the old API.

from xarray.core.dataarray import DataArray
from xarray.core.groupby import ResolvedGrouper, TimeResampler
from xarray.core.groupby import Resampler, ResolvedGrouper, TimeResampler
from xarray.core.resample import RESAMPLE_DIM

# note: the second argument (now 'skipna') used to be 'dim'
@@ -1079,15 +1079,19 @@ def _resample(
name=RESAMPLE_DIM,
)

grouper = TimeResampler(
freq=freq,
closed=closed,
label=label,
origin=origin,
offset=offset,
loffset=loffset,
base=base,
)
if isinstance(freq, str):
grouper = TimeResampler(
freq=freq,
closed=closed,
label=label,
origin=origin,
offset=offset,
loffset=loffset,
base=base,
)
else:
assert isinstance(freq, Resampler)
grouper = freq

rgrouper = ResolvedGrouper(grouper, group, self)

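With this change, the freq value resolved inside _resample may be either a frequency string or a Resampler instance, so a resample indexer can carry a pre-built grouper. A minimal sketch of the spelling this enables (an assumption based on the hunk above: the indexer value is forwarded to _resample as freq; TimeResampler is imported from xarray.core.groupby, as this commit's import line does):

import numpy as np
import pandas as pd
import xarray as xr
from xarray.core.groupby import TimeResampler

da = xr.DataArray(
    np.arange(6),
    dims="time",
    coords={"time": pd.date_range("2000-01-01", periods=6, freq="D")},
)

# Frequency string, as before
two_daily = da.resample(time="2D").mean()

# New: pass a Resampler instance as the indexer value
two_daily_grouper = da.resample(time=TimeResampler(freq="2D")).mean()

xr.testing.assert_identical(two_daily, two_daily_grouper)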
17 changes: 15 additions & 2 deletions xarray/core/dataarray.py
@@ -6636,9 +6636,10 @@ def interp_calendar(

def groupby(
self,
group: Hashable | DataArray | IndexVariable,
group: Hashable | DataArray | IndexVariable = None,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
**groupers,
) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
@@ -6710,7 +6711,19 @@ )
)

_validate_groupby_squeeze(squeeze)
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)

if group is not None:
assert not groupers
grouper = UniqueGrouper()
else:
if len(groupers) > 1:
raise ValueError("grouping by multiple variables is not supported yet.")
if not groupers:
raise ValueError
group, grouper = next(iter(groupers.items()))

rgrouper = ResolvedGrouper(grouper, group, self)

return DataArrayGroupBy(
self,
(rgrouper,),
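DataArray.groupby now also accepts a Grouper object passed as a keyword argument named after the group variable, alongside the existing positional form. A short sketch of the two equivalent spellings (assuming UniqueGrouper is imported from xarray.core.groupby, as the tests in this commit do):

import numpy as np
import xarray as xr
from xarray.core.groupby import UniqueGrouper

da = xr.DataArray(np.arange(6), dims="x", coords={"x": [1, 1, 2, 2, 3, 3]})

by_name = da.groupby("x").mean()                    # existing positional form
by_grouper = da.groupby(x=UniqueGrouper()).mean()   # new keyword-Grouper form

xr.testing.assert_identical(by_name, by_grouper)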
14 changes: 12 additions & 2 deletions xarray/core/dataset.py
@@ -10134,9 +10134,10 @@ def interp_calendar(

def groupby(
self,
group: Hashable | DataArray | IndexVariable,
group: Hashable | DataArray | IndexVariable | None = None,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
**groupers,
) -> DatasetGroupBy:
"""Returns a DatasetGroupBy object for performing grouped operations.
@@ -10186,7 +10187,16 @@ )
)

_validate_groupby_squeeze(squeeze)
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
if group is not None:
assert not groupers
rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
else:
if len(groupers) > 1:
raise ValueError("grouping by multiple variables is not supported yet.")
if not groupers:
raise ValueError
for group, grouper in groupers.items():
rgrouper = ResolvedGrouper(grouper, group, self)

return DatasetGroupBy(
self,
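Dataset.groupby gains the same keyword-Grouper entry point; passing more than one grouper raises, since only a single grouping variable is supported so far. A sketch of the intended usage and the error path (assumed behaviour, mirroring the ValueError in the hunk above):

import numpy as np
import xarray as xr
from xarray.core.groupby import UniqueGrouper

ds = xr.Dataset({"a": ("x", np.arange(4))}, coords={"x": [10, 10, 20, 20]})

ds.groupby(x=UniqueGrouper()).sum()   # equivalent to ds.groupby("x").sum()

# Only one keyword grouper is accepted for now:
# ds.groupby(x=UniqueGrouper(), a=UniqueGrouper())
# -> ValueError: grouping by multiple variables is not supported yet.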
2 changes: 1 addition & 1 deletion xarray/core/groupby.py
@@ -254,7 +254,7 @@ def attrs(self) -> dict:

def __getitem__(self, key):
if isinstance(key, tuple):
key = key[0]
(key,) = key
return self.values[key]

def to_index(self) -> pd.Index:
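The switch from key[0] to tuple unpacking makes multi-element tuple keys an error instead of silently using the first element. A tiny illustration of the difference (plain Python, not xarray-specific):

key = ("time",)
(key,) = key          # ok: key == "time"

key = ("time", "space")
old = key[0]          # previous behaviour: silently kept "time"
# (key,) = key        # new behaviour: ValueError: too many values to unpack (expected 1)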
49 changes: 35 additions & 14 deletions xarray/tests/test_groupby.py
@@ -12,7 +12,11 @@

import xarray as xr
from xarray import DataArray, Dataset, Variable
from xarray.core.groupby import _consolidate_slices
from xarray.core.groupby import (
BinGrouper,
UniqueGrouper,
_consolidate_slices,
)
from xarray.tests import (
InaccessibleArray,
assert_allclose,
@@ -112,8 +116,9 @@ def test_multi_index_groupby_map(dataset) -> None:
assert_equal(expected, actual)


def test_reduce_numeric_only(dataset) -> None:
gb = dataset.groupby("x", squeeze=False)
@pytest.mark.parametrize("grouper", [dict(group="x"), dict(x=UniqueGrouper())])
def test_reduce_numeric_only(dataset, grouper) -> None:
gb = dataset.groupby(**grouper, squeeze=False)
with xr.set_options(use_flox=False):
expected = gb.sum()
with xr.set_options(use_flox=True):
@@ -830,11 +835,12 @@ def test_groupby_dataset_reduce() -> None:

expected = data.mean("y")
expected["yonly"] = expected["yonly"].variable.set_dims({"x": 3})
actual = data.groupby("x").mean(...)
assert_allclose(expected, actual)
for gb in [data.groupby("x"), data.groupby(x=UniqueGrouper())]:
actual = gb.mean(...)
assert_allclose(expected, actual)

actual = data.groupby("x").mean("y")
assert_allclose(expected, actual)
actual = gb.mean("y")
assert_allclose(expected, actual)

letters = data["letters"]
expected = Dataset(
@@ -844,8 +850,9 @@ def test_groupby_dataset_reduce() -> None:
"yonly": data["yonly"].groupby(letters).mean(),
}
)
actual = data.groupby("letters").mean(...)
assert_allclose(expected, actual)
for gb in [data.groupby("letters"), data.groupby(letters=UniqueGrouper())]:
actual = gb.mean(...)
assert_allclose(expected, actual)


@pytest.mark.parametrize("squeeze", [True, False])
@@ -975,6 +982,14 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
)
assert_identical(expected, actual)

with xr.set_options(use_flox=use_flox):
actual = da.groupby(
x=BinGrouper(
bins=x_bins, cut_kwargs=dict(include_lowest=True, right=False)
),
).mean()
assert_identical(expected, actual)


@pytest.mark.parametrize("indexed_coord", [True, False])
def test_groupby_bins_math(indexed_coord) -> None:
Expand All @@ -983,11 +998,17 @@ def test_groupby_bins_math(indexed_coord) -> None:
if indexed_coord:
da["x"] = np.arange(N)
da["y"] = np.arange(N)
g = da.groupby_bins("x", np.arange(0, N + 1, 3))
mean = g.mean()
expected = da.isel(x=slice(1, None)) - mean.isel(x_bins=("x", [0, 0, 0, 1, 1, 1]))
actual = g - mean
assert_identical(expected, actual)

for g in [
da.groupby_bins("x", np.arange(0, N + 1, 3)),
da.groupby(x=BinGrouper(bins=np.arange(0, N + 1, 3))),
]:
mean = g.mean()
expected = da.isel(x=slice(1, None)) - mean.isel(
x_bins=("x", [0, 0, 0, 1, 1, 1])
)
actual = g - mean
assert_identical(expected, actual)


def test_groupby_math_nD_group() -> None:
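The new tests above exercise BinGrouper as the keyword counterpart of groupby_bins. A condensed standalone version of that equivalence (assumed to hold with default cut settings, as in the tests):

import numpy as np
import xarray as xr
from xarray.core.groupby import BinGrouper

da = xr.DataArray(np.arange(10), dims="x", coords={"x": np.arange(10)})
bins = np.arange(0, 11, 5)

via_bins = da.groupby_bins("x", bins).mean()
via_grouper = da.groupby(x=BinGrouper(bins=bins)).mean()

xr.testing.assert_identical(via_bins, via_grouper)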
