diff --git a/xarray/groupers.py b/xarray/groupers.py index 562a8323909..bf5f6f0c2d4 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,7 +7,9 @@ from __future__ import annotations import datetime +import functools import itertools +import operator from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Mapping, Sequence @@ -670,7 +672,12 @@ class SeasonResampler(Resampler): def __post_init__(self): self.season_inds = season_to_month_tuple(self.seasons) - self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False)) + all_inds = functools.reduce(operator.add, self.season_inds) + if len(all_inds) > len(set(all_inds)): + raise ValueError( + f"Overlapping seasons are not allowed. Received {self.seasons!r}" + ) + self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True)) def factorize(self, group): if group.ndim != 1: @@ -696,12 +703,22 @@ def factorize(self, group): season_label[month.isin(season_ind)] = season_str if "DJ" in season_str: after_dec = season_ind[season_str.index("D") + 1 :] + # important this is assuming non-overlapping seasons year[month.isin(after_dec)] -= 1 + # Allow users to skip one or more months? + # present_seasons is a mask that is True for months that are requestsed in the output + present_seasons = season_label != "" + if present_seasons.all(): + present_seasons = slice(None) frame = pd.DataFrame( - data={"index": np.arange(group.size), "month": month}, + data={ + "index": np.arange(group[present_seasons].size), + "month": month[present_seasons], + }, index=pd.MultiIndex.from_arrays( - [year.data, season_label], names=["year", "season"] + [year.data[present_seasons], season_label[present_seasons]], + names=["year", "season"], ), ) @@ -727,7 +744,7 @@ def factorize(self, group): sbins = first_items.values.astype(int) group_indices = [ - slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False) + slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True) ] group_indices += [slice(sbins[-1], None)] @@ -735,11 +752,11 @@ def factorize(self, group): # are for the correct months,if not we have incomplete seasons unique_codes = np.arange(len(unique_coord)) if self.drop_incomplete: - for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False): + for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=True): stamp_year, stamp_season = frame.index[idx] code = seasons.index(stamp_season) stamp_month = season_inds[code][idx] - if stamp_month != month[idx].item(): + if stamp_month != month[present_seasons][idx].item(): # we have an incomplete season! group_indices = group_indices[slicer] unique_coord = unique_coord[slicer] @@ -769,7 +786,9 @@ def factorize(self, group): if not full_index.equals(unique_coord): raise ValueError("Are there seasons missing in the middle of the dataset?") - codes = group.copy(data=np.repeat(unique_codes, counts), deep=False) + final_codes = np.full(group.data.size, -1) + final_codes[present_seasons] = np.repeat(unique_codes, counts) + codes = group.copy(data=final_codes, deep=False) unique_coord_var = Variable(group.name, unique_coord, group.attrs) return EncodedGroups( diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3949049f6a9..d9f1471c561 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2958,6 +2958,9 @@ def test_season_resampler(): # skip september da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum() + # "subsampling" + da.groupby(time=SeasonResampler(["JJAS"])).sum() + # overlapping with pytest.raises(ValueError): da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum()