From 464886b84ac4914a9e36c113bf5e9d5e5852881b Mon Sep 17 00:00:00 2001 From: tomvothecoder Date: Thu, 4 Apr 2024 14:44:25 -0700 Subject: [PATCH] Fix tests --- tests/test_temporal.py | 178 ++++++++++++++++++++++++++++------------- xcdat/temporal.py | 24 +++--- 2 files changed, 137 insertions(+), 65 deletions(-) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 525055d1..1b30566b 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -1,4 +1,5 @@ import logging +import warnings import cftime import numpy as np @@ -121,7 +122,7 @@ def test_averages_for_yearly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -139,7 +140,7 @@ def test_averages_for_yearly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_averages_for_monthly_time_series(self): # Set up dataset @@ -293,7 +294,7 @@ def test_averages_for_daily_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -310,7 +311,7 @@ def test_averages_for_daily_time_series(self): "weighted": "False", }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_averages_for_hourly_time_series(self): ds = xr.Dataset( @@ -378,7 +379,7 @@ def test_averages_for_hourly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) # Test unweighted averages result = ds.temporal.average("ts", weighted=False) @@ -396,7 +397,7 @@ def test_averages_for_hourly_time_series(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestGroupAverage: @@ -620,7 +621,7 @@ def test_weighted_seasonal_averages_with_DJF_without_dropping_incomplete_seasons }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): ds = generate_dataset(decode_times=True, cf_compliant=True, has_bounds=True) @@ -668,7 +669,7 @@ def test_weighted_seasonal_averages_with_DJF_and_drop_incomplete_seasons(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_averages_with_JFD(self): ds = self.ds.copy() @@ -728,7 +729,7 @@ def test_weighted_seasonal_averages_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_raises_error_with_incorrect_custom_seasons_argument(self): # Test raises error with non-3 letter strings @@ -816,7 +817,7 @@ def test_weighted_custom_seasonal_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_averages_drops_incomplete_seasons(self): ds = self.ds.copy() @@ -872,7 +873,7 @@ def test_weighted_custom_seasonal_averages_drops_incomplete_seasons(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_averages_with_seasons_spanning_calendar_years( self, @@ -973,7 +974,7 @@ def test_weighted_monthly_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_monthly_averages_with_masked_data(self): ds = self.ds.copy() @@ -1024,7 +1025,7 @@ def test_weighted_monthly_averages_with_masked_data(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_averages(self): ds = self.ds.copy() @@ -1067,7 +1068,7 @@ def test_weighted_daily_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_hourly_averages(self): ds = self.ds.copy() @@ -1111,7 +1112,7 @@ def test_weighted_hourly_averages(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestClimatology: @@ -1151,7 +1152,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000", "01-01-2000"), ) @@ -1159,7 +1160,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000"), ) @@ -1169,7 +1170,7 @@ def test_subsets_climatology_based_on_reference_period(self): result = ds.temporal.climatology( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -1201,11 +1202,11 @@ def test_subsets_climatology_based_on_reference_period(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "True", }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_climatology_with_DJF(self): ds = self.ds.copy() @@ -1259,7 +1260,71 @@ def test_weighted_seasonal_climatology_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) + + def test_raises_deprecation_warning_with_drop_incomplete_djf_season_config(self): + # NOTE: This will test will also cover the other public APIs that + # have drop_incomplete_djf as a season_config arg. + ds = self.ds.copy() + + with warnings.catch_warnings(record=True) as w: + result = ds.temporal.climatology( + "ts", + "season", + season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + ) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert str(w[0].message) == ( + "The `season_config` argument 'drop_incomplete_djf' is being deprecated. " + "Please use 'drop_incomplete_seasons' instead." + ) + + expected = ds.copy() + expected = expected.drop_dims("time") + expected_time = xr.DataArray( + data=np.array( + [ + cftime.DatetimeGregorian(1, 1, 1), + cftime.DatetimeGregorian(1, 4, 1), + cftime.DatetimeGregorian(1, 7, 1), + cftime.DatetimeGregorian(1, 10, 1), + ], + ), + coords={ + "time": np.array( + [ + cftime.DatetimeGregorian(1, 1, 1), + cftime.DatetimeGregorian(1, 4, 1), + cftime.DatetimeGregorian(1, 7, 1), + cftime.DatetimeGregorian(1, 10, 1), + ], + ), + }, + attrs={ + "axis": "T", + "long_name": "time", + "standard_name": "time", + "bounds": "time_bnds", + }, + ) + expected["ts"] = xr.DataArray( + name="ts", + data=np.ones((4, 4, 4)), + coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time}, + dims=["time", "lat", "lon"], + attrs={ + "operation": "temporal_avg", + "mode": "climatology", + "freq": "season", + "weighted": "True", + "drop_incomplete_seasons": "True", + "dec_mode": "DJF", + }, + ) + + xr.testing.assert_identical(result, expected) @requires_dask def test_chunked_weighted_seasonal_climatology_with_DJF(self): @@ -1314,7 +1379,7 @@ def test_chunked_weighted_seasonal_climatology_with_DJF(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_climatology_with_JFD(self): ds = self.ds.copy() @@ -1366,7 +1431,7 @@ def test_weighted_seasonal_climatology_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_climatology(self): ds = self.ds.copy() @@ -1430,7 +1495,7 @@ def test_weighted_custom_seasonal_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_years( self, @@ -1468,7 +1533,7 @@ def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_yea expected["ts"] = xr.DataArray( name="ts", - data=np.ones((4, 4, 4)), + data=np.ones((1, 4, 4)), coords={"lat": expected.lat, "lon": expected.lon, "time": expected_time}, dims=["time", "lat", "lon"], attrs={ @@ -1481,7 +1546,7 @@ def test_weighted_custom_seasonal_climatology_with_seasons_spanning_calendar_yea }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month") @@ -1544,7 +1609,7 @@ def test_weighted_monthly_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_monthly_climatology(self): result = self.ds.temporal.climatology("ts", "month", weighted=False) @@ -1606,7 +1671,7 @@ def test_unweighted_monthly_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_climatology(self): result = self.ds.temporal.climatology("ts", "day", weighted=True) @@ -1668,7 +1733,7 @@ def test_weighted_daily_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_climatology_drops_leap_days_with_matching_calendar(self): time = xr.DataArray( @@ -1759,7 +1824,7 @@ def test_weighted_daily_climatology_drops_leap_days_with_matching_calendar(self) }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_daily_climatology(self): result = self.ds.temporal.climatology("ts", "day", weighted=False) @@ -1821,7 +1886,7 @@ def test_unweighted_daily_climatology(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class TestDepartures: @@ -1902,7 +1967,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.departures( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000", "01-01-2000"), ) @@ -1910,7 +1975,7 @@ def test_raises_error_if_reference_period_arg_is_incorrect(self): ds.temporal.departures( "ts", "season", - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("01-01-2000"), ) @@ -1921,7 +1986,7 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): "ts", "season", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -1929,13 +1994,14 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[np.nan]], [[np.nan]], [[np.nan]]]), + data=np.array([[[0.0]], [[0.0]], [[np.nan]], [[np.nan]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -1959,11 +2025,11 @@ def test_seasonal_departures_relative_to_climatology_reference_period(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_djf": "True", + "drop_incomplete_seasons": "False", }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_monthly_departures_relative_to_climatology_reference_period_with_same_output_freq( self, @@ -1974,7 +2040,7 @@ def test_monthly_departures_relative_to_climatology_reference_period_with_same_o "ts", "month", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_djf": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, reference_period=("2000-01-01", "2000-06-01"), ) @@ -2048,7 +2114,7 @@ def test_monthly_departures_relative_to_climatology_reference_period_with_same_o }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -2057,20 +2123,21 @@ def test_weighted_seasonal_departures_with_DJF(self): "ts", "season", weighted=True, - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -2094,11 +2161,11 @@ def test_weighted_seasonal_departures_with_DJF(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_seasons": "True", + "drop_incomplete_seasons": "False", }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): ds = self.ds.copy() @@ -2108,20 +2175,21 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): "season", weighted=True, keep_weights=True, - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -2145,16 +2213,17 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): "freq": "season", "weighted": "True", "dec_mode": "DJF", - "drop_incomplete_seasons": "True", + "drop_incomplete_seasons": "False", }, ) expected["time_wts"] = xr.DataArray( name="ts", - data=np.array([1.0, 1.0, 1.0, 1.0]), + data=np.array([0.52542373, 1.0, 1.0, 1.0, 0.47457627]), coords={ "time_original": xr.DataArray( data=np.array( [ + "2000-01-16T12:00:00.000000000", "2000-03-16T12:00:00.000000000", "2000-06-16T00:00:00.000000000", "2000-09-16T00:00:00.000000000", @@ -2174,7 +2243,7 @@ def test_weighted_seasonal_departures_with_DJF_and_keep_weights(self): dims=["time_original"], ) - assert result.identical(expected) + xr.testing.assert_allclose(result, expected) def test_unweighted_seasonal_departures_with_DJF(self): ds = self.ds.copy() @@ -2183,20 +2252,21 @@ def test_unweighted_seasonal_departures_with_DJF(self): "ts", "season", weighted=False, - season_config={"dec_mode": "DJF", "drop_incomplete_seasons": True}, + season_config={"dec_mode": "DJF", "drop_incomplete_seasons": False}, ) expected = ds.copy() expected = expected.drop_dims("time") expected["ts"] = xr.DataArray( name="ts", - data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]]]), + data=np.array([[[0.0]], [[0.0]], [[0.0]], [[0.0]], [[0.0]]]), coords={ "lat": expected.lat, "lon": expected.lon, "time": xr.DataArray( data=np.array( [ + cftime.DatetimeGregorian(2000, 1, 1), cftime.DatetimeGregorian(2000, 4, 1), cftime.DatetimeGregorian(2000, 7, 1), cftime.DatetimeGregorian(2000, 10, 1), @@ -2219,12 +2289,12 @@ def test_unweighted_seasonal_departures_with_DJF(self): "mode": "departures", "freq": "season", "weighted": "False", - "drop_incomplete_seasons": "True", + "drop_incomplete_seasons": "False", "dec_mode": "DJF", }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_unweighted_seasonal_departures_with_JFD(self): ds = self.ds.copy() @@ -2275,7 +2345,7 @@ def test_unweighted_seasonal_departures_with_JFD(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) def test_weighted_daily_departures_drops_leap_days_with_matching_calendar(self): time = xr.DataArray( @@ -2368,7 +2438,7 @@ def test_weighted_daily_departures_drops_leap_days_with_matching_calendar(self): }, ) - assert result.identical(expected) + xr.testing.assert_identical(result, expected) class Test_GetWeights: diff --git a/xcdat/temporal.py b/xcdat/temporal.py index ff812a0d..d492d4ae 100644 --- a/xcdat/temporal.py +++ b/xcdat/temporal.py @@ -66,7 +66,6 @@ SeasonConfigInput = TypedDict( "SeasonConfigInput", { - "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[List[List[str]]], @@ -77,7 +76,6 @@ SeasonConfigAttr = TypedDict( "SeasonConfigAttr", { - "drop_incomplete_djf": bool, "drop_incomplete_seasons": bool, "dec_mode": Literal["DJF", "JFD"], "custom_seasons": Optional[Dict[str, List[str]]], @@ -318,7 +316,7 @@ def group_average( * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist - representing a custom season. This config overrides the `decod + representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') * Each month must be included once in a custom season @@ -503,7 +501,7 @@ def climatology( * "custom_seasons" ([List[List[str]]], by default None) List of sublists containing month strings, with each sublist - representing a custom season. This config overrides the `decod + representing a custom season. * Month strings must be in the three letter format (e.g., 'Jan') * Each month must be included once in a custom season @@ -996,7 +994,9 @@ def _set_arg_attrs( # "season" frequency specific configuration attributes. for key in season_config.keys(): - if key not in DEFAULT_SEASON_CONFIG.keys(): + # TODO: Deprecate `drop_incomplete_djf`. + valid_keys = list(DEFAULT_SEASON_CONFIG.keys()) + ["drop_incomplete_djf"] + if key not in valid_keys: raise KeyError( f"'{key}' is not a supported season config. Supported " f"configs include: {DEFAULT_SEASON_CONFIG.keys()}." @@ -1014,7 +1014,7 @@ def _set_arg_attrs( "deprecated. Please use 'drop_incomplete_seasons' instead.", DeprecationWarning, ) - self._season_config["drop_incomplete_seasons"] = drop_incomplete_djf + self._season_config["drop_incomplete_seasons"] = drop_incomplete_djf # type: ignore else: self._season_config["drop_incomplete_seasons"] = season_config.get( "drop_incomplete_seasons", False @@ -1114,7 +1114,7 @@ def _preprocess_dataset(self, ds: xr.Dataset) -> xr.Dataset: # "NDJFM", we should subset the dataset for time coordinates # belonging to those months. months = self._season_config["custom_seasons"].values() # type: ignore - months = list(chain.from_iterable(months.values())) # type: ignore + months = list(chain.from_iterable(months)) if len(months) != 12: ds = self._subset_coords_for_custom_seasons(ds, months) @@ -1158,12 +1158,14 @@ def _subset_coords_for_custom_seasons( The dataset with time coordinate subsetted to months used in custom seasons. """ - months_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) + month_ints = sorted([MONTH_STR_TO_INT[month] for month in months]) coords_by_month = ds.time.groupby(f"{self.dim}.month").groups - months_idxs = {k: coords_by_month[k] for k in months_ints} - months_idxs = sorted(list(chain.from_iterable(months_idxs.values()))) # type: ignore - ds_new = ds.isel({f"{self.dim}": months_idxs}) + month_to_time_idx = { + k: coords_by_month[k] for k in month_ints if k in coords_by_month + } + month_to_time_idx = sorted(list(chain.from_iterable(month_to_time_idx.values()))) # type: ignore + ds_new = ds.isel({f"{self.dim}": month_to_time_idx}) return ds_new