From f64f96782df17fe90cdbb501838ca0d6db94908b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 00:11:20 +0200 Subject: [PATCH] reflexive should broadcast as well --- xarray/namedarray/_array_api/_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 777d7f84d39..49dc8a24f88 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -410,6 +410,8 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] >>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4))) >>> _get_broadcasted_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) + >>> _get_broadcasted_dims(b, a) + (('x', 'y', 'z'), (5, 3, 4)) >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((1, 3, 4))) @@ -445,13 +447,17 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] ... ValueError: operands could not be broadcast together with dims = (('x', 'y', 'z'), ('x', 'y', 'z')) and shapes = ((5, 3, 4), (2, 3, 4)) """ + DEFAULT_SIZE = -1 + BROADCASTABLE_SIZES = (0, 1) + BROADCASTABLE_SIZES_OR_DEFAULT = BROADCASTABLE_SIZES + (DEFAULT_SIZE,) + arrays_dims = tuple(a.dims for a in arrays) arrays_shapes = tuple(a.shape for a in arrays) sizes: dict[Any, Any] = {} for dims, shape in zip( zip_longest(*map(reversed, arrays_dims), fillvalue=_default), - zip_longest(*map(reversed, arrays_shapes), fillvalue=-1), + zip_longest(*map(reversed, arrays_shapes), fillvalue=DEFAULT_SIZE), strict=False, ): for d, s in zip(reversed(dims), reversed(shape), strict=False): @@ -461,8 +467,11 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] if s is None: raise NotImplementedError("TODO: Handle None in shape, {shapes = }") - s_prev = sizes.get(d, -1) - if s_prev not in (-1, 0, 1, s): + s_prev = sizes.get(d, DEFAULT_SIZE) + if not ( + s == s_prev + or any(v in BROADCASTABLE_SIZES_OR_DEFAULT for v in (s, s_prev)) + ): raise ValueError( "operands could not be broadcast together with " f"dims = {arrays_dims} and shapes = {arrays_shapes}"