diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 264c07f562b..388d4d67340 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -51,6 +51,8 @@ Bug fixes the non-missing times could in theory be encoded with integers (:issue:`9488`, :pull:`9497`). By `Spencer Clark `_. +- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`). + By `Deepak Cherian `_. Documentation diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4b6185edf38..39679cbcff7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6786,8 +6786,8 @@ def groupby( >>> da.groupby("letters").sum() Size: 48B - array([[ 9., 11., 13.], - [ 9., 11., 13.]]) + array([[ 9, 11, 13], + [ 9, 11, 13]]) Coordinates: * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 08885e3cd8d..744c6d9eaa0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10390,7 +10390,7 @@ def groupby( * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y Data variables: - foo (letters, y) float64 48B 9.0 11.0 13.0 9.0 11.0 13.0 + foo (letters, y) int64 48B 9 11 13 9 11 13 Grouping by multiple variables diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a5e520b98b6..46339e5449a 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -791,14 +791,12 @@ def _maybe_restore_empty_groups(self, combined): """Our index contained empty groups (e.g., from a resampling or binning). If we reduced on that dimension, we want to restore the full index. """ - from xarray.groupers import BinGrouper, TimeResampler - + has_missing_groups = ( + self.encoded.unique_coord.size != self.encoded.full_index.size + ) indexers = {} for grouper in self.groupers: - if ( - isinstance(grouper.grouper, BinGrouper | TimeResampler) - and grouper.name in combined.dims - ): + if has_missing_groups and grouper.name in combined._indexes: indexers[grouper.name] = grouper.full_index if indexers: combined = combined.reindex(**indexers) @@ -853,10 +851,6 @@ def _flox_reduce( else obj._coords ) - any_isbin = any( - isinstance(grouper.grouper, BinGrouper) for grouper in self.groupers - ) - if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -930,14 +924,14 @@ def _flox_reduce( ): raise ValueError(f"cannot reduce over dimensions {dim}.") - if kwargs["func"] not in ["all", "any", "count"]: - kwargs.setdefault("fill_value", np.nan) - if any_isbin and kwargs["func"] == "count": - # This is an annoying hack. Xarray returns np.nan - # when there are no observations in a bin, instead of 0. - # We can fake that here by forcing min_count=1. - # note min_count makes no sense in the xarray world - # as a kwarg for count, so this should be OK + has_missing_groups = ( + self.encoded.unique_coord.size != self.encoded.full_index.size + ) + if has_missing_groups or kwargs.get("min_count", 0) > 0: + # Xarray *always* returns np.nan when there are no observations in a group, + # We can fake that here by forcing min_count=1 when it is not set. + # This handles boolean reductions, and count + # See GH8090, GH9398 kwargs.setdefault("fill_value", np.nan) kwargs.setdefault("min_count", 1) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 0e43738ed99..71ae1a7075f 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -148,6 +148,7 @@ def _importorskip( not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck" ) has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0") +_, requires_flox_0_9_12 = _importorskip("flox", "0.9.12") has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict") diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index fa6172c5d66..38feea88b18 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -34,6 +34,7 @@ requires_cftime, requires_dask, requires_flox, + requires_flox_0_9_12, requires_scipy, ) @@ -2859,6 +2860,60 @@ def test_multiple_groupers_mixed(use_flox) -> None: # ------ +@requires_flox_0_9_12 +@pytest.mark.parametrize( + "reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"] +) +def test_groupby_preserve_dtype(reduction): + # all groups are present, we should follow numpy exactly + ds = xr.Dataset( + { + "test": ( + ["x", "y"], + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"), + ) + }, + coords={"idx": ("x", [1, 2, 1])}, + ) + + kwargs = {} + if "nan" in reduction: + kwargs["skipna"] = True + # TODO: fix dtype with numbagg/bottleneck and use_flox=False + with xr.set_options(use_numbagg=False, use_bottleneck=False): + actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))( + **kwargs + ).test.dtype + expected = getattr(np, reduction)(ds.test.data, axis=0).dtype + + assert actual == expected + + +@requires_dask +@requires_flox_0_9_12 +@pytest.mark.parametrize("reduction", ["any", "all", "count"]) +def test_gappy_resample_reductions(reduction): + # GH8090 + dates = (("1988-12-01", "1990-11-30"), ("2000-12-01", "2001-11-30")) + times = [xr.date_range(*d, freq="D") for d in dates] + + da = xr.concat( + [ + xr.DataArray(np.random.rand(len(t)), coords={"time": t}, dims="time") + for t in times + ], + dim="time", + ).chunk(time=100) + + rs = (da > 0.5).resample(time="YS-DEC") + method = getattr(rs, reduction) + with xr.set_options(use_flox=True): + actual = method(dim="time") + with xr.set_options(use_flox=False): + expected = method(dim="time") + assert_identical(expected, actual) + + # Possible property tests # 1. lambda x: x # 2. grouped-reduce on unique coords is identical to array