From 594d4a7118159cfb45b4b14e41d6b35ce9847419 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 28 Jun 2024 11:00:08 -0400 Subject: [PATCH 01/10] Add SeasonGrouper, SeasonResampler These two groupers allow defining custom seasons, and dropping incomplete seasons from the output. Both cases are treated by adjusting the factorization -- conversion from group labels to integer codes -- appropriately. --- doc/api.rst | 2 + properties/test_properties.py | 31 +++++ xarray/core/toolzcompat.py | 56 +++++++++ xarray/groupers.py | 218 ++++++++++++++++++++++++++++++++++ xarray/tests/test_groupby.py | 28 ++++- 5 files changed, 329 insertions(+), 6 deletions(-) create mode 100644 xarray/core/toolzcompat.py diff --git a/doc/api.rst b/doc/api.rst index 87f116514cc..941ab8a7602 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1126,6 +1126,8 @@ Grouper Objects groupers.BinGrouper groupers.UniqueGrouper groupers.TimeResampler + groupers.SeasonGrouper + groupers.SeasonResampler Rolling objects diff --git a/properties/test_properties.py b/properties/test_properties.py index fc0a1955539..859f9d4e500 100644 --- a/properties/test_properties.py +++ b/properties/test_properties.py @@ -2,10 +2,12 @@ pytest.importorskip("hypothesis") +import hypothesis.strategies as st from hypothesis import given import xarray as xr import xarray.testing.strategies as xrst +from xarray.groupers import season_to_month_tuple @given(attrs=xrst.simple_attrs) @@ -15,3 +17,32 @@ def test_assert_identical(attrs): ds = xr.Dataset(attrs=attrs) xr.testing.assert_identical(ds, ds.copy(deep=True)) + + +@given( + roll=st.integers(min_value=0, max_value=12), + breaks=st.lists( + st.integers(min_value=0, max_value=11), min_size=1, max_size=12, unique=True + ), +) +def test_property_season_month_tuple(roll, breaks): + chars = list("JFMAMJJASOND") + months = tuple(range(1, 13)) + + rolled_chars = chars[roll:] + chars[:roll] + rolled_months = months[roll:] + months[:roll] + breaks = sorted(breaks) + if breaks[0] != 0: + 
breaks = [0] + breaks + if breaks[-1] != 12: + breaks = breaks + [12] + seasons = tuple( + "".join(rolled_chars[start:stop]) + for start, stop in zip(breaks[:-1], breaks[1:], strict=False) + ) + actual = season_to_month_tuple(seasons) + expected = tuple( + rolled_months[start:stop] + for start, stop in zip(breaks[:-1], breaks[1:], strict=False) + ) + assert expected == actual diff --git a/xarray/core/toolzcompat.py b/xarray/core/toolzcompat.py new file mode 100644 index 00000000000..4632419a845 --- /dev/null +++ b/xarray/core/toolzcompat.py @@ -0,0 +1,56 @@ +# This file contains functions copied from the toolz library in accordance +# with its license. The original copyright notice is duplicated below. + +# Copyright (c) 2013 Matthew Rocklin + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# a. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# b. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# c. Neither the name of toolz nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. + + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. + + +def sliding_window(n, seq): + """A sequence of overlapping subsequences + + >>> list(sliding_window(2, [1, 2, 3, 4])) + [(1, 2), (2, 3), (3, 4)] + + This function creates a sliding window suitable for transformations like + sliding means / smoothing + + >>> mean = lambda seq: float(sum(seq)) / len(seq) + >>> list(map(mean, sliding_window(2, [1, 2, 3, 4]))) + [1.5, 2.5, 3.5] + """ + import collections + import itertools + + return zip( + *( + collections.deque(itertools.islice(it, i), 0) or it + for i, it in enumerate(itertools.tee(seq, n)) + ), + strict=False, + ) diff --git a/xarray/groupers.py b/xarray/groupers.py index e4cb884e6de..0ba59f726df 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,7 +7,9 @@ from __future__ import annotations import datetime +import itertools from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, cast @@ -16,11 +18,13 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops +from xarray.core.common import _contains_datetime_like_objects from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper +from xarray.core.toolzcompat import 
sliding_window from xarray.core.types import ( Bins, DatetimeLike, @@ -485,3 +489,217 @@ def unique_value_groups( if isinstance(values, pd.MultiIndex): values.names = ar.names return values, inverse + + +def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]: + initials = "JFMAMJJASOND" + starts = dict( + ("".join(s), i + 1) + for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=False) + ) + result: list[tuple[int, ...]] = [] + for i, season in enumerate(seasons): + if len(season) == 1: + if i < len(seasons) - 1: + suffix = seasons[i + 1][0] + else: + suffix = seasons[0][0] + else: + suffix = season[1] + + start = starts[season[0] + suffix] + + month_append = [] + for i in range(len(season[1:])): + elem = start + i + 1 + month_append.append(elem - 12 * (elem > 12)) + result.append((start,) + tuple(month_append)) + return tuple(result) + + +@dataclass +class SeasonGrouper(Grouper): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: sequence of str + List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc. + + Examples + -------- + >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) + >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"]) + """ + + seasons: Sequence[str] + season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + # drop_incomplete: bool = field(default=True) # TODO + + def __post_init__(self) -> None: + self.season_inds = season_to_month_tuple(self.seasons) + + def factorize(self, group: T_Group) -> EncodedGroups: + if TYPE_CHECKING: + assert not isinstance(group, _DummyGroup) + if not _contains_datetime_like_objects(group.variable): + raise ValueError( + "SeasonGrouper can only be used to group by datetime-like arrays." 
+ ) + + seasons = self.seasons + season_inds = self.season_inds + + months = group.dt.month + codes_ = np.full(group.shape, -1) + group_indices: list[list[int]] = [[]] * len(seasons) + + index = np.arange(group.size) + for idx, season_tuple in enumerate(season_inds): + mask = months.isin(season_tuple) + codes_[mask] = idx + group_indices[idx] = index[mask] + + if np.all(codes_ == -1): + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" + ) + codes = group.copy(data=codes_, deep=False).rename("season") + unique_coord = Variable("season", seasons, attrs=group.attrs) + full_index = pd.Index(seasons) + return EncodedGroups( + codes=codes, + group_indices=tuple(group_indices), + unique_coord=unique_coord, + full_index=full_index, + ) + + +@dataclass +class SeasonResampler(Resampler): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: Sequence[str] + An ordered list of seasons. + drop_incomplete: bool + Whether to drop seasons that are not completely included in the data. + For example, if a time series starts in Jan-2001, and seasons includes `"DJF"` + then observations from Jan-2001, and Feb-2001 are ignored in the grouping + since Dec-2000 isn't present. + + Examples + -------- + >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"]) + >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"]) + """ + + seasons: Sequence[str] + drop_incomplete: bool = field(default=True) + season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False) + + def __post_init__(self): + self.season_inds = season_to_month_tuple(self.seasons) + self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False)) + + def factorize(self, group): + if group.ndim != 1: + raise ValueError( + "SeasonResampler can only be used to resample by 1D arrays." 
+ ) + if not _contains_datetime_like_objects(group.variable): + raise ValueError( + "SeasonResampler can only be used to group by datetime-like arrays." + ) + + seasons = self.seasons + season_inds = self.season_inds + season_tuples = self.season_tuples + + nstr = max(len(s) for s in seasons) + year = group.dt.year.astype(int) + month = group.dt.month.astype(int) + season_label = np.full(group.shape, "", dtype=f"U{nstr}") + + # offset years for seasons with December and January + for season_str, season_ind in zip(seasons, season_inds, strict=False): + season_label[month.isin(season_ind)] = season_str + if "DJ" in season_str: + after_dec = season_ind[season_str.index("D") + 1 :] + year[month.isin(after_dec)] -= 1 + + frame = pd.DataFrame( + data={"index": np.arange(group.size), "month": month}, + index=pd.MultiIndex.from_arrays( + [year.data, season_label], names=["year", "season"] + ), + ) + + series = frame["index"] + g = series.groupby(["year", "season"], sort=False) + first_items = g.first() + counts = g.count() + + # these are the seasons that are present + unique_coord = pd.DatetimeIndex( + [ + pd.Timestamp(year=year, month=season_tuples[season][0], day=1) + for year, season in first_items.index + ] + ) + + sbins = first_items.values.astype(int) + group_indices = [ + slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False) + ] + group_indices += [slice(sbins[-1], None)] + + # Make sure the first and last timestamps + # are for the correct months,if not we have incomplete seasons + unique_codes = np.arange(len(unique_coord)) + if self.drop_incomplete: + for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False): + stamp_year, stamp_season = frame.index[idx] + code = seasons.index(stamp_season) + stamp_month = season_inds[code][idx] + if stamp_month != month[idx].item(): + # we have an incomplete season! 
+ group_indices = group_indices[slicer] + unique_coord = unique_coord[slicer] + if idx == 0: + unique_codes -= 1 + unique_codes[idx] = -1 + + # all years and seasons + complete_index = pd.DatetimeIndex( + # This sorted call is a hack. It's hard to figure out how + # to start the iteration + sorted( + [ + pd.Timestamp(f"{y}-{m}-01") + for y, m in itertools.product( + range(year[0].item(), year[-1].item() + 1), + [s[0] for s in season_inds], + ) + ] + ) + ) + # only keep that included in data + range_ = complete_index.get_indexer(unique_coord[[0, -1]]) + full_index = complete_index[slice(range_[0], range_[-1] + 1)] + # check that there are no "missing" seasons in the middle + # print(full_index, unique_coord) + if not full_index.equals(unique_coord): + raise ValueError("Are there seasons missing in the middle of the dataset?") + + codes = group.copy(data=np.repeat(unique_codes, counts), deep=False) + unique_coord_var = Variable(group.name, unique_coord, group.attrs) + + return EncodedGroups( + codes=codes, + group_indices=group_indices, + unique_coord=unique_coord_var, + full_index=full_index, + ) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index dc869cc3a34..0a5fbccbdf4 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -21,6 +21,7 @@ Grouper, TimeResampler, UniqueGrouper, + season_to_month_tuple, ) from xarray.tests import ( InaccessibleArray, @@ -2915,12 +2916,6 @@ def test_gappy_resample_reductions(reduction): assert_identical(expected, actual) -# Possible property tests -# 1. lambda x: x -# 2. grouped-reduce on unique coords is identical to array -# 3. 
group_over == groupby-reduce along other dimensions - - def test_groupby_transpose(): # GH5361 data = xr.DataArray( @@ -2932,3 +2927,24 @@ def test_groupby_transpose(): second = data.groupby("x").sum() assert_identical(first, second.transpose(*first.dims)) + + +def test_season_to_month_tuple(): + assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == ( + (1, 2), + (3, 4, 5), + (6, 7, 8, 9), + (10, 11, 12), + ) + assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == ( + (12, 1, 2, 3), + (4, 5), + (6, 7, 8, 9), + (10, 11), + ) + + +# Possible property tests +# 1. lambda x: x +# 2. grouped-reduce on unique coords is identical to array +# 3. group_over == groupby-reduce along other dimensions From de0edf40405a2d6cf52220138b034e32856974bf Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 20 Sep 2024 17:46:48 -0600 Subject: [PATCH 02/10] Allow sliding seasons --- xarray/core/groupby.py | 5 ++- xarray/groupers.py | 92 ++++++++++++++++++++++++++++++++---------- 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 58971435018..a70598e5b6e 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -244,6 +244,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ from xarray.core.dataarray import DataArray if isinstance(group, DataArray): + group, obj = broadcast(group, obj) # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) @@ -906,7 +907,7 @@ def _flox_reduce( parsed_dim_list = list() # preserve order for dim_ in itertools.chain( - *(grouper.group.dims for grouper in self.groupers) + *(grouper.codes.dims for grouper in self.groupers) ): if dim_ not in parsed_dim_list: parsed_dim_list.append(dim_) @@ -920,7 +921,7 @@ def _flox_reduce( # Better to control it here than in flox. 
for grouper in self.groupers: if any( - d not in grouper.group.dims and d not in obj.dims for d in parsed_dim + d not in grouper.codes.dims and d not in obj.dims for d in parsed_dim ): raise ValueError(f"cannot reduce over dimensions {dim}.") diff --git a/xarray/groupers.py b/xarray/groupers.py index 0ba59f726df..c8c64f93503 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -79,9 +79,9 @@ class EncodedGroups: codes: DataArray full_index: pd.Index - group_indices: GroupIndices - unique_coord: Variable | _DummyGroup - coords: Coordinates + group_indices: GroupIndices = field(init=False, repr=False) + unique_coord: Variable | _DummyGroup = field(init=False, repr=False) + coords: Coordinates = field(init=False, repr=False) def __init__( self, @@ -517,6 +517,50 @@ def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...] return tuple(result) +def to_string(asints: tuple[tuple[tuple[int]]]): + inits = "JFMAMJJASOND" + + return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints) + + +@dataclass +class SeasonsGroup: + seasons: tuple[str, ...] + inds: tuple[tuple[int, ...], ...] 
+ codes: Sequence[int] + + +def find_independent_seasons(seasons: Sequence[str]): + from collections import defaultdict + from itertools import chain + + sinds = season_to_month_tuple(seasons) + grouped = defaultdict(list) + codes = defaultdict(list) + idx = 0 + seen: set[int] = set() + for i, first in enumerate(sinds): + if first not in seen: + grouped[idx].append(first) + codes[idx].append(i) + seen.update((first,)) + + for j, second in enumerate(sinds[i:]): + if not (set(chain(*grouped[idx])) & set(second)): + if second not in seen: + grouped[idx].append(second) + codes[idx].append(j + i) + seen.update((second,)) + idx += 1 + + grouped_ints = [idx for idx in grouped.values() if idx] + # TODO: inds is unncessary + return [ + SeasonsGroup(seasons=to_string(inds), inds=inds, codes=codes) + for inds, codes in zip(grouped_ints, codes.values(), strict=False) + ] + + @dataclass class SeasonGrouper(Grouper): """Allows grouping using a custom definition of seasons. @@ -533,11 +577,11 @@ class SeasonGrouper(Grouper): """ seasons: Sequence[str] - season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + # season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) # drop_incomplete: bool = field(default=True) # TODO - def __post_init__(self) -> None: - self.season_inds = season_to_month_tuple(self.seasons) + # def __post_init__(self) -> None: + # self.season_inds = season_to_month_tuple(self.seasons) def factorize(self, group: T_Group) -> EncodedGroups: if TYPE_CHECKING: @@ -546,27 +590,31 @@ def factorize(self, group: T_Group) -> EncodedGroups: raise ValueError( "SeasonGrouper can only be used to group by datetime-like arrays." 
) - - seasons = self.seasons - season_inds = self.season_inds - - months = group.dt.month - codes_ = np.full(group.shape, -1) - group_indices: list[list[int]] = [[]] * len(seasons) - - index = np.arange(group.size) - for idx, season_tuple in enumerate(season_inds): - mask = months.isin(season_tuple) - codes_[mask] = idx - group_indices[idx] = index[mask] + months = group.dt.month.data + seasons_groups = find_independent_seasons(self.seasons) + codes_ = np.full((len(seasons_groups),) + group.shape, -1, dtype=np.int8) + group_indices: list[list[int]] = [[]] * len(self.seasons) + for axis_index, seasgroup in enumerate(seasons_groups): + for season_tuple, code in zip( + seasgroup.inds, seasgroup.codes, strict=False + ): + mask = np.isin(months, season_tuple) + codes_[axis_index, mask] = code + (group_indices[code],) = mask.nonzero() if np.all(codes_ == -1): raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) - codes = group.copy(data=codes_, deep=False).rename("season") - unique_coord = Variable("season", seasons, attrs=group.attrs) - full_index = pd.Index(seasons) + needs_dummy_dim = len(seasons_groups) > 1 + codes = DataArray( + dims=(("__season_dim__",) if needs_dummy_dim else tuple()) + group.dims, + data=codes_ if needs_dummy_dim else codes_.squeeze(), + attrs=group.attrs, + name="season", + ) + unique_coord = Variable("season", self.seasons, attrs=group.attrs) + full_index = pd.Index(self.seasons) return EncodedGroups( codes=codes, group_indices=tuple(group_indices), From 6707b6269a61ef06569c0851ee4755299b947260 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 20 Sep 2024 20:04:10 -0600 Subject: [PATCH 03/10] minor fix --- xarray/core/groupby.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a70598e5b6e..ca51d2879d9 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -244,7 +244,8 @@ def _ensure_1d(group: T_Group, obj: 
T_DataWithCoords) -> tuple[ from xarray.core.dataarray import DataArray if isinstance(group, DataArray): - group, obj = broadcast(group, obj) + for dim in set(group.dims) - set(obj.dims): + obj = obj.expand_dims(dim) # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) From d296efde2feb56abfe25f8d29efef13fde16d57d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 20 Sep 2024 21:17:29 -0600 Subject: [PATCH 04/10] cleanup --- xarray/groupers.py | 54 +++++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index c8c64f93503..609c38018ae 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -9,8 +9,10 @@ import datetime import itertools from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Mapping, Sequence from dataclasses import dataclass, field +from itertools import chain from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np @@ -517,9 +519,8 @@ def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...] return tuple(result) -def to_string(asints: tuple[tuple[tuple[int]]]): +def inds_to_string(asints: tuple[tuple[int, ...]]) -> tuple[str, ...]: inits = "JFMAMJJASOND" - return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints) @@ -530,33 +531,39 @@ class SeasonsGroup: codes: Sequence[int] -def find_independent_seasons(seasons: Sequence[str]): - from collections import defaultdict - from itertools import chain - +def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]: + """ + Iterates through a list of seasons e.g. ["DJF", "FMA", ...], + and splits that into multiple sequences of non-overlapping seasons. 
+ """ sinds = season_to_month_tuple(seasons) grouped = defaultdict(list) codes = defaultdict(list) + seen: set[tuple[int, ...]] = set() idx = 0 - seen: set[int] = set() - for i, first in enumerate(sinds): - if first not in seen: - grouped[idx].append(first) + # This is quadratic, but the length of seasons is at most 12 + for i, current in enumerate(sinds): + # Start with a group + if current not in seen: + grouped[idx].append(current) codes[idx].append(i) - seen.update((first,)) + seen.add(current) + # Loop through remaining groups, and look for overlaps for j, second in enumerate(sinds[i:]): if not (set(chain(*grouped[idx])) & set(second)): if second not in seen: grouped[idx].append(second) codes[idx].append(j + i) - seen.update((second,)) + seen.add(second) + if len(seen) == len(seasons): + break + # found all non-overlapping groups for this row, increment and start over idx += 1 - grouped_ints = [idx for idx in grouped.values() if idx] - # TODO: inds is unncessary + grouped_ints = tuple(tuple(idx) for idx in grouped.values() if idx) return [ - SeasonsGroup(seasons=to_string(inds), inds=inds, codes=codes) + SeasonsGroup(seasons=inds_to_string(inds), inds=inds, codes=codes) for inds, codes in zip(grouped_ints, codes.values(), strict=False) ] @@ -573,16 +580,20 @@ class SeasonGrouper(Grouper): Examples -------- >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) - >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"]) + SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND']) + + The ordering is preserved + >>> SeasonGrouper(["MAM", "JJAS", "OND", "JF"]) + SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF']) + + Overlapping seasons are allowed + >>> SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"]) + SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND']) """ seasons: Sequence[str] - # season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) # drop_incomplete: bool = field(default=True) # TODO - # def __post_init__(self) -> None: - # self.season_inds = 
season_to_month_tuple(self.seasons) - def factorize(self, group: T_Group) -> EncodedGroups: if TYPE_CHECKING: assert not isinstance(group, _DummyGroup) @@ -640,7 +651,10 @@ class SeasonResampler(Resampler): Examples -------- >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"]) + SeasonResampler(seasons=['JF', 'MAM', 'JJAS', 'OND'], drop_incomplete=True) + >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"]) + SeasonResampler(seasons=['DJFM', 'AM', 'JJA', 'SON'], drop_incomplete=True) """ seasons: Sequence[str] From e3302c76f1f0c26e248edd632600f882c2f7138b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 20 Sep 2024 21:38:45 -0600 Subject: [PATCH 05/10] Fix quantile --- xarray/core/groupby.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ca51d2879d9..25a541cdc61 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1159,9 +1159,6 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - if dim is None: - dim = (self._group_dim,) - # Dataset.quantile does this, do it for flox to ensure same output. 
q = np.asarray(q, dtype=np.float64) @@ -1180,7 +1177,7 @@ def quantile( self._obj.__class__.quantile, shortcut=False, q=q, - dim=dim, + dim=dim or self._group_dim, method=method, keep_attrs=keep_attrs, skipna=skipna, From 07e60347d1a0d9cd59496aa7e2f2fef339cb6dd2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 20 Sep 2024 21:38:52 -0600 Subject: [PATCH 06/10] Nicer repr --- xarray/core/groupby.py | 2 +- xarray/tests/test_groupby.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 25a541cdc61..adbcf98db3b 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -677,7 +677,7 @@ def __repr__(self) -> str: for grouper in self.groupers: coord = grouper.unique_coord labels = ", ".join(format_array_flat(coord, 30).split()) - text += f"\n {grouper.name!r}: {coord.size} groups with labels {labels}" + text += f"\n {grouper.name!r}: {type(grouper.grouper).__name__}({grouper.group.name!r}), {coord.size} groups with labels {labels}" return text + ">" def _iter_grouped(self) -> Iterator[T_Xarray]: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 0a5fbccbdf4..dff1fc61313 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -586,7 +586,7 @@ def test_groupby_repr(obj, dim) -> None: N = len(np.unique(obj[dim])) expected = f"<{obj.__class__.__name__}GroupBy" expected += f", grouped over 1 grouper(s), {N} groups in total:" - expected += f"\n {dim!r}: {N} groups with labels " + expected += f"\n {dim!r}: UniqueGrouper({dim!r}), {N} groups with labels " if dim == "x": expected += "1, 2, 3, 4, 5>" elif dim == "y": @@ -603,7 +603,7 @@ def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"<{obj.__class__.__name__}GroupBy" expected += ", grouped over 1 grouper(s), 12 groups in total:\n" - expected += " 'month': 12 groups with labels " + expected += " 'month': UniqueGrouper('month'), 12 groups with 
labels " expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>" assert actual == expected From e035bfa83078dba48f0c828caad8b08d50ebe98e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 20 Sep 2024 21:45:27 -0600 Subject: [PATCH 07/10] fix --- xarray/groupers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 609c38018ae..ebde448b926 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -519,7 +519,7 @@ def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...] return tuple(result) -def inds_to_string(asints: tuple[tuple[int, ...]]) -> tuple[str, ...]: +def inds_to_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]: inits = "JFMAMJJASOND" return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints) @@ -611,7 +611,8 @@ def factorize(self, group: T_Group) -> EncodedGroups: ): mask = np.isin(months, season_tuple) codes_[axis_index, mask] = code - (group_indices[code],) = mask.nonzero() + (indices,) = mask.nonzero() + group_indices[code] = indices.tolist() if np.all(codes_ == -1): raise ValueError( From 17cc3d8d46d2600956a0730b2cb21a19d7fe9d45 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 21 Sep 2024 20:59:27 -0600 Subject: [PATCH 08/10] cftime support --- xarray/groupers.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index ebde448b926..562a8323909 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -19,10 +19,15 @@ import pandas as pd from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq +from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops -from xarray.core.common import _contains_datetime_like_objects +from xarray.core.common import ( + _contains_cftime_datetimes, + _contains_datetime_like_objects, +) from xarray.core.coordinates import Coordinates from xarray.core.dataarray import 
DataArray +from xarray.core.formatting import first_n_items from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper @@ -497,7 +502,7 @@ def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...] initials = "JFMAMJJASOND" starts = dict( ("".join(s), i + 1) - for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=False) + for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=True) ) result: list[tuple[int, ...]] = [] for i, season in enumerate(seasons): @@ -687,7 +692,7 @@ def factorize(self, group): season_label = np.full(group.shape, "", dtype=f"U{nstr}") # offset years for seasons with December and January - for season_str, season_ind in zip(seasons, season_inds, strict=False): + for season_str, season_ind in zip(seasons, season_inds, strict=True): season_label[month.isin(season_ind)] = season_str if "DJ" in season_str: after_dec = season_ind[season_str.index("D") + 1 :] @@ -705,10 +710,17 @@ def factorize(self, group): first_items = g.first() counts = g.count() + if _contains_cftime_datetimes(group.data): + index_class = CFTimeIndex + datetime_class = type(first_n_items(group.data, 1).item()) + else: + index_class = pd.DatetimeIndex + datetime_class = datetime.datetime + # these are the seasons that are present - unique_coord = pd.DatetimeIndex( + unique_coord = index_class( [ - pd.Timestamp(year=year, month=season_tuples[season][0], day=1) + datetime_class(year=year, month=season_tuples[season][0], day=1) for year, season in first_items.index ] ) @@ -736,12 +748,12 @@ def factorize(self, group): unique_codes[idx] = -1 # all years and seasons - complete_index = pd.DatetimeIndex( + complete_index = index_class( # This sorted call is a hack. 
It's hard to figure out how # to start the iteration sorted( [ - pd.Timestamp(f"{y}-{m}-01") + datetime_class(year=y, month=m, day=1) for y, m in itertools.product( range(year[0].item(), year[-1].item() + 1), [s[0] for s in season_inds], From 82f3c21e9c1c259c8b57039bd7c19c941eea6c9a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 21 Sep 2024 21:22:23 -0600 Subject: [PATCH 09/10] Add skeleton tests --- xarray/tests/test_groupby.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index dff1fc61313..3949049f6a9 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -11,7 +11,7 @@ from packaging.version import Version import xarray as xr -from xarray import DataArray, Dataset, Variable +from xarray import DataArray, Dataset, Variable, cftime_range from xarray.core.alignment import broadcast from xarray.core.groupby import _consolidate_slices from xarray.core.types import InterpOptions, ResampleCompatible @@ -19,6 +19,7 @@ BinGrouper, EncodedGroups, Grouper, + SeasonResampler, TimeResampler, UniqueGrouper, season_to_month_tuple, @@ -2944,6 +2945,24 @@ def test_season_to_month_tuple(): ) +def test_season_resampler(): + time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day") + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + + # through resample + da.resample(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).sum() + + # through groupby + da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "SON"])).sum() + + # skip september + da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum() + + # overlapping + with pytest.raises(ValueError): + da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum() + + # Possible property tests # 1. lambda x: x # 2. 
grouped-reduce on unique coords is identical to array From 91805360eb3a7fcd6db40ea188ae050834a13672 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Sat, 21 Sep 2024 21:22:33 -0600 Subject: [PATCH 10/10] Support "subsampled" seasons --- xarray/groupers.py | 33 ++++++++++++++++++++++++++------- xarray/tests/test_groupby.py | 3 +++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/xarray/groupers.py b/xarray/groupers.py index 562a8323909..bf5f6f0c2d4 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,7 +7,9 @@ from __future__ import annotations import datetime +import functools import itertools +import operator from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Mapping, Sequence @@ -670,7 +672,12 @@ class SeasonResampler(Resampler): def __post_init__(self): self.season_inds = season_to_month_tuple(self.seasons) - self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False)) + all_inds = functools.reduce(operator.add, self.season_inds) + if len(all_inds) > len(set(all_inds)): + raise ValueError( + f"Overlapping seasons are not allowed. Received {self.seasons!r}" + ) + self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True)) def factorize(self, group): if group.ndim != 1: @@ -696,12 +703,22 @@ def factorize(self, group): season_label[month.isin(season_ind)] = season_str if "DJ" in season_str: after_dec = season_ind[season_str.index("D") + 1 :] + # important this is assuming non-overlapping seasons year[month.isin(after_dec)] -= 1 + # Allow users to skip one or more months? 
+ # present_seasons is a mask that is True for months that are requested in the output + present_seasons = season_label != "" + if present_seasons.all(): + present_seasons = slice(None) frame = pd.DataFrame( - data={"index": np.arange(group.size), "month": month}, + data={ + "index": np.arange(group[present_seasons].size), + "month": month[present_seasons], + }, index=pd.MultiIndex.from_arrays( - [year.data, season_label], names=["year", "season"] + [year.data[present_seasons], season_label[present_seasons]], + names=["year", "season"], ), ) @@ -727,7 +744,7 @@ def factorize(self, group): sbins = first_items.values.astype(int) group_indices = [ - slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False) + slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True) ] group_indices += [slice(sbins[-1], None)] @@ -735,11 +752,11 @@ def factorize(self, group): # are for the correct months,if not we have incomplete seasons unique_codes = np.arange(len(unique_coord)) if self.drop_incomplete: - for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False): + for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=True): stamp_year, stamp_season = frame.index[idx] code = seasons.index(stamp_season) stamp_month = season_inds[code][idx] - if stamp_month != month[idx].item(): + if stamp_month != month[present_seasons][idx].item(): # we have an incomplete season!
group_indices = group_indices[slicer] unique_coord = unique_coord[slicer] @@ -769,7 +786,9 @@ def factorize(self, group): if not full_index.equals(unique_coord): raise ValueError("Are there seasons missing in the middle of the dataset?") - codes = group.copy(data=np.repeat(unique_codes, counts), deep=False) + final_codes = np.full(group.data.size, -1) + final_codes[present_seasons] = np.repeat(unique_codes, counts) + codes = group.copy(data=final_codes, deep=False) unique_coord_var = Variable(group.name, unique_coord, group.attrs) return EncodedGroups( diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3949049f6a9..d9f1471c561 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2958,6 +2958,9 @@ def test_season_resampler(): # skip september da.groupby(time=SeasonResampler(["DJF", "MAM", "JJA", "ON"])).sum() + # "subsampling" + da.groupby(time=SeasonResampler(["JJAS"])).sum() + # overlapping with pytest.raises(ValueError): da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum()