Skip to content

Commit

Permalink
vectorize 1d interpolators
Browse files Browse the repository at this point in the history
review changes
  • Loading branch information
hollymandel committed Sep 24, 2024
1 parent 074a480 commit 4106fde
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
12 changes: 10 additions & 2 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(
copy=False,
bounds_error=False,
order=None,
axis=-1,
**kwargs,
):
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -173,6 +174,7 @@ def __init__(
bounds_error=bounds_error,
assume_sorted=assume_sorted,
copy=copy,
axis=axis,
**self.cons_kwargs,
)

Expand Down Expand Up @@ -496,10 +498,13 @@ def _get_interpolator(
kwargs.update(method=method)
interp_class = ScipyInterpolator
elif method == "barycentric":
kwargs.update(axis=-1)
interp_class = _import_interpolant("BarycentricInterpolator", method)
elif method in ["krogh", "krog"]:
kwargs.update(axis=-1)
interp_class = _import_interpolant("KroghInterpolator", method)
elif method == "pchip":
kwargs.update(axis=-1)
interp_class = _import_interpolant("PchipInterpolator", method)
elif method == "spline":
# utils.emit_user_level_warning(
Expand All @@ -512,6 +517,10 @@ def _get_interpolator(
kwargs.update(method=method)
interp_class = SplineInterpolator
elif method == "akima":
kwargs.update(axis=-1)
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima", axis=-1)
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima")
Expand Down Expand Up @@ -761,8 +770,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):

def _interp1d(var, x, new_x, func, kwargs):
# x, new_x are tuples of size 1.
x, new_x = x[0].data, new_x[0].data

x, new_x = x[0], new_x[0]
rslt = func(x, var, **kwargs)(np.ravel(new_x))
if new_x.ndim > 1:
return reshape(rslt, (var.shape[:-1] + new_x.shape))
Expand Down
28 changes: 15 additions & 13 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def func(obj, new_x):
assert_allclose(actual, expected)


@requires_scipy
@pytest.mark.parametrize(
"use_dask, method",
(
Expand All @@ -152,37 +153,38 @@ def func(obj, new_x):
),
)
def test_interpolate_vectorize(use_dask: bool, method: str) -> None:
if not has_scipy:
pytest.skip("scipy is not installed.")

# scipy interpolation for the reference
def func(obj, dim, new_x, method):
scipy_kwargs = {}
interpolant_options = {
"barycentric": "BarycentricInterpolator",
"krogh": "KroghInterpolator",
"pchip": "PchipInterpolator",
"akima": "Akima1DInterpolator",
"makima": "Akima1DInterpolator",
"barycentric": scipy.interpolate.BarycentricInterpolator,
"krogh": scipy.interpolate.KroghInterpolator,
"pchip": scipy.interpolate.PchipInterpolator,
"akima": scipy.interpolate.Akima1DInterpolator,
"makima": scipy.interpolate.Akima1DInterpolator,
}

shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
for s in new_x.shape[::-1]:
shape.insert(obj.get_axis_num(dim), s)

if method in interpolant_options:
interpolant = getattr(scipy.interpolate, interpolant_options[method])
interpolant = interpolant_options[method]
if method == "makima":
scipy_kwargs["method"] = method
return interpolant(
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)
else:
scipy_kwargs["kind"] = method
scipy_kwargs["bounds_error"] = False
scipy_kwargs["fill_value"] = np.nan

return scipy.interpolate.interp1d(
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
kind=method,
bounds_error=False,
fill_value=np.nan,
**scipy_kwargs,
)(new_x).reshape(shape)

da = get_example_data(0)
Expand Down

0 comments on commit 4106fde

Please sign in to comment.