Skip to content

Commit

Permalink
vectorize 1d interpolators
Browse files Browse the repository at this point in the history
  • Loading branch information
hollymandel committed Sep 23, 2024
1 parent d26144d commit 9c33edf
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 42 deletions.
10 changes: 5 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,12 +2224,12 @@ def interp(
Performs univariate or multivariate interpolation of a DataArray onto
new coordinates using scipy's interpolation routines. If interpolating
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
called. When interpolating along multiple existing dimensions, an
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
is called. When interpolating along multiple existing dimensions, an
attempt is made to decompose the interpolation into multiple
1-dimensional interpolations. If this is possible,
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
:py:func:`scipy.interpolate.interpn` is called.
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator is called.
Otherwise, :py:func:`scipy.interpolate.interpn` is called.
Parameters
----------
Expand Down
10 changes: 5 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3885,12 +3885,12 @@ def interp(
Performs univariate or multivariate interpolation of a Dataset onto
new coordinates using scipy's interpolation routines. If interpolating
along an existing dimension, :py:class:`scipy.interpolate.interp1d` is
called. When interpolating along multiple existing dimensions, an
along an existing dimension, either :py:class:`scipy.interpolate.interp1d`
or a 1-dimensional scipy interpolator (e.g. :py:class:`scipy.interpolate.KroghInterpolator`)
is called. When interpolating along multiple existing dimensions, an
attempt is made to decompose the interpolation into multiple
1-dimensional interpolations. If this is possible,
:py:class:`scipy.interpolate.interp1d` is called. Otherwise,
:py:func:`scipy.interpolate.interpn` is called.
1-dimensional interpolations. If this is possible, the 1-dimensional interpolator
is called. Otherwise, :py:func:`scipy.interpolate.interpn` is called.
Parameters
----------
Expand Down
33 changes: 20 additions & 13 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def __init__(

self.method = method
self.cons_kwargs = kwargs
del self.cons_kwargs["axis"]
self.call_kwargs = {"nu": nu, "ext": ext}

if fill_value is not None:
Expand Down Expand Up @@ -479,7 +480,8 @@ def _get_interpolator(
interp1d_methods = get_args(Interp1dOptions)
valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))

# prioritize scipy.interpolate
# prefer numpy.interp for 1d linear interpolation. This function cannot
# take higher dimensional data but scipy.interp1d can.
if (
method == "linear"
and not kwargs.get("fill_value", None) == "extrapolate"
Expand All @@ -489,25 +491,31 @@ def _get_interpolator(
interp_class = NumpyInterpolator

elif method in valid_methods:
kwargs.update(axis=-1)
if method in interp1d_methods:
kwargs.update(method=method)
interp_class = ScipyInterpolator
elif vectorizeable_only:
raise ValueError(
f"{method} is not a vectorizeable interpolator. "
f"Available methods are {interp1d_methods}"
)
elif method == "barycentric":
interp_class = _import_interpolant("BarycentricInterpolator", method)
elif method in ["krogh", "krog"]:
interp_class = _import_interpolant("KroghInterpolator", method)
elif method == "pchip":
interp_class = _import_interpolant("PchipInterpolator", method)
elif method == "spline":
# utils.emit_user_level_warning(
# "The 1d SplineInterpolator class is performing an incorrect calculation and "
# "is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
# PendingDeprecationWarning,
# )
if vectorizeable_only:
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
kwargs.update(method=method)
interp_class = SplineInterpolator
elif method == "akima":
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima")
interp_class = _import_interpolant("Akima1DInterpolator", method)
else:
raise ValueError(f"{method} is not a valid scipy interpolator")
else:
Expand All @@ -525,6 +533,7 @@ def _get_interpolator_nd(method, **kwargs):

if method in valid_methods:
kwargs.update(method=method)
kwargs.setdefault("bounds_error", False)
interp_class = _import_interpolant("interpn", method)
else:
raise ValueError(
Expand Down Expand Up @@ -614,9 +623,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
if not indexes_coords:
return var.copy()

# default behavior
kwargs["bounds_error"] = kwargs.get("bounds_error", False)

result = var
# decompose the interpolation into a succession of independent interpolation
for indep_indexes_coords in decompose_interp(indexes_coords):
Expand Down Expand Up @@ -663,8 +669,8 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
new_x : a list of 1d array
New coordinates. Should not contain NaN.
method : string
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
1-dimensional interpolation.
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima',
'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation.
{'linear', 'nearest'} for multidimensional interpolation
**kwargs
Optional keyword arguments to be passed to scipy.interpolator
Expand Down Expand Up @@ -755,8 +761,9 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):

def _interp1d(var, x, new_x, func, kwargs):
# x, new_x are tuples of size 1.
x, new_x = x[0], new_x[0]
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
x, new_x = x[0].data, new_x[0].data

rslt = func(x, var, **kwargs)(np.ravel(new_x))
if new_x.ndim > 1:
return reshape(rslt, (var.shape[:-1] + new_x.shape))
if new_x.ndim == 0:
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def copy(
Interp1dOptions = Literal[
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
]
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
InterpolantOptions = Literal[
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
]
InterpOptions = Union[Interp1dOptions, InterpolantOptions]

DatetimeUnitOptions = Literal[
Expand Down
65 changes: 47 additions & 18 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,55 @@ def func(obj, new_x):
assert_allclose(actual, expected)


@pytest.mark.parametrize("use_dask", [False, True])
def test_interpolate_vectorize(use_dask: bool) -> None:
@pytest.mark.parametrize(
"use_dask, method",
(
(False, "linear"),
(False, "akima"),
(False, "makima"),
(True, "linear"),
(True, "akima"),
),
)
def test_interpolate_vectorize(use_dask: bool, method: str) -> None:
if not has_scipy:
pytest.skip("scipy is not installed.")

if not has_dask and use_dask:
pytest.skip("dask is not installed in the environment.")

# scipy interpolation for the reference
def func(obj, dim, new_x):
def func(obj, dim, new_x, method):
scipy_kwargs = {}
interpolant_options = {
"barycentric": "BarycentricInterpolator",
"krogh": "KroghInterpolator",
"pchip": "PchipInterpolator",
"akima": "Akima1DInterpolator",
"makima": "Akima1DInterpolator",
}

shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
for s in new_x.shape[::-1]:
shape.insert(obj.get_axis_num(dim), s)

return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
bounds_error=False,
fill_value=np.nan,
)(new_x).reshape(shape)
if method in interpolant_options:
interpolant = getattr(scipy.interpolate, interpolant_options[method])
if method == "makima":
scipy_kwargs["method"] = method
return interpolant(
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)
else:
scipy_kwargs["kind"] = method
scipy_kwargs["bounds_error"] = False
scipy_kwargs["fill_value"] = np.nan
return scipy.interpolate.interp1d(
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)

da = get_example_data(0)

if use_dask:
da = da.chunk({"y": 5})

Expand All @@ -165,17 +191,17 @@ def func(obj, dim, new_x):
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

Check failure on line 194 in xarray/tests/test_interp.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 min-all-deps

test_interpolate_vectorize[False-makima] TypeError: Akima1DInterpolator.__init__() got an unexpected keyword argument 'method'

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "y"],
coords={
"z": xdest["z"],
"z2": xdest["z2"],
"y": da["y"],
"x": ("z", xdest.values),
"x2": ("z", func(da["x2"], "x", xdest)),
"x2": ("z", func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
Expand All @@ -191,18 +217,18 @@ def func(obj, dim, new_x):
},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "w", "y"],
coords={
"z": xdest["z"],
"w": xdest["w"],
"z2": xdest["z2"],
"y": da["y"],
"x": (("z", "w"), xdest.data),
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
Expand Down Expand Up @@ -404,7 +430,7 @@ def test_errors(use_dask: bool) -> None:
pytest.skip("dask is not installed in the environment.")
da = da.chunk()

for method in ["akima", "spline"]:
for method in ["spline"]:
with pytest.raises(ValueError):
da.interp(x=[0.5, 1.5], method=method) # type: ignore[arg-type]

Expand Down Expand Up @@ -922,7 +948,10 @@ def test_interp1d_bounds_error() -> None:
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
],
)
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
def test_coord_attrs(
x,
expect_same_attrs: bool,
) -> None:
base_attrs = dict(foo="bar")
ds = xr.Dataset(
data_vars=dict(a=2 * np.arange(5)),
Expand Down

0 comments on commit 9c33edf

Please sign in to comment.