From b2696509d4d60b7997b8c440e9a664b0fcb10122 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 05:57:26 +0200 Subject: [PATCH] Handle unordered dims --- .../_array_api/_manipulation_functions.py | 22 ++++++++++++++----- xarray/namedarray/core.py | 1 + 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index a54066cac1b..114a3480efe 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -323,8 +323,17 @@ def _set_dims( >>> x_new.dims, x_new.shape (('x',), (3,)) + + Unordered dims + + >>> x = NamedArray(("y", "x"), np.zeros((2, 3))) + >>> x_new = _set_dims(x, ("x", "y"), None) + >>> x_new.dims, x_new.shape + (('x', 'y'), (3, 2)) + Error + >>> x = NamedArray(("x",), np.asarray([1, 2, 3])) >>> x_new = _set_dims(x, (), None) Traceback (most recent call last): ... @@ -335,13 +344,15 @@ def _set_dims( # remains writeable as long as possible: return x - extra_dims = tuple(d for d in dim if d not in x.dims) - if not extra_dims: + missing_dims = set(x.dims) - set(dim) + if missing_dims: raise ValueError( f"new dimensions {dim!r} must be a superset of " f"existing dimensions {x.dims!r}" ) + extra_dims = tuple(d for d in dim if d not in x.dims) + if shape is not None: # Add dimensions, with same size as shape: dims_map = dict(zip(dim, shape, strict=False)) @@ -404,9 +415,10 @@ def _arithmetic_broadcast( if not _OPTIONS["arithmetic_broadcast"]: if x1.dims != x2.dims: raise ValueError( - "Broadcasting is necessary but automatic broadcasting is disabled via " - "global option `'arithmetic_broadcast'`. " - "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + "Broadcasting is necessary but automatic broadcasting is disabled " + "via global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable " + "automatic broadcasting." ) return _broadcast_arrays_with_minimal_size(x1, x2) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index e3f98a7f205..42983104319 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -432,6 +432,7 @@ def _maybe_asarray( # x is proper array. Respect the chosen dtype. return x # x is a scalar. Use the same dtype as self. + # TODO: Is this a good idea? x[Any, int] + 1.4 => int result then. return asarray(x, dtype=self.dtype) # Required methods below: