From 9c33edfb7783b4e5f357250741092c1908cfd71b Mon Sep 17 00:00:00 2001 From: hollymandel Date: Thu, 19 Sep 2024 14:49:55 -0700 Subject: [PATCH] vectorize 1d interpolators --- xarray/core/dataarray.py | 10 +++--- xarray/core/dataset.py | 10 +++--- xarray/core/missing.py | 33 +++++++++++-------- xarray/core/types.py | 4 ++- xarray/tests/test_interp.py | 65 +++++++++++++++++++++++++++---------- 5 files changed, 80 insertions(+), 42 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index bcc57acd316..2adf862f1fd 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2224,12 +2224,12 @@ def interp( Performs univariate or multivariate interpolation of a DataArray onto new coordinates using scipy's interpolation routines. If interpolating - along an existing dimension, :py:class:`scipy.interpolate.interp1d` is - called. When interpolating along multiple existing dimensions, an + along an existing dimension, either :py:class:`scipy.interpolate.interp1d` + or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`) + is called. When interpolating along multiple existing dimensions, an attempt is made to decompose the interpolation into multiple - 1-dimensional interpolations. If this is possible, - :py:class:`scipy.interpolate.interp1d` is called. Otherwise, - :py:func:`scipy.interpolate.interpn` is called. + 1-dimensional interpolations. If this is possible, the 1-dimensional interpolator is called. + Otherwise, :py:func:`scipy.interpolate.interpn` is called. Parameters ---------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b1ce264cbc8..e0d9316f939 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3885,12 +3885,12 @@ def interp( Performs univariate or multivariate interpolation of a Dataset onto new coordinates using scipy's interpolation routines. If interpolating - along an existing dimension, :py:class:`scipy.interpolate.interp1d` is - called. When interpolating along multiple existing dimensions, an + along an existing dimension, either :py:class:`scipy.interpolate.interp1d` + or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`) + is called. When interpolating along multiple existing dimensions, an attempt is made to decompose the interpolation into multiple - 1-dimensional interpolations. If this is possible, - :py:class:`scipy.interpolate.interp1d` is called. Otherwise, - :py:func:`scipy.interpolate.interpn` is called. + 1-dimensional interpolations. If this is possible, the 1-dimensional interpolator + is called. Otherwise, :py:func:`scipy.interpolate.interpn` is called. Parameters ---------- diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 2df53b172f0..35dd42b24a6 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -203,6 +203,7 @@ def __init__( self.method = method self.cons_kwargs = kwargs + del self.cons_kwargs["axis"] self.call_kwargs = {"nu": nu, "ext": ext} if fill_value is not None: @@ -479,7 +480,8 @@ def _get_interpolator( interp1d_methods = get_args(Interp1dOptions) valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v)) - # prioritize scipy.interpolate + # prefer numpy.interp for 1d linear interpolation. This function cannot + # take higher dimensional data but scipy.interp1d can. if ( method == "linear" and not kwargs.get("fill_value", None) == "extrapolate" @@ -489,14 +491,10 @@ def _get_interpolator( interp_class = NumpyInterpolator elif method in valid_methods: + kwargs.update(axis=-1) if method in interp1d_methods: kwargs.update(method=method) interp_class = ScipyInterpolator - elif vectorizeable_only: - raise ValueError( - f"{method} is not a vectorizeable interpolator. " - f"Available methods are {interp1d_methods}" - ) elif method == "barycentric": interp_class = _import_interpolant("BarycentricInterpolator", method) elif method in ["krogh", "krog"]: @@ -504,10 +502,20 @@ def _get_interpolator( elif method == "pchip": interp_class = _import_interpolant("PchipInterpolator", method) elif method == "spline": + # utils.emit_user_level_warning( + # "The 1d SplineInterpolator class is performing an incorrect calculation and " + # "is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.", + # PendingDeprecationWarning, + # ) + if vectorizeable_only: + raise ValueError(f"{method} is not a vectorizeable interpolator. ") kwargs.update(method=method) interp_class = SplineInterpolator elif method == "akima": interp_class = _import_interpolant("Akima1DInterpolator", method) + elif method == "makima": + kwargs.update(method="makima") + interp_class = _import_interpolant("Akima1DInterpolator", method) else: raise ValueError(f"{method} is not a valid scipy interpolator") else: @@ -525,6 +533,7 @@ def _get_interpolator_nd(method, **kwargs): if method in valid_methods: kwargs.update(method=method) + kwargs.setdefault("bounds_error", False) interp_class = _import_interpolant("interpn", method) else: raise ValueError( @@ -614,9 +623,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): if not indexes_coords: return var.copy() - # default behavior - kwargs["bounds_error"] = kwargs.get("bounds_error", False) - result = var # decompose the interpolation into a succession of independent interpolation for indep_indexes_coords in decompose_interp(indexes_coords): @@ -663,8 +669,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): new_x : a list of 1d array New coordinates. Should not contain NaN. method : string - {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for - 1-dimensional interpolation. + {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima', + 'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation. {'linear', 'nearest'} for multidimensional interpolation **kwargs Optional keyword arguments to be passed to scipy.interpolator @@ -755,8 +761,9 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): def _interp1d(var, x, new_x, func, kwargs): # x, new_x are tuples of size 1. - x, new_x = x[0], new_x[0] - rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x)) + x, new_x = x[0].data, new_x[0].data + + rslt = func(x, var, **kwargs)(np.ravel(new_x)) if new_x.ndim > 1: return reshape(rslt, (var.shape[:-1] + new_x.shape)) if new_x.ndim == 0: diff --git a/xarray/core/types.py b/xarray/core/types.py index 34b6029ee15..1d383d550ec 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -228,7 +228,9 @@ def copy( Interp1dOptions = Literal[ "linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial" ] -InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"] +InterpolantOptions = Literal[ + "barycentric", "krogh", "pchip", "spline", "akima", "makima" +] InterpOptions = Union[Interp1dOptions, InterpolantOptions] DatetimeUnitOptions = Literal[ diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 5c03881242b..a13f9918b80 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -132,8 +132,17 @@ def func(obj, new_x): assert_allclose(actual, expected) -@pytest.mark.parametrize("use_dask", [False, True]) -def test_interpolate_vectorize(use_dask: bool) -> None: +@pytest.mark.parametrize( + "use_dask, method", + ( + (False, "linear"), + (False, "akima"), + (False, "makima"), + (True, "linear"), + (True, "akima"), + ), +) +def test_interpolate_vectorize(use_dask: bool, method: str) -> None: if not has_scipy: pytest.skip("scipy is not installed.") @@ -141,20 +150,37 @@ def test_interpolate_vectorize(use_dask: bool) -> None: pytest.skip("dask is not installed in the environment.") # scipy interpolation for the reference - def func(obj, dim, new_x): + def func(obj, dim, new_x, method): + scipy_kwargs = {} + interpolant_options = { + "barycentric": "BarycentricInterpolator", + "krogh": "KroghInterpolator", + "pchip": "PchipInterpolator", + "akima": "Akima1DInterpolator", + "makima": "Akima1DInterpolator", + } + shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)] for s in new_x.shape[::-1]: shape.insert(obj.get_axis_num(dim), s) - return scipy.interpolate.interp1d( - da[dim], - obj.data, - axis=obj.get_axis_num(dim), - bounds_error=False, - fill_value=np.nan, - )(new_x).reshape(shape) + if method in interpolant_options: + interpolant = getattr(scipy.interpolate, interpolant_options[method]) + if method == "makima": + scipy_kwargs["method"] = method + return interpolant( + da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs + )(new_x).reshape(shape) + else: + scipy_kwargs["kind"] = method + scipy_kwargs["bounds_error"] = False + scipy_kwargs["fill_value"] = np.nan + return scipy.interpolate.interp1d( + da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs + )(new_x).reshape(shape) da = get_example_data(0) + if use_dask: da = da.chunk({"y": 5}) @@ -165,17 +191,17 @@ def func(obj, dim, new_x): coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))}, ) - actual = da.interp(x=xdest, method="linear") + actual = da.interp(x=xdest, method=method) expected = xr.DataArray( - func(da, "x", xdest), + func(da, "x", xdest, method), dims=["z", "y"], coords={ "z": xdest["z"], "z2": xdest["z2"], "y": da["y"], "x": ("z", xdest.values), - "x2": ("z", func(da["x2"], "x", xdest)), + "x2": ("z", func(da["x2"], "x", xdest, method)), }, ) assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True)) @@ -191,10 +217,10 @@ def func(obj, dim, new_x): }, ) - actual = da.interp(x=xdest, method="linear") + actual = da.interp(x=xdest, method=method) expected = xr.DataArray( - func(da, "x", xdest), + func(da, "x", xdest, method), dims=["z", "w", "y"], coords={ "z": xdest["z"], @@ -202,7 +228,7 @@ def func(obj, dim, new_x): "z2": xdest["z2"], "y": da["y"], "x": (("z", "w"), xdest.data), - "x2": (("z", "w"), func(da["x2"], "x", xdest)), + "x2": (("z", "w"), func(da["x2"], "x", xdest, method)), }, ) assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True)) @@ -404,7 +430,7 @@ def test_errors(use_dask: bool) -> None: pytest.skip("dask is not installed in the environment.") da = da.chunk() - for method in ["akima", "spline"]: + for method in ["spline"]: with pytest.raises(ValueError): da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type] @@ -922,7 +948,10 @@ def test_interp1d_bounds_error() -> None: (("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False), ], ) -def test_coord_attrs(x, expect_same_attrs: bool) -> None: +def test_coord_attrs( + x, + expect_same_attrs: bool, +) -> None: base_attrs = dict(foo="bar") ds = xr.Dataset( data_vars=dict(a=2 * np.arange(5)),