Skip to content

Commit

Permalink
Support "subsampled" seasons
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Sep 22, 2024
1 parent 82f3c21 commit 9180536
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
33 changes: 26 additions & 7 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from __future__ import annotations

import datetime
import functools
import itertools
import operator
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping, Sequence
Expand Down Expand Up @@ -670,7 +672,12 @@ class SeasonResampler(Resampler):

def __post_init__(self):
self.season_inds = season_to_month_tuple(self.seasons)
self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False))
all_inds = functools.reduce(operator.add, self.season_inds)
if len(all_inds) > len(set(all_inds)):
raise ValueError(
f"Overlapping seasons are not allowed. Received {self.seasons!r}"
)
self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True))

def factorize(self, group):
if group.ndim != 1:
Expand All @@ -696,12 +703,22 @@ def factorize(self, group):
season_label[month.isin(season_ind)] = season_str
if "DJ" in season_str:
after_dec = season_ind[season_str.index("D") + 1 :]
# important this is assuming non-overlapping seasons
year[month.isin(after_dec)] -= 1

# Allow users to skip one or more months?
# present_seasons is a mask that is True for months that are requestsed in the output
present_seasons = season_label != ""
if present_seasons.all():
present_seasons = slice(None)
frame = pd.DataFrame(
data={"index": np.arange(group.size), "month": month},
data={
"index": np.arange(group[present_seasons].size),
"month": month[present_seasons],
},
index=pd.MultiIndex.from_arrays(
[year.data, season_label], names=["year", "season"]
[year.data[present_seasons], season_label[present_seasons]],
names=["year", "season"],
),
)

Expand All @@ -727,19 +744,19 @@ def factorize(self, group):

sbins = first_items.values.astype(int)
group_indices = [
slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False)
slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True)
]
group_indices += [slice(sbins[-1], None)]

# Make sure the first and last timestamps
# are for the correct months,if not we have incomplete seasons
unique_codes = np.arange(len(unique_coord))
if self.drop_incomplete:
for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False):
for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=True):
stamp_year, stamp_season = frame.index[idx]
code = seasons.index(stamp_season)
stamp_month = season_inds[code][idx]
if stamp_month != month[idx].item():
if stamp_month != month[present_seasons][idx].item():
# we have an incomplete season!
group_indices = group_indices[slicer]
unique_coord = unique_coord[slicer]
Expand Down Expand Up @@ -769,7 +786,9 @@ def factorize(self, group):
if not full_index.equals(unique_coord):
raise ValueError("Are there seasons missing in the middle of the dataset?")

codes = group.copy(data=np.repeat(unique_codes, counts), deep=False)
final_codes = np.full(group.data.size, -1)
final_codes[present_seasons] = np.repeat(unique_codes, counts)
codes = group.copy(data=final_codes, deep=False)
unique_coord_var = Variable(group.name, unique_coord, group.attrs)

return EncodedGroups(
Expand Down
3 changes: 3 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2958,6 +2958,9 @@ def test_season_resampler():
# skip september
da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum()

# "subsampling"
da.groupby(time=SeasonResampler(["JJAS"])).sum()

# overlapping
with pytest.raises(ValueError):
da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum()
Expand Down

0 comments on commit 9180536

Please sign in to comment.