Skip to content

Commit

Permalink
Handle unordered dims
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 20, 2024
1 parent 5540702 commit b269650
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
22 changes: 17 additions & 5 deletions xarray/namedarray/_array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,17 @@ def _set_dims(
>>> x_new.dims, x_new.shape
(('x',), (3,))
Unordered dims
>>> x = NamedArray(("y", "x"), np.zeros((2, 3)))
>>> x_new = _set_dims(x, ("x", "y"), None)
>>> x_new.dims, x_new.shape
(('x', 'y'), (3, 2))
Error
>>> x = NamedArray(("x",), np.asarray([1, 2, 3]))
>>> x_new = _set_dims(x, (), None)
Traceback (most recent call last):
...
Expand All @@ -335,13 +344,15 @@ def _set_dims(
# remains writeable as long as possible:
return x

extra_dims = tuple(d for d in dim if d not in x.dims)
if not extra_dims:
missing_dims = set(x.dims) - set(dim)
if missing_dims:
raise ValueError(
f"new dimensions {dim!r} must be a superset of "
f"existing dimensions {x.dims!r}"
)

extra_dims = tuple(d for d in dim if d not in x.dims)

if shape is not None:
# Add dimensions, with same size as shape:
dims_map = dict(zip(dim, shape, strict=False))
Expand Down Expand Up @@ -404,9 +415,10 @@ def _arithmetic_broadcast(
if not _OPTIONS["arithmetic_broadcast"]:
if x1.dims != x2.dims:
raise ValueError(
"Broadcasting is necessary but automatic broadcasting is disabled via "
"global option `'arithmetic_broadcast'`. "
"Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting."
"Broadcasting is necessary but automatic broadcasting is disabled "
"via global option `'arithmetic_broadcast'`. "
"Use `xr.set_options(arithmetic_broadcast=True)` to enable "
"automatic broadcasting."
)

return _broadcast_arrays_with_minimal_size(x1, x2)
Expand Down
1 change: 1 addition & 0 deletions xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def _maybe_asarray(
# x is proper array. Respect the chosen dtype.
return x
# x is a scalar. Use the same dtype as self.
# TODO: Is this a good idea? x[Any, int] + 1.4 => int result then.
return asarray(x, dtype=self.dtype)

# Required methods below:
Expand Down

0 comments on commit b269650

Please sign in to comment.