diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 89c8d3b4599..9cf630acea1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -32,6 +32,9 @@ New Features `Tom Nicholas `_. - Added zarr backends for :py:func:`open_groups` (:issue:`9430`, :pull:`9469`). By `Eni Awowale `_. +- Added support for vectorized interpolation using additional interpolators + from the ``scipy.interpolate`` module (:issue:`9049`, :pull:`9526`). + By `Holly Mandel `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index bcc57acd316..2adf862f1fd 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2224,12 +2224,12 @@ def interp( Performs univariate or multivariate interpolation of a DataArray onto new coordinates using scipy's interpolation routines. If interpolating - along an existing dimension, :py:class:`scipy.interpolate.interp1d` is - called. When interpolating along multiple existing dimensions, an + along an existing dimension, either :py:class:`scipy.interpolate.interp1d` + or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`) + is called. When interpolating along multiple existing dimensions, an attempt is made to decompose the interpolation into multiple - 1-dimensional interpolations. If this is possible, - :py:class:`scipy.interpolate.interp1d` is called. Otherwise, - :py:func:`scipy.interpolate.interpn` is called. + 1-dimensional interpolations. If this is possible, the 1-dimensional interpolator is called. + Otherwise, :py:func:`scipy.interpolate.interpn` is called. Parameters ---------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dd860b54933..82b60d7abc8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3885,12 +3885,12 @@ def interp( Performs univariate or multivariate interpolation of a Dataset onto new coordinates using scipy's interpolation routines. If interpolating - along an existing dimension, :py:class:`scipy.interpolate.interp1d` is - called. When interpolating along multiple existing dimensions, an + along an existing dimension, either :py:class:`scipy.interpolate.interp1d` + or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`) + is called. When interpolating along multiple existing dimensions, an attempt is made to decompose the interpolation into multiple - 1-dimensional interpolations. If this is possible, - :py:class:`scipy.interpolate.interp1d` is called. Otherwise, - :py:func:`scipy.interpolate.interpn` is called. + 1-dimensional interpolations. If this is possible, the 1-dimensional interpolator + is called. Otherwise, :py:func:`scipy.interpolate.interpn` is called. Parameters ---------- diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 2df53b172f0..6a380f53f0d 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -138,6 +138,7 @@ def __init__( copy=False, bounds_error=False, order=None, + axis=-1, **kwargs, ): from scipy.interpolate import interp1d @@ -173,6 +174,7 @@ def __init__( bounds_error=bounds_error, assume_sorted=assume_sorted, copy=copy, + axis=axis, **self.cons_kwargs, ) @@ -479,7 +481,8 @@ def _get_interpolator( interp1d_methods = get_args(Interp1dOptions) valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v)) - # prioritize scipy.interpolate + # prefer numpy.interp for 1d linear interpolation. This function cannot + # take higher dimensional data but scipy.interp1d can. if ( method == "linear" and not kwargs.get("fill_value", None) == "extrapolate" @@ -492,21 +495,33 @@ def _get_interpolator( if method in interp1d_methods: kwargs.update(method=method) interp_class = ScipyInterpolator - elif vectorizeable_only: - raise ValueError( - f"{method} is not a vectorizeable interpolator. " - f"Available methods are {interp1d_methods}" - ) elif method == "barycentric": + kwargs.update(axis=-1) interp_class = _import_interpolant("BarycentricInterpolator", method) elif method in ["krogh", "krog"]: + kwargs.update(axis=-1) interp_class = _import_interpolant("KroghInterpolator", method) elif method == "pchip": + kwargs.update(axis=-1) interp_class = _import_interpolant("PchipInterpolator", method) elif method == "spline": + utils.emit_user_level_warning( + "The 1d SplineInterpolator class is performing an incorrect calculation and " + "is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.", + PendingDeprecationWarning, + ) + if vectorizeable_only: + raise ValueError(f"{method} is not a vectorizeable interpolator. ") kwargs.update(method=method) interp_class = SplineInterpolator elif method == "akima": + kwargs.update(axis=-1) + interp_class = _import_interpolant("Akima1DInterpolator", method) + elif method == "makima": + kwargs.update(method="makima", axis=-1) + interp_class = _import_interpolant("Akima1DInterpolator", method) + elif method == "makima": + kwargs.update(method="makima", axis=-1) interp_class = _import_interpolant("Akima1DInterpolator", method) else: raise ValueError(f"{method} is not a valid scipy interpolator") @@ -525,6 +540,7 @@ def _get_interpolator_nd(method, **kwargs): if method in valid_methods: kwargs.update(method=method) + kwargs.setdefault("bounds_error", False) interp_class = _import_interpolant("interpn", method) else: raise ValueError( @@ -614,9 +630,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): if not indexes_coords: return var.copy() - # default behavior - kwargs["bounds_error"] = kwargs.get("bounds_error", False) - result = var # decompose the interpolation into a succession of independent interpolation for indep_indexes_coords in decompose_interp(indexes_coords): @@ -663,8 +676,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): new_x : a list of 1d array New coordinates. Should not contain NaN. method : string - {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for - 1-dimensional interpolation. + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima', + 'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation. {'linear', 'nearest'} for multidimensional interpolation **kwargs Optional keyword arguments to be passed to scipy.interpolator @@ -756,7 +769,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): def _interp1d(var, x, new_x, func, kwargs): # x, new_x are tuples of size 1. x, new_x = x[0], new_x[0] - rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x)) + rslt = func(x, var, **kwargs)(np.ravel(new_x)) if new_x.ndim > 1: return reshape(rslt, (var.shape[:-1] + new_x.shape)) if new_x.ndim == 0: diff --git a/xarray/core/types.py b/xarray/core/types.py index 34b6029ee15..1d383d550ec 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -228,7 +228,9 @@ def copy( Interp1dOptions = Literal[ "linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial" ] -InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"] +InterpolantOptions = Literal[ + "barycentric", "krogh", "pchip", "spline", "akima", "makima" +] InterpOptions = Union[Interp1dOptions, InterpolantOptions] DatetimeUnitOptions = Literal[ diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 71ae1a7075f..6b54994f311 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -87,6 +87,7 @@ def _importorskip( has_matplotlib, requires_matplotlib = _importorskip("matplotlib") has_scipy, requires_scipy = _importorskip("scipy") +has_scipy_ge_1_13, requires_scipy_ge_1_13 = _importorskip("scipy", "1.13") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 5c03881242b..d02e12dd695 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -16,6 +16,7 @@ assert_identical, has_dask, has_scipy, + has_scipy_ge_1_13, requires_cftime, requires_dask, requires_scipy, @@ -132,29 +133,66 @@ def func(obj, new_x): assert_allclose(actual, expected) -@pytest.mark.parametrize("use_dask", [False, True]) -def test_interpolate_vectorize(use_dask: bool) -> None: - if not has_scipy: - pytest.skip("scipy is not installed.") - - if not has_dask and use_dask: - pytest.skip("dask is not installed in the environment.") - +@requires_scipy +@pytest.mark.parametrize( + "use_dask, method", + ( + (False, "linear"), + (False, "akima"), + pytest.param( + False, + "makima", + marks=pytest.mark.skipif(not has_scipy_ge_1_13, reason="scipy too old"), + ), + pytest.param( + True, + "linear", + marks=pytest.mark.skipif(not has_dask, reason="dask not available"), + ), + pytest.param( + True, + "akima", + marks=pytest.mark.skipif(not has_dask, reason="dask not available"), + ), + ), +) +def test_interpolate_vectorize(use_dask: bool, method: InterpOptions) -> None: # scipy interpolation for the reference - def func(obj, dim, new_x): + def func(obj, dim, new_x, method): + scipy_kwargs = {} + interpolant_options = { + "barycentric": scipy.interpolate.BarycentricInterpolator, + "krogh": scipy.interpolate.KroghInterpolator, + "pchip": scipy.interpolate.PchipInterpolator, + "akima": scipy.interpolate.Akima1DInterpolator, + "makima": scipy.interpolate.Akima1DInterpolator, + } + shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)] for s in new_x.shape[::-1]: shape.insert(obj.get_axis_num(dim), s) - return scipy.interpolate.interp1d( - da[dim], - obj.data, - axis=obj.get_axis_num(dim), - bounds_error=False, - fill_value=np.nan, - )(new_x).reshape(shape) + if method in interpolant_options: + interpolant = interpolant_options[method] + if method == "makima": + scipy_kwargs["method"] = method + return interpolant( + da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs + )(new_x).reshape(shape) + else: + + return scipy.interpolate.interp1d( + da[dim], + obj.data, + axis=obj.get_axis_num(dim), + kind=method, + bounds_error=False, + fill_value=np.nan, + **scipy_kwargs, + )(new_x).reshape(shape) da = get_example_data(0) + if use_dask: da = da.chunk({"y": 5}) @@ -165,17 +203,17 @@ def func(obj, dim, new_x): coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))}, ) - actual = da.interp(x=xdest, method="linear") + actual = da.interp(x=xdest, method=method) expected = xr.DataArray( - func(da, "x", xdest), + func(da, "x", xdest, method), dims=["z", "y"], coords={ "z": xdest["z"], "z2": xdest["z2"], "y": da["y"], "x": ("z", xdest.values), - "x2": ("z", func(da["x2"], "x", xdest)), + "x2": ("z", func(da["x2"], "x", xdest, method)), }, ) assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True)) @@ -191,10 +229,10 @@ def func(obj, dim, new_x): }, ) - actual = da.interp(x=xdest, method="linear") + actual = da.interp(x=xdest, method=method) expected = xr.DataArray( - func(da, "x", xdest), + func(da, "x", xdest, method), dims=["z", "w", "y"], coords={ "z": xdest["z"], @@ -202,7 +240,7 @@ def func(obj, dim, new_x): "z2": xdest["z2"], "y": da["y"], "x": (("z", "w"), xdest.data), - "x2": (("z", "w"), func(da["x2"], "x", xdest)), + "x2": (("z", "w"), func(da["x2"], "x", xdest, method)), }, ) assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True)) @@ -393,19 +431,17 @@ def test_nans(use_dask: bool) -> None: assert actual.count() > 0 +@requires_scipy @pytest.mark.parametrize("use_dask", [True, False]) def test_errors(use_dask: bool) -> None: - if not has_scipy: - pytest.skip("scipy is not installed.") - - # akima and spline are unavailable + # spline is unavailable da = xr.DataArray([0, 1, np.nan, 2], dims="x", coords={"x": range(4)}) if not has_dask and use_dask: pytest.skip("dask is not installed in the environment.") da = da.chunk() - for method in ["akima", "spline"]: - with pytest.raises(ValueError): + for method in ["spline"]: + with pytest.raises(ValueError), pytest.warns(PendingDeprecationWarning): da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type] # not sorted @@ -922,7 +958,10 @@ def test_interp1d_bounds_error() -> None: (("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False), ], ) -def test_coord_attrs(x, expect_same_attrs: bool) -> None: +def test_coord_attrs( + x, + expect_same_attrs: bool, +) -> None: base_attrs = dict(foo="bar") ds = xr.Dataset( data_vars=dict(a=2 * np.arange(5)), diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index bf90074a7cc..58d8a9dcf5d 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -137,7 +137,11 @@ def test_scipy_methods_function(method) -> None: # Note: Pandas does some wacky things with these methods and the full # integration tests won't work. da, _ = make_interpolate_example_data((25, 25), 0.4, non_uniform=True) - actual = da.interpolate_na(method=method, dim="time") + if method == "spline": + with pytest.warns(PendingDeprecationWarning): + actual = da.interpolate_na(method=method, dim="time") + else: + actual = da.interpolate_na(method=method, dim="time") assert (da.count("time") <= actual.count("time")).all()