From 18400d53fc6333b8180987b42be100a86b0f6027 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:29:44 +0200 Subject: [PATCH] indexing are always 1d? --- xarray/namedarray/core.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 72c22c838b9..5d2c40bdb45 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -505,9 +505,9 @@ def __ge__(self, other: int | float | NamedArray, /): def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: """ - sdfd + Returns self[key]. - Some rules + Some rules: * Integers removes the dim. * Slices and ellipsis maintains same dim. * None adds a dim. @@ -519,14 +519,14 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: 1D >>> x = NamedArray(("x",), np.array([0, 1, 2])) - >>> mask = NamedArray(("x",), np.array([1, 0, 0], dtype=bool)) - >>> xm = x[mask] + >>> key = NamedArray(("x",), np.array([1, 0, 0], dtype=bool)) + >>> xm = x[key] >>> xm.dims, xm.shape (('x',), (1,)) >>> x = NamedArray(("x",), np.array([0, 1, 2])) - >>> mask = NamedArray(("x",), np.array([0, 0, 0], dtype=bool)) - >>> xm = x[mask] + >>> key = NamedArray(("x",), np.array([0, 0, 0], dtype=bool)) + >>> xm = x[key] >>> xm.dims, xm.shape (('x',), (0,)) @@ -543,15 +543,24 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: >>> xm.dims, xm.shape (('dim_2', 'x', 'y'), (1, 3, 4)) - >>> mask = NamedArray(("x", "y"), np.ones((3, 4), dtype=bool)) - >>> xm = x[mask] + >>> key = NamedArray(("x", "y"), np.ones((3, 4), dtype=bool)) + >>> xm = x[key] >>> xm.dims, xm.shape ((('x', 'y'),), (12,)) + + 0D + + >>> x = NamedArray((), np.array(False, dtype=np.bool)) + >>> key = NamedArray((), np.array(False, dtype=np.bool)) + >>> xm = x[key] + >>> xm.dims, xm.shape + (('dim_0',), (0,)) """ from xarray.namedarray._array_api._manipulation_functions import ( _broadcast_arrays, ) from xarray.namedarray._array_api._utils import ( + _atleast1d_dims, _dims_from_tuple_indexing, _flatten_dims, ) @@ -559,7 +568,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: if isinstance(key, NamedArray): self_new, key_new = _broadcast_arrays(self, key) _data = self_new._data[key_new._data] - _dims = _flatten_dims(self_new.dims) + _dims = _flatten_dims(_atleast1d_dims(self_new.dims)) return self._new(_dims, _data) # elif isinstance(key, int): # return self._new(self.dims[1:], self._data[key])