diff --git a/xarray/tests/test_ufuncs.py b/xarray/tests/test_ufuncs.py index 27e8b46c5fa..580d3a8b94f 100644 --- a/xarray/tests/test_ufuncs.py +++ b/xarray/tests/test_ufuncs.py @@ -196,8 +196,11 @@ def setUp(self): @pytest.mark.filterwarnings("ignore::RuntimeWarning") @pytest.mark.parametrize("name", xu.__all__) def test_ufuncs(self, name, request): - np_func = getattr(np, name) xu_func = getattr(xu, name) + if isinstance(xu_func, xu._UnavailableUfunc): + pytest.xfail(f"Ufunc {name} is not available in numpy {np.__version__}.") + + np_func = getattr(np, name) if name == "isnat": args = (self.xt,) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index d5eb2060a78..65240e042a7 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -41,16 +41,40 @@ def get_array_namespace(*args): return next(iter(xps)) if len(xps) else np -class _UFuncDispatcher: - """Wrapper for dispatching ufuncs.""" +class _UnaryUfunc: + """Wrapper for dispatching unary ufuncs.""" def __init__(self, name): self._name = name - def __call__(self, *args, **kwargs): - xp = get_array_namespace(*args) + def __call__(self, x, **kwargs): + xp = get_array_namespace(x) func = getattr(xp, self._name) - return xr.apply_ufunc(func, *args, dask="allowed", **kwargs) + return xr.apply_ufunc(func, x, dask="allowed", **kwargs) + + +class _BinaryUfunc: + """Wrapper for dispatching binary ufuncs.""" + + def __init__(self, name): + self._name = name + + def __call__(self, x, y, **kwargs): + xp = get_array_namespace(x, y) + func = getattr(xp, self._name) + return xr.apply_ufunc(func, x, y, dask="allowed", **kwargs) + + +class _UnavailableUfunc: + """Wrapper for unimplemented ufuncs in older numpy versions.""" + + def __init__(self, name): + self._name = name + + def __call__(self, *args, **kwargs): + raise NotImplementedError( + f"Ufunc {self._name} is not available in numpy {np.__version__}." + ) def _skip_signature(doc, name): @@ -90,7 +114,18 @@ def _dedent(doc): def _create_op(name): - func = _UFuncDispatcher(name) + if not hasattr(np, name): + # handle older numpy versions with missing array api standard aliases + if np.lib.NumpyVersion(np.__version__) < "2.0.0": + return _UnavailableUfunc(name) + raise ValueError(f"'{name}' is not a valid numpy function") + + np_func = getattr(np, name) + if hasattr(np_func, "nin") and np_func.nin == 2: + func = _BinaryUfunc(name) + else: + func = _UnaryUfunc(name) + func.__name__ = name doc = getattr(np, name).__doc__ @@ -111,7 +146,7 @@ def _create_op(name): # Ufuncs that use core dimensions or product multiple output arrays are # not currently supported, and left commented below. -# unary +# UNARY abs = _create_op("abs") absolute = _create_op("absolute") acos = _create_op("acos") @@ -142,7 +177,7 @@ def _create_op(name): expm1 = _create_op("expm1") fabs = _create_op("fabs") floor = _create_op("floor") -# frexp +# frexp = _create_op("frexp") invert = _create_op("invert") isfinite = _create_op("isfinite") isinf = _create_op("isinf") @@ -153,7 +188,7 @@ def _create_op(name): log1p = _create_op("log1p") log2 = _create_op("log2") logical_not = _create_op("logical_not") -# modf +# modf = _create_op("modf") negative = _create_op("negative") positive = _create_op("positive") rad2deg = _create_op("rad2deg") @@ -171,7 +206,7 @@ def _create_op(name): tanh = _create_op("tanh") trunc = _create_op("trunc") -# binary +# BINARY add = _create_op("add") arctan2 = _create_op("arctan2") atan2 = _create_op("atan2") @@ -182,7 +217,7 @@ def _create_op(name): bitwise_xor = _create_op("bitwise_xor") copysign = _create_op("copysign") divide = _create_op("divide") -# divmod +# divmod = _create_op("divmod") equal = _create_op("equal") float_power = _create_op("float_power") floor_divide = _create_op("floor_divide") @@ -204,7 +239,7 @@ def _create_op(name): logical_and = _create_op("logical_and") logical_or = _create_op("logical_or") logical_xor = _create_op("logical_xor") -# matmul +# matmul = _create_op("matmul") maximum = _create_op("maximum") minimum = _create_op("minimum") mod = _create_op("mod") @@ -217,7 +252,7 @@ def _create_op(name): right_shift = _create_op("right_shift") subtract = _create_op("subtract") true_divide = _create_op("true_divide") -# vecdot +# vecdot = _create_op("vecdot") # elementwise non-ufunc angle = _create_op("angle")