Skip to content

Commit

Permalink
Add SeasonGrouper, SeasonResampler
Browse files Browse the repository at this point in the history
These two groupers allow defining custom seasons, and dropping
incomplete seasons from the output. Both cases are treated by adjusting
the factorization -- conversion from group labels to integer codes -- appropriately.
  • Loading branch information
dcherian committed Sep 20, 2024
1 parent 3c74509 commit 594d4a7
Show file tree
Hide file tree
Showing 5 changed files with 329 additions and 6 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,8 @@ Grouper Objects
groupers.BinGrouper
groupers.UniqueGrouper
groupers.TimeResampler
groupers.SeasonGrouper
groupers.SeasonResampler


Rolling objects
Expand Down
31 changes: 31 additions & 0 deletions properties/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

pytest.importorskip("hypothesis")

import hypothesis.strategies as st
from hypothesis import given

import xarray as xr
import xarray.testing.strategies as xrst
from xarray.groupers import season_to_month_tuple


@given(attrs=xrst.simple_attrs)
Expand All @@ -15,3 +17,32 @@ def test_assert_identical(attrs):

ds = xr.Dataset(attrs=attrs)
xr.testing.assert_identical(ds, ds.copy(deep=True))


@given(
    roll=st.integers(min_value=0, max_value=12),
    breaks=st.lists(
        st.integers(min_value=0, max_value=11), min_size=1, max_size=12, unique=True
    ),
)
def test_property_season_month_tuple(roll, breaks):
    """Property test: arbitrary contiguous partitions of a (rolled) year round-trip
    through ``season_to_month_tuple``."""
    # Rotate the month initials and month numbers by the same offset so the
    # two sequences stay aligned, then cut both at identical break points.
    month_initials = list("JFMAMJJASOND")
    month_numbers = tuple(range(1, 13))

    shifted_initials = month_initials[roll:] + month_initials[:roll]
    shifted_numbers = month_numbers[roll:] + month_numbers[:roll]

    # Normalize the break points so they always span the full [0, 12] range.
    cuts = sorted(breaks)
    if cuts[0] != 0:
        cuts.insert(0, 0)
    if cuts[-1] != 12:
        cuts.append(12)

    bounds = list(zip(cuts[:-1], cuts[1:], strict=False))
    seasons = tuple("".join(shifted_initials[lo:hi]) for lo, hi in bounds)
    actual = season_to_month_tuple(seasons)
    expected = tuple(shifted_numbers[lo:hi] for lo, hi in bounds)
    assert expected == actual
56 changes: 56 additions & 0 deletions xarray/core/toolzcompat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# This file contains functions copied from the toolz library in accordance
# with its license. The original copyright notice is duplicated below.

# Copyright (c) 2013 Matthew Rocklin

# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of toolz nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.


# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.


def sliding_window(n, seq):
"""A sequence of overlapping subsequences
>>> list(sliding_window(2, [1, 2, 3, 4]))
[(1, 2), (2, 3), (3, 4)]
This function creates a sliding window suitable for transformations like
sliding means / smoothing
>>> mean = lambda seq: float(sum(seq)) / len(seq)
>>> list(map(mean, sliding_window(2, [1, 2, 3, 4])))
[1.5, 2.5, 3.5]
"""
import collections
import itertools

return zip(
*(
collections.deque(itertools.islice(it, i), 0) or it
for i, it in enumerate(itertools.tee(seq, n))
),
strict=False,
)
218 changes: 218 additions & 0 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from __future__ import annotations

import datetime
import itertools
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast

Expand All @@ -16,11 +18,13 @@

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.core import duck_array_ops
from xarray.core.common import _contains_datetime_like_objects
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.groupby import T_Group, _DummyGroup
from xarray.core.indexes import safe_cast_to_index
from xarray.core.resample_cftime import CFTimeGrouper
from xarray.core.toolzcompat import sliding_window
from xarray.core.types import (
Bins,
DatetimeLike,
Expand Down Expand Up @@ -485,3 +489,217 @@ def unique_value_groups(
if isinstance(values, pd.MultiIndex):
values.names = ar.names
return values, inverse


def season_to_month_tuple(seasons: Sequence[str]) -> tuple[tuple[int, ...], ...]:
    """Convert season strings to tuples of month numbers.

    Parameters
    ----------
    seasons : sequence of str
        Seasons spelled with month initials, e.g. ``["DJF", "MAM", "JJA", "SON"]``.
        A single-letter season is disambiguated using the first letter of the
        *next* season (wrapping around at the end).

    Returns
    -------
    tuple of tuple of int
        One tuple of 1-based month numbers per season, e.g.
        ``((12, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10, 11))``.

    Examples
    --------
    >>> season_to_month_tuple(["JF", "MAM", "JJAS", "OND"])
    ((1, 2), (3, 4, 5), (6, 7, 8, 9), (10, 11, 12))
    """
    initials = "JFMAMJJASOND"
    # Map each two-initial pair (e.g. "DJ") to the 1-based month number of its
    # first letter; the second letter disambiguates repeated initials (J, M, A).
    starts = {initials[i] + initials[(i + 1) % 12]: i + 1 for i in range(12)}

    result: list[tuple[int, ...]] = []
    for i, season in enumerate(seasons):
        if len(season) == 1:
            # Single-month season: borrow the next season's first initial
            # (wrapping to the first season after the last one).
            suffix = seasons[(i + 1) % len(seasons)][0]
        else:
            suffix = season[1]

        start = starts[season[0] + suffix]
        # The remaining months follow consecutively, wrapping December -> January.
        result.append(
            tuple((start - 1 + offset) % 12 + 1 for offset in range(len(season)))
        )
    return tuple(result)


@dataclass
class SeasonGrouper(Grouper):
    """Allows grouping using a custom definition of seasons.

    Parameters
    ----------
    seasons: sequence of str
        List of strings representing seasons. E.g. ``"JF"`` or ``"JJA"`` etc.

    Examples
    --------
    >>> SeasonGrouper(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonGrouper(["DJFM", "AM", "JJA", "SON"])
    """

    seasons: Sequence[str]
    # Month numbers for each season, e.g. ((1, 2), (3, 4, 5), ...);
    # derived from ``seasons`` in ``__post_init__``.
    season_inds: Sequence[Sequence[int]] = field(init=False, repr=False)
    # drop_incomplete: bool = field(default=True) # TODO

    def __post_init__(self) -> None:
        self.season_inds = season_to_month_tuple(self.seasons)

    def factorize(self, group: T_Group) -> EncodedGroups:
        """Assign an integer season code to every element of ``group``.

        ``group`` must be datetime-like; months not covered by any season
        receive code -1. Raises ValueError if no element matches any season.
        """
        if TYPE_CHECKING:
            assert not isinstance(group, _DummyGroup)
        if not _contains_datetime_like_objects(group.variable):
            raise ValueError(
                "SeasonGrouper can only be used to group by datetime-like arrays."
            )

        seasons = self.seasons
        season_inds = self.season_inds

        months = group.dt.month
        codes_ = np.full(group.shape, -1)
        # Use a comprehension to create *independent* lists: `[[]] * n` would
        # alias one shared list n times, a latent mutation hazard.
        group_indices: list[list[int]] = [[] for _ in seasons]

        index = np.arange(group.size)
        for idx, season_tuple in enumerate(season_inds):
            mask = months.isin(season_tuple)
            codes_[mask] = idx
            group_indices[idx] = index[mask]

        if np.all(codes_ == -1):
            raise ValueError(
                "Failed to group data. Are you grouping by a variable that is all NaN?"
            )
        codes = group.copy(data=codes_, deep=False).rename("season")
        unique_coord = Variable("season", seasons, attrs=group.attrs)
        full_index = pd.Index(seasons)
        return EncodedGroups(
            codes=codes,
            group_indices=tuple(group_indices),
            unique_coord=unique_coord,
            full_index=full_index,
        )


@dataclass
class SeasonResampler(Resampler):
    """Allows grouping using a custom definition of seasons.

    Parameters
    ----------
    seasons: Sequence[str]
        An ordered list of seasons.
    drop_incomplete: bool
        Whether to drop seasons that are not completely included in the data.
        For example, if a time series starts in Jan-2001, and seasons includes `"DJF"`
        then observations from Jan-2001, and Feb-2001 are ignored in the grouping
        since Dec-2000 isn't present.

    Examples
    --------
    >>> SeasonResampler(["JF", "MAM", "JJAS", "OND"])
    >>> SeasonResampler(["DJFM", "AM", "JJA", "SON"])
    """

    seasons: Sequence[str]
    drop_incomplete: bool = field(default=True)
    # Month numbers for each season; derived from ``seasons`` in __post_init__.
    season_inds: Sequence[Sequence[int]] = field(init=False, repr=False)
    # Mapping from season string to its tuple of month numbers.
    season_tuples: Mapping[str, Sequence[int]] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self.season_inds = season_to_month_tuple(self.seasons)
        self.season_tuples = dict(zip(self.seasons, self.season_inds, strict=False))

    def factorize(self, group):
        """Assign each element of the 1D datetime-like ``group`` to a
        (year, season) bin and return the resulting EncodedGroups.

        Assumes ``group`` is sorted in time so that each (year, season) bin
        occupies a contiguous slice of positions.
        """
        if group.ndim != 1:
            raise ValueError(
                "SeasonResampler can only be used to resample by 1D arrays."
            )
        if not _contains_datetime_like_objects(group.variable):
            raise ValueError(
                "SeasonResampler can only be used to group by datetime-like arrays."
            )

        seasons = self.seasons
        season_inds = self.season_inds
        season_tuples = self.season_tuples

        # Longest season string fixes the width of the fixed-size label dtype.
        nstr = max(len(s) for s in seasons)
        year = group.dt.year.astype(int)
        month = group.dt.month.astype(int)
        season_label = np.full(group.shape, "", dtype=f"U{nstr}")

        # offset years for seasons with December and January
        for season_str, season_ind in zip(seasons, season_inds, strict=False):
            season_label[month.isin(season_ind)] = season_str
            if "DJ" in season_str:
                # Months after December in e.g. "DJF" belong to the season that
                # *started* the previous December, so shift their year down.
                after_dec = season_ind[season_str.index("D") + 1 :]
                year[month.isin(after_dec)] -= 1

        # Group positional indices by (year, season) so that each bin can be
        # recovered later from its first index and count.
        frame = pd.DataFrame(
            data={"index": np.arange(group.size), "month": month},
            index=pd.MultiIndex.from_arrays(
                [year.data, season_label], names=["year", "season"]
            ),
        )

        series = frame["index"]
        g = series.groupby(["year", "season"], sort=False)
        first_items = g.first()
        counts = g.count()

        # these are the seasons that are present
        unique_coord = pd.DatetimeIndex(
            [
                pd.Timestamp(year=year, month=season_tuples[season][0], day=1)
                for year, season in first_items.index
            ]
        )

        # Each bin is the contiguous slice between consecutive first indices;
        # the last bin runs to the end of the data.
        sbins = first_items.values.astype(int)
        group_indices = [
            slice(i, j) for i, j in zip(sbins[:-1], sbins[1:], strict=False)
        ]
        group_indices += [slice(sbins[-1], None)]

        # Make sure the first and last timestamps
        # are for the correct months, if not we have incomplete seasons
        unique_codes = np.arange(len(unique_coord))
        if self.drop_incomplete:
            for idx, slicer in zip([0, -1], (slice(1, None), slice(-1)), strict=False):
                # idx=0 checks the first observation against the first month of
                # its season; idx=-1 checks the last against the last month.
                stamp_year, stamp_season = frame.index[idx]
                code = seasons.index(stamp_season)
                stamp_month = season_inds[code][idx]
                if stamp_month != month[idx].item():
                    # we have an incomplete season!
                    group_indices = group_indices[slicer]
                    unique_coord = unique_coord[slicer]
                    if idx == 0:
                        # all remaining codes shift down after dropping bin 0
                        unique_codes -= 1
                    unique_codes[idx] = -1

        # all years and seasons
        complete_index = pd.DatetimeIndex(
            # This sorted call is a hack. It's hard to figure out how
            # to start the iteration
            sorted(
                [
                    pd.Timestamp(f"{y}-{m}-01")
                    for y, m in itertools.product(
                        range(year[0].item(), year[-1].item() + 1),
                        [s[0] for s in season_inds],
                    )
                ]
            )
        )
        # only keep that included in data
        range_ = complete_index.get_indexer(unique_coord[[0, -1]])
        full_index = complete_index[slice(range_[0], range_[-1] + 1)]
        # check that there are no "missing" seasons in the middle
        if not full_index.equals(unique_coord):
            raise ValueError("Are there seasons missing in the middle of the dataset?")

        codes = group.copy(data=np.repeat(unique_codes, counts), deep=False)
        unique_coord_var = Variable(group.name, unique_coord, group.attrs)

        return EncodedGroups(
            codes=codes,
            group_indices=group_indices,
            unique_coord=unique_coord_var,
            full_index=full_index,
        )
28 changes: 22 additions & 6 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Grouper,
TimeResampler,
UniqueGrouper,
season_to_month_tuple,
)
from xarray.tests import (
InaccessibleArray,
Expand Down Expand Up @@ -2915,12 +2916,6 @@ def test_gappy_resample_reductions(reduction):
assert_identical(expected, actual)


# Possible property tests
# 1. lambda x: x
# 2. grouped-reduce on unique coords is identical to array
# 3. group_over == groupby-reduce along other dimensions


def test_groupby_transpose():
# GH5361
data = xr.DataArray(
Expand All @@ -2932,3 +2927,24 @@ def test_groupby_transpose():
second = data.groupby("x").sum()

assert_identical(first, second.transpose(*first.dims))


def test_season_to_month_tuple():
    """Spot-check season-string to month-tuple conversion."""
    # Each case pairs a season spelling with the expected month numbers.
    cases = [
        (
            ["JF", "MAM", "JJAS", "OND"],
            ((1, 2), (3, 4, 5), (6, 7, 8, 9), (10, 11, 12)),
        ),
        (
            ["DJFM", "AM", "JJAS", "ON"],
            ((12, 1, 2, 3), (4, 5), (6, 7, 8, 9), (10, 11)),
        ),
    ]
    for seasons, expected in cases:
        assert season_to_month_tuple(seasons) == expected


# Possible property tests
# 1. lambda x: x
# 2. grouped-reduce on unique coords is identical to array
# 3. group_over == groupby-reduce along other dimensions

0 comments on commit 594d4a7

Please sign in to comment.