Skip to content

Commit

Permalink
vectorize 1d interpolators
Browse files Browse the repository at this point in the history
  • Loading branch information
hollymandel committed Sep 24, 2024
1 parent 52f13d4 commit 623be9f
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 53 deletions.
10 changes: 5 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,12 +2224,12 @@ def interp(
Performs univariate or multivariate interpolation of a DataArray onto
new coordinates using scipy's interpolation routines. If interpolating
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
called. When interpolating along multiple existing dimensions, an
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
is called. When interpolating along multiple existing dimensions, an
attempt is made to decompose the interpolation into multiple
1-dimensional interpolations. If this is possible,
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
:py:func:`scipy.interpolate.interpn` is called.
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator is called.
Otherwise, :py:func:`scipy.interpolate.interpn` is called.
Parameters
----------
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3885,12 +3885,12 @@ def interp(
Performs univariate or multivariate interpolation of a Dataset onto
new coordinates using scipy's interpolation routines. If interpolating
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
called. When interpolating along multiple existing dimensions, an
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
is called. When interpolating along multiple existing dimensions, an
attempt is made to decompose the interpolation into multiple
1-dimensional interpolations. If this is possible,
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
:py:func:`scipy.interpolate.interpn` is called.
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator
is called. Otherwise, :py:func:`scipy.interpolate.interpn` is called.
Parameters
----------
Expand Down
37 changes: 25 additions & 12 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(
copy=False,
bounds_error=False,
order=None,
axis=-1,
**kwargs,
):
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -173,6 +174,7 @@ def __init__(
bounds_error=bounds_error,
assume_sorted=assume_sorted,
copy=copy,
axis=axis,
**self.cons_kwargs,
)

Expand Down Expand Up @@ -479,7 +481,8 @@ def _get_interpolator(
interp1d_methods = get_args(Interp1dOptions)
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))

# prioritize scipy.interpolate
# prefer numpy.interp for 1d linear interpolation. This function cannot
# take higher dimensional data but scipy.interp1d can.
if (
method == "linear"
and not kwargs.get("fill_value", None) == "extrapolate"
Expand All @@ -492,21 +495,33 @@ def _get_interpolator(
if method in interp1d_methods:
kwargs.update(method=method)
interp_class = ScipyInterpolator
elif vectorizeable_only:
raise ValueError(
f"{method} is not a vectorizeable interpolator. "
f"Available methods are {interp1d_methods}"
)
elif method == "barycentric":
kwargs.update(axis=-1)
interp_class = _import_interpolant("BarycentricInterpolator", method)
elif method in ["krogh", "krog"]:
kwargs.update(axis=-1)
interp_class = _import_interpolant("KroghInterpolator", method)
elif method == "pchip":
kwargs.update(axis=-1)
interp_class = _import_interpolant("PchipInterpolator", method)
elif method == "spline":
utils.emit_user_level_warning(
"The 1d SplineInterpolator class is performing an incorrect calculation and "
"is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
PendingDeprecationWarning,
)
if vectorizeable_only:
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
kwargs.update(method=method)
interp_class = SplineInterpolator
elif method == "akima":
kwargs.update(axis=-1)
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima", axis=-1)
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima", axis=-1)
interp_class = _import_interpolant("Akima1DInterpolator", method)
else:
raise ValueError(f"{method} is not a valid scipy interpolator")
Expand All @@ -525,6 +540,7 @@ def _get_interpolator_nd(method, **kwargs):

if method in valid_methods:
kwargs.update(method=method)
kwargs.setdefault("bounds_error", False)
interp_class = _import_interpolant("interpn", method)
else:
raise ValueError(
Expand Down Expand Up @@ -614,9 +630,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
if not indexes_coords:
return var.copy()

# default behavior
kwargs["bounds_error"] = kwargs.get("bounds_error", False)

result = var
# decompose the interpolation into a succession of independent interpolation
for indep_indexes_coords in decompose_interp(indexes_coords):
Expand Down Expand Up @@ -663,8 +676,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
new_x : a list of 1d array
New coordinates. Should not contain NaN.
method : string
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
1-dimensional interpolation.
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima',
'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation.
{'linear', 'nearest'} for multidimensional interpolation
**kwargs
Optional keyword arguments to be passed to scipy.interpolator
Expand Down Expand Up @@ -756,7 +769,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
def _interp1d(var, x, new_x, func, kwargs):
# x, new_x are tuples of size 1.
x, new_x = x[0], new_x[0]
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
rslt = func(x, var, **kwargs)(np.ravel(new_x))
if new_x.ndim > 1:
return reshape(rslt, (var.shape[:-1] + new_x.shape))
if new_x.ndim == 0:
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def copy(
Interp1dOptions = Literal[
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
]
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
InterpolantOptions = Literal[
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
]
InterpOptions = Union[Interp1dOptions, InterpolantOptions]

DatetimeUnitOptions = Literal[
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def _importorskip(

has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
has_scipy, requires_scipy = _importorskip("scipy")
has_scipy_ge_1_13, requires_scipy_ge_1_13 = _importorskip("scipy", "1.13")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
97 changes: 68 additions & 29 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
assert_identical,
has_dask,
has_scipy,
has_scipy_ge_1_13,
requires_cftime,
requires_dask,
requires_scipy,
Expand Down Expand Up @@ -132,29 +133,66 @@ def func(obj, new_x):
assert_allclose(actual, expected)


@pytest.mark.parametrize("use_dask", [False, True])
def test_interpolate_vectorize(use_dask: bool) -> None:
if not has_scipy:
pytest.skip("scipy is not installed.")

if not has_dask and use_dask:
pytest.skip("dask is not installed in the environment.")

@requires_scipy
@pytest.mark.parametrize(
"use_dask, method",
(
(False, "linear"),
(False, "akima"),
pytest.param(
False,
"makima",
marks=pytest.mark.skipif(not has_scipy_ge_1_13, reason="scipy too old"),
),
pytest.param(
True,
"linear",
marks=pytest.mark.skipif(not has_dask, reason="dask not available"),
),
pytest.param(
True,
"akima",
marks=pytest.mark.skipif(not has_dask, reason="dask not available"),
),
),
)
def test_interpolate_vectorize(use_dask: bool, method: InterpOptions) -> None:
# scipy interpolation for the reference
def func(obj, dim, new_x):
def func(obj, dim, new_x, method):
scipy_kwargs = {}
interpolant_options = {
"barycentric": scipy.interpolate.BarycentricInterpolator,
"krogh": scipy.interpolate.KroghInterpolator,
"pchip": scipy.interpolate.PchipInterpolator,
"akima": scipy.interpolate.Akima1DInterpolator,
"makima": scipy.interpolate.Akima1DInterpolator,
}

shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
for s in new_x.shape[::-1]:
shape.insert(obj.get_axis_num(dim), s)

return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
bounds_error=False,
fill_value=np.nan,
)(new_x).reshape(shape)
if method in interpolant_options:
interpolant = interpolant_options[method]
if method == "makima":
scipy_kwargs["method"] = method
return interpolant(
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)
else:

return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
kind=method,
bounds_error=False,
fill_value=np.nan,
**scipy_kwargs,
)(new_x).reshape(shape)

da = get_example_data(0)

if use_dask:
da = da.chunk({"y": 5})

Expand All @@ -165,17 +203,17 @@ def func(obj, dim, new_x):
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "y"],
coords={
"z": xdest["z"],
"z2": xdest["z2"],
"y": da["y"],
"x": ("z", xdest.values),
"x2": ("z", func(da["x2"], "x", xdest)),
"x2": ("z", func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
Expand All @@ -191,18 +229,18 @@ def func(obj, dim, new_x):
},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "w", "y"],
coords={
"z": xdest["z"],
"w": xdest["w"],
"z2": xdest["z2"],
"y": da["y"],
"x": (("z", "w"), xdest.data),
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
Expand Down Expand Up @@ -393,19 +431,17 @@ def test_nans(use_dask: bool) -> None:
assert actual.count() > 0


@requires_scipy
@pytest.mark.parametrize("use_dask", [True, False])
def test_errors(use_dask: bool) -> None:
if not has_scipy:
pytest.skip("scipy is not installed.")

# akima and spline are unavailable
# spline is unavailable
da = xr.DataArray([0, 1, np.nan, 2], dims="x", coords={"x": range(4)})
if not has_dask and use_dask:
pytest.skip("dask is not installed in the environment.")
da = da.chunk()

for method in ["akima", "spline"]:
with pytest.raises(ValueError):
for method in ["spline"]:
with pytest.raises(ValueError), pytest.warns(PendingDeprecationWarning):
da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type]

# not sorted
Expand Down Expand Up @@ -922,7 +958,10 @@ def test_interp1d_bounds_error() -> None:
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
],
)
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
def test_coord_attrs(
x,
expect_same_attrs: bool,
) -> None:
base_attrs = dict(foo="bar")
ds = xr.Dataset(
data_vars=dict(a=2 * np.arange(5)),
Expand Down
6 changes: 5 additions & 1 deletion xarray/tests/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def test_scipy_methods_function(method) -> None:
# Note: Pandas does some wacky things with these methods and the full
# integration tests won't work.
da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True)
actual = da.interpolate_na(method=method, dim="time")
if method == "spline":
with pytest.warns(PendingDeprecationWarning):
actual = da.interpolate_na(method=method, dim="time")
else:
actual = da.interpolate_na(method=method, dim="time")
assert (da.count("time") <= actual.count("time")).all()


Expand Down

0 comments on commit 623be9f

Please sign in to comment.