Commit
flox: don't set fill_value where possible (#9433)
* flox: don't set fill_value where possible

Closes #8090
Closes #8206
Closes #9398

* Update doctest

* Fix test

* fix test

* Test for flox >= 0.9.12

* fix whats-new
dcherian committed Sep 18, 2024
1 parent 9cb9958 commit e313853
Showing 6 changed files with 73 additions and 21 deletions.
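
To illustrate the behavior this commit targets, here is a sketch adapted from the new test_groupby_preserve_dtype test added below. It is not part of the diff and assumes flox >= 0.9.12 is installed: when every group is present in the data, flox no longer needs a fill_value, so a grouped reduction on integer data follows numpy's dtype rules instead of being promoted to float64.

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"test": (["x"], np.array([1, 2, 3], dtype="int16"))},
        coords={"idx": ("x", [1, 2, 1])},
    )
    # Both groups (idx=1 and idx=2) have observations, so no fill_value
    # is needed and the reduction can follow numpy's dtype rules.
    result = ds.groupby("idx").max()
    assert result.test.dtype == np.dtype("int16")  # previously float64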
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -51,6 +51,8 @@ Bug fixes
   the non-missing times could in theory be encoded with integers
   (:issue:`9488`, :pull:`9497`). By `Spencer Clark
   <https://github.com/spencerkclark>`_.
+- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
+  By `Deepak Cherian <https://github.com/dcherian>`_.


Documentation
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
@@ -6786,8 +6786,8 @@ def groupby(
         >>> da.groupby("letters").sum()
         <xarray.DataArray (letters: 2, y: 3)> Size: 48B
-        array([[ 9., 11., 13.],
-               [ 9., 11., 13.]])
+        array([[ 9, 11, 13],
+               [ 9, 11, 13]])
         Coordinates:
           * letters  (letters) object 16B 'a' 'b'
         Dimensions without coordinates: y
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
@@ -10390,7 +10390,7 @@ def groupby(
           * letters  (letters) object 16B 'a' 'b'
         Dimensions without coordinates: y
         Data variables:
-            foo      (letters, y) float64 48B 9.0 11.0 13.0 9.0 11.0 13.0
+            foo      (letters, y) int64 48B 9 11 13 9 11 13
 
         Grouping by multiple variables
30 changes: 12 additions & 18 deletions xarray/core/groupby.py
@@ -791,14 +791,12 @@ def _maybe_restore_empty_groups(self, combined):
         """Our index contained empty groups (e.g., from a resampling or binning). If we
         reduced on that dimension, we want to restore the full index.
         """
-        from xarray.groupers import BinGrouper, TimeResampler
-
+        has_missing_groups = (
+            self.encoded.unique_coord.size != self.encoded.full_index.size
+        )
         indexers = {}
         for grouper in self.groupers:
-            if (
-                isinstance(grouper.grouper, BinGrouper | TimeResampler)
-                and grouper.name in combined.dims
-            ):
+            if has_missing_groups and grouper.name in combined._indexes:
                 indexers[grouper.name] = grouper.full_index
         if indexers:
             combined = combined.reindex(**indexers)
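
The rewritten check above restores any group that received no observations by reindexing onto the grouper's full index, and it now does so for every grouper type rather than only bins and resamples. A minimal sketch of that behavior (assumed output, not part of the diff):

    import numpy as np
    import xarray as xr

    da = xr.DataArray([1.0, 2.0, 3.0], dims="x", coords={"x": [1, 2, 3]})
    # Two bins, (0, 4] and (4, 6]; the second bin contains no points.
    result = da.groupby_bins("x", bins=[0, 4, 6]).sum()
    # Reindexing against the full index restores the empty bin as NaN.
    print(result.values)  # [ 6. nan]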
@@ -853,10 +851,6 @@ def _flox_reduce(
             else obj._coords
         )
 
-        any_isbin = any(
-            isinstance(grouper.grouper, BinGrouper) for grouper in self.groupers
-        )
-
         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=True)
 
@@ -930,14 +924,14 @@
         ):
             raise ValueError(f"cannot reduce over dimensions {dim}.")
 
-        if kwargs["func"] not in ["all", "any", "count"]:
-            kwargs.setdefault("fill_value", np.nan)
-        if any_isbin and kwargs["func"] == "count":
-            # This is an annoying hack. Xarray returns np.nan
-            # when there are no observations in a bin, instead of 0.
-            # We can fake that here by forcing min_count=1.
-            # note min_count makes no sense in the xarray world
-            # as a kwarg for count, so this should be OK
+        has_missing_groups = (
+            self.encoded.unique_coord.size != self.encoded.full_index.size
+        )
+        if has_missing_groups or kwargs.get("min_count", 0) > 0:
+            # Xarray *always* returns np.nan when there are no observations in a group,
+            # We can fake that here by forcing min_count=1 when it is not set.
+            # This handles boolean reductions, and count
+            # See GH8090, GH9398
+            kwargs.setdefault("fill_value", np.nan)
             kwargs.setdefault("min_count", 1)
 
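At the flox level, the min_count trick in this hunk works roughly like the sketch below (assumed usage of flox's groupby_reduce API, not part of the diff). Setting min_count=1 makes flox treat groups with zero observations as missing, and fill_value=np.nan then fills them, reproducing xarray's NaN-for-empty-groups convention without forcing a fill_value onto fully populated reductions:

    import numpy as np
    from flox import groupby_reduce

    array = np.array([1.0, 2.0, 3.0])
    by = np.array([0, 0, 2])  # group 1 receives no observations
    result, groups = groupby_reduce(
        array,
        by,
        func="sum",
        expected_groups=np.array([0, 1, 2]),
        # min_count=1 flags the empty group as missing; fill_value=np.nan
        # then fills it, instead of returning sum's identity value 0.
        fill_value=np.nan,
        min_count=1,
    )
    print(result)  # [ 3. nan  3.]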
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
@@ -148,6 +148,7 @@ def _importorskip(
     not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck"
 )
 has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0")
+_, requires_flox_0_9_12 = _importorskip("flox", "0.9.12")
 
 has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict")

55 changes: 55 additions & 0 deletions xarray/tests/test_groupby.py
@@ -34,6 +34,7 @@
     requires_cftime,
     requires_dask,
     requires_flox,
+    requires_flox_0_9_12,
     requires_scipy,
 )

@@ -2859,6 +2860,60 @@ def test_multiple_groupers_mixed(use_flox) -> None:
 # ------
 
 
+@requires_flox_0_9_12
+@pytest.mark.parametrize(
+    "reduction", ["max", "min", "nanmax", "nanmin", "sum", "nansum", "prod", "nanprod"]
+)
+def test_groupby_preserve_dtype(reduction):
+    # all groups are present, we should follow numpy exactly
+    ds = xr.Dataset(
+        {
+            "test": (
+                ["x", "y"],
+                np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype="int16"),
+            )
+        },
+        coords={"idx": ("x", [1, 2, 1])},
+    )
+
+    kwargs = {}
+    if "nan" in reduction:
+        kwargs["skipna"] = True
+    # TODO: fix dtype with numbagg/bottleneck and use_flox=False
+    with xr.set_options(use_numbagg=False, use_bottleneck=False):
+        actual = getattr(ds.groupby("idx"), reduction.removeprefix("nan"))(
+            **kwargs
+        ).test.dtype
+    expected = getattr(np, reduction)(ds.test.data, axis=0).dtype
+
+    assert actual == expected
+
+
+@requires_dask
+@requires_flox_0_9_12
+@pytest.mark.parametrize("reduction", ["any", "all", "count"])
+def test_gappy_resample_reductions(reduction):
+    # GH8090
+    dates = (("1988-12-01", "1990-11-30"), ("2000-12-01", "2001-11-30"))
+    times = [xr.date_range(*d, freq="D") for d in dates]
+
+    da = xr.concat(
+        [
+            xr.DataArray(np.random.rand(len(t)), coords={"time": t}, dims="time")
+            for t in times
+        ],
+        dim="time",
+    ).chunk(time=100)
+
+    rs = (da > 0.5).resample(time="YS-DEC")
+    method = getattr(rs, reduction)
+    with xr.set_options(use_flox=True):
+        actual = method(dim="time")
+    with xr.set_options(use_flox=False):
+        expected = method(dim="time")
+    assert_identical(expected, actual)
+
+
 # Possible property tests
 # 1. lambda x: x
 # 2. grouped-reduce on unique coords is identical to array
