Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support vectorized interpolation with more scipy interpolators #9526

Merged
merged 3 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ New Features
`Tom Nicholas <https://github.com/TomNicholas>`_.
- Added zarr backends for :py:func:`open_groups` (:issue:`9430`, :pull:`9469`).
By `Eni Awowale <https://github.com/eni-awowale>`_.
- Added support for vectorized interpolation using additional interpolators
from the ``scipy.interpolate`` module (:issue:`9049`, :pull:`9526`).
By `Holly Mandel <https://github.com/hollymandel>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,12 +2224,12 @@ def interp(

Performs univariate or multivariate interpolation of a DataArray onto
new coordinates using scipy's interpolation routines. If interpolating
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
called. When interpolating along multiple existing dimensions, an
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
is called. When interpolating along multiple existing dimensions, an
attempt is made to decompose the interpolation into multiple
1-dimensional interpolations. If this is possible,
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
:py:func:`scipy.interpolate.interpn` is called.
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator is called.
Otherwise, :py:func:`scipy.interpolate.interpn` is called.

Parameters
----------
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3885,12 +3885,12 @@ def interp(

Performs univariate or multivariate interpolation of a Dataset onto
new coordinates using scipy's interpolation routines. If interpolating
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
called. When interpolating along multiple existing dimensions, an
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
is called. When interpolating along multiple existing dimensions, an
attempt is made to decompose the interpolation into multiple
1-dimensional interpolations. If this is possible,
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
:py:func:`scipy.interpolate.interpn` is called.
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator
is called. Otherwise, :py:func:`scipy.interpolate.interpn` is called.

Parameters
----------
Expand Down
37 changes: 25 additions & 12 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(
copy=False,
bounds_error=False,
order=None,
axis=-1,
**kwargs,
):
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -173,6 +174,7 @@ def __init__(
bounds_error=bounds_error,
assume_sorted=assume_sorted,
copy=copy,
axis=axis,
**self.cons_kwargs,
)

Expand Down Expand Up @@ -479,7 +481,8 @@ def _get_interpolator(
interp1d_methods = get_args(Interp1dOptions)
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))

# prioritize scipy.interpolate
# prefer numpy.interp for 1d linear interpolation. This function cannot
# take higher dimensional data but scipy.interp1d can.
if (
method == "linear"
and not kwargs.get("fill_value", None) == "extrapolate"
Expand All @@ -492,21 +495,33 @@ def _get_interpolator(
if method in interp1d_methods:
kwargs.update(method=method)
interp_class = ScipyInterpolator
elif vectorizeable_only:
raise ValueError(
f"{method} is not a vectorizeable interpolator. "
f"Available methods are {interp1d_methods}"
)
elif method == "barycentric":
kwargs.update(axis=-1)
interp_class = _import_interpolant("BarycentricInterpolator", method)
elif method in ["krogh", "krog"]:
kwargs.update(axis=-1)
interp_class = _import_interpolant("KroghInterpolator", method)
elif method == "pchip":
kwargs.update(axis=-1)
interp_class = _import_interpolant("PchipInterpolator", method)
elif method == "spline":
utils.emit_user_level_warning(
"The 1d SplineInterpolator class is performing an incorrect calculation and "
"is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
PendingDeprecationWarning,
)
if vectorizeable_only:
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
kwargs.update(method=method)
hollymandel marked this conversation as resolved.
Show resolved Hide resolved
interp_class = SplineInterpolator
elif method == "akima":
kwargs.update(axis=-1)
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima", axis=-1)
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima", axis=-1)
interp_class = _import_interpolant("Akima1DInterpolator", method)
else:
raise ValueError(f"{method} is not a valid scipy interpolator")
Expand All @@ -525,6 +540,7 @@ def _get_interpolator_nd(method, **kwargs):

if method in valid_methods:
kwargs.update(method=method)
kwargs.setdefault("bounds_error", False)
interp_class = _import_interpolant("interpn", method)
else:
raise ValueError(
Expand Down Expand Up @@ -614,9 +630,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
if not indexes_coords:
return var.copy()

# default behavior
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice cleanup!

kwargs["bounds_error"] = kwargs.get("bounds_error", False)

result = var
# decompose the interpolation into a succession of independent interpolation
for indep_indexes_coords in decompose_interp(indexes_coords):
Expand Down Expand Up @@ -663,8 +676,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
new_x : a list of 1d array
New coordinates. Should not contain NaN.
method : string
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
1-dimensional interpolation.
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima',
'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation.
{'linear', 'nearest'} for multidimensional interpolation
**kwargs
Optional keyword arguments to be passed to scipy.interpolator
Expand Down Expand Up @@ -756,7 +769,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
def _interp1d(var, x, new_x, func, kwargs):
# x, new_x are tuples of size 1.
x, new_x = x[0], new_x[0]
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
rslt = func(x, var, **kwargs)(np.ravel(new_x))
if new_x.ndim > 1:
return reshape(rslt, (var.shape[:-1] + new_x.shape))
if new_x.ndim == 0:
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def copy(
Interp1dOptions = Literal[
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
]
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
InterpolantOptions = Literal[
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
]
InterpOptions = Union[Interp1dOptions, InterpolantOptions]

DatetimeUnitOptions = Literal[
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def _importorskip(

has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
has_scipy, requires_scipy = _importorskip("scipy")
has_scipy_ge_1_13, requires_scipy_ge_1_13 = _importorskip("scipy", "1.13")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
97 changes: 68 additions & 29 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
assert_identical,
has_dask,
has_scipy,
has_scipy_ge_1_13,
requires_cftime,
requires_dask,
requires_scipy,
Expand Down Expand Up @@ -132,29 +133,66 @@ def func(obj, new_x):
assert_allclose(actual, expected)


@pytest.mark.parametrize("use_dask", [False, True])
def test_interpolate_vectorize(use_dask: bool) -> None:
if not has_scipy:
pytest.skip("scipy is not installed.")

if not has_dask and use_dask:
pytest.skip("dask is not installed in the environment.")

@requires_scipy
@pytest.mark.parametrize(
hollymandel marked this conversation as resolved.
Show resolved Hide resolved
"use_dask, method",
(
(False, "linear"),
(False, "akima"),
pytest.param(
False,
"makima",
marks=pytest.mark.skipif(not has_scipy_ge_1_13, reason="scipy too old"),
),
pytest.param(
True,
"linear",
marks=pytest.mark.skipif(not has_dask, reason="dask not available"),
),
pytest.param(
True,
"akima",
marks=pytest.mark.skipif(not has_dask, reason="dask not available"),
),
),
)
def test_interpolate_vectorize(use_dask: bool, method: InterpOptions) -> None:
# scipy interpolation for the reference
def func(obj, dim, new_x):
def func(obj, dim, new_x, method):
scipy_kwargs = {}
interpolant_options = {
"barycentric": scipy.interpolate.BarycentricInterpolator,
"krogh": scipy.interpolate.KroghInterpolator,
"pchip": scipy.interpolate.PchipInterpolator,
"akima": scipy.interpolate.Akima1DInterpolator,
"makima": scipy.interpolate.Akima1DInterpolator,
}

shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
for s in new_x.shape[::-1]:
shape.insert(obj.get_axis_num(dim), s)

return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
bounds_error=False,
fill_value=np.nan,
)(new_x).reshape(shape)
if method in interpolant_options:
interpolant = interpolant_options[method]
if method == "makima":
scipy_kwargs["method"] = method
return interpolant(
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)
else:

return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
kind=method,
bounds_error=False,
fill_value=np.nan,
**scipy_kwargs,
)(new_x).reshape(shape)

da = get_example_data(0)

if use_dask:
da = da.chunk({"y": 5})

Expand All @@ -165,17 +203,17 @@ def func(obj, dim, new_x):
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "y"],
coords={
"z": xdest["z"],
"z2": xdest["z2"],
"y": da["y"],
"x": ("z", xdest.values),
"x2": ("z", func(da["x2"], "x", xdest)),
"x2": ("z", func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
Expand All @@ -191,18 +229,18 @@ def func(obj, dim, new_x):
},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "w", "y"],
coords={
"z": xdest["z"],
"w": xdest["w"],
"z2": xdest["z2"],
"y": da["y"],
"x": (("z", "w"), xdest.data),
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
Expand Down Expand Up @@ -393,19 +431,17 @@ def test_nans(use_dask: bool) -> None:
assert actual.count() > 0


@requires_scipy
@pytest.mark.parametrize("use_dask", [True, False])
def test_errors(use_dask: bool) -> None:
if not has_scipy:
pytest.skip("scipy is not installed.")

# akima and spline are unavailable
# spline is unavailable
da = xr.DataArray([0, 1, np.nan, 2], dims="x", coords={"x": range(4)})
if not has_dask and use_dask:
pytest.skip("dask is not installed in the environment.")
da = da.chunk()

for method in ["akima", "spline"]:
with pytest.raises(ValueError):
for method in ["spline"]:
with pytest.raises(ValueError), pytest.warns(PendingDeprecationWarning):
da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type]

# not sorted
Expand Down Expand Up @@ -922,7 +958,10 @@ def test_interp1d_bounds_error() -> None:
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
],
)
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
def test_coord_attrs(
x,
expect_same_attrs: bool,
) -> None:
base_attrs = dict(foo="bar")
ds = xr.Dataset(
data_vars=dict(a=2 * np.arange(5)),
Expand Down
6 changes: 5 additions & 1 deletion xarray/tests/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def test_scipy_methods_function(method) -> None:
# Note: Pandas does some wacky things with these methods and the full
# integration tests won't work.
da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True)
actual = da.interpolate_na(method=method, dim="time")
if method == "spline":
with pytest.warns(PendingDeprecationWarning):
actual = da.interpolate_na(method=method, dim="time")
else:
actual = da.interpolate_na(method=method, dim="time")
assert (da.count("time") <= actual.count("time")).all()


Expand Down
Loading