From 873e0cecc459410fe9e0dd7db7a9e78344e565a8 Mon Sep 17 00:00:00 2001
From: Max Grover
Date: Mon, 25 Sep 2023 09:09:47 -0500
Subject: [PATCH] ADD: Add new xradar functionality (#1456)

* FIX: Add isinstance checks to improve linting

* ADD: Add suite of new utilities to xradar accessor

* ADD: Add new xradar tests
---
 pyart/xradar/accessor.py      | 320 +++++++++++++++++++++++++++++++++-
 tests/xradar/test_accessor.py |  14 +-
 2 files changed, 330 insertions(+), 4 deletions(-)

diff --git a/pyart/xradar/accessor.py b/pyart/xradar/accessor.py
index 283d62e08a..00f4584830 100644
--- a/pyart/xradar/accessor.py
+++ b/pyart/xradar/accessor.py
@@ -4,9 +4,176 @@
 """
 
 
+import copy
+from collections.abc import Hashable, Mapping
+from typing import Any, overload
+
+import numpy as np
+import pandas as pd
+from datatree import DataTree, formatting, formatting_html
+from datatree.treenode import NodePath
+from xarray import DataArray, Dataset, concat
+from xarray.core import utils
+
+
 class Xradar:
-    def __init__(self, xradar):
+    def __init__(self, xradar, default_sweep="sweep_0"):
         self.xradar = xradar
+        self.scan_type = "ppi"
+        self.combined_sweeps = self._combine_sweeps(self.xradar)
+        self.fields = self._find_fields(self.combined_sweeps)
+        self.scan_type = None
+        self.time = dict(
+            data=(self.combined_sweeps.time - self.combined_sweeps.time.min()).astype(
+                "int64"
+            )
+            / 1e9,
+            units=f"seconds since {pd.to_datetime(self.combined_sweeps.time.min().values).strftime('%Y-%m-%d %H:%M:%S.0')}",
+            calendar="gregorian",
+        )
+        self.range = dict(data=self.combined_sweeps.range.values)
+        self.azimuth = dict(data=self.combined_sweeps.azimuth.values)
+        self.elevation = dict(data=self.combined_sweeps.elevation.values)
+        self.fixed_angle = dict(data=self.combined_sweeps.sweep_fixed_angle.values)
+        self.antenna_transition = None
+        self.latitude = dict(data=self.xradar["latitude"].values)
+        self.longitude = dict(data=self.xradar["longitude"].values)
+        self.sweep_end_ray_index = dict(
+            data=self.combined_sweeps.ngates.groupby("sweep_number").max().values
+        )
+        self.sweep_start_ray_index = dict(
+            data=self.combined_sweeps.ngates.groupby("sweep_number").min().values
+        )
+        self.metadata = dict(**self.xradar.attrs)
+        self.ngates = len(self.range["data"])
+        self.nrays = len(self.azimuth["data"])
+        self.nsweeps = len(self.xradar.sweep_group_name)
+        self.instrument_parameters = dict(**self.xradar["radar_parameters"].attrs)
+
+    def __repr__(self):
+        return formatting.datatree_repr(self.xradar)
+
+    def _repr_html_(self):
+        return formatting_html.datatree_repr(self.xradar)
+
+    @overload
+    def __getitem__(self, key: Mapping) -> Dataset:  # type: ignore[misc]
+        ...
+
+    @overload
+    def __getitem__(self, key: Hashable) -> DataArray:  # type: ignore[misc]
+        ...
+
+    @overload
+    def __getitem__(self, key: Any) -> Dataset:
+        ...
+
+    def __getitem__(self, key: str) -> "DataTree | DataArray":
+        """
+        Access child nodes, variables, or coordinates stored anywhere in this tree.
+
+        Returned object will be either a DataTree or DataArray object depending on whether the key given points to a
+        child or variable.
+
+        Parameters
+        ----------
+        key : str
+            Name of variable / child within this node, or unix-like path to variable / child within another node.
+
+        Returns
+        -------
+        Union[DataTree, DataArray]
+        """
+
+        # Either:
+        if utils.is_dict_like(key):
+            # dict-like indexing
+            raise NotImplementedError("Should this index over whole tree?")
+        elif isinstance(key, str):
+            # path-like: a name of a node/variable, or path to a node/variable
+            path = NodePath(key)
+            return self.xradar._get_item(path)
+        elif utils.is_list_like(key):
+            # iterable of variable names
+            raise NotImplementedError(
+                "Selecting via tags is deprecated, and selecting multiple items should be "
+                "implemented via .subset"
+            )
+        else:
+            raise ValueError(f"Invalid format for key: {key}")
+
+    # Iterators
+
+    def iter_start(self):
+        """Return an iterator over the sweep start indices."""
+        return (s for s in self.sweep_start_ray_index["data"])
+
+    def iter_end(self):
+        """Return an iterator over the sweep end indices."""
+        return (s for s in self.sweep_end_ray_index["data"])
+
+    def iter_start_end(self):
+        """Return an iterator over the sweep start and end indices."""
+        return ((s, e) for s, e in zip(self.iter_start(), self.iter_end()))
+
+    def iter_slice(self):
+        """Return an iterator which returns sweep slice objects."""
+        return (slice(s, e + 1) for s, e in self.iter_start_end())
+
+    def iter_field(self, field_name):
+        """Return an iterator which returns sweep field data."""
+        self.check_field_exists(field_name)
+        return (self.fields[field_name]["data"][s] for s in self.iter_slice())
+
+    def iter_azimuth(self):
+        """Return an iterator which returns sweep azimuth data."""
+        return (self.azimuth["data"][s] for s in self.iter_slice())
+
+    def iter_elevation(self):
+        """Return an iterator which returns sweep elevation data."""
+        return (self.elevation["data"][s] for s in self.iter_slice())
+
+    def add_field(self, field_name, dic, replace_existing=False):
+        """
+        Add a field to the object.
+
+        Parameters
+        ----------
+        field_name : str
+            Name of the field to add to the dictionary of fields.
+        dic : dict
+            Dictionary containing field data and metadata.
+        replace_existing : bool, optional
+            True to replace the existing field with key field_name if it
+            exists, losing any existing data. False will raise a ValueError
+            when the field already exists.
+
+        """
+        # check that the field dictionary to add is valid
+        if field_name in self.fields and not replace_existing:
+            err = f"A field with name: {field_name} already exists"
+            raise ValueError(err)
+        if "data" not in dic:
+            raise KeyError("dic must contain a 'data' key")
+        if dic["data"].shape != (self.nrays, self.ngates):
+            t = (self.nrays, self.ngates)
+            err = "'data' has invalid shape, should be (%i, %i)" % t
+            raise ValueError(err)
+        # add the field
+        self.fields[field_name] = dic
+        for sweep in range(self.nsweeps):
+            sweep_ds = (
+                self.xradar[f"sweep_{sweep}"].to_dataset().drop_duplicates("azimuth")
+            )
+            sweep_ds[field_name] = (
+                ("azimuth", "range"),
+                self.fields[field_name]["data"][self.get_slice(sweep)],
+            )
+            attrs = dic.copy()
+            del attrs["data"]
+            sweep_ds[field_name].attrs = attrs
+            self.xradar[f"sweep_{sweep}"].ds = sweep_ds
+        return
 
     def get_field(self, sweep, field_name, copy=False):
         """
@@ -39,6 +206,22 @@ def get_field(self, sweep, field_name, copy=False):
         else:
             return data
 
+    def check_field_exists(self, field_name):
+        """
+        Check that a field exists in the fields dictionary.
+
+        If the field does not exist raise a KeyError.
+
+        Parameters
+        ----------
+        field_name : str
+            Name of field to check.
+
+        """
+        if field_name not in self.fields:
+            raise KeyError("Field not available: " + field_name)
+        return
+
     def get_gate_x_y_z(self, sweep, edges=False, filter_transitions=False):
         """
         Return the x, y and z gate locations in meters for a given sweep.
@@ -75,8 +258,141 @@ def get_gate_x_y_z(self, sweep, edges=False, filter_transitions=False):
 
         """
         # Check to see if the data needs to be georeferenced
-        if "x" not in self.xradar[f"sweep_{0}"].coords:
+        if "x" not in self.xradar[f"sweep_{sweep}"].coords:
             self.xradar = self.xradar.xradar.georeference()
 
         data = self.xradar[f"sweep_{sweep}"].xradar.georeference()
         return data["x"].values, data["y"].values, data["z"].values
+
+    def _combine_sweeps(self, radar):
+        # Loop through and extract the different datasets
+        ds_list = []
+        for sweep in radar.sweep_group_name.values:
+            ds_list.append(radar[sweep].ds.drop_duplicates("azimuth"))
+
+        # Merge based on the sweep number
+        merged = concat(ds_list, dim="sweep_number")
+
+        # Stack the sweep number and azimuth together
+        stacked = merged.stack(gates=["sweep_number", "azimuth"]).transpose()
+
+        # Drop the missing gates
+        cleaned = stacked.where(stacked.time == stacked.time.dropna("gates"))
+
+        # Add in number of gates variable
+        cleaned["ngates"] = ("gates", np.arange(len(cleaned.gates)))
+
+        # Return the non-missing times, ensuring valid data is returned
+        return cleaned
+
+    def add_filter(self, gatefilter, replace_existing=False, include_fields=None):
+        """
+        Updates the radar object with an applied gatefilter provided
+        by the user that masks values in fields within the radar object.
+
+        Parameters
+        ----------
+        gatefilter : GateFilter
+            GateFilter instance. This filter will exclude equal to
+            the conditions provided in the gatefilter and mask values
+            in fields specified or all fields if include_fields is None.
+        replace_existing : bool, optional
+            If True, replaces the fields in the radar object with
+            copies of those fields with the applied gatefilter.
+            False will return new fields with the appended 'filtered_'
+            prefix.
+        include_fields : list, optional
+            List of fields to have filtered applied to. If none, all
+            fields will have applied filter.
+
+        """
+        # If include_fields is None, sets list to all fields to include.
+        if include_fields is None:
+            include_fields = [*self.fields.keys()]
+
+        try:
+            # Replace current fields with masked versions with applied gatefilter.
+            if replace_existing:
+                for field in include_fields:
+                    self.fields[field]["data"] = np.ma.masked_where(
+                        gatefilter.gate_excluded, self.fields[field]["data"]
+                    )
+            # Add new fields with prefix 'filtered_'
+            else:
+                for field in include_fields:
+                    field_dict = copy.deepcopy(self.fields[field])
+                    field_dict["data"] = np.ma.masked_where(
+                        gatefilter.gate_excluded, field_dict["data"]
+                    )
+                    self.add_field(
+                        "filtered_" + field, field_dict, replace_existing=True
+                    )
+
+        # If fields don't match up throw an error.
+        except KeyError:
+            raise KeyError(
+                field + " not found in the original radar object, "
+                "please check that names in the include_fields list "
+                "match those in the radar object."
+            )
+        return
+
+    def get_nyquist_vel(self, sweep, check_uniform=True):
+        """
+        Return the Nyquist velocity in meters per second for a given sweep.
+
+        Raises a LookupError if the Nyquist velocity is not available, an
+        Exception is raised if the velocities are not uniform in the sweep
+        unless check_uniform is set to False.
+
+        Parameters
+        ----------
+        sweep : int
+            Sweep number to retrieve data for, 0 based.
+        check_uniform : bool
+            True to check to perform a check on the Nyquist velocities that
+            they are uniform in the sweep, False will skip this check and
+            return the velocity of the first ray in the sweep.
+
+        Returns
+        -------
+        nyquist_velocity : float
+            Array containing the Nyquist velocity in m/s for a given sweep.
+
+        """
+        s = self.get_slice(sweep)
+        try:
+            nyq_vel = self.instrument_parameters["nyquist_velocity"]["data"][s]
+        except (KeyError, TypeError):
+            raise LookupError("Nyquist velocity unavailable")
+        if check_uniform:
+            if np.any(nyq_vel != nyq_vel[0]):
+                raise Exception("Nyquist velocities are not uniform in sweep")
+        return float(nyq_vel[0])
+
+    def get_start(self, sweep):
+        """Return the starting ray index for a given sweep."""
+        return int(self.combined_sweeps.ngates.sel(sweep_number=sweep).min())
+
+    def get_end(self, sweep):
+        """Return the ending ray for a given sweep."""
+        return self.sweep_end_ray_index["data"][sweep]
+
+    def get_start_end(self, sweep):
+        """Return the starting and ending ray for a given sweep."""
+        return self.get_start(sweep), self.get_end(sweep)
+
+    def get_slice(self, sweep):
+        """Return a slice for selecting rays for a given sweep."""
+        start, end = self.get_start_end(sweep)
+        return slice(start, end + 1)
+
+    def _find_fields(self, ds):
+        fields = {}
+        for field in ds.variables:
+            if ds[field].dims == ("gates", "range"):
+                fields[field] = {
+                    "data": ds[field].values,
+                    **ds[field].attrs,
+                }
+        return fields
diff --git a/tests/xradar/test_accessor.py b/tests/xradar/test_accessor.py
index 5664267fa9..282af3cd54 100644
--- a/tests/xradar/test_accessor.py
+++ b/tests/xradar/test_accessor.py
@@ -9,7 +9,6 @@
 def test_get_field(filename=filename):
     dtree = xd.io.open_cfradial1_datatree(
         filename,
-        first_dim="time",
         optional=False,
     )
     radar = pyart.xradar.Xradar(dtree)
@@ -20,7 +19,6 @@ def test_get_field(filename=filename):
 def test_get_gate_x_y_z(filename=filename):
     dtree = xd.io.open_cfradial1_datatree(
         filename,
-        first_dim="time",
         optional=False,
     )
     radar = pyart.xradar.Xradar(dtree)
@@ -28,3 +26,15 @@ def test_get_gate_x_y_z(filename=filename):
     assert x.shape == (483, 996)
     assert y.shape == (483, 996)
     assert z.shape == (483, 996)
+
+
+def test_add_field(filename=filename):
+    dtree = xd.io.open_cfradial1_datatree(
+        filename,
+        optional=False,
+    )
+    radar = pyart.xradar.Xradar(dtree)
+    new_field = radar.fields["DBZ"]
+    radar.add_field("reflectivity", new_field)
+    assert "reflectivity" in radar.fields
+    assert radar["sweep_0"]["reflectivity"].shape == radar["sweep_0"]["DBZ"].shape