Skip to content

Commit

Permalink
Improve nearest neighbour interpolation for gridded data
Browse files Browse the repository at this point in the history
  • Loading branch information
mkstratos committed Feb 24, 2020
1 parent 72cafb9 commit 279b9d7
Showing 1 changed file with 109 additions and 21 deletions.
130 changes: 109 additions & 21 deletions build_antarctica.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import xarray as xr
from nco import Nco
import pyproj
from pykdtree.kdtree import KDTree
from util import projections


Expand Down Expand Up @@ -64,7 +65,7 @@ def load_cryosat(in_file, **kwargs):
# Some datasets have a time axis, crop that because it will be added later
# to the interpolated dataset
if data.ndim == 3:
_slice = slice(-1, None, None)
_slice = (-1, slice(None), slice(None))
else:
_slice = slice(None)

Expand All @@ -81,7 +82,6 @@ def load_cryosat(in_file, **kwargs):
dims[0]: (dim_orig, y_transform),
dims[1]: (dim_orig, x_transform),
}

_out = xr.DataArray(
np.ma.masked_values(data[_slice], -9999), dims=dim_orig, coords=coords
)
Expand Down Expand Up @@ -208,6 +208,66 @@ def rect_bivariate_interp(xin, yin, data, xout, yout):
return new_data


def nn_interp(x_in, y_in, x_out, y_out, z_in, nbrs=1):
"""Perform nearest neighbour interpolation on Cartesian coordinate data.
Parameters
----------
x_in, y_in : array_like
1- or 2-D arrays of input coordinate data
x_out, y_out : array_like
1- or 2-D arrays on which to interpolate (must be same coordinate
reference system as x_in y_in)
z_in : array_like
2D (ny, nx) array of input data
nbrs : int
Number of neighbours to use in interpolation
Returns
-------
z_interp : array_like
2D Array (shape of y_out, x_out) of interpolated data
"""
# X,Y must be the same length. If input data is a grid then use meshgrid
# otherwise, skip this step, because each index of x, y, z already describe
# the same location
if x_in.ndim == 1 and z_in.ndim >= 2:
x_in, y_in = np.meshgrid(x_in, y_in)

if x_out.ndim == 1:
x_out, y_out = np.meshgrid(x_out, y_out)

z_in = np.ma.masked_invalid(z_in)

# Creates a K-Dimensional tree for input grid, allows
# "lookup" for nearest neighbour
tree = KDTree(np.vstack([x_in.flatten(), y_in.flatten()]).T)

# Perform "lookup" for nearest neighbour of new grid,
# use the nearest nbrs neighbours
dist, inds = tree.query(
np.vstack([x_out.flatten(), y_out.flatten()]).T, k=nbrs
)

if nbrs > 1:
# Weight by distance of the neighbours
wgts = 1.0 / dist ** 2

# Select indicies on flattened (1D) input data, weight by
# sum of distance (axis=1 is "neighbour" axis)
z_out = np.sum(wgts * z_in.flatten()[inds], axis=1) / np.sum(
wgts, axis=1
)
else:
# If using only one neighbour, just use the single
# nearest neighbour as the interpolated value
z_out = z_in.flatten()[inds]

# Return re-shaped data
return z_out.reshape(x_out.shape)


def interp(in_cfg, in_var, out_cfg, output):
"""Take input data, interpolate, assgin metadata, return DataArray."""
print(f"Interpolating {in_var} from {in_cfg['file']}")
Expand All @@ -220,21 +280,37 @@ def interp(in_cfg, in_var, out_cfg, output):
ny_in = in_data[in_cfg["coords"]["y"]].shape[0]
if nx_in >= xout.shape[0] and ny_in >= yout.shape[0]:
# Use nearest neighbour when input data is higher res than output
method = "nearest"
nbrs = 1
else:
# Otherwise use linear
method = "linear"

print(f" using {method}")
breakpoint()
intp_data = in_data[in_var].interp(
**{in_cfg["coords"]["x"]: xout, in_cfg["coords"]["y"]: yout},
method=method,
nbrs = 1

print(f" using {nbrs} neighbour(s)")
# intp_data = in_data[in_var].interp(
# **{in_cfg["coords"]["x"]: xout, in_cfg["coords"]["y"]: yout},
# method=method,
# )

intp_data = nn_interp(
in_data[in_var][in_cfg["coords"]["x"]].values,
in_data[in_var][in_cfg["coords"]["y"]].values,
xout,
yout,
in_data[in_var].values,
nbrs,
)
for coord in ["y", "x"]:
if in_cfg["coords"][coord] != out_cfg["coords"][coord]:
# Drop original coordinate copy from the output DataArray
intp_data = intp_data.drop(in_cfg["coords"][coord])
intp_data = xr.DataArray(
intp_data,
dims=[out_cfg["coords"]["y"], out_cfg["coords"]["x"]],
coords={
out_cfg["coords"]["y"]: yout,
out_cfg["coords"]["x"]: xout,
},
)
# for coord in ["y", "x"]:
# if in_cfg["coords"][coord] != out_cfg["coords"][coord]:
# # Drop original coordinate copy from the output DataArray
# intp_data = intp_data.drop(in_cfg["coords"][coord])
# else:
# # Otherwise use Bivariate Spline on rectangular grid
# print(" using bivariate spline")
Expand All @@ -248,12 +324,24 @@ def interp(in_cfg, in_var, out_cfg, output):
# yout.values,
# )
else:
out_grid = np.meshgrid(xout, yout, indexing="ij")
intp_data = scipy.interpolate.griddata(
in_data[:, :2][:, ::-1],
in_data[:, 2],
tuple(out_grid),
method="nearest",
# out_grid = np.meshgrid(xout, yout, indexing="ij")
# intp_data = scipy.interpolate.griddata(
# in_data[:, :2][:, ::-1],
# in_data[:, 2],
# tuple(out_grid),
# method="nearest",
# )
intp_data = nn_interp(
in_data[:, 0], in_data[:, 1], xout, yout, in_data[:, 2], 1,
)

intp_data = xr.DataArray(
intp_data,
dims=[out_cfg["coords"]["y"], out_cfg["coords"]["x"]],
coords={
out_cfg["coords"]["y"]: yout,
out_cfg["coords"]["x"]: xout,
},
)

if not isinstance(intp_data, xr.DataArray):
Expand Down Expand Up @@ -453,4 +541,4 @@ def main(island="antarctica", resolution=1, proj_opt=None):


if __name__ == "__main__":
main("greenland", proj_opt="mcb")
main("antarctica", proj_opt=None)

0 comments on commit 279b9d7

Please sign in to comment.