From 4106fde7b0c49b2c0c572dfbb2d530bb9bbbc4a8 Mon Sep 17 00:00:00 2001 From: hollymandel Date: Thu, 19 Sep 2024 14:49:55 -0700 Subject: [PATCH] vectorize 1d interpolators review changes --- xarray/core/missing.py | 12 ++++++++++-- xarray/tests/test_interp.py | 28 +++++++++++++++------------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 35dd42b24a6..204380363db 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -138,6 +138,7 @@ def __init__( copy=False, bounds_error=False, order=None, + axis=-1, **kwargs, ): from scipy.interpolate import interp1d @@ -173,6 +174,7 @@ def __init__( bounds_error=bounds_error, assume_sorted=assume_sorted, copy=copy, + axis=axis, **self.cons_kwargs, ) @@ -496,10 +498,13 @@ def _get_interpolator( kwargs.update(method=method) interp_class = ScipyInterpolator elif method == "barycentric": + kwargs.update(axis=-1) interp_class = _import_interpolant("BarycentricInterpolator", method) elif method in ["krogh", "krog"]: + kwargs.update(axis=-1) interp_class = _import_interpolant("KroghInterpolator", method) elif method == "pchip": + kwargs.update(axis=-1) interp_class = _import_interpolant("PchipInterpolator", method) elif method == "spline": # utils.emit_user_level_warning( @@ -512,6 +517,10 @@ def _get_interpolator( kwargs.update(method=method) interp_class = SplineInterpolator elif method == "akima": + kwargs.update(axis=-1) + interp_class = _import_interpolant("Akima1DInterpolator", method) + elif method == "makima": + kwargs.update(method="makima", axis=-1) interp_class = _import_interpolant("Akima1DInterpolator", method) elif method == "makima": kwargs.update(method="makima") @@ -761,8 +770,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): def _interp1d(var, x, new_x, func, kwargs): # x, new_x are tuples of size 1. - x, new_x = x[0].data, new_x[0].data - + x, new_x = x[0], new_x[0] rslt = func(x, var, **kwargs)(np.ravel(new_x)) if new_x.ndim > 1: return reshape(rslt, (var.shape[:-1] + new_x.shape)) diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index f0c7f959687..7885ec5119f 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -133,6 +133,7 @@ def func(obj, new_x): assert_allclose(actual, expected) +@requires_scipy @pytest.mark.parametrize( "use_dask, method", ( @@ -152,18 +153,15 @@ def func(obj, new_x): ), ) def test_interpolate_vectorize(use_dask: bool, method: str) -> None: - if not has_scipy: - pytest.skip("scipy is not installed.") - # scipy interpolation for the reference def func(obj, dim, new_x, method): scipy_kwargs = {} interpolant_options = { - "barycentric": "BarycentricInterpolator", - "krogh": "KroghInterpolator", - "pchip": "PchipInterpolator", - "akima": "Akima1DInterpolator", - "makima": "Akima1DInterpolator", + "barycentric": scipy.interpolate.BarycentricInterpolator, + "krogh": scipy.interpolate.KroghInterpolator, + "pchip": scipy.interpolate.PchipInterpolator, + "akima": scipy.interpolate.Akima1DInterpolator, + "makima": scipy.interpolate.Akima1DInterpolator, } shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)] @@ -171,18 +169,22 @@ def func(obj, dim, new_x, method): shape.insert(obj.get_axis_num(dim), s) if method in interpolant_options: - interpolant = getattr(scipy.interpolate, interpolant_options[method]) + interpolant = interpolant_options[method] if method == "makima": scipy_kwargs["method"] = method return interpolant( da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs )(new_x).reshape(shape) else: - scipy_kwargs["kind"] = method - scipy_kwargs["bounds_error"] = False - scipy_kwargs["fill_value"] = np.nan + return scipy.interpolate.interp1d( - da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs + da[dim], + obj.data, + axis=obj.get_axis_num(dim), + kind=method, + bounds_error=False, + fill_value=np.nan, + **scipy_kwargs, )(new_x).reshape(shape) da = get_example_data(0)