Skip to content

Commit

Permalink
Update temporal.py to properly handle piControl simulations (#696)
Browse files Browse the repository at this point in the history
* Update temporal.py to properly handle piControl simulations

* pre-commit style fix

* to ensure year to be int

* Convert `self.dim` to a str to fix mypy warnings
- Remove unused type ignore comments

---------

Co-authored-by: Tom Vo <[email protected]>
  • Loading branch information
lee1043 and tomvothecoder authored Sep 23, 2024
1 parent 753a046 commit 643e72c
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def _averager(
# it becomes obsolete after the data variable is averaged. When the
# averaged data variable is added to the dataset, the new time dimension
# and its associated coordinates are also added.
ds = ds.drop_dims(self.dim) # type: ignore
ds = ds.drop_dims(self.dim)
ds[dv_avg.name] = dv_avg

if keep_weights:
Expand Down Expand Up @@ -847,7 +847,7 @@ def _set_data_var_attrs(self, data_var: str):
dv = _get_data_var(self._dataset, data_var)

self.data_var = data_var
self.dim = get_dim_coords(dv, "T").name
self.dim = str(get_dim_coords(dv, "T").name)

if not _contains_datetime_like_objects(dv[self.dim]):
first_time_coord = dv[self.dim].values[0]
Expand Down Expand Up @@ -1084,7 +1084,11 @@ def _drop_incomplete_djf(self, dataset: xr.Dataset) -> xr.Dataset:
ds[self.dim].dt.year.values[0],
ds[self.dim].dt.year.values[-1],
)
incomplete_seasons = (f"{start_year}-01", f"{start_year}-02", f"{end_year}-12")
incomplete_seasons = (
f"{int(start_year):04d}-01",
f"{int(start_year):04d}-02",
f"{int(end_year):04d}-12",
)

for year_month in incomplete_seasons:
try:
Expand Down Expand Up @@ -1115,9 +1119,7 @@ def _drop_leap_days(self, ds: xr.Dataset):
-------
xr.Dataset
"""
ds = ds.sel( # type: ignore
**{self.dim: ~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))}
)
ds = ds.sel(**{self.dim: ~((ds.time.dt.month == 2) & (ds.time.dt.day == 29))})
return ds

def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
Expand All @@ -1142,9 +1144,9 @@ def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

dv = dv.weighted(self._weights).mean(dim=self.dim) # type: ignore
dv = dv.weighted(self._weights).mean(dim=self.dim)
else:
dv = dv.mean(dim=self.dim) # type: ignore
dv = dv.mean(dim=self.dim)

dv = self._add_operation_attrs(dv)

Expand Down

0 comments on commit 643e72c

Please sign in to comment.