diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 9ac24921955..03ae2a7f848 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -513,6 +513,9 @@ def _get_interpolator( interp_class = SplineInterpolator elif method == "akima": interp_class = _import_interpolant("Akima1DInterpolator", method) + elif method == "makima": + kwargs.update(method="makima") + interp_class = _import_interpolant("Akima1DInterpolator", method) else: raise ValueError(f"{method} is not a valid scipy interpolator") else: diff --git a/xarray/core/types.py b/xarray/core/types.py index 3eb97f86c4a..bf310b2fa87 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -218,7 +218,9 @@ def copy( Interp1dOptions = Literal[ "linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial" ] -InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"] +InterpolantOptions = Literal[ + "barycentric", "krogh", "pchip", "spline", "akima", "makima" +] InterpOptions = Union[Interp1dOptions, InterpolantOptions] DatetimeUnitOptions = Literal[ diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index ae9dbe5fde4..2ec222fb8cb 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -134,7 +134,13 @@ def func(obj, new_x): @pytest.mark.parametrize( "use_dask, method", - ((False, "linear"), (False, "akima"), (True, "linear"), (True, "akima")), + ( + (False, "linear"), + (False, "akima"), + (False, "makima"), + (True, "linear"), + (True, "akima"), + ), ) def test_interpolate_vectorize(use_dask: bool, method: str) -> None: if not has_scipy: @@ -145,11 +151,13 @@ def test_interpolate_vectorize(use_dask: bool, method: str) -> None: # scipy interpolation for the reference def func(obj, dim, new_x, method): + scipy_kwargs = {} interpolant_options = { "barycentric": "BarycentricInterpolator", "krogh": "KroghInterpolator", "pchip": "PchipInterpolator", "akima": "Akima1DInterpolator", + "makima": "Akima1DInterpolator", } shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)] @@ -160,21 +168,17 @@ def func(obj, dim, new_x, method): from scipy import interpolate interpolant = getattr(interpolate, interpolant_options[method]) + if method == "makima": + scipy_kwargs["method"] = method return interpolant( - da[dim], - obj.data, - axis=obj.get_axis_num(dim), - )( - new_x - ).reshape(shape) + da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs + )(new_x).reshape(shape) else: + scipy_kwargs["kind"] = method + scipy_kwargs["bounds_error"] = False + scipy_kwargs["fill_value"] = np.nan return scipy.interpolate.interp1d( - da[dim], - obj.data, - axis=obj.get_axis_num(dim), - bounds_error=False, - fill_value=np.nan, - kind=method, + da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs )(new_x).reshape(shape) da = get_example_data(0)