Skip to content

Commit

Permalink
handle np versions, separate unary/binary path
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang committed Nov 17, 2024
1 parent bba3183 commit 6a55ea6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 14 deletions.
5 changes: 4 additions & 1 deletion xarray/tests/test_ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,11 @@ def setUp(self):
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
@pytest.mark.parametrize("name", xu.__all__)
def test_ufuncs(self, name, request):
np_func = getattr(np, name)
xu_func = getattr(xu, name)
if isinstance(xu_func, xu._UnavailableUfunc):
pytest.xfail(f"Ufunc {name} is not available in numpy {np.__version__}.")

np_func = getattr(np, name)

if name == "isnat":
args = (self.xt,)
Expand Down
61 changes: 48 additions & 13 deletions xarray/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,40 @@ def get_array_namespace(*args):
return next(iter(xps)) if len(xps) else np


class _UFuncDispatcher:
"""Wrapper for dispatching ufuncs."""
class _UnaryUfunc:
"""Wrapper for dispatching unary ufuncs."""

def __init__(self, name):
self._name = name

def __call__(self, *args, **kwargs):
xp = get_array_namespace(*args)
def __call__(self, x, **kwargs):
xp = get_array_namespace(x)
func = getattr(xp, self._name)
return xr.apply_ufunc(func, *args, dask="allowed", **kwargs)
return xr.apply_ufunc(func, x, dask="allowed", **kwargs)


class _BinaryUfunc:
"""Wrapper for dispatching binary ufuncs."""

def __init__(self, name):
self._name = name

def __call__(self, x, y, **kwargs):
xp = get_array_namespace(x, y)
func = getattr(xp, self._name)
return xr.apply_ufunc(func, x, y, dask="allowed", **kwargs)


class _UnavailableUfunc:
"""Wrapper for unimplemented ufuncs in older numpy versions."""

def __init__(self, name):
self._name = name

def __call__(self, *args, **kwargs):
raise NotImplementedError(
f"Ufunc {self._name} is not available in numpy {np.__version__}."
)


def _skip_signature(doc, name):
Expand Down Expand Up @@ -90,7 +114,18 @@ def _dedent(doc):


def _create_op(name):
func = _UFuncDispatcher(name)
if not hasattr(np, name):
# handle older numpy versions with missing array api standard aliases
if np.lib.NumpyVersion(np.__version__) < "2.0.0":
return _UnavailableUfunc(name)
raise ValueError(f"'{name}' is not a valid numpy function")

np_func = getattr(np, name)
if hasattr(np_func, "nin") and np_func.nin == 2:
func = _BinaryUfunc(name)
else:
func = _UnaryUfunc(name)

func.__name__ = name
doc = getattr(np, name).__doc__

Expand All @@ -111,7 +146,7 @@ def _create_op(name):
# Ufuncs that use core dimensions or product multiple output arrays are
# not currently supported, and left commented below.

# unary
# UNARY
abs = _create_op("abs")
absolute = _create_op("absolute")
acos = _create_op("acos")
Expand Down Expand Up @@ -142,7 +177,7 @@ def _create_op(name):
expm1 = _create_op("expm1")
fabs = _create_op("fabs")
floor = _create_op("floor")
# frexp
# frexp = _create_op("frexp")
invert = _create_op("invert")
isfinite = _create_op("isfinite")
isinf = _create_op("isinf")
Expand All @@ -153,7 +188,7 @@ def _create_op(name):
log1p = _create_op("log1p")
log2 = _create_op("log2")
logical_not = _create_op("logical_not")
# modf
# modf = _create_op("modf")
negative = _create_op("negative")
positive = _create_op("positive")
rad2deg = _create_op("rad2deg")
Expand All @@ -171,7 +206,7 @@ def _create_op(name):
tanh = _create_op("tanh")
trunc = _create_op("trunc")

# binary
# BINARY
add = _create_op("add")
arctan2 = _create_op("arctan2")
atan2 = _create_op("atan2")
Expand All @@ -182,7 +217,7 @@ def _create_op(name):
bitwise_xor = _create_op("bitwise_xor")
copysign = _create_op("copysign")
divide = _create_op("divide")
# divmod
# divmod = _create_op("divmod")
equal = _create_op("equal")
float_power = _create_op("float_power")
floor_divide = _create_op("floor_divide")
Expand All @@ -204,7 +239,7 @@ def _create_op(name):
logical_and = _create_op("logical_and")
logical_or = _create_op("logical_or")
logical_xor = _create_op("logical_xor")
# matmul
# matmul = _create_op("matmul")
maximum = _create_op("maximum")
minimum = _create_op("minimum")
mod = _create_op("mod")
Expand All @@ -217,7 +252,7 @@ def _create_op(name):
right_shift = _create_op("right_shift")
subtract = _create_op("subtract")
true_divide = _create_op("true_divide")
# vecdot
# vecdot = _create_op("vecdot")

# elementwise non-ufunc
angle = _create_op("angle")
Expand Down

0 comments on commit 6a55ea6

Please sign in to comment.