diff --git a/examples/plotting/plot_max_cappi.py b/examples/plotting/plot_max_cappi.py new file mode 100644 index 0000000000..823ae95856 --- /dev/null +++ b/examples/plotting/plot_max_cappi.py @@ -0,0 +1,111 @@ +""" +============= +Plot Max-CAPPI +============= + +This is an example of how to plot a Max-CAPPI +within a Py-ART grid display object. + +""" + +print(__doc__) + +# Author: Hamid Ali Syed (syed44@purdue.edu) +# License: BSD 3 clause + +import matplotlib.pyplot as plt +import numpy as np + +import pyart +from pyart.testing import get_test_data + +######################################### +# ** MAX-CAPPI Display +# + +# Define and Read in the test data +grid_file = get_test_data("20110520100000_nexrad_grid.nc") +grid = pyart.io.read_grid(grid_file) + + +# Create a grid display +gdisplay = pyart.graph.GridMapDisplay(grid) +gdisplay.plot_maxcappi(field="REF", range_rings=True, add_slogan=True) + + +######################################### +# ** Second Example +# +# Let's read in a cfradial file and create a grid. + + +import logging +from datetime import datetime + +import fsspec +import pytz + + +def download_nexrad(timezone, date, site, local_date=False): + """Download NEXRAD radar data from an S3 bucket.""" + try: + utc_date = ( + pytz.timezone(timezone).localize(date).astimezone(pytz.utc) + if local_date + else date + ) + logging.info(f"Time: {utc_date}") + + fs = fsspec.filesystem("s3", anon=True) + nexrad_path = utc_date.strftime( + f"s3://noaa-nexrad-level2/%Y/%m/%d/{site}/{site}%Y%m%d_%H*" + ) + files = sorted(fs.glob(nexrad_path)) + + return [file for file in files if not file.endswith("_MDM")] + except Exception as e: + logging.error("Error in processing: %s", e) + return [] + + +# Load NEXRAD data from S3 Bucket +site = "PHWA" +timezone = "UTC" +date = datetime(2024, 8, 25, 8, 29) + +# Correctly passing the site and timezone +file = download_nexrad(timezone, date, site, local_date=False)[0] + + +# Read the data using nexrad_archive reader +radar = pyart.io.read_nexrad_archive("s3://" + file) + +# Create a 3D grid +# Mask out last 10 gates of each ray, this removes the "ring" around the radar. +radar.fields["reflectivity"]["data"][:, -10:] = np.ma.masked + +# Exclude masked gates from the gridding +gatefilter = pyart.filters.GateFilter(radar) +gatefilter.exclude_transition() +gatefilter.exclude_masked("reflectivity") +gatefilter.exclude_outside("reflectivity", 10, 80) + +# Perform Cartesian mapping, limit to the reflectivity field. +max_range = np.ceil(radar.range["data"].max()) +if max_range / 1e3 > 250: + max_range = 250 * 1e3 + +grid = pyart.map.grid_from_radars( + (radar,), + gatefilters=(gatefilter,), + grid_shape=(30, 441, 441), + grid_limits=((0, 10000), (-max_range, max_range), (-max_range, max_range)), + fields=["reflectivity"], +) + +# Create a grid display +gdisplay = pyart.graph.GridMapDisplay(grid) +with plt.style.context("dark_background"): + gdisplay.plot_maxcappi( + field="reflectivity", cmap="pyart_HomeyerRainbow", add_slogan=True + ) diff --git a/pyart/core/grid.py b/pyart/core/grid.py index 9026f5d35b..b79295f4fd 100644 --- a/pyart/core/grid.py +++ b/pyart/core/grid.py @@ -320,68 +320,136 @@ def to_xarray(self): x, y, z : dict, 1D Distance from the grid origin for each Cartesian coordinate axis in a one dimensional array. - """ - if not _XARRAY_AVAILABLE: raise MissingOptionalDependency( - "Xarray is required to use Grid.to_xarray but is not " + "installed!" + "Xarray is required to use Grid.to_xarray but is not installed!" ) + def _process_radar_name(radar_name): + """Process radar_name to handle different formats.""" + if radar_name.dtype.kind == "S" and radar_name.ndim > 1: + # Join each row of bytes into a single byte string + return np.array( + [b"".join(row) for row in radar_name], + dtype=f"|S{radar_name.shape[1]}", + ) + return radar_name + lon, lat = self.get_point_longitude_latitude() z = self.z["data"] y = self.y["data"] x = self.x["data"] - time = np.array([num2date(self.time["data"][0], self.time["units"])]) + time = np.array([num2date(self.time["data"][0], units=self.time["units"])]) ds = xarray.Dataset() - for field in list(self.fields.keys()): - field_data = self.fields[field]["data"] + for field, field_info in self.fields.items(): + field_data = field_info["data"] data = xarray.DataArray( np.ma.expand_dims(field_data, 0), dims=("time", "z", "y", "x"), coords={ - "time": (["time"], time), - "z": (["z"], z), + "time": time, + "z": z, "lat": (["y", "x"], lat), "lon": (["y", "x"], lon), - "y": (["y"], y), - "x": (["x"], x), + "y": y, + "x": x, }, ) - for meta in list(self.fields[field].keys()): - if meta != "data": - data.attrs.update({meta: self.fields[field][meta]}) - + data.attrs.update({k: v for k, v in field_info.items() if k != "data"}) ds[field] = data - ds.lon.attrs = [ - ("long_name", "longitude of grid cell center"), - ("units", "degree_E"), - ("standard_name", "Longitude"), - ] - ds.lat.attrs = [ - ("long_name", "latitude of grid cell center"), - ("units", "degree_N"), - ("standard_name", "Latitude"), - ] - - ds.z.attrs = get_metadata("z") - ds.y.attrs = get_metadata("y") - ds.x.attrs = get_metadata("x") - - ds.z.encoding["_FillValue"] = None - ds.lat.encoding["_FillValue"] = None - ds.lon.encoding["_FillValue"] = None - - # Grab original radar(s) name and number of radars used to make grid - ds.attrs["nradar"] = self.nradar - ds.attrs["radar_name"] = self.radar_name - - # Grab all metadata - ds.attrs.update(self.metadata) - - ds.close() + + ds.lon.attrs = { + "long_name": "longitude of grid cell center", + "units": "degree_E", + "standard_name": "Longitude", + } + ds.lat.attrs = { + "long_name": "latitude of grid cell center", + "units": "degree_N", + "standard_name": "Latitude", + } + + for attr in [ds.z, ds.lat, ds.lon]: + attr.encoding["_FillValue"] = None + + from ..io.grid_io import _make_coordinatesystem_dict + + ds.coords["ProjectionCoordinateSystem"] = xarray.DataArray( + data=np.array(1, dtype="int32"), + attrs=_make_coordinatesystem_dict(self), + ) + + projection = self.projection.copy() + if "_include_lon_0_lat_0" in projection: + projection["_include_lon_0_lat_0"] = str( + projection["_include_lon_0_lat_0"] + ).lower() + ds.coords["projection"] = xarray.DataArray( + data=np.array(1, dtype="int32"), + attrs=projection, + ) + + # Handle origin and radar attributes with appropriate dimensions + for attr_name in [ + "origin_latitude", + "origin_longitude", + "origin_altitude", + ]: + if hasattr(self, attr_name): + attr_data = getattr(self, attr_name) + if attr_data is not None: + attr_value = np.ma.expand_dims(attr_data["data"][0], 0) + ds.coords[attr_name] = xarray.DataArray( + attr_value, dims=("time",), attrs=get_metadata(attr_name) + ) + + # Radar-specific attributes that should have the nradar dimension + for attr_name in [ + "radar_altitude", + "radar_latitude", + "radar_longitude", + "radar_time", + ]: + if hasattr(self, attr_name): + attr_data = getattr(self, attr_name) + if attr_data is not None: + ds.coords[attr_name] = xarray.DataArray( + attr_data["data"], + dims=("nradar",), + attrs=get_metadata(attr_name), + ) + + if "radar_time" in ds.variables: + ds.radar_time.attrs.pop("calendar") + + # Handle radar_name and ensure it has the correct dimension + if self.radar_name is not None: + radar_name = _process_radar_name(self.radar_name["data"]) + ds.coords["radar_name"] = xarray.DataArray( + radar_name, dims=("nradar",), attrs=get_metadata("radar_name") + ) + else: + radar_name = np.array(["ExampleRadar"], dtype="U") + ds.coords["radar_name"] = xarray.DataArray( + radar_name, dims=("nradar",), attrs=get_metadata("radar_name") + ) + + # Add radar_name to attributes + ds.attrs["radar_name"] = ( + radar_name.tolist() if radar_name.size > 1 else radar_name.item() + ) + ds.attrs["nradar"] = radar_name.size + ds.attrs.update(self.metadata) + for key in ds.attrs: + try: + ds.attrs[key] = ds.attrs[key].decode("utf-8") + except AttributeError: + pass + + ds.close() return ds def add_field(self, field_name, field_dict, replace_existing=False): diff --git a/pyart/graph/__init__.py b/pyart/graph/__init__.py index 8de2949cf6..49f08214ae 100644 --- a/pyart/graph/__init__.py +++ b/pyart/graph/__init__.py @@ -62,5 +62,6 @@ from .radardisplay_airborne import AirborneRadarDisplay # noqa from .radarmapdisplay import RadarMapDisplay # noqa from .radarmapdisplay_basemap import RadarMapDisplayBasemap # noqa +from .max_cappi import plot_maxcappi # noqa __all__ = [s for s in dir() if not s.startswith("_")] diff --git a/pyart/graph/gridmapdisplay.py b/pyart/graph/gridmapdisplay.py index fc7dcaecb8..80ca87b877 100644 --- a/pyart/graph/gridmapdisplay.py +++ b/pyart/graph/gridmapdisplay.py @@ -50,6 +50,8 @@ except ImportError: _LAMBERT_GRIDLINES = False +from . import max_cappi # noqa + class GridMapDisplay: """ @@ -120,7 +122,7 @@ def plot_grid( add_grid_lines=True, ticks=None, ticklabs=None, - **kwargs + **kwargs, ): """ Plot the grid using xarray and cartopy. @@ -257,7 +259,7 @@ def plot_grid( vmin=vmin, vmax=vmax, add_colorbar=False, - **kwargs + **kwargs, ) self.mappables.append(pm) @@ -414,7 +416,7 @@ def plot_latitudinal_level( fig=None, ticks=None, ticklabs=None, - **kwargs + **kwargs, ): """ Plot a slice along a given latitude. @@ -575,7 +577,7 @@ def plot_longitudinal_level( fig=None, ticks=None, ticklabs=None, - **kwargs + **kwargs, ): """ Plot a slice along a given longitude. @@ -718,7 +720,7 @@ def plot_cross_section( fig=None, ticks=None, ticklabs=None, - **kwargs + **kwargs, ): """ Plot a cross section through a set of given points (latitude, @@ -849,7 +851,7 @@ def plot_cross_section( add_colorbar=False, ax=ax, cmap=cmap, - **kwargs + **kwargs, ) self.mappables.append(plot) @@ -1110,6 +1112,46 @@ def cartopy_coastlines(self): category="physical", name="coastline", scale="10m", facecolor="none" ) + def plot_maxcappi( + self, + field, + cmap=None, + vmin=None, + vmax=None, + title=None, + lat_lines=None, + lon_lines=None, + add_map=True, + projection=None, + colorbar=True, + range_rings=False, + dpi=100, + savedir=None, + show_figure=True, + add_slogan=False, + **kwargs, + ): + # Call the plot_maxcappi function from the max_cappi module or object + max_cappi.plot_maxcappi( + grid=self.grid, # Assuming `self.grid` holds the Grid object in your class + field=field, + cmap=cmap, + vmin=vmin, + vmax=vmax, + title=title, + lat_lines=lat_lines, + lon_lines=lon_lines, + add_map=add_map, + projection=projection, + colorbar=colorbar, + range_rings=range_rings, + dpi=dpi, + savedir=savedir, + show_figure=show_figure, + add_slogan=add_slogan, + **kwargs, + ) + # These methods are a hack to allow gridlines when the projection is lambert # https://nbviewer.jupyter.org/gist/ajdawson/dd536f786741e987ae4e diff --git a/pyart/graph/max_cappi.py b/pyart/graph/max_cappi.py new file mode 100644 index 0000000000..25d0309f75 --- /dev/null +++ b/pyart/graph/max_cappi.py @@ -0,0 +1,434 @@ +""" +Plot Max-CAPPI + +This module provides a function to plot a Maximum Constant Altitude Plan Position Indicator (Max-CAPPI) +from radar data using an xarray dataset. The function includes options for adding map features, range rings, +color bars, and customized visual settings. + +Author: Syed Hamid Ali (@syedhamidali) +""" + +__all__ = ["plot_maxcappi"] + +import os +import warnings + +import cartopy.crs as ccrs +import cartopy.feature as feat +import matplotlib.pyplot as plt +import numpy as np +from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER +from matplotlib.ticker import NullFormatter + +warnings.filterwarnings("ignore") + + +def plot_maxcappi( + grid, + field, + cmap=None, + vmin=None, + vmax=None, + title=None, + lat_lines=None, + lon_lines=None, + add_map=True, + projection=None, + colorbar=True, + range_rings=False, + dpi=100, + savedir=None, + show_figure=True, + add_slogan=False, + **kwargs, +): + """ + Plots a Constant Altitude Plan Position Indicator (CAPPI) using an xarray Dataset. + + Parameters + ---------- + grid : pyart.core.Grid + The grid object containing the radar data to be plotted. + field : str + The radar field to be plotted (e.g., "REF", "VEL", "WIDTH"). + cmap : str or matplotlib colormap, optional + Colormap to use for the plot. Default is "SyedSpectral" if available, otherwise "HomeyerRainbow". + vmin : float, optional + Minimum value for the color scaling. Default is None, which sets it to the minimum value of the data. + vmax : float, optional + Maximum value for the color scaling. Default is None, which sets it to the maximum value of the data. + title : str, optional + Title of the plot. If None, the title is set to "Max-{field}". + lat_lines : array-like, optional + Latitude lines to be included in the plot. Default is calculated based on dataset coordinates. + lon_lines : array-like, optional + Longitude lines to be included in the plot. Default is calculated based on dataset coordinates. + add_map : bool, optional + Whether to include a map background in the plot. Default is True. + projection : cartopy.crs.Projection, optional + The map projection for the plot. Default is cartopy.crs.LambertAzimuthalEqualArea(). + colorbar : bool, optional + Whether to include a colorbar in the plot. Default is True. + range_rings : bool, optional + Whether to include range rings at 50 km intervals. Default is True. + dpi : int, optional + DPI (dots per inch) for the plot. Default is 100. + savedir : str, optional + Directory where the plot will be saved. If None, the plot is not saved. + show_figure : bool, optional + Whether to display the plot. Default is True. + **kwargs : dict, optional + Additional keyword arguments to pass to matplotlib's `pcolormesh` function. + + Returns + ------- + None + This function does not return any value. It generates and optionally displays/saves a plot. + + Notes + ----- + Author : Hamid Ali Syed (@syedhamidali) + """ + + ds = grid.to_xarray().squeeze() + + if lon_lines is None: + lon_lines = np.arange(int(ds.lon.min().values), int(ds.lon.max().values) + 1) + if lat_lines is None: + lat_lines = np.arange(int(ds.lat.min().values), int(ds.lat.max().values) + 1) + + plt.rcParams.copy() + plt.rcParams.update( + { + "font.weight": "bold", + "axes.labelweight": "bold", + "xtick.direction": "in", + "ytick.direction": "in", + "xtick.major.size": 10, + "ytick.major.size": 10, + "xtick.minor.size": 7, + "ytick.minor.size": 7, + "font.size": 14, + "axes.linewidth": 2, + "ytick.labelsize": 12, + "xtick.labelsize": 12, + } + ) + + max_c = ds[field].max(dim="z") + max_x = ds[field].max(dim="y") + max_y = ds[field].max(dim="x").T + + trgx = ds["x"].values + trgy = ds["y"].values + trgz = ds["z"].values + + max_height = int(np.floor(trgz.max()) / 1e3) + sideticks = np.arange(max_height / 4, max_height + 1, max_height / 4).astype(int) + + if cmap is None: + cmap = "pyart_Carbone42" + if vmin is None: + vmin = grid.fields[field]["data"].min() + if vmax is None: + vmax = grid.fields[field]["data"].max() + if title is None: + title = f"Max-{field.upper()[:3]}" + + def plot_range_rings(ax_xy, max_range): + """ + Plots range rings at 50 km intervals. + + Parameters + ---------- + ax_xy : matplotlib.axes.Axes + The axis on which to plot the range rings. + max_range : float + The maximum range for the range rings. + + Returns + ------- + None + """ + background_color = ax_xy.get_facecolor() + color = "k" if sum(background_color[:3]) / 3 > 0.5 else "w" + + for i, r in enumerate(np.arange(5e4, np.floor(max_range) + 1, 5e4)): + label = f"Ring Dist. {int(r/1e3)} km" if i == 0 else None + ax_xy.plot( + r * np.cos(np.arange(0, 360) * np.pi / 180), + r * np.sin(np.arange(0, 360) * np.pi / 180), + color=color, + ls="--", + linewidth=0.4, + alpha=0.3, + label=label, + ) + + ax_xy.legend(loc="upper right", prop={"weight": "normal", "size": 8}) + + def _get_projection(ds): + """ + Determine the central latitude and longitude from a dataset + and return the corresponding projection. + + Parameters + ---------- + ds : xarray.Dataset + The dataset from which to extract latitude and longitude + information. + + Returns + ------- + projection : cartopy.crs.Projection + A Cartopy projection object centered on the extracted or + calculated latitude and longitude. + """ + + def get_coord_or_attr(ds, coord_name, attr_name): + """Helper function to get a coordinate or attribute, or + calculate median if available. + """ + if coord_name in ds: + return ( + ds[coord_name].values.item() + if ds[coord_name].values.ndim == 0 + else ds[coord_name].values[0] + ) + if f"origin_{coord_name}" in ds.coords: + return ds.coords[f"origin_{coord_name}"].median().item() + if f"radar_{coord_name}" in ds.coords: + return ds.coords[f"radar_{coord_name}"].median().item() + return ds.attrs.get(attr_name, None) + + lat_0 = get_coord_or_attr( + ds, "latitude", "origin_latitude" + ) or get_coord_or_attr(ds, "radar_latitude", "origin_latitude") + lon_0 = get_coord_or_attr( + ds, "longitude", "origin_longitude" + ) or get_coord_or_attr(ds, "radar_longitude", "origin_longitude") + + if lat_0 is None or lon_0 is None: + lat_0 = ds.lat.mean().item() + lon_0 = ds.lon.mean().item() + + projection = ccrs.LambertAzimuthalEqualArea(lon_0, lat_0) + return projection + + projection = _get_projection(ds) + + # FIG + fig = plt.figure(figsize=[10.3, 10]) + left, bottom, width, height = 0.1, 0.1, 0.6, 0.2 + ax_xy = plt.axes((left, bottom, width, width), projection=projection) + ax_x = plt.axes((left, bottom + width, width, height)) + ax_y = plt.axes((left + width, bottom, height, width)) + ax_cnr = plt.axes((left + width, bottom + width, left + left, height)) + if colorbar: + ax_cb = plt.axes((left - 0.015 + width + height + 0.02, bottom, 0.02, width)) + + # Set axis label formatters + ax_x.xaxis.set_major_formatter(NullFormatter()) + ax_y.yaxis.set_major_formatter(NullFormatter()) + ax_cnr.yaxis.set_major_formatter(NullFormatter()) + ax_cnr.xaxis.set_major_formatter(NullFormatter()) + ax_x.set_ylabel("Height (km)", size=13) + ax_y.set_xlabel("Height (km)", size=13) + + # Draw CAPPI + plt.sca(ax_xy) + xy = ax_xy.pcolormesh(trgx, trgy, max_c, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) + + # Add map features + if add_map: + map_features(ax_xy, lat_lines, lon_lines) + + ax_xy.minorticks_on() + + if range_rings: + plot_range_rings(ax_xy, trgx.max()) + + ax_xy.set_xlim(trgx.min(), trgx.max()) + ax_xy.set_ylim(trgx.min(), trgx.max()) + + # Draw colorbar + if colorbar: + cb = plt.colorbar(xy, cax=ax_cb) + cb.set_label(ds[field].attrs["units"], size=15) + + background_color = ax_xy.get_facecolor() + color = "k" if sum(background_color[:3]) / 3 > 0.5 else "w" + + plt.sca(ax_x) + plt.pcolormesh(trgx / 1e3, trgz / 1e3, max_x, cmap=cmap, vmin=vmin, vmax=vmax) + plt.yticks(sideticks) + ax_x.set_xlim(trgx.min() / 1e3, trgx.max() / 1e3) + ax_x.grid(axis="y", lw=0.5, color=color, alpha=0.5, ls=":") + ax_x.minorticks_on() + + plt.sca(ax_y) + plt.pcolormesh(trgz / 1e3, trgy / 1e3, max_y, cmap=cmap, vmin=vmin, vmax=vmax) + ax_y.set_xticks(sideticks) + ax_y.set_ylim(trgx.min() / 1e3, trgx.max() / 1e3) + ax_y.grid(axis="x", lw=0.5, color=color, alpha=0.5, ls=":") + ax_y.minorticks_on() + + plt.sca(ax_cnr) + plt.tick_params( + axis="both", # changes apply to both axes + which="both", # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + left=False, + right=False, + labelbottom=False, + ) + + # Retrieve instrument name + instrument_name = ds.attrs.get("instrument_name", "N/A")[:4] + + # Center-align text in the corner box + plt.text( + 0.5, + 0.90, + f"Site: {instrument_name}", + size=13, + weight="bold", + ha="center", + va="center", + ) + plt.text(0.5, 0.76, title, size=13, weight="bold", ha="center", va="center") + plt.text( + 0.5, + 0.63, + f"Max Range: {np.floor(trgx.max() / 1e3)} km", + size=11, + ha="center", + va="center", + ) + plt.text( + 0.5, + 0.47, + f"Max Height: {np.floor(trgz.max() / 1e3)} km", + size=11, + ha="center", + va="center", + ) + plt.text( + 0.5, + 0.28, + ds["time"].dt.strftime("%H:%M:%S Z").values.item(), + weight="bold", + size=16, + ha="center", + va="center", + ) + plt.text( + 0.5, + 0.13, + ds["time"].dt.strftime("%d %b, %Y UTC").values.item(), + size=13.5, + ha="center", + va="center", + ) + ax_xy.set_aspect("auto") + + if add_slogan: + fig.text( + 0.1, + 0.06, + "Powered by Py-ART", # Coordinates close to (0, 0) for lower-left corner + fontsize=9, + fontname="Courier New", + # bbox=dict(facecolor='none', boxstyle='round,pad=0.5') + ) + + if savedir is not None: + radar_name = ds.attrs.get("instrument_name", "Radar") + time_str = ds["time"].dt.strftime("%Y%m%d%H%M%S").values.item() + figname = f"{savedir}{os.sep}{title}_{radar_name}_{time_str}.png" + plt.savefig(fname=figname, dpi=dpi, bbox_inches="tight") + print(f"Figure(s) saved as {figname}") + + # plt.rcParams.update(original_rc_params) + plt.rcdefaults() + + if show_figure: + plt.show() + else: + plt.close() + + +def map_features(ax, lat_lines, lon_lines): + """ + Adds map features and gridlines to the plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis on which to add map features and gridlines. + lat_lines : array-like + Latitude lines for the gridlines. + lon_lines : array-like + Longitude lines for the gridlines. + + Returns + ------- + None + """ + background_color = ax.get_facecolor() + color = "k" if sum(background_color[:3]) / 3 > 0.5 else "w" + + # Labeling gridlines depending on the projection + if isinstance(ax.projection, (ccrs.PlateCarree, ccrs.Mercator)): + gl = ax.gridlines( + xlocs=lon_lines, + ylocs=lat_lines, + linewidth=1, + alpha=0.5, + linestyle="--", + draw_labels=True, + ) + gl.top_labels = False + gl.right_labels = False + gl.xlines = False + gl.ylines = False + gl.xlabel_style = {"color": color} + gl.ylabel_style = {"color": color} + ax.add_feature(feat.COASTLINE, alpha=0.8, lw=1, ec=color) + ax.add_feature(feat.BORDERS, alpha=0.7, lw=0.7, ls="--", ec=color) + ax.add_feature( + feat.STATES.with_scale("10m"), alpha=0.6, lw=0.5, ls=":", ec=color + ) + + elif isinstance( + ax.projection, (ccrs.LambertConformal, ccrs.LambertAzimuthalEqualArea) + ): + ax.figure.canvas.draw() + gl = ax.gridlines( + crs=ccrs.PlateCarree(), + xlocs=lon_lines, + ylocs=lat_lines, + linewidth=1, + alpha=0.5, + linestyle="--", + draw_labels=False, + ) + gl.xlines = False + gl.ylines = False + gl.xlabel_style = {"color": color} + gl.ylabel_style = {"color": color} + ax.add_feature(feat.COASTLINE, alpha=0.8, lw=1, ec=color) + ax.add_feature(feat.BORDERS, alpha=0.7, lw=0.7, ls="--", ec=color) + ax.add_feature( + feat.STATES.with_scale("10m"), alpha=0.6, lw=0.5, ls=":", ec=color + ) + # Label the end-points of the gridlines using custom tick makers + ax.xaxis.set_major_formatter(LONGITUDE_FORMATTER) + ax.yaxis.set_major_formatter(LATITUDE_FORMATTER) + from pyart.graph.gridmapdisplay import lambert_xticks, lambert_yticks + + lambert_xticks(ax, lon_lines) + lambert_yticks(ax, lat_lines) + else: + ax.gridlines(xlocs=lon_lines, ylocs=lat_lines) diff --git a/tests/core/test_grid.py b/tests/core/test_grid.py index bdc2ba68ab..0fedbe0450 100644 --- a/tests/core/test_grid.py +++ b/tests/core/test_grid.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from numpy.testing import assert_almost_equal, assert_equal +from numpy.testing import assert_almost_equal try: import pyproj @@ -88,21 +88,23 @@ def test_grid_to_xarray(): grid = pyart.testing.make_target_grid() ds = grid.to_xarray() - lon, lat = pyart.core.Grid.get_point_longitude_latitude(grid) + lon, lat = grid.get_point_longitude_latitude() time = np.array([netCDF4.num2date(grid.time["data"][0], grid.time["units"])]) - z = grid.z["data"] - y = grid.y["data"] - x = grid.x["data"] - assert_equal(ds.x.data, x) - assert_equal(ds.y.data, y) - assert_equal(ds.z.data, z) - assert_equal(ds.lon.data, lon) - assert_equal(ds.lat.data, lat) - assert_equal(ds.time.data, time) + # Check dimensions + assert ds.dims == {"time": 1, "z": 2, "y": 400, "x": 320, "nradar": 1} + # Check coordinate data + assert np.array_equal(ds.x.data, grid.x["data"]) + assert np.array_equal(ds.y.data, grid.y["data"]) + assert np.array_equal(ds.z.data, grid.z["data"]) + assert np.array_equal(ds.lon.data, lon) + assert np.array_equal(ds.lat.data, lat) + assert np.array_equal(ds.time.data, time) + + # Check radar-specific attributes assert ds.attrs["nradar"] == 1 - assert ds.attrs["radar_name"]["data"][0] == "ExampleRadar" + assert ds.attrs["radar_name"] == "ExampleRadar" def _check_dicts_similar(dic1, dic2): diff --git a/tests/graph/test_plot_maxcappi.py b/tests/graph/test_plot_maxcappi.py new file mode 100644 index 0000000000..857183fcd2 --- /dev/null +++ b/tests/graph/test_plot_maxcappi.py @@ -0,0 +1,106 @@ +import os + +import matplotlib.pyplot as plt +import pytest + +import pyart + + +@pytest.mark.skipif( + not pyart.graph.gridmapdisplay._CARTOPY_AVAILABLE, reason="Cartopy is not installed" +) +def test_plot_maxcappi_simple(outfile=None): + """ + Test the basic functionality of plot_maxcappi. + """ + # Create a test grid using Py-ART's testing utility + grid = pyart.testing.make_target_grid() + grid.z["data"] = grid.z["data"] * 10 + 100 + grid.metadata["instrument_name"] = "GRID" + + # Use plot_maxcappi with the generated grid + pyart.graph.max_cappi.plot_maxcappi( + grid=grid, + field="reflectivity", + savedir=None, # Do not save the plot + show_figure=False, # Do not show the plot + ) + + if outfile: + plt.savefig(outfile) + plt.close() + + +@pytest.mark.skipif( + not pyart.graph.gridmapdisplay._CARTOPY_AVAILABLE, reason="Cartopy is not installed" +) +def test_plot_maxcappi_with_save(outfile=None): + """ + Test plot_maxcappi and save the output to a file. + """ + # Create a test grid using Py-ART's testing utility + grid = pyart.testing.make_target_grid() + grid.z["data"] = grid.z["data"] * 10 + 100 + grid.metadata["instrument_name"] = "GRID" + + # Define the output file path + outfile = outfile or "test_plot_maxcappi_output.png" + + # Use plot_maxcappi with the generated grid + pyart.graph.max_cappi.plot_maxcappi( + grid=grid, + field="reflectivity", + savedir=None, # Handle saving manually below + show_figure=False, # Do not show the plot + ) + + # Save the figure to a file + plt.savefig(outfile) + plt.close() + + # Check if the file was created + assert os.path.exists(outfile), "The plot was not saved as expected." + + +@pytest.mark.skipif( + not pyart.graph.gridmapdisplay._CARTOPY_AVAILABLE, reason="Cartopy is not installed" +) +def test_plot_maxcappi_with_all_options(outfile=None): + """ + Test plot_maxcappi with all options enabled. + """ + import cartopy.crs as ccrs + + # Create a test grid using Py-ART's testing utility + grid = pyart.testing.make_target_grid() + grid.z["data"] = grid.z["data"] * 10 + 100 + grid.metadata["instrument_name"] = "GRID" + + # Use a custom projection for testing + projection = ccrs.Mercator() + + # Use plot_maxcappi with additional options + pyart.graph.max_cappi.plot_maxcappi( + grid=grid, + field="reflectivity", + title="Test Max-CAPPI", + lat_lines=None, + lon_lines=None, + add_map=True, + projection=projection, + colorbar=True, + range_rings=True, + dpi=150, + savedir=None, + show_figure=False, + ) + + if outfile: + plt.savefig(outfile) + plt.close() + + +if __name__ == "__main__": + test_plot_maxcappi_simple("figure_plot_maxcappi_simple.png") + test_plot_maxcappi_with_save("figure_plot_maxcappi_output.png") + test_plot_maxcappi_with_all_options("figure_plot_maxcappi_all_options.png") diff --git a/tests/retrieve/test_qvp.py b/tests/retrieve/test_qvp.py index 2293ae56f7..7991fe8b45 100644 --- a/tests/retrieve/test_qvp.py +++ b/tests/retrieve/test_qvp.py @@ -498,7 +498,9 @@ def test_find_nearest_gate(test_radar): assert ind_ray == 141.0 assert ind_rng == 145.0 assert azi == 141.0 - assert rng == 14514.514 + assert ( + abs(rng - 14514.514) < 1e-3 + ) # Allow for a small tolerance in floating-point comparison def test_find_neighbour_gates(test_radar):