diff --git a/doc/api.rst b/doc/api.rst index 63427447d53..ddba586fc83 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1139,6 +1139,8 @@ Grouper Objects groupers.BinGrouper groupers.UniqueGrouper groupers.TimeResampler + groupers.SeasonGrouper + groupers.SeasonResampler Rolling objects diff --git a/properties/test_properties.py b/properties/test_properties.py index fc0a1955539..24de8049f58 100644 --- a/properties/test_properties.py +++ b/properties/test_properties.py @@ -1,11 +1,15 @@ +import itertools + import pytest pytest.importorskip("hypothesis") +import hypothesis.strategies as st from hypothesis import given import xarray as xr import xarray.testing.strategies as xrst +from xarray.groupers import season_to_month_tuple @given(attrs=xrst.simple_attrs) @@ -15,3 +19,30 @@ def test_assert_identical(attrs): ds = xr.Dataset(attrs=attrs) xr.testing.assert_identical(ds, ds.copy(deep=True)) + + +@given( + roll=st.integers(min_value=0, max_value=12), + breaks=st.lists( + st.integers(min_value=0, max_value=11), min_size=1, max_size=12, unique=True + ), +) +def test_property_season_month_tuple(roll, breaks): + chars = list("JFMAMJJASOND") + months = tuple(range(1, 13)) + + rolled_chars = chars[roll:] + chars[:roll] + rolled_months = months[roll:] + months[:roll] + breaks = sorted(breaks) + if breaks[0] != 0: + breaks = [0] + breaks + if breaks[-1] != 12: + breaks = breaks + [12] + seasons = tuple( + "".join(rolled_chars[start:stop]) for start, stop in itertools.pairwise(breaks) + ) + actual = season_to_month_tuple(seasons) + expected = tuple( + rolled_months[start:stop] for start, stop in itertools.pairwise(breaks) + ) + assert expected == actual diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 52ce2463d51..7320daa11a5 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6850,7 +6850,7 @@ def groupby( >>> da.groupby("letters") + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b'> Execute a reduction @@ -6866,8 +6866,8 @@ def groupby( >>> da.groupby(["letters", "x"]) + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b' + 'x': UniqueGrouper('x'), 4/4 groups with labels 10, 20, 30, 40> Use Grouper objects to express more complicated GroupBy operations diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index cc34a8cc04b..f4fd1eb1e58 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10448,7 +10448,7 @@ def groupby( >>> ds.groupby("letters") + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b'> Execute a reduction @@ -10465,8 +10465,8 @@ def groupby( >>> ds.groupby(["letters", "x"]) + 'letters': UniqueGrouper('letters'), 2/2 groups with labels 'a', 'b' + 'x': UniqueGrouper('x'), 4/4 groups with labels 10, 20, 30, 40> Use Grouper objects to express more complicated GroupBy operations diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 7a32cd7b1db..74f954bc308 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -253,6 +253,8 @@ def _ensure_1d( from xarray.core.dataarray import DataArray if isinstance(group, DataArray): + for dim in set(group.dims) - set(obj.dims): + obj = obj.expand_dims(dim) # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) @@ -750,7 +752,10 @@ def __repr__(self) -> str: for grouper in self.groupers: coord = grouper.unique_coord labels = ", ".join(format_array_flat(coord, 30).split()) - text += f"\n {grouper.name!r}: {coord.size}/{grouper.full_index.size} groups present with labels {labels}" + text += ( + f"\n {grouper.name!r}: {type(grouper.grouper).__name__}({grouper.group.name!r}), " + f"{coord.size}/{grouper.full_index.size} groups with labels {labels}" + ) return text + ">" def _iter_grouped(self) -> Iterator[T_Xarray]: @@ -974,7 +979,7 @@ def _flox_reduce( parsed_dim_list = list() # preserve order for dim_ in itertools.chain( - *(grouper.group.dims for grouper in self.groupers) + *(grouper.codes.dims for grouper in self.groupers) ): if dim_ not in parsed_dim_list: parsed_dim_list.append(dim_) @@ -988,7 +993,7 @@ def _flox_reduce( # Better to control it here than in flox. for grouper in self.groupers: if any( - d not in grouper.group.dims and d not in obj.dims for d in parsed_dim + d not in grouper.codes.dims and d not in obj.dims for d in parsed_dim ): raise ValueError(f"cannot reduce over dimensions {dim}.") @@ -1232,9 +1237,6 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - if dim is None: - dim = (self._group_dim,) - # Dataset.quantile does this, do it for flox to ensure same output. q = np.asarray(q, dtype=np.float64) @@ -1253,7 +1255,7 @@ def quantile( self._obj.__class__.quantile, shortcut=False, q=q, - dim=dim, + dim=dim or self._group_dim, method=method, keep_attrs=keep_attrs, skipna=skipna, diff --git a/xarray/core/toolzcompat.py b/xarray/core/toolzcompat.py new file mode 100644 index 00000000000..4632419a845 --- /dev/null +++ b/xarray/core/toolzcompat.py @@ -0,0 +1,56 @@ +# This file contains functions copied from the toolz library in accordance +# with its license. The original copyright notice is duplicated below. + +# Copyright (c) 2013 Matthew Rocklin + +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# a. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# b. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# c. Neither the name of toolz nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. + + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. + + +def sliding_window(n, seq): + """A sequence of overlapping subsequences + + >>> list(sliding_window(2, [1, 2, 3, 4])) + [(1, 2), (2, 3), (3, 4)] + + This function creates a sliding window suitable for transformations like + sliding means / smoothing + + >>> mean = lambda seq: float(sum(seq)) / len(seq) + >>> list(map(mean, sliding_window(2, [1, 2, 3, 4]))) + [1.5, 2.5, 3.5] + """ + import collections + import itertools + + return zip( + *( + collections.deque(itertools.islice(it, i), 0) or it + for i, it in enumerate(itertools.tee(seq, n)) + ), + strict=False, + ) diff --git a/xarray/groupers.py b/xarray/groupers.py index 89b189e582e..60c7575ab42 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -7,9 +7,14 @@ from __future__ import annotations import datetime +import functools +import itertools +import operator from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from itertools import pairwise +from itertools import chain, pairwise from typing import TYPE_CHECKING, Any, Literal, cast import numpy as np @@ -17,19 +22,27 @@ from numpy.typing import ArrayLike from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq +from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops +from xarray.core.common import ( + _contains_cftime_datetimes, + _contains_datetime_like_objects, +) from xarray.core.computation import apply_ufunc from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.dataarray import DataArray from xarray.core.duck_array_ops import isnull +from xarray.core.formatting import first_n_items from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index from xarray.core.resample_cftime import CFTimeGrouper +from xarray.core.toolzcompat import sliding_window from xarray.core.types import ( Bins, DatetimeLike, GroupIndices, ResampleCompatible, + Self, SideOptions, ) from xarray.core.variable import Variable @@ -69,9 +82,9 @@ class EncodedGroups: codes: DataArray full_index: pd.Index - group_indices: GroupIndices - unique_coord: Variable | _DummyGroup - coords: Coordinates + group_indices: GroupIndices = field(init=False, repr=False) + unique_coord: Variable | _DummyGroup = field(init=False, repr=False) + coords: Coordinates = field(init=False, repr=False) def __init__( self, @@ -553,3 +566,370 @@ def unique_value_groups( if isinstance(values, pd.MultiIndex): values.names = ar.names return values, inverse + + +def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]: + initials = "JFMAMJJASOND" + starts = dict( + ("".join(s), i + 1) + for s, i in zip(sliding_window(2, initials + "J"), range(12), strict=True) + ) + result: list[tuple[int, ...]] = [] + for i, season in enumerate(seasons): + if len(season) == 1: + if i < len(seasons) - 1: + suffix = seasons[i + 1][0] + else: + suffix = seasons[0][0] + else: + suffix = season[1] + + start = starts[season[0] + suffix] + + month_append = [] + for i in range(len(season[1:])): + elem = start + i + 1 + month_append.append(elem - 12 * (elem > 12)) + result.append((start,) + tuple(month_append)) + return tuple(result) + + +def inds_to_string(asints: tuple[tuple[int, ...], ...]) -> tuple[str, ...]: + inits = "JFMAMJJASOND" + return tuple("".join([inits[i_ - 1] for i_ in t]) for t in asints) + + +def is_sorted_periodic(lst): + n = len(lst) + + # Find the wraparound point where the list decreases + wrap_point = -1 + for i in range(1, n): + if lst[i] < lst[i - 1]: + wrap_point = i + break + + # If no wraparound point is found, the list is already sorted + if wrap_point == -1: + return True + + # Check if both parts around the wrap point are sorted + for i in range(1, wrap_point): + if lst[i] < lst[i - 1]: + return False + for i in range(wrap_point + 1, n): + if lst[i] < lst[i - 1]: + return False + + # Check wraparound condition + return lst[-1] <= lst[0] + + +@dataclass +class SeasonsGroup: + seasons: tuple[str, ...] + inds: tuple[tuple[int, ...], ...] + codes: Sequence[int] + + +def find_independent_seasons(seasons: Sequence[str]) -> Sequence[SeasonsGroup]: + """ + Iterates though a list of seasons e.g. ["DJF", "FMA", ...], + and splits that into multiple sequences of non-overlapping seasons. + """ + sinds = season_to_month_tuple(seasons) + grouped = defaultdict(list) + codes = defaultdict(list) + seen: set[tuple[int, ...]] = set() + idx = 0 + # This is quadratic, but the length of seasons is at most 12 + for i, current in enumerate(sinds): + # Start with a group + if current not in seen: + grouped[idx].append(current) + codes[idx].append(i) + seen.add(current) + + # Loop through remaining groups, and look for overlaps + for j, second in enumerate(sinds[i:]): + if not (set(chain(*grouped[idx])) & set(second)): + if second not in seen: + grouped[idx].append(second) + codes[idx].append(j + i) + seen.add(second) + if len(seen) == len(seasons): + break + # found all non-overlapping groups for this row, increment and start over + idx += 1 + + grouped_ints = tuple(tuple(idx) for idx in grouped.values() if idx) + return [ + SeasonsGroup(seasons=inds_to_string(inds), inds=inds, codes=codes) + for inds, codes in zip(grouped_ints, codes.values(), strict=False) + ] + + +@dataclass +class SeasonGrouper(Grouper): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: sequence of str + List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc. + + Examples + -------- + >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"]) + SeasonGrouper(seasons=['JF', 'MAM', 'JJAS', 'OND']) + + The ordering is preserved + >>> SeasonGrouper(["MAM", "JJAS", "OND", "JF"]) + SeasonGrouper(seasons=['MAM', 'JJAS', 'OND', 'JF']) + + Overlapping seasons are allowed + >>> SeasonGrouper(["DJFM", "MAMJ", "JJAS", "SOND"]) + SeasonGrouper(seasons=['DJFM', 'MAMJ', 'JJAS', 'SOND']) + """ + + seasons: Sequence[str] + # drop_incomplete: bool = field(default=True) # TODO + + def factorize(self, group: T_Group) -> EncodedGroups: + if TYPE_CHECKING: + assert not isinstance(group, _DummyGroup) + if not _contains_datetime_like_objects(group.variable): + raise ValueError( + "SeasonGrouper can only be used to group by datetime-like arrays." + ) + months = group.dt.month.data + seasons_groups = find_independent_seasons(self.seasons) + codes_ = np.full((len(seasons_groups),) + group.shape, -1, dtype=np.int8) + group_indices: list[list[int]] = [[]] * len(self.seasons) + for axis_index, seasgroup in enumerate(seasons_groups): + for season_tuple, code in zip( + seasgroup.inds, seasgroup.codes, strict=False + ): + mask = np.isin(months, season_tuple) + codes_[axis_index, mask] = code + (indices,) = mask.nonzero() + group_indices[code] = indices.tolist() + + if np.all(codes_ == -1): + raise ValueError( + "Failed to group data. Are you grouping by a variable that is all NaN?" + ) + needs_dummy_dim = len(seasons_groups) > 1 + codes = DataArray( + dims=(("__season_dim__",) if needs_dummy_dim else tuple()) + group.dims, + data=codes_ if needs_dummy_dim else codes_.squeeze(), + attrs=group.attrs, + name="season", + ) + unique_coord = Variable("season", self.seasons, attrs=group.attrs) + full_index = pd.Index(self.seasons) + return EncodedGroups( + codes=codes, + group_indices=tuple(group_indices), + unique_coord=unique_coord, + full_index=full_index, + ) + + def reset(self) -> Self: + return type(self)(self.seasons) + + +@dataclass +class SeasonResampler(Resampler): + """Allows grouping using a custom definition of seasons. + + Parameters + ---------- + seasons: Sequence[str] + An ordered list of seasons. + drop_incomplete: bool + Whether to drop seasons that are not completely included in the data. + For example, if a time series starts in Jan-2001, and seasons includes `"DJF"` + then observations from Jan-2001, and Feb-2001 are ignored in the grouping + since Dec-2000 isn't present. + + Examples + -------- + >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"]) + SeasonResampler(seasons=['JF', 'MAM', 'JJAS', 'OND'], drop_incomplete=True) + + >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"]) + SeasonResampler(seasons=['DJFM', 'AM', 'JJA', 'SON'], drop_incomplete=True) + """ + + seasons: Sequence[str] + drop_incomplete: bool = field(default=True, kw_only=True) + season_inds: Sequence[Sequence[int]] = field(init=False, repr=False) + season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False) + + def __post_init__(self): + self.season_inds = season_to_month_tuple(self.seasons) + all_inds = functools.reduce(operator.add, self.season_inds) + if len(all_inds) > len(set(all_inds)): + raise ValueError( + f"Overlapping seasons are not allowed. Received {self.seasons!r}" + ) + self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=True)) + + if not is_sorted_periodic(list(itertools.chain(*self.season_inds))): + raise ValueError( + "Resampling is only supported with sorted seasons. " + f"Provided seasons {self.seasons!r} are not sorted." + ) + + def factorize(self, group: T_Group) -> EncodedGroups: + if group.ndim != 1: + raise ValueError( + "SeasonResampler can only be used to resample by 1D arrays." + ) + if not isinstance(group, DataArray) or not _contains_datetime_like_objects( + group.variable + ): + raise ValueError( + "SeasonResampler can only be used to group by datetime-like DataArrays." + ) + + seasons = self.seasons + season_inds = self.season_inds + season_tuples = self.season_tuples + + nstr = max(len(s) for s in seasons) + year = group.dt.year.astype(int) + month = group.dt.month.astype(int) + season_label = np.full(group.shape, "", dtype=f"U{nstr}") + + # offset years for seasons with December and January + for season_str, season_ind in zip(seasons, season_inds, strict=True): + season_label[month.isin(season_ind)] = season_str + if "DJ" in season_str: + after_dec = season_ind[season_str.index("D") + 1 :] + # important: this is assuming non-overlapping seasons + year[month.isin(after_dec)] -= 1 + + # Allow users to skip one or more months? + # present_seasons is a mask that is True for months that are requested in the output + present_seasons = season_label != "" + if present_seasons.all(): + # avoid copies if we can. + present_seasons = slice(None) + frame = pd.DataFrame( + data={ + "index": np.arange(group[present_seasons].size), + "month": month[present_seasons], + }, + index=pd.MultiIndex.from_arrays( + [year.data[present_seasons], season_label[present_seasons]], + names=["year", "season"], + ), + ) + + agged = ( + frame["index"] + .groupby(["year", "season"], sort=False) + .agg(["first", "count"]) + ) + first_items = agged["first"] + counts = agged["count"] + + index_class: CFTimeIndex | pd.DatetimeIndex + if _contains_cftime_datetimes(group.data): + index_class = CFTimeIndex + datetime_class = type(first_n_items(group.data, 1).item()) + else: + index_class = pd.DatetimeIndex + datetime_class = datetime.datetime + + # these are the seasons that are present + unique_coord = index_class( + [ + datetime_class(year=year, month=season_tuples[season][0], day=1) + for year, season in first_items.index + ] + ) + + # sbins = first_items.values.astype(int) + # group_indices = [ + # slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=True) + # ] + # group_indices += [slice(sbins[-1], None)] + + # This sorted call is a hack. It's hard to figure out how + # to start the iteration for arbitrary season ordering + # for example "DJF" as first entry or last entry + # So we construct the largest possible index and slice it to the + # range present in the data. + complete_index = index_class( + sorted( + [ + datetime_class(year=y, month=m, day=1) + for y, m in itertools.product( + range(year[0].item(), year[-1].item() + 1), + [s[0] for s in season_inds], + ) + ] + ) + ) + + # all years and seasons + def get_label(year, season): + month, *_ = season_tuples[season] + return f"{year}-{month:02d}-01" + + unique_codes = np.arange(len(unique_coord)) + valid_season_mask = season_label != "" + first_valid_season, last_valid_season = season_label[valid_season_mask][[0, -1]] + first_year, last_year = year.data[[0, -1]] + if self.drop_incomplete: + if month.data[valid_season_mask][0] != season_tuples[first_valid_season][0]: + if "DJ" in first_valid_season: + first_year += 1 + first_valid_season = seasons[ + (seasons.index(first_valid_season) + 1) % len(seasons) + ] + # group_indices = group_indices[slice(1, None)] + unique_codes -= 1 + + if ( + month.data[valid_season_mask][-1] + != season_tuples[last_valid_season][-1] + ): + last_valid_season = seasons[seasons.index(last_valid_season) - 1] + if "DJ" in last_valid_season: + last_year -= 1 + # group_indices = group_indices[slice(-1)] + unique_codes[-1] = -1 + + first_label = get_label(first_year, first_valid_season) + last_label = get_label(last_year, last_valid_season) + + slicer = complete_index.slice_indexer(first_label, last_label) + full_index = complete_index[slicer] + # TODO: group must be sorted + # codes = np.searchsorted(edges, group.data, side="left") + # codes -= 1 + # codes[~present_seasons | group.data >= edges[-1]] = -1 + # codes[isnull(group.data)] = -1 + # import ipdb; ipdb.set_trace() + # check that there are no "missing" seasons in the middle + # if not full_index.equals(unique_coord): + # raise ValueError("Are there seasons missing in the middle of the dataset?") + + final_codes = np.full(group.data.size, -1) + final_codes[present_seasons] = np.repeat(unique_codes, counts) + codes = group.copy(data=final_codes, deep=False) + # unique_coord_var = Variable(group.name, unique_coord, group.attrs) + + return EncodedGroups( + codes=codes, + # group_indices=group_indices, + # unique_coord=unique_coord_var, + full_index=full_index, + ) + + def reset(self) -> Self: + return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 5ed334e61dd..349389a6448 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -16,6 +16,7 @@ import xarray.testing from xarray import Dataset +from xarray.coding.times import _STANDARD_CALENDARS as _STANDARD_CALENDARS_UNSORTED from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 from xarray.core.extension_array import PandasExtensionArray from xarray.core.options import set_options @@ -353,12 +354,52 @@ def create_test_data( _CFTIME_CALENDARS = [ + pytest.param( + cal, marks=pytest.mark.skipif(not has_cftime, reason="requires cftime") + ) + for cal in sorted( + [ + "365_day", + "360_day", + "julian", + "all_leap", + "366_day", + "gregorian", + "proleptic_gregorian", + "standard", + ] + ) +] + +_STANDARD_CALENDAR_NAMES = sorted(_STANDARD_CALENDARS_UNSORTED) +_NON_STANDARD_CALENDAR_NAMES = { + "noleap", "365_day", "360_day", "julian", "all_leap", "366_day", - "gregorian", - "proleptic_gregorian", - "standard", +} +_NON_STANDARD_CALENDARS = [ + pytest.param( + cal, marks=pytest.mark.skipif(not has_cftime, reason="requires cftime") + ) + for cal in sorted(_NON_STANDARD_CALENDAR_NAMES) ] +_STANDARD_CALENDARS = [pytest.param(_) for _ in _STANDARD_CALENDAR_NAMES] +_ALL_CALENDARS = sorted(_STANDARD_CALENDARS + _NON_STANDARD_CALENDARS) + + +def _all_cftime_date_types(): + import cftime + + return { + "noleap": cftime.DatetimeNoLeap, + "365_day": cftime.DatetimeNoLeap, + "360_day": cftime.Datetime360Day, + "julian": cftime.DatetimeJulian, + "all_leap": cftime.DatetimeAllLeap, + "366_day": cftime.DatetimeAllLeap, + "gregorian": cftime.DatetimeGregorian, + "proleptic_gregorian": cftime.DatetimeProlepticGregorian, + } diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 64309966103..0d51a292be1 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -6,6 +6,8 @@ import xarray as xr from xarray.tests import ( + _CFTIME_CALENDARS, + _all_cftime_date_types, assert_allclose, assert_array_equal, assert_chunks_equal, @@ -390,15 +392,6 @@ def test_dask_accessor_method(self, method, parameters) -> None: assert_equal(actual.compute(), expected.compute()) -_CFTIME_CALENDARS = [ - "365_day", - "360_day", - "julian", - "all_leap", - "366_day", - "gregorian", - "proleptic_gregorian", -] _NT = 100 @@ -407,6 +400,13 @@ def calendar(request): return request.param +@pytest.fixture() +def cftime_date_type(calendar): + if calendar == "standard": + calendar = "proleptic_gregorian" + return _all_cftime_date_types()[calendar] + + @pytest.fixture() def times(calendar): import cftime @@ -571,13 +571,6 @@ def test_dask_field_access(times_3d, data, field) -> None: assert_equal(result.compute(), expected) -@pytest.fixture() -def cftime_date_type(calendar): - from xarray.tests.test_coding_times import _all_cftime_date_types - - return _all_cftime_date_types()[calendar] - - @requires_cftime def test_seasons(cftime_date_type) -> None: dates = xr.DataArray( diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 116487e2bcf..03ea5e544ed 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -18,16 +18,14 @@ parse_iso8601_like, ) from xarray.tests import ( + _ALL_CALENDARS, + _NON_STANDARD_CALENDAR_NAMES, + _all_cftime_date_types, assert_array_equal, assert_identical, has_cftime, requires_cftime, ) -from xarray.tests.test_coding_times import ( - _ALL_CALENDARS, - _NON_STANDARD_CALENDARS, - _all_cftime_date_types, -) # cftime 1.5.2 renames "gregorian" to "standard" standard_or_gregorian = "" @@ -1161,7 +1159,7 @@ def test_to_datetimeindex(calendar, unsafe): index = xr.cftime_range("2000", periods=5, calendar=calendar) expected = pd.date_range("2000", periods=5) - if calendar in _NON_STANDARD_CALENDARS and not unsafe: + if calendar in _NON_STANDARD_CALENDAR_NAMES and not unsafe: with pytest.warns(RuntimeWarning, match="non-standard"): result = index.to_datetimeindex() else: diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 9a51ca40d07..90c5724782e 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -19,7 +19,6 @@ date_range, decode_cf, ) -from xarray.coding.times import _STANDARD_CALENDARS as _STANDARD_CALENDARS_UNSORTED from xarray.coding.times import ( CFDatetimeCoder, _encode_datetime_with_cftime, @@ -42,7 +41,12 @@ from xarray.core.utils import is_duck_dask_array from xarray.testing import assert_equal, assert_identical from xarray.tests import ( + _ALL_CALENDARS, + _NON_STANDARD_CALENDARS, + _STANDARD_CALENDAR_NAMES, + _STANDARD_CALENDARS, FirstElementAccessibleArray, + _all_cftime_date_types, arm_xfail, assert_array_equal, assert_duckarray_allclose, @@ -53,17 +57,6 @@ requires_dask, ) -_NON_STANDARD_CALENDARS_SET = { - "noleap", - "365_day", - "360_day", - "julian", - "all_leap", - "366_day", -} -_STANDARD_CALENDARS = sorted(_STANDARD_CALENDARS_UNSORTED) -_ALL_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET.union(_STANDARD_CALENDARS)) -_NON_STANDARD_CALENDARS = sorted(_NON_STANDARD_CALENDARS_SET) _CF_DATETIME_NUM_DATES_UNITS = [ (np.arange(10), "days since 2000-01-01"), (np.arange(10).astype("float64"), "days since 2000-01-01"), @@ -99,26 +92,11 @@ _CF_DATETIME_TESTS = [ num_dates_units + (calendar,) for num_dates_units, calendar in product( - _CF_DATETIME_NUM_DATES_UNITS, _STANDARD_CALENDARS + _CF_DATETIME_NUM_DATES_UNITS, _STANDARD_CALENDAR_NAMES ) ] -def _all_cftime_date_types(): - import cftime - - return { - "noleap": cftime.DatetimeNoLeap, - "365_day": cftime.DatetimeNoLeap, - "360_day": cftime.Datetime360Day, - "julian": cftime.DatetimeJulian, - "all_leap": cftime.DatetimeAllLeap, - "366_day": cftime.DatetimeAllLeap, - "gregorian": cftime.DatetimeGregorian, - "proleptic_gregorian": cftime.DatetimeProlepticGregorian, - } - - @requires_cftime @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") @pytest.mark.filterwarnings("ignore:Times can't be serialized faithfully") @@ -666,13 +644,13 @@ def test_decode_cf(calendar) -> None: ds[v].attrs["units"] = "days since 2001-01-01" ds[v].attrs["calendar"] = calendar - if not has_cftime and calendar not in _STANDARD_CALENDARS: + if not has_cftime and calendar not in _STANDARD_CALENDAR_NAMES: with pytest.raises(ValueError): ds = decode_cf(ds) else: ds = decode_cf(ds) - if calendar not in _STANDARD_CALENDARS: + if calendar not in _STANDARD_CALENDAR_NAMES: assert ds.test.dtype == np.dtype("O") else: assert ds.test.dtype == np.dtype("M8[ns]") @@ -1006,7 +984,7 @@ def test_decode_ambiguous_time_warns(calendar) -> None: # we don't decode non-standard calendards with # pandas so expect no warning will be emitted - is_standard_calendar = calendar in _STANDARD_CALENDARS + is_standard_calendar = calendar in _STANDARD_CALENDAR_NAMES dates = [1, 2, 3] units = "days since 1-1-1" diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3fc7fcac132..5e1a4401d0d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -12,7 +12,7 @@ from packaging.version import Version import xarray as xr -from xarray import DataArray, Dataset, Variable +from xarray import DataArray, Dataset, Variable, cftime_range, date_range from xarray.core.alignment import broadcast from xarray.core.groupby import _consolidate_slices from xarray.core.types import InterpOptions, ResampleCompatible @@ -20,11 +20,15 @@ BinGrouper, EncodedGroups, Grouper, + SeasonGrouper, + SeasonResampler, TimeResampler, UniqueGrouper, + season_to_month_tuple, ) from xarray.namedarray.pycompat import is_chunked_array from xarray.tests import ( + _ALL_CALENDARS, InaccessibleArray, assert_allclose, assert_equal, @@ -610,7 +614,7 @@ def test_groupby_repr(obj, dim) -> None: N = len(np.unique(obj[dim])) expected = f"<{obj.__class__.__name__}GroupBy" expected += f", grouped over 1 grouper(s), {N} groups in total:" - expected += f"\n {dim!r}: {N}/{N} groups present with labels " + expected += f"\n {dim!r}: UniqueGrouper({dim!r}), {N}/{N} groups with labels " if dim == "x": expected += "1, 2, 3, 4, 5>" elif dim == "y": @@ -627,7 +631,7 @@ def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"<{obj.__class__.__name__}GroupBy" expected += ", grouped over 1 grouper(s), 12 groups in total:\n" - expected += " 'month': 12/12 groups present with labels " + expected += " 'month': UniqueGrouper('month'), 12/12 groups with labels " expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>" assert actual == expected @@ -3142,6 +3146,162 @@ def test_groupby_dask_eager_load_warnings(): ds.groupby_bins("x", bins=[1, 2, 3], eagerly_compute_group=False) +class TestSeasonGrouperAndResampler: + def test_season_to_month_tuple(self): + assert season_to_month_tuple(["JF", "MAM", "JJAS", "OND"]) == ( + (1, 2), + (3, 4, 5), + (6, 7, 8, 9), + (10, 11, 12), + ) + assert season_to_month_tuple(["DJFM", "AM", "JJAS", "ON"]) == ( + (12, 1, 2, 3), + (4, 5), + (6, 7, 8, 9), + (10, 11), + ) + + @pytest.mark.parametrize("calendar", _ALL_CALENDARS) + def test_season_grouper_simple(self, calendar) -> None: + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + expected = da.groupby("time.season").mean() + # note season order matches expected + actual = da.groupby( + time=SeasonGrouper( + ["DJF", "JJA", "MAM", "SON"], # drop_incomplete=False + ) + ).mean() + assert_identical(expected, actual) + + @pytest.mark.parametrize("seasons", [["JJA", "MAM", "SON", "DJF"]]) + def test_season_resampling_raises_unsorted_seasons(self, seasons): + calendar = "standard" + time = date_range("2001-01-01", "2002-12-30", freq="D", calendar=calendar) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + with pytest.raises(ValueError, match="sort"): + da.resample(time=SeasonResampler(seasons)) + + @pytest.mark.parametrize( + "use_cftime", + [ + pytest.param( + True, marks=pytest.mark.skipif(not has_cftime, reason="no cftime") + ), + False, + ], + ) + @pytest.mark.parametrize("drop_incomplete", [True, False]) + @pytest.mark.parametrize( + "seasons", + [ + pytest.param(["DJF", "MAM", "JJA", "SON"], id="standard"), + pytest.param(["NDJ", "FMA", "MJJ", "ASO"], id="nov-first"), + pytest.param(["MAM", "JJA", "SON", "DJF"], id="standard-diff-order"), + pytest.param(["JFM", "AMJ", "JAS", "OND"], id="december-same-year"), + pytest.param(["DJF", "MAM", "JJA", "ON"], id="skip-september"), + pytest.param(["JJAS"], id="jjas-only"), + pytest.param(["MAM", "JJA", "SON", "DJF"], id="different-order"), + ], + ) + def test_season_resampler( + self, seasons: list[str], drop_incomplete: bool, use_cftime: bool + ) -> None: + calendar = "standard" + time = date_range( + "2001-01-01", + "2002-12-30", + freq="D", + calendar=calendar, + use_cftime=use_cftime, + ) + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + counts = da.resample(time="ME").count() + + seasons_as_ints = season_to_month_tuple(seasons) + month = counts.time.dt.month.data + year = counts.time.dt.year.data + for season, as_ints in zip(seasons, seasons_as_ints, strict=True): + if "DJ" in season: + for imonth in as_ints[season.index("D") + 1 :]: + year[month == imonth] -= 1 + counts["time"] = ( + "time", + [pd.Timestamp(f"{y}-{m}-01") for y, m in zip(year, month, strict=True)], + ) + if has_cftime: + counts = counts.convert_calendar(calendar, "time", align_on="date") + + expected_vals = [] + expected_time = [] + for year in [2001, 2002, 2003]: + for season, as_ints in zip(seasons, seasons_as_ints, strict=True): + out_year = year + if "DJ" in season: + out_year = year - 1 + if out_year == 2003: + # this is a dummy year added to make sure we cover 2002-DJF + continue + available = [ + counts.sel(time=f"{out_year}-{month:02d}").data for month in as_ints + ] + if any(len(a) == 0 for a in available) and drop_incomplete: + continue + output_label = pd.Timestamp(f"{out_year}-{as_ints[0]:02d}-01") + expected_time.append(output_label) + # use concatenate to handle empty array when dec value does not exist + expected_vals.append(np.concatenate(available).sum()) + + expected = ( + # we construct expected in the standard calendar + xr.DataArray(expected_vals, dims="time", coords={"time": expected_time}) + ) + if has_cftime: + # and then convert to the expected calendar, + expected = expected.convert_calendar( + calendar, align_on="date", use_cftime=use_cftime + ) + # and finally sort since DJF will be out-of-order + expected = expected.sortby("time") + + rs = SeasonResampler(seasons, drop_incomplete=drop_incomplete) + # through resample + actual = da.resample(time=rs).sum() + assert_identical(actual, expected) + + @requires_cftime + def test_season_resampler_errors(self): + time = cftime_range("2001-01-01", "2002-12-30", freq="D", calendar="360_day") + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + + # non-datetime array + with pytest.raises(ValueError): + DataArray(np.ones(5), dims="time").groupby(time=SeasonResampler(["DJF"])) + + # ndim > 1 array + with pytest.raises(ValueError): + DataArray( + np.ones((5, 5)), dims=("t", "x"), coords={"x": np.arange(5)} + ).groupby(x=SeasonResampler(["DJF"])) + + # overlapping seasons + with pytest.raises(ValueError): + da.groupby(time=SeasonResampler(["DJFM", "MAMJ", "JJAS", "SOND"])).sum() + + @requires_cftime + def test_season_resampler_groupby_identical(self): + time = date_range("2001-01-01", "2002-12-30", freq="D") + da = DataArray(np.ones(time.size), dims="time", coords={"time": time}) + + # through resample + resampler = SeasonResampler(["DJF", "MAM", "JJA", "SON"]) + rs = da.resample(time=resampler).sum() + + # through groupby + gb = da.groupby(time=resampler).sum() + assert_identical(rs, gb) + + # TODO: Possible property tests to add to this module # 1. lambda x: x # 2. grouped-reduce on unique coords is identical to array