Skip to content

Commit

Permalink
makima
Browse files Browse the repository at this point in the history
  • Loading branch information
hollymandel committed Sep 20, 2024
1 parent fe1476f commit 3d824d0
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
3 changes: 3 additions & 0 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ def _get_interpolator(
interp_class = SplineInterpolator
elif method == "akima":
interp_class = _import_interpolant("Akima1DInterpolator", method)
elif method == "makima":
kwargs.update(method="makima")
interp_class = _import_interpolant("Akima1DInterpolator", method)
else:
raise ValueError(f"{method} is not a valid scipy interpolator")
else:
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ def copy(
Interp1dOptions = Literal[
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
]
InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
InterpolantOptions = Literal[
"barycentric", "krogh", "pchip", "spline", "akima", "makima"
]
InterpOptions = Union[Interp1dOptions, InterpolantOptions]

DatetimeUnitOptions = Literal[
Expand Down
30 changes: 17 additions & 13 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,13 @@ def func(obj, new_x):

@pytest.mark.parametrize(
"use_dask, method",
((False, "linear"), (False, "akima"), (True, "linear"), (True, "akima")),
(
(False, "linear"),
(False, "akima"),
(False, "makima"),
(True, "linear"),
(True, "akima"),
),
)
def test_interpolate_vectorize(use_dask: bool, method: str) -> None:
if not has_scipy:
Expand All @@ -145,11 +151,13 @@ def test_interpolate_vectorize(use_dask: bool, method: str) -> None:

# scipy interpolation for the reference
def func(obj, dim, new_x, method):
scipy_kwargs = {}
interpolant_options = {
"barycentric": "BarycentricInterpolator",
"krogh": "KroghInterpolator",
"pchip": "PchipInterpolator",
"akima": "Akima1DInterpolator",
"makima": "Akima1DInterpolator",
}

shape = [s for i, s in enumerate(obj.shape) if i != obj.get_axis_num(dim)]
Expand All @@ -160,21 +168,17 @@ def func(obj, dim, new_x, method):
from scipy import interpolate

interpolant = getattr(interpolate, interpolant_options[method])
if method == "makima":
scipy_kwargs["method"] = method
return interpolant(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
)(
new_x
).reshape(shape)
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)
else:
scipy_kwargs["kind"] = method
scipy_kwargs["bounds_error"] = False
scipy_kwargs["fill_value"] = np.nan
return scipy.interpolate.interp1d(
da[dim],
obj.data,
axis=obj.get_axis_num(dim),
bounds_error=False,
fill_value=np.nan,
kind=method,
da[dim], obj.data, axis=obj.get_axis_num(dim), **scipy_kwargs
)(new_x).reshape(shape)

da = get_example_data(0)
Expand Down

0 comments on commit 3d824d0

Please sign in to comment.