diff --git a/tests/test_temporal.py b/tests/test_temporal.py index b3e08b02..bda10d68 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -575,27 +575,28 @@ def test_weighted_annual_averages_with_chunking(self): assert result.ts.attrs == expected.ts.attrs assert result.time.attrs == expected.time.attrs - def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): + def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons( + self, + ): ds = self.ds.copy() result = ds.temporal.group_average( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() - # Drop the incomplete DJF seasons - expected = expected.isel(time=slice(2, -1)) expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[1]], [[1]], [[1]], [[2.0]]]), + data=np.array([[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -618,35 +619,33 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): "mode": "group_average", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "DJF", - "drop_incomplete_seasons": "True", }, ) assert result.identical(expected) - def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons( - self, - ): - ds = self.ds.copy() + def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): + ds = generate_dataset(decode_times=True, cf_compliant=True, has_bounds=True) result = ds.temporal.group_average( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, ) + expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[2.0]], [[1.0]], [[1.0]], [[1.0]], [[2.0]]]), + data=np.ones((4, 4, 4)), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ - cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -664,13 +663,12 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons }, dims=["time", "lat", "lon"], attrs={ - "test_attr": "test", "operation": "temporal_avg", "mode": "group_average", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "True", "dec_mode": "DJF", - "drop_incomplete_seasons": "False", }, ) @@ -729,6 +727,7 @@ def test_weighted_seasonal_averages_with_JFD(self): "mode": "group_average", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) @@ -810,13 +809,14 @@ def test_weighted_custom_seasonal_averages(self): "operation": "temporal_avg", "mode": "group_average", "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "False", "custom_seasons": [ "JanFebMar", "AprMayJun", "JulAugSep", "OctNovDec", ], - "weighted": "True", }, ) @@ -870,8 +870,9 @@ def test_weighted_custom_seasonal_averages_drops_incomplete_seasons(self): "operation": "temporal_avg", "mode": "group_average", "freq": "season", - "custom_seasons": ["NovDec", "FebMarApr"], "weighted": "True", + "drop_incomplete_seasons": "True", + "custom_seasons": ["NovDec", "FebMarApr"], }, ) @@ -926,8 +927,9 @@ def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years( "operation": "temporal_avg", "mode": "group_average", "freq": "season", - "custom_seasons": ["NovDecJanFebMar"], "weighted": "True", + "drop_incomplete_seasons": "False", + "custom_seasons": ["NovDecJanFebMar"], }, ) @@ -1198,8 +1200,8 @@ def test_weighted_seasonal_climatology_with_DJF(self): "mode": "climatology", "freq": "season", "weighted": "True", - "dec_mode": "DJF", "drop_incomplete_seasons": "True", + "dec_mode": "DJF", }, ) @@ -1305,6 +1307,7 @@ def test_weighted_seasonal_climatology_with_JFD(self): "mode": "climatology", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) @@ -1363,6 +1366,7 @@ def test_weighted_custom_seasonal_climatology(self): "mode": "climatology", "freq": "season", "weighted": "True", + "drop_incomplete_seasons": "False", "custom_seasons": [ "JanFebMar", "AprMayJun", @@ -1374,24 +1378,18 @@ def test_weighted_custom_seasonal_climatology(self): assert result.identical(expected) - @pytest.mark.xfail def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_years( self, ): ds = self.ds.copy() - custom_seasons = [ - ["Jan", "Feb", "Mar"], - ["Apr", "May", "Jun"], - ["Jul", "Aug", "Sep"], - ["Oct", "Nov", "Dec"], - ] + custom_seasons = [["Nov", "Dec", "Jan", "Feb", "Mar"]] result = ds.temporal.climatology( "ts", "season", season_config={ + "drop_incomplete_seasons": False, "custom_seasons": custom_seasons, - "drop_incomplete_seasons": True, }, ) @@ -1399,21 +1397,11 @@ def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_yea expected = expected.drop_dims("time") expected_time = xr.DataArray( data=np.array( - [ - cftime.DatetimeGregorian(1, 2, 1), - cftime.DatetimeGregorian(1, 5, 1), - cftime.DatetimeGregorian(1, 8, 1), - cftime.DatetimeGregorian(1, 11, 1), - ], + [cftime.DatetimeGregorian(1, 1, 1)], ), coords={ "time": np.array( - [ - cftime.DatetimeGregorian(1, 2, 1), - cftime.DatetimeGregorian(1, 5, 1), - cftime.DatetimeGregorian(1, 8, 1), - cftime.DatetimeGregorian(1, 11, 1), - ], + [cftime.DatetimeGregorian(1, 1, 1)], ), }, attrs={ @@ -1434,12 +1422,8 @@ def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_yea "mode": "climatology", "freq": "season", "weighted": "True", - "custom_seasons": [ - "JanFebMar", - "AprMayJun", - "JulAugSep", - "OctNovDec", - ], + "drop_incomplete_seasons": "False", + "custom_seasons": ["NovDecJanFebMar"], }, ) @@ -1938,8 +1922,8 @@ def test_unweighted_seasonal_departures_with_DJF(self): "mode": "departures", "freq": "season", "weighted": "False", - "dec_mode": "DJF", "drop_incomplete_seasons": "True", + "dec_mode": "DJF", }, ) @@ -1969,6 +1953,7 @@ def test_unweighted_seasonal_departures_with_JFD(self): "mode": "departures", "freq": "season", "weighted": "False", + "drop_incomplete_seasons": "False", "dec_mode": "JFD", }, ) diff --git a/xcdat/temporal.py b/xcdat/temporal.py index 172ed354..8aa4fce5 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -63,8 +63,8 @@ SeasonConfigInput = TypedDict( "SeasonConfigInput", { - "dec_mode": Literal["DJF", "JFD"], "drop_incomplete_seasons": bool, + "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[List[List[str]]], }, total=False, @@ -73,16 +73,16 @@ SeasonConfigAttr = TypedDict( "SeasonConfigAttr", { - "dec_mode": Literal["DJF", "JFD"], "drop_incomplete_seasons": bool, + "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[Dict[str, List[str]]], }, total=False, ) DEFAULT_SEASON_CONFIG: SeasonConfigInput = { - "dec_mode": "DJF", "drop_incomplete_seasons": False, + "dec_mode": "DJF", "custom_seasons": None, } @@ -241,16 +241,6 @@ def group_average( predefined seasons are passed, configs for custom seasons are ignored and vice versa. - Configs for predefined seasons: - - * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") - The mode for the season that includes December. - - * "DJF": season includes the previous year December. - * "JFD": season includes the same year December. - Xarray labels the season with December as "DJF", but it is - actually "JFD". - * "drop_incomplete_seasons" (bool, by default False) Seasons are considered incomplete if they do not have all of the required months to form the season. For example, if we have @@ -264,11 +254,19 @@ def group_average( season because it only has "Jan" and "Feb". Therefore, these time coordinates are dropped. - Configs for custom seasons: + * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") + The mode for the season that includes December in the list of + list of pre-defined seasons ("DJF"/"JFD", "MAM", "JJA", "SON"). + This config is ignored if the ``custom_seasons`` config is set. + + * "DJF": season includes the previous year December. + * "JFD": season includes the same year December. + Xarray labels the season with December as "DJF", but it is + actually "JFD". * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist - representing a custom season. + representing a custom season. This config overrides the `decod * Month strings must be in the three letter format (e.g., 'Jan') * Each month must be included once in a custom season @@ -388,8 +386,6 @@ def climatology( predefined seasons are passed, configs for custom seasons are ignored and vice versa. - General configs: - * "drop_incomplete_seasons" (bool, by default False) Seasons are considered incomplete if they do not have all of the required months to form the season. For example, if we have @@ -403,21 +399,19 @@ def climatology( season because it only has "Jan" and "Feb". Therefore, these time coordinates are dropped. - Configs for predefined seasons: - * "dec_mode" (Literal["DJF", "JFD"], by default "DJF") - The mode for the season that includes December. + The mode for the season that includes December in the list of + list of pre-defined seasons ("DJF"/"JFD", "MAM", "JJA", "SON"). + This config is ignored if the ``custom_seasons`` config is set. * "DJF": season includes the previous year December. * "JFD": season includes the same year December. Xarray labels the season with December as "DJF", but it is actually "JFD". - Configs for custom seasons: - * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist - representing a custom season. + representing a custom season. This config overrides the `decod * Month strings must be in the three letter format (e.g., 'Jan') * Each month must be included once in a custom season @@ -935,7 +929,22 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: """ if ( self._freq == "season" - and self._season_config.get("drop_incomplete_seasons") is True + and self._season_config.get("custom_seasons") is not None + ): + # Get a flat list of all of the months included in the custom + # seasons to determine if the dataset needs to be subsetted + # on just those months. For example, if we define a custom season + # "NDJFM", we should subset the dataset for time coordinates + # belonging to those months. + months = self._season_config["custom_seasons"].values() # type: ignore + months = list(chain.from_iterable(months.values())) # type: ignore + + if len(months) != 12: + ds = self._subset_coords_for_custom_seasons(ds, months) + + if ( + self._freq == "season" + and self._season_config["drop_incomplete_seasons"] is True ): ds = self._drop_incomplete_seasons(ds) @@ -948,6 +957,34 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: return ds + def _subset_coords_for_custom_seasons( + self, ds: xr.Dataset, months: List[str] + ) -> xr.Dataset: + """Subsets time coordinates to the months included in custom seasons. + + Parameters + ---------- + ds : xr.Dataset + The dataset. + months : List[str] + A list of months included in custom seasons. + Example: ["Nov", "Dec", "Jan"] + + Returns + ------- + xr.Dataset + The dataset with time coordinate subsetted to months used in + custom seasons. + """ + months_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) + + coords_by_month = ds.time.groupby(f"{self.dim}.month").groups + months_idxs = {k: coords_by_month[k] for k in months_ints} + months_idxs = sorted(list(chain.from_iterable(months_idxs.values()))) # type: ignore + ds_new = ds.isel({f"{self.dim}": months_idxs}) + + return ds_new + def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: """Drops incomplete seasons within a continuous time series. @@ -974,37 +1011,34 @@ def _drop_incomplete_seasons(self, ds: xr.Dataset) -> xr.Dataset: A DataFrame of seasonal datetime components with only complete seasons. """ - # Algorithm - # Prereq - This needs to be done AFTER time coordinates are labeled - # and BEFORE obsoelete columns are dropped because custom seasons can be - # assigned to the time coordiantes first. - # 1. Get the count of months per season (pre-defined seasons by xarray - # all have 3), otherwise use custom seasons count - # 2. Label all time coordinates by groups - # 3. Group the time coordinates by group and the get count - # 4. Drop time coordinates where count != expected count for season - ds_new = ds.copy() - time_coords = ds[self.dim].copy() - # Transform the time coords into a DataFrame of seasonal datetime # components based on the grouping mode. + time_coords = ds[self.dim].copy() df = self._get_df_dt_components(time_coords, drop_obsolete_cols=False) - # Add a column for the expected count of months for that season - # For example, "NovDec" is split into ["Nov", "Dec"] which equals an - # expected count of 2 months. + # Get the expected and actual number of months for each season group. df["expected_months"] = df["season"].str.split(r"(?<=.)(?=[A-Z])").str.len() - # Add a column for the actual count of months for that season. - df["actual_months"] = df.groupby(["season"])["year"].transform("count") + df["actual_months"] = df.groupby(["year", "season"])["year"].transform("count") # Get the incomplete seasons and drop the time coordinates that are in # those incomplete seasons. indexes_to_drop = df[df["expected_months"] != df["actual_months"]].index if len(indexes_to_drop) > 0: + # The dataset needs to be split into a dataset with and a dataset + # without the time dimension because the xarray `.where()` method + # concatenates the time dimension to non-time dimension data vars, + # which is an undesired behavior. + ds_no_time = ds.get([v for v in ds.data_vars if self.dim not in ds[v].dims]) # type: ignore + ds_time = ds.get([v for v in ds.data_vars if self.dim in ds[v].dims]) # type: ignore + coords_to_drop = time_coords.values[indexes_to_drop] - ds_new = ds_new.where(~time_coords.isin(coords_to_drop), drop=True) + ds_time = ds_time.where(~time_coords.isin(coords_to_drop), drop=True) - return ds_new + ds_new = xr.merge([ds_time, ds_no_time]) + + return ds_new + + return ds def _drop_leap_days(self, ds: xr.Dataset): """Drop leap days from time coordinates. @@ -1681,8 +1715,8 @@ def _add_operation_attrs(self, data_var: xr.DataArray) -> xr.DataArray: ) if self._freq == "season": - data_var.attrs["drop_incomplete_seasons"] = self._season_config.get( - "drop_incomplete_seasons" + data_var.attrs["drop_incomplete_seasons"] = str( + self._season_config["drop_incomplete_seasons"] ) custom_seasons = self._season_config.get("custom_seasons")