Skip to content

Commit

Permalink
vectorization tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hollymandel committed Sep 20, 2024
1 parent f3e230a commit fe1476f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 27 deletions.
14 changes: 7 additions & 7 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def _get_interpolator(
interp_class = NumpyInterpolator

elif method in valid_methods:
kwargs.update(axis=-1)
if method in interp1d_methods:
kwargs.update(method=method)
interp_class = ScipyInterpolator
Expand All @@ -501,11 +502,11 @@ def _get_interpolator(
elif method == "pchip":
interp_class = _import_interpolant("PchipInterpolator", method)
elif method == "spline":
utils.emit_user_level_warning(
"The 1d SplineInterpolator class is performing an incorrect calculation and "
"is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
PendingDeprecationWarning,
)
# utils.emit_user_level_warning(
# "The 1d SplineInterpolator class is performing an incorrect calculation and "
# "is being deprecated. Please use `method=polynomial` for 1D Spline Interpolation.",
# PendingDeprecationWarning,
# )
if vectorizeable_only:
raise ValueError(f"{method} is not a vectorizeable interpolator. ")
kwargs.update(method=method)
Expand All @@ -529,6 +530,7 @@ def _get_interpolator_nd(method, **kwargs):

if method in valid_methods:
kwargs.update(method=method)
kwargs.update(bounds_error=False)
interp_class = _import_interpolant("interpn", method)
else:
raise ValueError(
Expand Down Expand Up @@ -618,8 +620,6 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
if not indexes_coords:
return var.copy()

kwargs["axis"] = -1

result = var
# decompose the interpolation into a succession of independent interpolation
for indexes_coords in decompose_interp(indexes_coords):
Expand Down
79 changes: 59 additions & 20 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,53 @@ def func(obj, new_x):
assert_allclose(actual, expected)


@pytest.mark.parametrize("use_dask", [False, True])
def test_interpolate_vectorize(use_dask: bool) -> None:
@pytest.mark.parametrize(
"use_dask, method",
((False, "linear"), (False, "akima"), (True, "linear"), (True, "akima")),
)
def test_interpolate_vectorize(use_dask: bool, method: str) -> None:
if not has_scipy:
pytest.skip("scipy is not installed.")

if not has_dask and use_dask:
pytest.skip("dask is not installed in the environment.")

# scipy interpolation for the reference
def func(obj, dim, new_x):
def func(obj, dim, new_x, method):
interpolant_options = {
"barycentric": "BarycentricInterpolator",
"krogh": "KroghInterpolator",
"pchip": "PchipInterpolator",
"akima": "Akima1DInterpolator",
}

shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
for s in new_x.shape[::-1]:
shape.insert(obj.get_axis_num(dim), s)

return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
bounds_error=False,
fill_value=np.nan,
)(new_x).reshape(shape)
if method in interpolant_options:
from scipy import interpolate

interpolant = getattr(interpolate, interpolant_options[method])
return interpolant(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
)(
new_x
).reshape(shape)
else:
return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
bounds_error=False,
fill_value=np.nan,
kind=method,
)(new_x).reshape(shape)

da = get_example_data(0)

if use_dask:
da = da.chunk({"y": 5})

Expand All @@ -165,17 +189,17 @@ def func(obj, dim, new_x):
coords={"z": np.random.randn(30), "z2": ("z", np.random.randn(30))},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "y"],
coords={
"z": xdest["z"],
"z2": xdest["z2"],
"y": da["y"],
"x": ("z", xdest.values),
"x2": ("z", func(da["x2"], "x", xdest)),
"x2": ("z", func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "y", transpose_coords=True))
Expand All @@ -191,18 +215,18 @@ def func(obj, dim, new_x):
},
)

actual = da.interp(x=xdest, method="linear")
actual = da.interp(x=xdest, method=method)

expected = xr.DataArray(
func(da, "x", xdest),
func(da, "x", xdest, method),
dims=["z", "w", "y"],
coords={
"z": xdest["z"],
"w": xdest["w"],
"z2": xdest["z2"],
"y": da["y"],
"x": (("z", "w"), xdest.data),
"x2": (("z", "w"), func(da["x2"], "x", xdest)),
"x2": (("z", "w"), func(da["x2"], "x", xdest, method)),
},
)
assert_allclose(actual, expected.transpose("z", "w", "y", transpose_coords=True))
Expand Down Expand Up @@ -404,7 +428,7 @@ def test_errors(use_dask: bool) -> None:
pytest.skip("dask is not installed in the environment.")
da = da.chunk()

for method in ["akima", "spline"]:
for method in ["spline"]:
with pytest.raises(ValueError):
da.interp(x=[0.5, 1.5], method=method) # type: ignore

Expand Down Expand Up @@ -606,7 +630,15 @@ def test_interp_like() -> None:
],
)
@pytest.mark.filterwarnings("ignore:Converting non-nanosecond")
def test_datetime(x_new, expected) -> None:
def test_datetime(
x_new: (
DatetimeIndex
| ndarray[Any, dtype[datetime64]]
| list[str]
| Literal["2000-01-01T12:00"]
),
expected: list[int] | list[float] | float,
) -> None:
da = xr.DataArray(
np.arange(24),
dims="time",
Expand Down Expand Up @@ -788,7 +820,7 @@ def test_decompose(method: InterpOptions) -> None:
],
)
def test_interpolate_chunk_1d(
method: InterpOptions, data_ndim, interp_ndim, nscalar, chunked: bool
method: InterpOptions, data_ndim: Any, interp_ndim: Any, nscalar: Any, chunked: bool
) -> None:
"""Interpolate nd array with multiple independent indexers
Expand Down Expand Up @@ -922,7 +954,14 @@ def test_interp1d_bounds_error() -> None:
(("x", np.array([0, 0.5, 1, 2]), dict(unit="s")), False),
],
)
def test_coord_attrs(x, expect_same_attrs: bool) -> None:
def test_coord_attrs(
x: (
float
| ndarray[Any, dtype[Any]]
| tuple[Literal[x], ndarray[Any, dtype[Any]], dict[str, str]]
),
expect_same_attrs: bool,
) -> None:
base_attrs = dict(foo="bar")
ds = xr.Dataset(
data_vars=dict(a=2 * np.arange(5)),
Expand Down

0 comments on commit fe1476f

Please sign in to comment.