From 3acf03c185a5895e5e5b12cb479b281fc5d0acda Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Wed, 18 Oct 2023 20:58:28 -0400 Subject: [PATCH 01/13] ENH: xarray grid compatibility --- pyart/core/grid.py | 58 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 22dfe1fb0b..8ca3475212 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -316,11 +316,6 @@ def to_xarray(self): """ - if not _XARRAY_AVAILABLE: - raise MissingOptionalDependency( - "Xarray is required to use Grid.to_xarray but is not " + "installed!" - ) - lon, lat = self.get_point_longitude_latitude() z = self.z["data"] y = self.y["data"] @@ -348,6 +343,7 @@ def to_xarray(self): data.attrs.update({meta: self.fields[field][meta]}) ds[field] = data + ds.lon.attrs = [ ("long_name", "longitude of grid cell center"), ("units", "degree_E"), @@ -366,6 +362,58 @@ def to_xarray(self): ds.z.encoding["_FillValue"] = None ds.lat.encoding["_FillValue"] = None ds.lon.encoding["_FillValue"] = None + + # Delayed import + from ..io.grid_io import _make_coordinatesystem_dict + + ds["ProjectionCoordinateSystem"] = xarray.DataArray( + data=np.array(1, dtype="int32"), + dims=None, + attrs=_make_coordinatesystem_dict(self), + ) + + if self.origin_latitude is not None: + ds["origin_latitude"] = xarray.DataArray( + np.ma.expand_dims(self.origin_latitude["data"][0], 0), + dims=("time"), + attrs=get_metadata("origin_latitude"), + ) + + if self.origin_longitude is not None: + ds["origin_longitude"] = xarray.DataArray( + np.ma.expand_dims(self.origin_longitude["data"][0], 0), + dims=("time"), + attrs=get_metadata("origin_longitude"), + ) + + if self.origin_altitude is not None: + ds["origin_altitude"] = xarray.DataArray( + np.ma.expand_dims(self.origin_altitude["data"][0], 0), + dims=("time"), + attrs=get_metadata("origin_altitude"), + ) + + if self.radar_altitude is not None: + ds["radar_altitude"] = xarray.DataArray( + np.ma.expand_dims(self.radar_altitude["data"][0], 0), + dims=("nradar"), + attrs=get_metadata("radar_altitude"), + ) + + if self.radar_latitude is not None: + ds["radar_latitude"] = xarray.DataArray( + np.ma.expand_dims(self.radar_latitude["data"][0], 0), + dims=("nradar"), + attrs=get_metadata("radar_latitude"), + ) + + if self.radar_longitude is not None: + ds["radar_longitude"] = xarray.DataArray( + np.ma.expand_dims(self.radar_longitude["data"][0], 0), + dims=("nradar"), + attrs=get_metadata("radar_longitude"), + ) + ds.close() return ds From 8ad98a918ae3a7a846a93ddbcd780ce577547b36 Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Thu, 19 Oct 2023 20:36:18 -0400 Subject: [PATCH 02/13] ENH: xarray grid output --- pyart/core/grid.py | 171 ++++++++++++++++++++++++--------------------- 1 file changed, 91 insertions(+), 80 deletions(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 8ca3475212..081595afcb 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -315,106 +315,117 @@ def to_xarray(self): in a one dimensional array. """ - lon, lat = self.get_point_longitude_latitude() z = self.z["data"] y = self.y["data"] x = self.x["data"] - time = np.array([num2date(self.time["data"][0], self.time["units"])]) + time = np.array( + [num2date(self.time["data"][0], units=self.time["units"])], + dtype="datetime64[ns]", + ) ds = xarray.Dataset() - for field in list(self.fields.keys()): - field_data = self.fields[field]["data"] + for field, field_info in self.fields.items(): + field_data = field_info["data"] data = xarray.DataArray( np.ma.expand_dims(field_data, 0), dims=("time", "z", "y", "x"), coords={ - "time": (["time"], time), - "z": (["z"], z), + "time": time, + "z": z, "lat": (["y", "x"], lat), "lon": (["y", "x"], lon), - "y": (["y"], y), - "x": (["x"], x), + "y": y, + "x": x, }, ) - for meta in list(self.fields[field].keys()): + + for meta, value in field_info.items(): if meta != "data": - data.attrs.update({meta: self.fields[field][meta]}) + data.attrs.update({meta: value}) ds[field] = data - ds.lon.attrs = [ - ("long_name", "longitude of grid cell center"), - ("units", "degree_E"), - ("standard_name", "Longitude"), - ] - ds.lat.attrs = [ - ("long_name", "latitude of grid cell center"), - ("units", "degree_N"), - ("standard_name", "Latitude"), - ] - - ds.z.attrs = get_metadata("z") - ds.y.attrs = get_metadata("y") - ds.x.attrs = get_metadata("x") - - ds.z.encoding["_FillValue"] = None - ds.lat.encoding["_FillValue"] = None - ds.lon.encoding["_FillValue"] = None - - # Delayed import - from ..io.grid_io import _make_coordinatesystem_dict - - ds["ProjectionCoordinateSystem"] = xarray.DataArray( - data=np.array(1, dtype="int32"), - dims=None, - attrs=_make_coordinatesystem_dict(self), - ) + ds.lon.attrs = [ + ("long_name", "longitude of grid cell center"), + ("units", "degree_E"), + ("standard_name", "Longitude"), + ] + ds.lat.attrs = [ + ("long_name", "latitude of grid cell center"), + ("units", "degree_N"), + ("standard_name", "Latitude"), + ] + + ds.z.attrs = get_metadata("z") + ds.y.attrs = get_metadata("y") + ds.x.attrs = get_metadata("x") + + for attr in [ds.z, ds.lat, ds.lon]: + attr.encoding["_FillValue"] = None + + # Delayed import + from ..io.grid_io import _make_coordinatesystem_dict + + ds["ProjectionCoordinateSystem"] = xarray.DataArray( + data=np.array(1, dtype="int32"), + attrs=_make_coordinatesystem_dict(self), + ) - if self.origin_latitude is not None: - ds["origin_latitude"] = xarray.DataArray( - np.ma.expand_dims(self.origin_latitude["data"][0], 0), - dims=("time"), - attrs=get_metadata("origin_latitude"), - ) - - if self.origin_longitude is not None: - ds["origin_longitude"] = xarray.DataArray( - np.ma.expand_dims(self.origin_longitude["data"][0], 0), - dims=("time"), - attrs=get_metadata("origin_longitude"), - ) - - if self.origin_altitude is not None: - ds["origin_altitude"] = xarray.DataArray( - np.ma.expand_dims(self.origin_altitude["data"][0], 0), - dims=("time"), - attrs=get_metadata("origin_altitude"), - ) - - if self.radar_altitude is not None: - ds["radar_altitude"] = xarray.DataArray( - np.ma.expand_dims(self.radar_altitude["data"][0], 0), - dims=("nradar"), - attrs=get_metadata("radar_altitude"), - ) - - if self.radar_latitude is not None: - ds["radar_latitude"] = xarray.DataArray( - np.ma.expand_dims(self.radar_latitude["data"][0], 0), - dims=("nradar"), - attrs=get_metadata("radar_latitude"), - ) - - if self.radar_longitude is not None: - ds["radar_longitude"] = xarray.DataArray( - np.ma.expand_dims(self.radar_longitude["data"][0], 0), - dims=("nradar"), - attrs=get_metadata("radar_longitude"), - ) - - ds.close() + # write the projection dictionary as a scalar + projection = self.projection.copy() + # NetCDF does not support boolean attribute, covert to string + if "_include_lon_0_lat_0" in projection: + include = projection["_include_lon_0_lat_0"] + projection["_include_lon_0_lat_0"] = ["false", "true"][include] + ds["projection"] = xarray.DataArray( + data=np.array(1, dtype="int32"), + dims=None, + attrs=projection, + ) + + for attr_name in [ + "origin_latitude", + "origin_longitude", + "origin_altitude", + "radar_altitude", + "radar_latitude", + "radar_longitude", + "radar_time", + ]: + if hasattr(self, attr_name): + attr_data = getattr(self, attr_name) + if attr_data is not None: + if "radar_time" not in attr_name: + attr_value = np.ma.expand_dims(attr_data["data"][0], 0) + else: + attr_value = [ + np.array( + num2date( + attr_data["data"][0], + units=attr_data["units"], + ), + dtype="datetime64[ns]", + ) + ] + dims = ("nradar",) + ds[attr_name] = xarray.DataArray( + attr_value, dims=dims, attrs=get_metadata(attr_name) + ) + + if "radar_time" in ds.variables: + ds.radar_time.attrs.pop("calendar") + + if self.radar_name is not None: + radar_name = self.radar_name["data"][0] + ds["radar_name"] = xarray.DataArray( + np.array([b"".join(radar_name)], dtype="S4"), + dims=("nradar"), + attrs=get_metadata("radar_name"), + ) + ds.attrs = self.metadata + ds.close() return ds def add_field(self, field_name, field_dict, replace_existing=False): From 6f8a2607d9f65490e2a8684fea87b6e945192cac Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Thu, 19 Oct 2023 22:23:18 -0400 Subject: [PATCH 03/13] adding back xarray availablity check --- pyart/core/grid.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 081595afcb..8ac24dec81 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -315,6 +315,12 @@ def to_xarray(self): in a one dimensional array. """ + + if not _XARRAY_AVAILABLE: + raise MissingOptionalDependency( + "Xarray is required to use Grid.to_xarray but is not " + "installed!" + ) + lon, lat = self.get_point_longitude_latitude() z = self.z["data"] y = self.y["data"] From 116a31c07d7021abe576f0a42ea38e81ac5db5d8 Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Thu, 19 Oct 2023 22:28:17 -0400 Subject: [PATCH 04/13] pre-commit run --- pyart/core/grid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 8ac24dec81..f078681d27 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -315,11 +315,11 @@ def to_xarray(self): in a one dimensional array. """ - + if not _XARRAY_AVAILABLE: raise MissingOptionalDependency( "Xarray is required to use Grid.to_xarray but is not " + "installed!" - ) + ) lon, lat = self.get_point_longitude_latitude() z = self.z["data"] From 2f73ad042c7e20ef8c1239569ec5d8272b64a57e Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Fri, 20 Oct 2023 12:09:43 -0400 Subject: [PATCH 05/13] fix minor error in grid_to_xarary --- pyart/core/grid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index f078681d27..17c3a4942a 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -328,7 +328,6 @@ def to_xarray(self): time = np.array( [num2date(self.time["data"][0], units=self.time["units"])], - dtype="datetime64[ns]", ) ds = xarray.Dataset() @@ -424,12 +423,13 @@ def to_xarray(self): ds.radar_time.attrs.pop("calendar") if self.radar_name is not None: - radar_name = self.radar_name["data"][0] + radar_name = self.radar_name["data"] ds["radar_name"] = xarray.DataArray( - np.array([b"".join(radar_name)], dtype="S4"), + np.array([b"".join(radar_name)]), dims=("nradar"), attrs=get_metadata("radar_name"), ) + ds.attrs = self.metadata ds.close() return ds From 46261c8d31daf5c04ddd22fc06fd5b8ab5828b3b Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Wed, 28 Aug 2024 13:59:15 -0400 Subject: [PATCH 06/13] modified coords in function to_xarray() --- pyart/core/grid.py | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index d9a6289283..13baeb0c42 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -373,7 +373,7 @@ def to_xarray(self): # Delayed import from ..io.grid_io import _make_coordinatesystem_dict - ds["ProjectionCoordinateSystem"] = xarray.DataArray( + ds.coords["ProjectionCoordinateSystem"] = xarray.DataArray( data=np.array(1, dtype="int32"), attrs=_make_coordinatesystem_dict(self), ) @@ -384,7 +384,7 @@ def to_xarray(self): if "_include_lon_0_lat_0" in projection: include = projection["_include_lon_0_lat_0"] projection["_include_lon_0_lat_0"] = ["false", "true"][include] - ds["projection"] = xarray.DataArray( + ds.coords["projection"] = xarray.DataArray( data=np.array(1, dtype="int32"), dims=None, attrs=projection, @@ -402,20 +402,26 @@ def to_xarray(self): if hasattr(self, attr_name): attr_data = getattr(self, attr_name) if attr_data is not None: - if "radar_time" not in attr_name: + if attr_name in ["origin_latitude", "origin_longitude", "origin_altitude"]: + # Adjusting the dims to 'time' for the origin attributes attr_value = np.ma.expand_dims(attr_data["data"][0], 0) + dims = ("time",) else: - attr_value = [ - np.array( - num2date( - attr_data["data"][0], - units=attr_data["units"], - ), - dtype="datetime64[ns]", - ) - ] - dims = ("nradar",) - ds[attr_name] = xarray.DataArray( + if "radar_time" not in attr_name: + attr_value = np.ma.expand_dims(attr_data["data"][0], 0) + else: + attr_value = [ + np.array( + num2date( + attr_data["data"][0], + units=attr_data["units"], + ), + dtype="datetime64[ns]", + ) + ] + dims = ("nradar",) + + ds.coords[attr_name] = xarray.DataArray( attr_value, dims=dims, attrs=get_metadata(attr_name) ) @@ -431,6 +437,13 @@ def to_xarray(self): ) ds.attrs = self.metadata + for key in ds.attrs: + try: + ds.attrs[key] = ds.attrs[key].decode('utf-8') + except AttributeError: + # If the attribute is not a byte string, just pass + pass + ds.close() return ds @@ -454,7 +467,7 @@ def add_field(self, field_name, field_dict, replace_existing=False): if "data" not in field_dict: raise KeyError('Field dictionary must contain a "data" key') if field_name in self.fields and replace_existing is False: - raise ValueError(f"A field named {field_name} already exists") + raise ValueError("A field named %s already exists" % (field_name)) if field_dict["data"].shape != (self.nz, self.ny, self.nx): raise ValueError("Field has invalid shape") From 5eb6dd67258ae8a79148b6508a34c56402d0f7e0 Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Wed, 28 Aug 2024 14:03:58 -0400 Subject: [PATCH 07/13] modified coords in function to_xarray() + pre-commit --- pyart/core/grid.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 13baeb0c42..bb8102e4d6 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -402,7 +402,11 @@ def to_xarray(self): if hasattr(self, attr_name): attr_data = getattr(self, attr_name) if attr_data is not None: - if attr_name in ["origin_latitude", "origin_longitude", "origin_altitude"]: + if attr_name in [ + "origin_latitude", + "origin_longitude", + "origin_altitude", + ]: # Adjusting the dims to 'time' for the origin attributes attr_value = np.ma.expand_dims(attr_data["data"][0], 0) dims = ("time",) @@ -420,7 +424,7 @@ def to_xarray(self): ) ] dims = ("nradar",) - + ds.coords[attr_name] = xarray.DataArray( attr_value, dims=dims, attrs=get_metadata(attr_name) ) @@ -439,7 +443,7 @@ def to_xarray(self): ds.attrs = self.metadata for key in ds.attrs: try: - ds.attrs[key] = ds.attrs[key].decode('utf-8') + ds.attrs[key] = ds.attrs[key].decode("utf-8") except AttributeError: # If the attribute is not a byte string, just pass pass From 7fecb2da6297c2ff6b660aa37aa85a944a0afdc2 Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Thu, 29 Aug 2024 13:18:19 -0400 Subject: [PATCH 08/13] Add plot_max_cappi example and update related modules --- examples/plotting/plot_max_cappi.py | 112 +++++++ pyart/graph/__init__.py | 2 + pyart/graph/gridmapdisplay.py | 54 +++- pyart/graph/max_cappi.py | 443 ++++++++++++++++++++++++++++ tests/graph/test_plot_maxcappi.py | 106 +++++++ 5 files changed, 711 insertions(+), 6 deletions(-) create mode 100644 examples/plotting/plot_max_cappi.py create mode 100644 pyart/graph/max_cappi.py create mode 100644 tests/graph/test_plot_maxcappi.py diff --git a/examples/plotting/plot_max_cappi.py b/examples/plotting/plot_max_cappi.py new file mode 100644 index 0000000000..3e7567fa40 --- /dev/null +++ b/examples/plotting/plot_max_cappi.py @@ -0,0 +1,112 @@ +""" +============= +Plot Max-CAPPI +============= + +This is an example of how to plot a Max-CAPPI +within a Py-ART grid display object. + +""" + +print(__doc__) + +# Author: Hamid Ali Syed (syed44@purdue.edu) +# License: BSD 3 clause + +import matplotlib.pyplot as plt +import numpy as np + +import pyart +from pyart.testing import get_test_data + +######################################### +# ** MAX-CAPPI Display +# + +# Define and Read in the test data +grid_file = get_test_data("20110520100000_nexrad_grid.nc") +grid = pyart.io.read_grid(grid_file) + + +# Create a grid display +gdisplay = pyart.graph.GridMapDisplay(grid) +gdisplay.plot_maxcappi(field="REF", range_rings=True, add_slogan=True) + + +######################################### +# ** Second Example +# +# Let's read in a cfradial file and +# create a grid. + + +import logging +from datetime import datetime + +import fsspec +import pytz + + +def download_nexrad(timezone, date, site, local_date=False): + """Download NEXRAD radar data from an S3 bucket.""" + try: + utc_date = ( + pytz.timezone(timezone).localize(date).astimezone(pytz.utc) + if local_date + else date + ) + logging.info(f"Time: {utc_date}") + + fs = fsspec.filesystem("s3", anon=True) + nexrad_path = utc_date.strftime( + f"s3://noaa-nexrad-level2/%Y/%m/%d/{site}/{site}%Y%m%d_%H*" + ) + files = sorted(fs.glob(nexrad_path)) + + return [file for file in files if not file.endswith("_MDM")] + except Exception as e: + logging.error("Error in processing: %s", e) + return [] + + +# Load NEXRAD data from S3 Bucket +site = "PHWA" +timezone = "UTC" +date = datetime(2024, 8, 25, 8, 29) + +# Correctly passing the site and timezone +file = download_nexrad(timezone, date, site, local_date=False)[0] + + +# Read the data using nexrad_archive reader +radar = pyart.io.read_nexrad_archive("s3://" + file) + +# Create a 3D grid +# mask out last 10 gates of each ray, this removes the "ring" around the radar. +radar.fields["reflectivity"]["data"][:, -10:] = np.ma.masked + +# exclude masked gates from the gridding +gatefilter = pyart.filters.GateFilter(radar) +gatefilter.exclude_transition() +gatefilter.exclude_masked("reflectivity") +gatefilter.exclude_outside("reflectivity", 10, 80) + +# perform Cartesian mapping, limit to the reflectivity field. +max_range = np.ceil(radar.range["data"].max()) +if max_range / 1e3 > 250: + max_range = 250 * 1e3 + +grid = pyart.map.grid_from_radars( + (radar,), + gatefilters=(gatefilter,), + grid_shape=(30, 441, 441), + grid_limits=((0, 10000), (-max_range, max_range), (-max_range, max_range)), + fields=["reflectivity"], +) + +# Create a grid display +gdisplay = pyart.graph.GridMapDisplay(grid) +with plt.style.context("dark_background"): + gdisplay.plot_maxcappi( + field="reflectivity", cmap="pyart_HomeyerRainbow", add_slogan=True + ) diff --git a/pyart/graph/__init__.py b/pyart/graph/__init__.py index 8de2949cf6..2cf9b6cd8a 100644 --- a/pyart/graph/__init__.py +++ b/pyart/graph/__init__.py @@ -62,5 +62,7 @@ from .radardisplay_airborne import AirborneRadarDisplay # noqa from .radarmapdisplay import RadarMapDisplay # noqa from .radarmapdisplay_basemap import RadarMapDisplayBasemap # noqa +from . import max_cappi # noqa +from .max_cappi import plot_maxcappi # noqa __all__ = [s for s in dir() if not s.startswith("_")] diff --git a/pyart/graph/gridmapdisplay.py b/pyart/graph/gridmapdisplay.py index a2f38b7706..c1a630642b 100644 --- a/pyart/graph/gridmapdisplay.py +++ b/pyart/graph/gridmapdisplay.py @@ -50,6 +50,8 @@ except ImportError: _LAMBERT_GRIDLINES = False +from . import max_cappi # noqa + class GridMapDisplay: """ @@ -120,7 +122,7 @@ def plot_grid( add_grid_lines=True, ticks=None, ticklabs=None, - **kwargs + **kwargs, ): """ Plot the grid using xarray and cartopy. @@ -257,7 +259,7 @@ def plot_grid( vmin=vmin, vmax=vmax, add_colorbar=False, - **kwargs + **kwargs, ) self.mappables.append(pm) @@ -414,7 +416,7 @@ def plot_latitudinal_level( fig=None, ticks=None, ticklabs=None, - **kwargs + **kwargs, ): """ Plot a slice along a given latitude. @@ -575,7 +577,7 @@ def plot_longitudinal_level( fig=None, ticks=None, ticklabs=None, - **kwargs + **kwargs, ): """ Plot a slice along a given longitude. @@ -718,7 +720,7 @@ def plot_cross_section( fig=None, ticks=None, ticklabs=None, - **kwargs + **kwargs, ): """ Plot a cross section through a set of given points (latitude, @@ -849,7 +851,7 @@ def plot_cross_section( add_colorbar=False, ax=ax, cmap=cmap, - **kwargs + **kwargs, ) self.mappables.append(plot) @@ -1110,6 +1112,46 @@ def cartopy_coastlines(self): category="physical", name="coastline", scale="10m", facecolor="none" ) + def plot_maxcappi( + self, + field, + cmap=None, + vmin=None, + vmax=None, + title=None, + lat_lines=None, + lon_lines=None, + add_map=True, + projection=None, + colorbar=True, + range_rings=False, + dpi=100, + savedir=None, + show_figure=True, + add_slogan=False, + **kwargs, + ): + # Call the plot_maxcappi function from the max_cappi module or object + max_cappi.plot_maxcappi( + grid=self.grid, # Assuming `self.grid` holds the Grid object in your class + field=field, + cmap=cmap, + vmin=vmin, + vmax=vmax, + title=title, + lat_lines=lat_lines, + lon_lines=lon_lines, + add_map=add_map, + projection=projection, + colorbar=colorbar, + range_rings=range_rings, + dpi=dpi, + savedir=savedir, + show_figure=show_figure, + add_slogan=add_slogan, + **kwargs, + ) + # These methods are a hack to allow gridlines when the projection is lambert # https://nbviewer.jupyter.org/gist/ajdawson/dd536f786741e987ae4e diff --git a/pyart/graph/max_cappi.py b/pyart/graph/max_cappi.py new file mode 100644 index 0000000000..3e0015a936 --- /dev/null +++ b/pyart/graph/max_cappi.py @@ -0,0 +1,443 @@ +""" +Plot Max-CAPPI + +This module provides a function to plot a Maximum Constant Altitude Plan Position Indicator (Max-CAPPI) +from radar data using an xarray dataset. The function includes options for adding map features, range rings, +color bars, and customized visual settings. + +Author: Syed Hamid Ali (@syedhamidali) +""" + +__all__ = ["plot_maxcappi"] + +import os +import warnings + +import cartopy.crs as ccrs +import cartopy.feature as feat +import matplotlib.pyplot as plt +import numpy as np +from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER +from matplotlib.ticker import NullFormatter + +warnings.filterwarnings("ignore") + + +def plot_maxcappi( + grid, + field, + cmap=None, + vmin=None, + vmax=None, + title=None, + lat_lines=None, + lon_lines=None, + add_map=True, + projection=None, + colorbar=True, + range_rings=False, + dpi=100, + show_progress=False, + savedir=None, + show_figure=True, + add_slogan=False, + **kwargs, +): + """ + Plots a Constant Altitude Plan Position Indicator (CAPPI) using an xarray Dataset. + + Parameters + ---------- + grid : pyart.core.Grid + The grid object containing the radar data to be plotted. + field : str + The radar field to be plotted (e.g., "REF", "VEL", "WIDTH"). + cmap : str or matplotlib colormap, optional + Colormap to use for the plot. Default is "SyedSpectral" if available, otherwise "HomeyerRainbow". + vmin : float, optional + Minimum value for the color scaling. Default is None, which sets it to the minimum value of the data. + vmax : float, optional + Maximum value for the color scaling. Default is None, which sets it to the maximum value of the data. + title : str, optional + Title of the plot. If None, the title is set to "Max-{field}". + lat_lines : array-like, optional + Latitude lines to be included in the plot. Default is calculated based on dataset coordinates. + lon_lines : array-like, optional + Longitude lines to be included in the plot. Default is calculated based on dataset coordinates. + add_map : bool, optional + Whether to include a map background in the plot. Default is True. + projection : cartopy.crs.Projection, optional + The map projection for the plot. Default is cartopy.crs.LambertAzimuthalEqualArea(). + colorbar : bool, optional + Whether to include a colorbar in the plot. Default is True. + range_rings : bool, optional + Whether to include range rings at 50 km intervals. Default is True. + dpi : int, optional + DPI (dots per inch) for the plot. Default is 100. + show_progress : bool, optional + Whether to print progress messages. Default is False. + savedir : str, optional + Directory where the plot will be saved. If None, the plot is not saved. + show_figure : bool, optional + Whether to display the plot. Default is True. + **kwargs : dict, optional + Additional keyword arguments to pass to matplotlib's `pcolormesh` function. + + Returns + ------- + None + This function does not return any value. It generates and optionally displays/saves a plot. + Notes + ----- + Author : Hamid Ali Syed (@syedhamidali) + """ + + ds = grid.to_xarray().squeeze() + + if lon_lines is None: + lon_lines = np.arange(int(ds.lon.min().values), int(ds.lon.max().values) + 1) + if lat_lines is None: + lat_lines = np.arange(int(ds.lat.min().values), int(ds.lat.max().values) + 1) + + plt.rcParams.copy() + plt.rcParams.update( + { + "font.weight": "bold", + "axes.labelweight": "bold", + "xtick.direction": "in", + "ytick.direction": "in", + "xtick.major.size": 10, + "ytick.major.size": 10, + "xtick.minor.size": 7, + "ytick.minor.size": 7, + "font.size": 14, + "axes.linewidth": 2, + "ytick.labelsize": 12, + "xtick.labelsize": 12, + } + ) + + max_c = ds[field].max(dim="z") + max_x = ds[field].max(dim="y") + max_y = ds[field].max(dim="x").T + + trgx = ds["x"].values + trgy = ds["y"].values + trgz = ds["z"].values + + max_height = int(np.floor(trgz.max()) / 1e3) + sideticks = np.arange(max_height / 4, max_height + 1, max_height / 4).astype(int) + + if cmap is None: + cmap = "pyart_Carbone42" + if vmin is None: + vmin = grid.fields[field]["data"].min() + if vmax is None: + vmax = grid.fields[field]["data"].max() + if title is None: + title = f"Max-{field.upper()[:3]}" + + def plot_range_rings(ax_xy, max_range): + """ + Plots range rings at 50 km intervals. + + Parameters + ---------- + ax_xy : matplotlib.axes.Axes + The axis on which to plot the range rings. + max_range : float + The maximum range for the range rings. + + Returns + ------- + None + """ + background_color = ax_xy.get_facecolor() + color = "k" if sum(background_color[:3]) / 3 > 0.5 else "w" + + for i, r in enumerate(np.arange(5e4, np.floor(max_range) + 1, 5e4)): + label = f"Ring Dist. {int(r/1e3)} km" if i == 0 else None + ax_xy.plot( + r * np.cos(np.arange(0, 360) * np.pi / 180), + r * np.sin(np.arange(0, 360) * np.pi / 180), + color=color, + ls="--", + linewidth=0.4, + alpha=0.3, + label=label, + ) + + ax_xy.legend(loc="upper right", prop={"weight": "normal", "size": 8}) + + if show_progress: + print("...................................") + print( + f"Plotting {title}: {ds['time'].dt.strftime('%Y%m%d %H:%M:%S').values.item()}" + ) + print("...................................\n") + + def _get_projection(ds): + """ + Determine the central latitude and longitude from a dataset + and return the corresponding projection. + + Parameters + ---------- + ds : xarray.Dataset + The dataset from which to extract latitude and longitude + information. + + Returns + ------- + projection : cartopy.crs.Projection + A Cartopy projection object centered on the extracted or + calculated latitude and longitude. + """ + + def get_coord_or_attr(ds, coord_name, attr_name): + """Helper function to get a coordinate or attribute, or + calculate median if available. + """ + if coord_name in ds: + return ( + ds[coord_name].values.item() + if ds[coord_name].values.ndim == 0 + else ds[coord_name].values[0] + ) + if f"origin_{coord_name}" in ds.coords: + return ds.coords[f"origin_{coord_name}"].median().item() + if f"radar_{coord_name}" in ds.coords: + return ds.coords[f"radar_{coord_name}"].median().item() + return ds.attrs.get(attr_name, None) + + lat_0 = get_coord_or_attr( + ds, "latitude", "origin_latitude" + ) or get_coord_or_attr(ds, "radar_latitude", "origin_latitude") + lon_0 = get_coord_or_attr( + ds, "longitude", "origin_longitude" + ) or get_coord_or_attr(ds, "radar_longitude", "origin_longitude") + + if lat_0 is None or lon_0 is None: + lat_0 = ds.lat.mean().item() + lon_0 = ds.lon.mean().item() + + projection = ccrs.LambertAzimuthalEqualArea(lon_0, lat_0) + return projection + + projection = _get_projection(ds) + + # FIG + fig = plt.figure(figsize=[10.3, 10]) + left, bottom, width, height = 0.1, 0.1, 0.6, 0.2 + ax_xy = plt.axes((left, bottom, width, width), projection=projection) + ax_x = plt.axes((left, bottom + width, width, height)) + ax_y = plt.axes((left + width, bottom, height, width)) + ax_cnr = plt.axes((left + width, bottom + width, left + left, height)) + if colorbar: + ax_cb = plt.axes((left - 0.015 + width + height + 0.02, bottom, 0.02, width)) + + # Set axis label formatters + ax_x.xaxis.set_major_formatter(NullFormatter()) + ax_y.yaxis.set_major_formatter(NullFormatter()) + ax_cnr.yaxis.set_major_formatter(NullFormatter()) + ax_cnr.xaxis.set_major_formatter(NullFormatter()) + ax_x.set_ylabel("Height (km)", size=13) + ax_y.set_xlabel("Height (km)", size=13) + + # Draw CAPPI + plt.sca(ax_xy) + xy = ax_xy.pcolormesh(trgx, trgy, max_c, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) + + # Add map features + if add_map: + map_features(ax_xy, lat_lines, lon_lines) + + ax_xy.minorticks_on() + + if range_rings: + plot_range_rings(ax_xy, trgx.max()) + + ax_xy.set_xlim(trgx.min(), trgx.max()) + ax_xy.set_ylim(trgx.min(), trgx.max()) + + # Draw colorbar + if colorbar: + cb = plt.colorbar(xy, cax=ax_cb) + cb.set_label(ds[field].attrs["units"], size=15) + + background_color = ax_xy.get_facecolor() + color = "k" if sum(background_color[:3]) / 3 > 0.5 else "w" + + plt.sca(ax_x) + plt.pcolormesh(trgx / 1e3, trgz / 1e3, max_x, cmap=cmap, vmin=vmin, vmax=vmax) + plt.yticks(sideticks) + ax_x.set_xlim(trgx.min() / 1e3, trgx.max() / 1e3) + ax_x.grid(axis="y", lw=0.5, color=color, alpha=0.5, ls=":") + ax_x.minorticks_on() + + plt.sca(ax_y) + plt.pcolormesh(trgz / 1e3, trgy / 1e3, max_y, cmap=cmap, vmin=vmin, vmax=vmax) + ax_y.set_xticks(sideticks) + ax_y.set_ylim(trgx.min() / 1e3, trgx.max() / 1e3) + ax_y.grid(axis="x", lw=0.5, color=color, alpha=0.5, ls=":") + ax_y.minorticks_on() + + plt.sca(ax_cnr) + plt.tick_params( + axis="both", # changes apply to both axes + which="both", # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + left=False, + right=False, + labelbottom=False, + ) + + # Retrieve instrument name + instrument_name = ds.attrs.get("instrument_name", "N/A")[:4] + + # Center-align text in the corner box + plt.text( + 0.5, + 0.90, + f"Site: {instrument_name}", + size=13, + weight="bold", + ha="center", + va="center", + ) + plt.text(0.5, 0.76, title, size=13, weight="bold", ha="center", va="center") + plt.text( + 0.5, + 0.63, + f"Max Range: {np.floor(trgx.max() / 1e3)} km", + size=11, + ha="center", + va="center", + ) + plt.text( + 0.5, + 0.47, + f"Max Height: {np.floor(trgz.max() / 1e3)} km", + size=11, + ha="center", + va="center", + ) + plt.text( + 0.5, + 0.28, + ds["time"].dt.strftime("%H:%M:%S Z").values.item(), + weight="bold", + size=16, + ha="center", + va="center", + ) + plt.text( + 0.5, + 0.13, + ds["time"].dt.strftime("%d %b, %Y UTC").values.item(), + size=13.5, + ha="center", + va="center", + ) + ax_xy.set_aspect("auto") + + if add_slogan: + fig.text( + 0.1, + 0.06, + "Powered by Py-ART", # Coordinates close to (0, 0) for lower-left corner + fontsize=9, + fontname="Courier New", + # bbox=dict(facecolor='none', boxstyle='round,pad=0.5') + ) + + if savedir is not None: + radar_name = ds.attrs.get("instrument_name", "Radar") + time_str = ds["time"].dt.strftime("%Y%m%d%H%M%S").values.item() + figname = f"{savedir}{os.sep}{title}_{radar_name}_{time_str}.png" + plt.savefig(fname=figname, dpi=dpi, bbox_inches="tight") + print(f"Figure(s) saved as {figname}") + + # plt.rcParams.update(original_rc_params) + plt.rcdefaults() + + if show_figure: + plt.show() + else: + plt.close() + + +def map_features(ax, lat_lines, lon_lines): + """ + Adds map features and gridlines to the plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis on which to add map features and gridlines. + lat_lines : array-like + Latitude lines for the gridlines. + lon_lines : array-like + Longitude lines for the gridlines. + + Returns + ------- + None + """ + background_color = ax.get_facecolor() + color = "k" if sum(background_color[:3]) / 3 > 0.5 else "w" + + # Labeling gridlines depending on the projection + if isinstance(ax.projection, (ccrs.PlateCarree, ccrs.Mercator)): + gl = ax.gridlines( + xlocs=lon_lines, + ylocs=lat_lines, + linewidth=1, + alpha=0.5, + linestyle="--", + draw_labels=True, + ) + gl.top_labels = False + gl.right_labels = False + gl.xlines = False + gl.ylines = False + gl.xlabel_style = {"color": color} + gl.ylabel_style = {"color": color} + ax.add_feature(feat.COASTLINE, alpha=0.8, lw=1, ec=color) + ax.add_feature(feat.BORDERS, alpha=0.7, lw=0.7, ls="--", ec=color) + ax.add_feature( + feat.STATES.with_scale("10m"), alpha=0.6, lw=0.5, ls=":", ec=color + ) + + elif isinstance( + ax.projection, (ccrs.LambertConformal, ccrs.LambertAzimuthalEqualArea) + ): + ax.figure.canvas.draw() + gl = ax.gridlines( + crs=ccrs.PlateCarree(), + xlocs=lon_lines, + ylocs=lat_lines, + linewidth=1, + alpha=0.5, + linestyle="--", + draw_labels=False, + ) + gl.xlines = False + gl.ylines = False + gl.xlabel_style = {"color": color} + gl.ylabel_style = {"color": color} + ax.add_feature(feat.COASTLINE, alpha=0.8, lw=1, ec=color) + ax.add_feature(feat.BORDERS, alpha=0.7, lw=0.7, ls="--", ec=color) + ax.add_feature( + feat.STATES.with_scale("10m"), alpha=0.6, lw=0.5, ls=":", ec=color + ) + # Label the end-points of the gridlines using custom tick makers + ax.xaxis.set_major_formatter(LONGITUDE_FORMATTER) + ax.yaxis.set_major_formatter(LATITUDE_FORMATTER) + from pyart.graph.gridmapdisplay import lambert_xticks, lambert_yticks + + lambert_xticks(ax, lon_lines) + lambert_yticks(ax, lat_lines) + else: + ax.gridlines(xlocs=lon_lines, ylocs=lat_lines) diff --git a/tests/graph/test_plot_maxcappi.py b/tests/graph/test_plot_maxcappi.py new file mode 100644 index 0000000000..857183fcd2 --- /dev/null +++ b/tests/graph/test_plot_maxcappi.py @@ -0,0 +1,106 @@ +import os + +import matplotlib.pyplot as plt +import pytest + +import pyart + + +@pytest.mark.skipif( + not pyart.graph.gridmapdisplay._CARTOPY_AVAILABLE, reason="Cartopy is not installed" +) +def test_plot_maxcappi_simple(outfile=None): + """ + Test the basic functionality of plot_maxcappi. + """ + # Create a test grid using Py-ART's testing utility + grid = pyart.testing.make_target_grid() + grid.z["data"] = grid.z["data"] * 10 + 100 + grid.metadata["instrument_name"] = "GRID" + + # Use plot_maxcappi with the generated grid + pyart.graph.max_cappi.plot_maxcappi( + grid=grid, + field="reflectivity", + savedir=None, # Do not save the plot + show_figure=False, # Do not show the plot + ) + + if outfile: + plt.savefig(outfile) + plt.close() + + +@pytest.mark.skipif( + not pyart.graph.gridmapdisplay._CARTOPY_AVAILABLE, reason="Cartopy is not installed" +) +def test_plot_maxcappi_with_save(outfile=None): + """ + Test plot_maxcappi and save the output to a file. + """ + # Create a test grid using Py-ART's testing utility + grid = pyart.testing.make_target_grid() + grid.z["data"] = grid.z["data"] * 10 + 100 + grid.metadata["instrument_name"] = "GRID" + + # Define the output file path + outfile = outfile or "test_plot_maxcappi_output.png" + + # Use plot_maxcappi with the generated grid + pyart.graph.max_cappi.plot_maxcappi( + grid=grid, + field="reflectivity", + savedir=None, # Handle saving manually below + show_figure=False, # Do not show the plot + ) + + # Save the figure to a file + plt.savefig(outfile) + plt.close() + + # Check if the file was created + assert os.path.exists(outfile), "The plot was not saved as expected." + + +@pytest.mark.skipif( + not pyart.graph.gridmapdisplay._CARTOPY_AVAILABLE, reason="Cartopy is not installed" +) +def test_plot_maxcappi_with_all_options(outfile=None): + """ + Test plot_maxcappi with all options enabled. + """ + import cartopy.crs as ccrs + + # Create a test grid using Py-ART's testing utility + grid = pyart.testing.make_target_grid() + grid.z["data"] = grid.z["data"] * 10 + 100 + grid.metadata["instrument_name"] = "GRID" + + # Use a custom projection for testing + projection = ccrs.Mercator() + + # Use plot_maxcappi with additional options + pyart.graph.max_cappi.plot_maxcappi( + grid=grid, + field="reflectivity", + title="Test Max-CAPPI", + lat_lines=None, + lon_lines=None, + add_map=True, + projection=projection, + colorbar=True, + range_rings=True, + dpi=150, + savedir=None, + show_figure=False, + ) + + if outfile: + plt.savefig(outfile) + plt.close() + + +if __name__ == "__main__": + test_plot_maxcappi_simple("figure_plot_maxcappi_simple.png") + test_plot_maxcappi_with_save("figure_plot_maxcappi_output.png") + test_plot_maxcappi_with_all_options("figure_plot_maxcappi_all_options.png") From 6de78cad36bde9f2ce9242c8d5e8e8497a9e0ce8 Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Thu, 29 Aug 2024 17:37:23 -0400 Subject: [PATCH 09/13] Update tests and fix Example --- examples/plotting/plot_max_cappi.py | 9 ++++----- pyart/core/grid.py | 11 ++++++++++- tests/core/test_grid.py | 4 ++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/plotting/plot_max_cappi.py b/examples/plotting/plot_max_cappi.py index 3e7567fa40..823ae95856 100644 --- a/examples/plotting/plot_max_cappi.py +++ b/examples/plotting/plot_max_cappi.py @@ -36,8 +36,7 @@ ######################################### # ** Second Example # -# Let's read in a cfradial file and -# create a grid. +# Let's read in a cfradial file and create a grid. import logging @@ -82,16 +81,16 @@ def download_nexrad(timezone, date, site, local_date=False): radar = pyart.io.read_nexrad_archive("s3://" + file) # Create a 3D grid -# mask out last 10 gates of each ray, this removes the "ring" around the radar. +# Mask out last 10 gates of each ray, this removes the "ring" around the radar. radar.fields["reflectivity"]["data"][:, -10:] = np.ma.masked -# exclude masked gates from the gridding +# Exclude masked gates from the gridding gatefilter = pyart.filters.GateFilter(radar) gatefilter.exclude_transition() gatefilter.exclude_masked("reflectivity") gatefilter.exclude_outside("reflectivity", 10, 80) -# perform Cartesian mapping, limit to the reflectivity field. +# Perform Cartesian mapping, limit to the reflectivity field. max_range = np.ceil(radar.range["data"].max()) if max_range / 1e3 > 250: max_range = 250 * 1e3 diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 2ac37d8a3a..85fe149022 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -425,8 +425,17 @@ def _process_radar_name(radar_name): ds["radar_name"] = xarray.DataArray( radar_name, dims=("nradar"), attrs=get_metadata("radar_name") ) + else: + radar_name = np.array( + ["ExampleRadar"], dtype="S" + ) # or use 'U' for unicode strings - ds.attrs = self.metadata + # Add radar_name to attributes, defaulting to 'ExampleRadar' if radar_name doesn't exist or is empty + ds.attrs["radar_name"] = ( + radar_name.item() if radar_name.size > 0 else "ExampleRadar" + ) + ds.attrs["nradar"] = self.nradar + ds.attrs.update(self.metadata) for key in ds.attrs: try: ds.attrs[key] = ds.attrs[key].decode("utf-8") diff --git a/tests/core/test_grid.py b/tests/core/test_grid.py index 8d6f775aec..0fedbe0450 100644 --- a/tests/core/test_grid.py +++ b/tests/core/test_grid.py @@ -103,8 +103,8 @@ def test_grid_to_xarray(): assert np.array_equal(ds.time.data, time) # Check radar-specific attributes - assert "radar_name" in ds.data_vars - assert ds.radar_name.data[0] == "ExampleRadar" + assert ds.attrs["nradar"] == 1 + assert ds.attrs["radar_name"] == "ExampleRadar" def _check_dicts_similar(dic1, dic2): From 89efaf5841baccb0ea8312f3ffcfa0d41499ce67 Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Thu, 29 Aug 2024 19:03:18 -0400 Subject: [PATCH 10/13] drop a few args --- pyart/graph/__init__.py | 1 - pyart/graph/max_cappi.py | 11 +---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/pyart/graph/__init__.py b/pyart/graph/__init__.py index 2cf9b6cd8a..49f08214ae 100644 --- a/pyart/graph/__init__.py +++ b/pyart/graph/__init__.py @@ -62,7 +62,6 @@ from .radardisplay_airborne import AirborneRadarDisplay # noqa from .radarmapdisplay import RadarMapDisplay # noqa from .radarmapdisplay_basemap import RadarMapDisplayBasemap # noqa -from . import max_cappi # noqa from .max_cappi import plot_maxcappi # noqa __all__ = [s for s in dir() if not s.startswith("_")] diff --git a/pyart/graph/max_cappi.py b/pyart/graph/max_cappi.py index 3e0015a936..25d0309f75 100644 --- a/pyart/graph/max_cappi.py +++ b/pyart/graph/max_cappi.py @@ -37,7 +37,6 @@ def plot_maxcappi( colorbar=True, range_rings=False, dpi=100, - show_progress=False, savedir=None, show_figure=True, add_slogan=False, @@ -74,8 +73,6 @@ def plot_maxcappi( Whether to include range rings at 50 km intervals. Default is True. dpi : int, optional DPI (dots per inch) for the plot. Default is 100. - show_progress : bool, optional - Whether to print progress messages. Default is False. savedir : str, optional Directory where the plot will be saved. If None, the plot is not saved. show_figure : bool, optional @@ -87,6 +84,7 @@ def plot_maxcappi( ------- None This function does not return any value. It generates and optionally displays/saves a plot. + Notes ----- Author : Hamid Ali Syed (@syedhamidali) @@ -169,13 +167,6 @@ def plot_range_rings(ax_xy, max_range): ax_xy.legend(loc="upper right", prop={"weight": "normal", "size": 8}) - if show_progress: - print("...................................") - print( - f"Plotting {title}: {ds['time'].dt.strftime('%Y%m%d %H:%M:%S').values.item()}" - ) - print("...................................\n") - def _get_projection(ds): """ Determine the central latitude and longitude from a dataset From 0272e7be2a819498c66c577450a830102c7fd253 Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Thu, 29 Aug 2024 19:56:17 -0400 Subject: [PATCH 11/13] Fix grid.to_xarray() --- pyart/core/grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 85fe149022..3fc0ca562f 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -422,7 +422,7 @@ def _process_radar_name(radar_name): if self.radar_name is not None: radar_name = _process_radar_name(self.radar_name["data"]) - ds["radar_name"] = xarray.DataArray( + ds.coords["radar_name"] = xarray.DataArray( radar_name, dims=("nradar"), attrs=get_metadata("radar_name") ) else: From ebcf03141b438c74882a8b0eee6e4fd7ac600884 Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Fri, 30 Aug 2024 00:17:23 -0400 Subject: [PATCH 12/13] FIX: Fixed some tests and to_xarray() function --- pyart/core/grid.py | 49 +++++++++++++++++++------------------- tests/core/test_grid.py | 2 +- tests/retrieve/test_qvp.py | 4 +++- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 3fc0ca562f..cf8da27be1 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -329,7 +329,7 @@ def to_xarray(self): def _process_radar_name(radar_name): """Process radar_name to handle different formats.""" if radar_name.dtype.kind in {"S", "U"} and radar_name.ndim > 1: - return np.array([b"".join(radar_name.flatten())]) + return radar_name.flatten() return radar_name lon, lat = self.get_point_longitude_latitude() @@ -388,10 +388,22 @@ def _process_radar_name(radar_name): attrs=projection, ) + # Handle origin and radar attributes with appropriate dimensions for attr_name in [ "origin_latitude", "origin_longitude", "origin_altitude", + ]: + if hasattr(self, attr_name): + attr_data = getattr(self, attr_name) + if attr_data is not None: + attr_value = np.ma.expand_dims(attr_data["data"][0], 0) + ds.coords[attr_name] = xarray.DataArray( + attr_value, dims=("time",), attrs=get_metadata(attr_name) + ) + + # Radar-specific attributes that should have the nradar dimension + for attr_name in [ "radar_altitude", "radar_latitude", "radar_longitude", @@ -400,41 +412,30 @@ def _process_radar_name(radar_name): if hasattr(self, attr_name): attr_data = getattr(self, attr_name) if attr_data is not None: - dims = ("time",) if "origin_" in attr_name else ("nradar",) - attr_value = ( - np.ma.expand_dims(attr_data["data"][0], 0) - if "radar_time" not in attr_name - else [ - np.array( - num2date( - attr_data["data"][0], units=attr_data["units"] - ), - dtype="datetime64[ns]", - ) - ] - ) ds.coords[attr_name] = xarray.DataArray( - attr_value, dims=dims, attrs=get_metadata(attr_name) + attr_data["data"], + dims=("nradar",), + attrs=get_metadata(attr_name), ) if "radar_time" in ds.variables: ds.radar_time.attrs.pop("calendar") + # Handle radar_name and ensure it has the correct dimension if self.radar_name is not None: radar_name = _process_radar_name(self.radar_name["data"]) ds.coords["radar_name"] = xarray.DataArray( - radar_name, dims=("nradar"), attrs=get_metadata("radar_name") + radar_name, dims=("nradar",), attrs=get_metadata("radar_name") ) else: - radar_name = np.array( - ["ExampleRadar"], dtype="S" - ) # or use 'U' for unicode strings + radar_name = np.array(["ExampleRadar"], dtype="U") + ds.coords["radar_name"] = xarray.DataArray( + radar_name, dims=("nradar",), attrs=get_metadata("radar_name") + ) - # Add radar_name to attributes, defaulting to 'ExampleRadar' if radar_name doesn't exist or is empty - ds.attrs["radar_name"] = ( - radar_name.item() if radar_name.size > 0 else "ExampleRadar" - ) - ds.attrs["nradar"] = self.nradar + # Add radar_name to attributes + ds.attrs["radar_name"] = radar_name if radar_name.size > 0 else "ExampleRadar" + ds.attrs["nradar"] = radar_name.size ds.attrs.update(self.metadata) for key in ds.attrs: try: diff --git a/tests/core/test_grid.py b/tests/core/test_grid.py index 0fedbe0450..8ea9c58b47 100644 --- a/tests/core/test_grid.py +++ b/tests/core/test_grid.py @@ -104,7 +104,7 @@ def test_grid_to_xarray(): # Check radar-specific attributes assert ds.attrs["nradar"] == 1 - assert ds.attrs["radar_name"] == "ExampleRadar" + assert ds.attrs["radar_name"][0] == "ExampleRadar" def _check_dicts_similar(dic1, dic2): diff --git a/tests/retrieve/test_qvp.py b/tests/retrieve/test_qvp.py index 2293ae56f7..7991fe8b45 100644 --- a/tests/retrieve/test_qvp.py +++ b/tests/retrieve/test_qvp.py @@ -498,7 +498,9 @@ def test_find_nearest_gate(test_radar): assert ind_ray == 141.0 assert ind_rng == 145.0 assert azi == 141.0 - assert rng == 14514.514 + assert ( + abs(rng - 14514.514) < 1e-3 + ) # Allow for a small tolerance in floating-point comparison def test_find_neighbour_gates(test_radar): From ef72aa00f895757cdb755e92afc26b8b90c0de2a Mon Sep 17 00:00:00 2001 From: syedhamidali Date: Fri, 30 Aug 2024 05:45:49 -0400 Subject: [PATCH 13/13] FIX: Fixed to_xarray() for more than one radar --- pyart/core/grid.py | 12 +++++++++--- tests/core/test_grid.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index cf8da27be1..b79295f4fd 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -328,8 +328,12 @@ def to_xarray(self): def _process_radar_name(radar_name): """Process radar_name to handle different formats.""" - if radar_name.dtype.kind in {"S", "U"} and radar_name.ndim > 1: - return radar_name.flatten() + if radar_name.dtype.kind == "S" and radar_name.ndim > 1: + # Join each row of bytes into a single byte string + return np.array( + [b"".join(row) for row in radar_name], + dtype=f"|S{radar_name.shape[1]}", + ) return radar_name lon, lat = self.get_point_longitude_latitude() @@ -434,7 +438,9 @@ def _process_radar_name(radar_name): ) # Add radar_name to attributes - ds.attrs["radar_name"] = radar_name if radar_name.size > 0 else "ExampleRadar" + ds.attrs["radar_name"] = ( + radar_name.tolist() if radar_name.size > 1 else radar_name.item() + ) ds.attrs["nradar"] = radar_name.size ds.attrs.update(self.metadata) for key in ds.attrs: diff --git a/tests/core/test_grid.py b/tests/core/test_grid.py index 8ea9c58b47..0fedbe0450 100644 --- a/tests/core/test_grid.py +++ b/tests/core/test_grid.py @@ -104,7 +104,7 @@ def test_grid_to_xarray(): # Check radar-specific attributes assert ds.attrs["nradar"] == 1 - assert ds.attrs["radar_name"][0] == "ExampleRadar" + assert ds.attrs["radar_name"] == "ExampleRadar" def _check_dicts_similar(dic1, dic2):