diff --git a/properties/test_properties.py b/properties/test_properties.py index fc0a1955539..859f9d4e500 100644 --- a/properties/test_properties.py +++ b/properties/test_properties.py @@ -2,10 +2,12 @@ pytest.importorskip("hypothesis") +import hypothesis.strategies as st from hypothesis import given import xarray as xr import xarray.testing.strategies as xrst +from xarray.groupers import season_to_month_tuple @given(attrs=xrst.simple_attrs) @@ -15,3 +17,32 @@ def test_assert_identical(attrs): ds = xr.Dataset(attrs=attrs) xr.testing.assert_identical(ds, ds.copy(deep=True)) + + +@given( + roll=st.integers(min_value=0, max_value=12), + breaks=st.lists( + st.integers(min_value=0, max_value=11), min_size=1, max_size=12, unique=True + ), +) +def test_property_season_month_tuple(roll, breaks): + chars = list("JFMAMJJASOND") + months = tuple(range(1, 13)) + + rolled_chars = chars[roll:] + chars[:roll] + rolled_months = months[roll:] + months[:roll] + breaks = sorted(breaks) + if breaks[0] != 0: + breaks = [0] + breaks + if breaks[-1] != 12: + breaks = breaks + [12] + seasons = tuple( + "".join(rolled_chars[start:stop]) + for start, stop in zip(breaks[:-1], breaks[1:], strict=False) + ) + actual = season_to_month_tuple(seasons) + expected = tuple( + rolled_months[start:stop] + for start, stop in zip(breaks[:-1], breaks[1:], strict=False) + ) + assert expected == actual diff --git a/xarray/core/toolzcompat.py b/xarray/core/toolzcompat.py new file mode 100644 index 00000000000..5eeb9681afc --- /dev/null +++ b/xarray/core/toolzcompat.py @@ -0,0 +1,55 @@ +# This file contains functions copied from the toolz library in accordance +# with its license. The original copyright notice is duplicated below. + +# Copyright (c) 2013 Matthew Rocklin + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# a. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# b. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# c. Neither the name of toolz nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. + + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. + + +def sliding_window(n, seq): + """A sequence of overlapping subsequences + + >>> list(sliding_window(2, [1, 2, 3, 4])) + [(1, 2), (2, 3), (3, 4)] + + This function creates a sliding window suitable for transformations like + sliding means / smoothing + + >>> mean = lambda seq: float(sum(seq)) / len(seq) + >>> list(map(mean, sliding_window(2, [1, 2, 3, 4]))) + [1.5, 2.5, 3.5] + """ + import collections + import itertools + + return zip( + *( + collections.deque(itertools.islice(it, i), 0) or it + for i, it in enumerate(itertools.tee(seq, n)) + ) + ) diff --git a/xarray/groupers.py b/xarray/groupers.py index e4cb884e6de..434bd764fa4 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,7 +7,9 @@ from __future__ import annotations import datetime +import itertools from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, cast @@ -16,6 +18,7 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops +from xarray.core.common import _contains_datetime_like_objects from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.groupby import T_Group, _DummyGroup @@ -28,6 +31,7 @@ ResampleCompatible, SideOptions, ) +from xarray.core.toolzcompat import sliding_window from xarray.core.variable import Variable __all__ = [ @@ -485,3 +489,215 @@ def unique_value_groups( if isinstance(values, pd.MultiIndex): values.names = ar.names return values, inverse + + +def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]: + initials = "JFMAMJJASOND" + starts = dict( + ("".join(s), i + 1) + for s, i in zip(sliding_window(2, initials + "J"), range(12)) + ) + result: list[tuple[int, ...]] = [] + for i, season in enumerate(seasons): + if len(season) == 1: + if i < len(seasons) - 1: + suffix = seasons[i + 1][0] + else: + suffix = seasons[0][0] + else: + suffix = season[1] + + start = starts[season[0] + suffix] + + month_append = [] + for i in range(len(season[1:])): + elem = start + i + 1 + month_append.append(elem - 12 * (elem > 12)) + result.append((start,) + tuple(month_append)) + return tuple(result) + + +@dataclass +class SeasonGrouper(Grouper): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: sequence of str + List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc. + + Examples + -------- + >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) + >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"]) + """ + + seasons: Sequence[str] + season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + # drop_incomplete: bool = field(default=True) # TODO + + def __post_init__(self) -> None: + self.season_inds = season_to_month_tuple(self.seasons) + + def factorize(self, group: T_Group) -> EncodedGroups: + if TYPE_CHECKING: + assert not isinstance(group, _DummyGroup) + if not _contains_datetime_like_objects(group.variable): + raise ValueError( + "SeasonGrouper can only be used to group by datetime-like arrays." + ) + + seasons = self.seasons + season_inds = self.season_inds + + months = group.dt.month + codes_ = np.full(group.shape, -1) + group_indices: list[list[int]] = [[]] * len(seasons) + + index = np.arange(group.size) + for idx, season_tuple in enumerate(season_inds): + mask = months.isin(season_tuple) + codes_[mask] = idx + group_indices[idx] = index[mask] + + if np.all(codes_ == -1): + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" + ) + codes = group.copy(data=codes_, deep=False).rename("season") + unique_coord = Variable("season", seasons, attrs=group.attrs) + full_index = pd.Index(seasons) + return EncodedGroups( + codes=codes, + group_indices=tuple(group_indices), + unique_coord=unique_coord, + full_index=full_index, + ) + + +@dataclass +class SeasonResampler(Resampler): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: Sequence[str] + An ordered list of seasons. + drop_incomplete: bool + Whether to drop seasons that are not completely included in the data. + For example, if a time series starts in Jan-2001, and seasons includes `"DJF"` + then observations from Jan-2001, and Feb-2001 are ignored in the grouping + since Dec-2000 isn't present. + + Examples + -------- + >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"]) + >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"]) + """ + + seasons: Sequence[str] + drop_incomplete: bool = field(default=True) + season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False) + + def __post_init__(self): + self.season_inds = season_to_month_tuple(self.seasons) + self.season_tuples = dict(zip(self.seasons, self.season_inds)) + + def factorize(self, group): + if group.ndim != 1: + raise ValueError( + "SeasonResampler can only be used to resample by 1D arrays." + ) + if not _contains_datetime_like_objects(group.variable): + raise ValueError( + "SeasonResampler can only be used to group by datetime-like arrays." + ) + + seasons = self.seasons + season_inds = self.season_inds + season_tuples = self.season_tuples + + nstr = max(len(s) for s in seasons) + year = group.dt.year.astype(int) + month = group.dt.month.astype(int) + season_label = np.full(group.shape, "", dtype=f"U{nstr}") + + # offset years for seasons with December and January + for season_str, season_ind in zip(seasons, season_inds): + season_label[month.isin(season_ind)] = season_str + if "DJ" in season_str: + after_dec = season_ind[season_str.index("D") + 1 :] + year[month.isin(after_dec)] -= 1 + + frame = pd.DataFrame( + data={"index": np.arange(group.size), "month": month}, + index=pd.MultiIndex.from_arrays( + [year.data, season_label], names=["year", "season"] + ), + ) + + series = frame["index"] + g = series.groupby(["year", "season"], sort=False) + first_items = g.first() + counts = g.count() + + # these are the seasons that are present + unique_coord = pd.DatetimeIndex( + [ + pd.Timestamp(year=year, month=season_tuples[season][0], day=1) + for year, season in first_items.index + ] + ) + + sbins = first_items.values.astype(int) + group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + group_indices += [slice(sbins[-1], None)] + + # Make sure the first and last timestamps + # are for the correct months,if not we have incomplete seasons + unique_codes = np.arange(len(unique_coord)) + if self.drop_incomplete: + for idx, slicer in zip([0, -1], (slice(1, None), slice(-1))): + stamp_year, stamp_season = frame.index[idx] + code = seasons.index(stamp_season) + stamp_month = season_inds[code][idx] + if stamp_month != month[idx].item(): + # we have an incomplete season! + group_indices = group_indices[slicer] + unique_coord = unique_coord[slicer] + if idx == 0: + unique_codes -= 1 + unique_codes[idx] = -1 + + # all years and seasons + complete_index = pd.DatetimeIndex( + # This sorted call is a hack. It's hard to figure out how + # to start the iteration + sorted( + [ + pd.Timestamp(f"{y}-{m}-01") + for y, m in itertools.product( + range(year[0].item(), year[-1].item() + 1), + [s[0] for s in season_inds], + ) + ] + ) + ) + # only keep that included in data + range_ = complete_index.get_indexer(unique_coord[[0, -1]]) + full_index = complete_index[slice(range_[0], range_[-1] + 1)] + # check that there are no "missing" seasons in the middle + # print(full_index, unique_coord) + if not full_index.equals(unique_coord): + raise ValueError("Are there seasons missing in the middle of the dataset?") + + codes = group.copy(data=np.repeat(unique_codes, counts), deep=False) + unique_coord_var = Variable(group.name, unique_coord, group.attrs) + + return EncodedGroups( + codes=codes, + group_indices=group_indices, + unique_coord=unique_coord_var, + full_index=full_index, + ) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index dc869cc3a34..0a5fbccbdf4 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -21,6 +21,7 @@ Grouper, TimeResampler, UniqueGrouper, + season_to_month_tuple, ) from xarray.tests import ( InaccessibleArray, @@ -2915,12 +2916,6 @@ def test_gappy_resample_reductions(reduction): assert_identical(expected, actual) -# Possible property tests -# 1. lambda x: x -# 2. grouped-reduce on unique coords is identical to array -# 3. group_over == groupby-reduce along other dimensions - - def test_groupby_transpose(): # GH5361 data = xr.DataArray( @@ -2932,3 +2927,24 @@ def test_groupby_transpose(): second = data.groupby("x").sum() assert_identical(first, second.transpose(*first.dims)) + + +def test_season_to_month_tuple(): + assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == ( + (1, 2), + (3, 4, 5), + (6, 7, 8, 9), + (10, 11, 12), + ) + assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == ( + (12, 1, 2, 3), + (4, 5), + (6, 7, 8, 9), + (10, 11), + ) + + +# Possible property tests +# 1. lambda x: x +# 2. grouped-reduce on unique coords is identical to array +# 3. group_over == groupby-reduce along other dimensions