vectorize 1d interpolators
hollymandel committed Sep 23, 2024
1 parent d26144d commit 09a9e73
Showing 3 changed files with 73 additions and 40 deletions.
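
For context, a minimal sketch of the behavior this commit targets (not taken from the diff; the array shapes, coordinate values, and variable names below are illustrative): vectorized 1D interpolation along one dimension of a multi-dimensional DataArray, dispatched to scipy interpolants such as "pchip", plus the newly added "makima" option, which assumes a scipy version whose Akima1DInterpolator accepts a method argument.

import numpy as np
import xarray as xr

# Illustrative data: values on a 1D "x" grid with an extra "y" dimension.
da = xr.DataArray(
    np.random.randn(30, 4),
    dims=["x", "y"],
    coords={"x": np.linspace(0, 1, 30)},
)

# Destination points indexed by a new "z" dimension (vectorized interpolation).
xdest = xr.DataArray(np.linspace(0.1, 0.9, 10), dims="z")

# 1D scipy interpolants are applied along "x" for every "y" slice.
pchip_result = da.interp(x=xdest, method="pchip")
makima_result = da.interp(x=xdest, method="makima")  # option added by this commit
print(pchip_result.dims)  # "x" is replaced by "z"; "y" is carried through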
29 changes: 18 additions & 11 deletions xarray/core/missing.py
@@ -203,6 +203,7 @@ def __init__(

self.method = method
self.cons_kwargs = kwargs
del self.cons_kwargs["axis"]
self.call_kwargs = {"nu": nu, "ext": ext}

if fill_value is not None:
@@ -479,7 +480,8 @@ def _get_interpolator(
interp1d_methods = get_args(Interp1dOptions)
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))

# prioritize scipy.interpolate
# prefer numpy.interp for 1d linear interpolation. This function cannot
# take higher dimensional data but scipy.interp1d can.
if (
method == "linear"
and not kwargs.get("fill_value", None) == "extrapolate"
@@ -489,25 +491,31 @@
interp_class = NumpyInterpolator

elif method in valid_methods:
kwargs.update(axis=-1)
if method in interp1d_methods:
kwargs.update(method=method)
interp_class = ScipyInterpolator
elif vectorizeable_only:
raise ValueError(
f"{method} is not a vectorizeable interpolator. "
f"Available methods are {interp1d_methods}"
)
elif method == "barycentric":
interp_class = _import_interpolant("BarycentricInterpolator", method)
elif method in ["krogh", "krog"]:
interp_class = _import_interpolant("KroghInterpolator", method)
elif method == "pchip":
interp_class = _import_interpolant("PchipInterpolator", method)
elif method == "spline":
# utils.emit_user_level_warning(
# "The 1d SplineInterpolator class is performing an incorrect calculation and "
# "is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
# PendingDeprecationWarning,
# )
if vectorizeable_only:
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
kwargs.update(method=method)
interp_class = SplineInterpolator
elif method == "akima":
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima")
interp_class = _import_interpolant("Akima1DInterpolator", method)
else:
raise ValueError(f"{method} is not a valid scipy interpolator")
else:
@@ -525,6 +533,7 @@ def _get_interpolator_nd(method, **kwargs):

if method in valid_methods:
kwargs.update(method=method)
kwargs.update(bounds_error=False)
interp_class = _import_interpolant("interpn", method)
else:
raise ValueError(
@@ -614,9 +623,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
if not indexes_coords:
return var.copy()

# default behavior
kwargs["bounds_error"] = kwargs.get("bounds_error", False)

result = var
# decompose the interpolation into a succession of independent interpolation
for indep_indexes_coords in decompose_interp(indexes_coords):
@@ -755,8 +761,9 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):

def _interp1d(var, x, new_x, func, kwargs):
# x, new_x are tuples of size 1.
x, new_x = x[0], new_x[0]
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
x, new_x = x[0].data, new_x[0].data

rslt = func(x, var, **kwargs)(np.ravel(new_x))
if new_x.ndim > 1:
return reshape(rslt, (var.shape[:-1] + new_x.shape))
if new_x.ndim == 0:
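
The reworked _interp1d above follows a simple pattern: build a scipy interpolant from (x, var) with the interpolation axis last, evaluate it at the raveled destination points, then reshape so the destination's shape replaces the interpolated axis. A standalone sketch of that pattern, using PchipInterpolator as a stand-in for whichever class func resolves to (shapes are illustrative):

import numpy as np
from scipy.interpolate import PchipInterpolator

x = np.linspace(0, 1, 20)
var = np.random.randn(4, 20)     # interpolation axis is the last axis
new_x = np.random.rand(3, 5)     # multi-dimensional destination points

# Build the interpolant, evaluate at the raveled targets, then reshape so
# new_x.shape replaces the interpolated axis: (4, 20) -> (4, 3, 5).
rslt = PchipInterpolator(x, var, axis=-1)(np.ravel(new_x))
rslt = rslt.reshape(var.shape[:-1] + new_x.shape)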
4 changes: 3 additions & 1 deletion xarray/core/types.py
@@ -228,7 +228,9 @@ def copy(
Interp1dOptions = Literal[
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
]
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
InterpolantOptions = Literal[
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
]
InterpOptions = Union[Interp1dOptions, InterpolantOptions]

DatetimeUnitOptions = Literal[
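
With "makima" added to InterpolantOptions, the flattened set of valid method strings that _get_interpolator derives from these Literal types now includes it. A small sketch, importing from the internal xarray.core.types module as the code above does:

from typing import get_args

from xarray.core.types import InterpOptions

# Mirror how _get_interpolator flattens the Union of Literals into the
# full tuple of accepted method strings.
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))
assert "makima" in valid_methods
assert "pchip" in valid_methods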
80 changes: 52 additions & 28 deletions xarray/tests/test_interp.py
@@ -10,16 +10,9 @@
import xarray as xr
from xarray.coding.cftimeindex import _parse_array_of_cftime_strings
from xarray.core.types import InterpOptions
from xarray.tests import (
assert_allclose,
assert_equal,
assert_identical,
has_dask,
has_scipy,
requires_cftime,
requires_dask,
requires_scipy,
)
from xarray.tests import (assert_allclose, assert_equal, assert_identical,
has_dask, has_scipy, requires_cftime, requires_dask,
requires_scipy)
from xarray.tests.test_dataset import create_test_data

try:
@@ -132,29 +125,57 @@ def func(obj, new_x):
assert_allclose(actual, expected)


@pytest.mark.parametrize("use_dask", [False, True])
def test_interpolate_vectorize(use_dask: bool) -> None:
@pytest.mark.parametrize(
"use_dask, method",
(
(False, "linear"),
(False, "akima"),
(False, "makima"),
(True, "linear"),
(True, "akima"),
),
)
def test_interpolate_vectorize(use_dask: bool, method: str) -> None:
if not has_scipy:
pytest.skip("scipy is not installed.")

if not has_dask and use_dask:
pytest.skip("dask is not installed in the environment.")

# scipy interpolation for the reference
def func(obj, dim, new_x):
def func(obj, dim, new_x, method):
scipy_kwargs = {}
interpolant_options = {
"barycentric": "BarycentricInterpolator",
"krogh": "KroghInterpolator",
"pchip": "PchipInterpolator",
"akima": "Akima1DInterpolator",
"makima": "Akima1DInterpolator",
}

shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
for s in new_x.shape[::-1]:
shape.insert(obj.get_axis_num(dim), s)

return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
bounds_error=False,
fill_value=np.nan,
)(new_x).reshape(shape)
if method in interpolant_options:
from scipy import interpolate

interpolant = getattr(interpolate, interpolant_options[method])
if method == "makima":
scipy_kwargs["method"] = method
return interpolant(
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)
else:
scipy_kwargs["kind"] = method
scipy_kwargs["bounds_error"] = False
scipy_kwargs["fill_value"] = np.nan
return scipy.interpolate.interp1d(
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)

da = get_example_data(0)

if use_dask:
da = da.chunk({"y": 5})

@@ -165,17 +186,17 @@ def func(obj, dim, new_x):
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "y"],
coords={
"z": xdest["z"],
"z2": xdest["z2"],
"y": da["y"],
"x": ("z", xdest.values),
"x2": ("z", func(da["x2"], "x", xdest)),
"x2": ("z", func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
@@ -191,18 +212,18 @@ def func(obj, dim, new_x):
},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "w", "y"],
coords={
"z": xdest["z"],
"w": xdest["w"],
"z2": xdest["z2"],
"y": da["y"],
"x": (("z", "w"), xdest.data),
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
@@ -404,7 +425,7 @@ def test_errors(use_dask: bool) -> None:
pytest.skip("dask is not installed in the environment.")
da = da.chunk()

for method in ["akima", "spline"]:
for method in ["spline"]:
with pytest.raises(ValueError):
da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type]

@@ -922,7 +943,10 @@ def test_interp1d_bounds_error() -> None:
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
],
)
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
def test_coord_attrs(
x,
expect_same_attrs: bool,
) -> None:
base_attrs = dict(foo="bar")
ds = xr.Dataset(
data_vars=dict(a=2 * np.arange(5)),
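
For reference, a standalone sketch of the scipy call the updated test uses as its ground truth on the "makima" path; shapes are illustrative, and passing method="makima" to Akima1DInterpolator assumes a scipy version that supports it:

import numpy as np
from scipy.interpolate import Akima1DInterpolator

x = np.linspace(0, 1, 30)
obj = np.random.randn(30, 5)            # interpolate along axis 0
new_x = np.linspace(0.1, 0.9, 10)

# Modified-Akima reference, as dispatched for method="makima" in func() above
expected = Akima1DInterpolator(x, obj, axis=0, method="makima")(new_x)
print(expected.shape)                   # (10, 5)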
