Skip to content

Commit

Permalink
Initial work on #596
Browse files Browse the repository at this point in the history
  • Loading branch information
pochedls committed Feb 4, 2024
1 parent fbf1db6 commit c027d0e
Showing 1 changed file with 69 additions and 11 deletions.
80 changes: 69 additions & 11 deletions xcdat/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#: Type alias for a dictionary of axis keys mapped to their bounds.
AxisWeights = Dict[Hashable, xr.DataArray]
#: Type alias for supported spatial axis keys.
SpatialAxis = Literal["X", "Y"]
SpatialAxis = Literal["X", "Y", "Z"]
SPATIAL_AXES: Tuple[SpatialAxis, ...] = get_args(SpatialAxis)
#: Type alias for a tuple of floats/ints for the regional selection bounds.
RegionAxisBounds = Tuple[float, float]
Expand Down Expand Up @@ -73,10 +73,12 @@ def average(
keep_weights: bool = False,
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
lev_bounds: Optional[RegionAxisBounds] = None,
) -> xr.Dataset:
"""
Calculates the spatial average for a rectilinear grid over an optionally
specified regional domain.
Calculates the weighted spatial and/or vertical average for a
rectilinear grid over an optionally specified regional and/or vertical
domain.
Operations include:
Expand All @@ -101,7 +103,7 @@ def average(
average.
axis : List[SpatialAxis]
List of axis dimensions to average over, by default ["X", "Y"].
Valid axis keys include "X" and "Y".
Valid axis keys include "X", "Y", and "Z".
weights : {"generate", xr.DataArray}, optional
If "generate", then weights are generated. Otherwise, pass a
DataArray containing the regional weights used for weighted
Expand All @@ -122,6 +124,10 @@ def average(
ignored if ``weights`` are supplied. The lower bound can be larger
than the upper bound (e.g., across the prime meridian, dateline), by
default None.
lev_bounds : Optional[RegionAxisBounds], optional
A tuple of floats/ints for the regional lower and upper level
boundaries. This arg is used when calculating axis weights, but is
ignored if ``weights`` are supplied. The default is None.
Returns
-------
Expand All @@ -143,11 +149,15 @@ def average(
>>>
>>> ds.lon.attrs["axis"]
>>> X
>>>
>>> ds.level.attrs["axis"]
>>> Z
Set the 'axis' attribute for the required coordinates if it isn't:
>>> ds.lat.attrs["axis"] = "Y"
>>> ds.lon.attrs["axis"] = "X"
>>> ds.level.attrs["axis"] = "Z"
Call spatial averaging method:
Expand All @@ -167,6 +177,10 @@ def average(
>>> ts_zonal = ds.spatial.average("tas", axis=["X"])["tas"]
Get the vertical average (between 100 and 1000 hPa):
>>> ta_column = ds.spatial.average("ta", axis=["Z"], lev_bounds=(100, 1000))["ta"]
Using custom weights for averaging:
>>> # The shape of the weights must align with the data var.
Expand All @@ -178,6 +192,12 @@ def average(
>>>
>>> ts_global = ds.spatial.average("tas", axis=["X", "Y"],
>>> weights=weights)["tas"]
Notes:
------
Weights are generally computed as the difference between the bounds. If
sub-selecting a region, the units must match the axis units (e.g.,
Pa/hPa or m/km).
"""
ds = self._dataset.copy()
dv = _get_data_var(ds, data_var)
Expand All @@ -188,7 +208,11 @@ def average(
self._validate_region_bounds("Y", lat_bounds)
if lon_bounds is not None:
self._validate_region_bounds("X", lon_bounds)
self._weights = self.get_weights(axis, lat_bounds, lon_bounds, data_var)
if lev_bounds is not None:
self._validate_region_bounds("Z", lev_bounds)
self._weights = self.get_weights(
axis, lat_bounds, lon_bounds, lev_bounds, data_var
)
elif isinstance(weights, xr.DataArray):
self._weights = weights

Expand All @@ -205,6 +229,7 @@ def get_weights(
axis: List[SpatialAxis],
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
lev_bounds: Optional[RegionAxisBounds] = None,
data_var: Optional[str] = None,
) -> xr.DataArray:
"""
Expand All @@ -216,9 +241,9 @@ def get_weights(
weights are then combined to form a DataArray of weights that can be
used to perform a weighted (spatial) average.
If ``lat_bounds`` or ``lon_bounds`` are supplied, then grid cells
outside this selected regional domain are given zero weight. Grid cells
that are partially in this domain are given partial weight.
If ``lat_bounds``, ``lon_bounds``, or ``lev_bounds`` are supplied, then
grid cells outside this selected regional domain are given zero weight.
Grid cells that are partially in this domain are given partial weight.
Parameters
----------
Expand All @@ -230,6 +255,9 @@ def get_weights(
lon_bounds : Optional[RegionAxisBounds]
Tuple of longitude boundaries for regional selection, by default
None.
lev_bounds : Optional[RegionAxisBounds]
Tuple of level boundaries for vertical selection, by default
None.
data_var: Optional[str]
The key of the data variable, by default None. Pass this argument
when the dataset has more than one bounds per axis (e.g., "lon"
Expand All @@ -246,9 +274,7 @@ def get_weights(
Notes
-----
This method was developed for rectilinear grids only. ``get_weights()``
recognizes and operate on latitude and longitude, but could be extended
to work with other standard geophysical dimensions (e.g., time, depth,
and pressure).
recognizes and operate on latitude, longitude, and vertical levels.
"""
Bounds = TypedDict(
"Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]}
Expand All @@ -267,6 +293,12 @@ def get_weights(
if lat_bounds is not None
else None,
},
"Z": {
"weights_method": self._get_vertical_weights,
"region": np.array(lev_bounds, dtype="float")
if lev_bounds is not None
else None,
},
}

axis_weights: AxisWeights = {}
Expand Down Expand Up @@ -476,6 +508,32 @@ def _get_latitude_weights(
weights = self._calculate_weights(d_bounds)
return weights

def _get_vertical_weights(
self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
) -> xr.DataArray:
"""Gets weights for the vertical axis.
This method scales the domain to a region (if selected) and returns weights
proportional to the difference between each pair of level bounds.
Parameters
----------
domain_bounds : xr.DataArray
The array of bounds for the vertical domain.
region_bounds : Optional[np.ndarray]
The array of bounds for vertical selection.
Returns
-------
xr.DataArray
The vertical axis weights.
"""
if region_bounds is not None:
domain_bounds = self._scale_domain_to_region(domain_bounds, region_bounds)

weights = self._calculate_weights(domain_bounds)
return weights

def _calculate_weights(self, domain_bounds: xr.DataArray):
"""Calculate weights for the domain.
Expand Down

0 comments on commit c027d0e

Please sign in to comment.