Skip to content

Commit

Permalink
reflexive should broadcast as well
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 19, 2024
1 parent 9d4b9d2 commit f64f967
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]
>>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4)))
>>> _get_broadcasted_dims(a, b)
(('x', 'y', 'z'), (5, 3, 4))
>>> _get_broadcasted_dims(b, a)
(('x', 'y', 'z'), (5, 3, 4))
>>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4)))
>>> b = NamedArray(("x", "y", "z"), np.zeros((1, 3, 4)))
Expand Down Expand Up @@ -445,13 +447,17 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]
...
ValueError: operands could not be broadcast together with dims = (('x', 'y', 'z'), ('x', 'y', 'z')) and shapes = ((5, 3, 4), (2, 3, 4))
"""
DEFAULT_SIZE = -1
BROADCASTABLE_SIZES = (0, 1)
BROADCASTABLE_SIZES_OR_DEFAULT = BROADCASTABLE_SIZES + (DEFAULT_SIZE,)

arrays_dims = tuple(a.dims for a in arrays)
arrays_shapes = tuple(a.shape for a in arrays)

sizes: dict[Any, Any] = {}
for dims, shape in zip(
zip_longest(*map(reversed, arrays_dims), fillvalue=_default),
zip_longest(*map(reversed, arrays_shapes), fillvalue=-1),
zip_longest(*map(reversed, arrays_shapes), fillvalue=DEFAULT_SIZE),
strict=False,
):
for d, s in zip(reversed(dims), reversed(shape), strict=False):
Expand All @@ -461,8 +467,11 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]
if s is None:
raise NotImplementedError("TODO: Handle None in shape, {shapes = }")

s_prev = sizes.get(d, -1)
if s_prev not in (-1, 0, 1, s):
s_prev = sizes.get(d, DEFAULT_SIZE)
if not (
s == s_prev
or any(v in BROADCASTABLE_SIZES_OR_DEFAULT for v in (s, s_prev))
):
raise ValueError(
"operands could not be broadcast together with "
f"dims = {arrays_dims} and shapes = {arrays_shapes}"
Expand Down

0 comments on commit f64f967

Please sign in to comment.