From 520c1699f0a75aae6347690d71dbd5b903470826 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:04:48 +0200 Subject: [PATCH 001/367] Add mean to array_api --- xarray/core/formatting.py | 2 +- xarray/namedarray/_array_api.py | 114 +++++++++++++++++++++++--- xarray/namedarray/_typing.py | 18 ++++- xarray/namedarray/core.py | 136 ++++++++++++++++++++++---------- 4 files changed, 216 insertions(+), 54 deletions(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 96a767f95ac..561b8d3cc0d 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -632,7 +632,7 @@ def short_data_repr(array): return short_array_repr(array) elif is_duck_array(internal_data): return limit_lines(repr(array.data), limit=40) - elif array._in_memory: + elif getattr(array, "_in_memory", None): return short_array_repr(array) else: # internal xarray array type diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 69f97305686..03b366b2df6 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -1,24 +1,49 @@ from types import ModuleType from typing import Any +import warnings import numpy as np from xarray.namedarray._typing import ( + duckarray, _arrayapi, + _Axes, + _AxisLike, + _Dims, + _DimsLike, _DType, _ScalarType, _ShapeType, _SupportsImag, _SupportsReal, ) -from xarray.namedarray.core import NamedArray +from xarray.namedarray.core import NamedArray, _dims_to_axis, _get_remaining_dims + + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"The numpy.array_api submodule is still experimental", + category=UserWarning, + ) + import numpy.array_api as nxp def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: if isinstance(x._data, _arrayapi): return x._data.__array_namespace__() - else: - return np + + return np + + +def _to_nxp( + x: duckarray[_ShapeType, _DType] +) -> tuple[ModuleType, _arrayapi[_ShapeType, _DType]]: + + return nxp, nxp.asarray(x) + + +# %% Creation Functions def astype( @@ -49,18 +74,21 @@ def astype( Examples -------- - >>> narr = NamedArray(("x",), np.array([1.5, 2.5])) + >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5])) >>> astype(narr, np.dtype(int)).data - array([1, 2]) + Array([1, 2], dtype=int32) """ if isinstance(x._data, _arrayapi): xp = x._data.__array_namespace__() - return x._new(data=xp.astype(x, dtype, copy=copy)) + return x._new(data=xp.astype(x._data, dtype, copy=copy)) # np.astype doesn't exist yet: return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] +# %% Elementwise Functions + + def imag( x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var] ) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: @@ -83,7 +111,7 @@ def imag( Examples -------- - >>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j])) + >>> narr = NamedArray(("x",), np.asarray([1. + 2j, 2 + 4j])) # TODO: Use nxp >>> imag(narr).data array([2., 4.]) """ @@ -114,9 +142,77 @@ def real( Examples -------- - >>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j])) + >>> narr = NamedArray(("x",), np.asarray([1. + 2j, 2 + 4j])) # TODO: Use nxp >>> real(narr).data array([1., 2.]) """ xp = _get_data_namespace(x) - return x._new(data=xp.real(x._data)) + out = x._new(data=xp.real(x._data)) + return out + + +# %% Statistical Functions + + +def mean( + x: NamedArray[Any, _DType], + /, + *, + axis: _AxisLike | None = None, + keepdims: bool = False, + dims: _Dims | None = None, +) -> NamedArray[Any, _DType]: + """ + Calculates the arithmetic mean of the input array x. + + Parameters + ---------- + x : + Should have a real-valued floating-point data type. + dims : + Dim or dims along which arithmetic means must be computed. By default, + the mean must be computed over the entire array. If a tuple of hashables, + arithmetic means must be computed over multiple axes. + Default: None. + keepdims : + if True, the reduced axes (dimensions) must be included in the result + as singleton dimensions, and, accordingly, the result must be compatible + with the input array (see Broadcasting). Otherwise, if False, the + reduced axes (dimensions) must not be included in the result. + Default: False. + axis : + Axis or axes along which arithmetic means must be computed. By default, + the mean must be computed over the entire array. If a tuple of integers, + arithmetic means must be computed over multiple axes. + Default: None. + + Returns + ------- + out : + If the arithmetic mean was computed over the entire array, + a zero-dimensional array containing the arithmetic mean; otherwise, + a non-zero-dimensional array containing the arithmetic means. + The returned array must have the same data type as x. + + Examples + -------- + >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> mean(x).data + Array(2.5, dtype=float64) + """ + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.mean(x._data, axis=axis, keepdims=keepdims) + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +if __name__ == "__main__": + import doctest + + doctest.testmod() + + x = NamedArray(("x", "y"), _to_nxp(np.array([[1.0, 2], [3, 4]]))[-1]) + x_mean = mean(x) + print(x_mean) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 820371a7463..452f283cf0d 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -19,6 +19,12 @@ if TYPE_CHECKING: from numpy.typing import NDArray + # https://data-apis.org/array-api/latest/API_specification/indexing.html + # TODO: np.array_api doesn't allow None for some reason, maybe they're + # recommending to use expand_dims? + _IndexKey = Union[int, slice, ellipsis, None] + _IndexKeys = tuple[_IndexKey, ...] + _IndexKeyLike = Union[_IndexKey, _IndexKeys] # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array _T = TypeVar("_T") @@ -48,12 +54,16 @@ def dtype(self) -> _DType_co: ] # For unknown shapes Dask uses np.nan, array_api uses None: -_IntOrUnknown = int +_IntOrUnknown = int # Union[int, _Unknown] _Shape = tuple[_IntOrUnknown, ...] _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] _ShapeType = TypeVar("_ShapeType", bound=Any) _ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True) +_Axis = int +_Axes = tuple[_Axis, ...] +_AxisLike = Union[_Axis, _Axes] + _Chunks = tuple[_Shape, ...] _Dim = Hashable @@ -116,6 +126,9 @@ class _arrayfunction( Corresponds to np.ndarray. """ + def __getitem__(self, key: _IndexKeyLike, /) -> _arrayfunction[Any, _DType_co]: + ... + # TODO: Should return the same subclass but with a new dtype generic. # https://github.com/python/typing/issues/548 def __array_ufunc__( @@ -151,6 +164,9 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType Corresponds to np.ndarray. """ + def __getitem__(self, key: _IndexKeyLike, /) -> _arrayapi[Any, _DType_co]: + ... + def __array_namespace__(self) -> ModuleType: ... diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index f13c7e8d2c6..f98774f9af8 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -42,6 +42,8 @@ from xarray.namedarray._typing import ( DuckArray, _AttrsLike, + _Axes, + _AxisLike, _Chunks, _Dim, _Dims, @@ -80,6 +82,81 @@ ) +def _dims_to_axis( + x: NamedArray[Any, Any], dims: _DimsLike | None, axis: _AxisLike | None +) -> _Axes | None: + """ + Convert dims to axis indices. + + Examples + -------- + >>> narr = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) + >>> _dims_to_axis(narr, ("y",), None) + (1,) + >>> _dims_to_axis(narr, None, 0) + (0,) + >>> _dims_to_axis(narr, None, None) + """ + # if dims == ...: + # dims = None + if dims is not None and axis is not None: + raise ValueError("cannot supply both 'axis' and 'dim' arguments") + + if dims is None: + if axis is None or isinstance(axis, tuple): + return axis + return (axis,) + else: + return x.get_axis_num(dims) + + +def _get_remaining_dims( + x: NamedArray[Any, _DType], + data: duckarray[Any, _DType], + axis: _Axes, + *, + keepdims: bool, +) -> tuple[_Dims, duckarray[Any, _DType]]: + """ + Get the reamining dims after a reduce operation. + + Parameters + ---------- + x : + DESCRIPTION. + data : + DESCRIPTION. + axis : + DESCRIPTION. + keepdims : + DESCRIPTION. + + Returns + ------- + tuple[_Dims, duckarray[Any, _DType]] + DESCRIPTION. + + """ + removed_axes: np.ndarray[Any, np.dtype[int]] + if axis is None: + removed_axes = np.arange(x.ndim, dtype=int) + else: + removed_axes = np.atleast_1d(axis) % x.ndim + + if keepdims: + # Insert np.newaxis for removed dims + slices = tuple( + np.newaxis if i in removed_axes else slice(None, None) + for i in range(x.ndim) + ) + data = data[slices] + dims = x.dims + else: + dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes) + + return dims, data + + @overload def _new( x: NamedArray[Any, _DType_co], @@ -597,7 +674,14 @@ def _dask_finalize( data = array_func(results, *args, **kwargs) return type(self)(self._dims, data, attrs=self._attrs) - def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: + def _get_axis_num(self, dim: _Dim) -> int: + try: + out = self.dims.index(dim) + return out + except ValueError: + raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") + + def get_axis_num(self, dims: _Dims) -> _Axes: """Return axis number(s) corresponding to dimension(s) in this array. Parameters @@ -610,16 +694,7 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, . int or tuple of int Axis number or numbers corresponding to the given dimensions. """ - if not isinstance(dim, str) and isinstance(dim, Iterable): - return tuple(self._get_axis_num(d) for d in dim) - else: - return self._get_axis_num(dim) - - def _get_axis_num(self: Any, dim: Hashable) -> int: - try: - return self.dims.index(dim) # type: ignore[no-any-return] - except ValueError: - raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") + return tuple(self._get_axis_num(d) for d in dims) @property def chunks(self) -> _Chunks | None: @@ -704,51 +779,26 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ - if dim == ...: - dim = None - if dim is not None and axis is not None: - raise ValueError("cannot supply both 'axis' and 'dim' arguments") - - if dim is not None: - axis = self.get_axis_num(dim) + axis_ = _dims_to_axis(self, dim, axis) with warnings.catch_warnings(): warnings.filterwarnings( "ignore", r"Mean of empty slice", category=RuntimeWarning ) - if axis is not None: - if isinstance(axis, tuple) and len(axis) == 1: + if axis_ is not None: + if len(axis_) == 1: # unpack axis for the benefit of functions # like np.argmin which can't handle tuple arguments - axis = axis[0] - data = func(self.data, axis=axis, **kwargs) + data = func(self.data, axis=axis_[0], **kwargs) + else: + data = func(self.data, axis=axis_, **kwargs) else: data = func(self.data, **kwargs) if getattr(data, "shape", ()) == self.shape: dims = self.dims else: - removed_axes: Iterable[int] - if axis is None: - removed_axes = range(self.ndim) - else: - removed_axes = np.atleast_1d(axis) % self.ndim - if keepdims: - # Insert np.newaxis for removed dims - slices = tuple( - np.newaxis if i in removed_axes else slice(None, None) - for i in range(self.ndim) - ) - if getattr(data, "shape", None) is None: - # Reduce has produced a scalar value, not an array-like - data = np.asanyarray(data)[slices] - else: - data = data[slices] - dims = self.dims - else: - dims = tuple( - adim for n, adim in enumerate(self.dims) if n not in removed_axes - ) + dims, data = _get_remaining_dims(data, data, axis, keepdims=keepdims) # Return NamedArray to handle IndexVariable when data is nD return from_array(dims, data, attrs=self._attrs) From c0e30e792c39770ce604c37a8ae6471e12b526d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Oct 2023 21:05:50 +0000 Subject: [PATCH 002/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 12 ++++-------- xarray/namedarray/core.py | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 03b366b2df6..37f7ad55f05 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -1,25 +1,22 @@ +import warnings from types import ModuleType from typing import Any -import warnings import numpy as np from xarray.namedarray._typing import ( - duckarray, _arrayapi, - _Axes, _AxisLike, _Dims, - _DimsLike, _DType, _ScalarType, _ShapeType, _SupportsImag, _SupportsReal, + duckarray, ) from xarray.namedarray.core import NamedArray, _dims_to_axis, _get_remaining_dims - with warnings.catch_warnings(): warnings.filterwarnings( "ignore", @@ -39,7 +36,6 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: def _to_nxp( x: duckarray[_ShapeType, _DType] ) -> tuple[ModuleType, _arrayapi[_ShapeType, _DType]]: - return nxp, nxp.asarray(x) @@ -111,7 +107,7 @@ def imag( Examples -------- - >>> narr = NamedArray(("x",), np.asarray([1. + 2j, 2 + 4j])) # TODO: Use nxp + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp >>> imag(narr).data array([2., 4.]) """ @@ -142,7 +138,7 @@ def real( Examples -------- - >>> narr = NamedArray(("x",), np.asarray([1. + 2j, 2 + 4j])) # TODO: Use nxp + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp >>> real(narr).data array([1., 2.]) """ diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index f98774f9af8..177bdae4aba 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -4,7 +4,7 @@ import math import sys import warnings -from collections.abc import Hashable, Iterable, Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, From 39f66a18fa3830ee12357ec573f7e8e5ff1a1c17 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 00:02:17 +0200 Subject: [PATCH 003/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 36 ++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 03b366b2df6..7de8a09eac3 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -199,20 +199,42 @@ def mean( >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) >>> mean(x).data Array(2.5, dtype=float64) + >>> mean(x, dims=("x",)).data + Array([2., 3.], dtype=float64) + + Using keepdims: + + >>> mean(x, dims=("x",), keepdims=True).data + Array([[2., 3.]], dtype=float64) + >>> mean(x, dims=("y",), keepdims=True).data + Array([[1.5], + [3.5]], dtype=float64) """ xp = _get_data_namespace(x) axis_ = _dims_to_axis(x, dims, axis) - d = xp.mean(x._data, axis=axis, keepdims=keepdims) + d = xp.mean(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out -if __name__ == "__main__": - import doctest +# if __name__ == "__main__": +# import doctest + +# doctest.testmod() +# err +# arr = nxp.asarray([[1.0, 2], [3, 4]]) +# arr_m = nxp.mean(arr, axis=(1,), keepdims=True) - doctest.testmod() +# removed_axes = np.atleast_1d((0,)) % arr.ndim +# slices = tuple( +# np.newaxis if i in removed_axes else slice(None, None) for i in range(arr.ndim) +# ) +# arr_m[slices] +# err +# x = NamedArray(("x", "y"), arr) +# x_mean = mean(x, axis=(0,), keepdims=True) +# print(x_mean) - x = NamedArray(("x", "y"), _to_nxp(np.array([[1.0, 2], [3, 4]]))[-1]) - x_mean = mean(x) - print(x_mean) +# x_mean = mean(x, dims=("x",), keepdims=True) From 985965ff1bde0b3569c247a87e31c232da3814dc Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 00:29:03 +0200 Subject: [PATCH 004/367] Update core.py --- xarray/namedarray/core.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 177bdae4aba..173720deced 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -97,8 +97,6 @@ def _dims_to_axis( (0,) >>> _dims_to_axis(narr, None, None) """ - # if dims == ...: - # dims = None if dims is not None and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim' arguments") @@ -137,9 +135,9 @@ def _get_remaining_dims( DESCRIPTION. """ - removed_axes: np.ndarray[Any, np.dtype[int]] + removed_axes: np.ndarray[Any, np.dtype[np.intp]] if axis is None: - removed_axes = np.arange(x.ndim, dtype=int) + removed_axes = np.arange(x.ndim, dtype=np.intp) else: removed_axes = np.atleast_1d(axis) % x.ndim @@ -779,7 +777,12 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ - axis_ = _dims_to_axis(self, dim, axis) + d: _Dims + if dim == ...: + d = None + else: + d = dim + axis_ = _dims_to_axis(self, d, axis) with warnings.catch_warnings(): warnings.filterwarnings( @@ -796,12 +799,12 @@ def reduce( data = func(self.data, **kwargs) if getattr(data, "shape", ()) == self.shape: - dims = self.dims + dims_ = self.dims else: - dims, data = _get_remaining_dims(data, data, axis, keepdims=keepdims) + dims_, data = _get_remaining_dims(data, data, axis, keepdims=keepdims) # Return NamedArray to handle IndexVariable when data is nD - return from_array(dims, data, attrs=self._attrs) + return from_array(dims_, data, attrs=self._attrs) def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]: """Equivalent numpy's nonzero but returns a tuple of NamedArrays.""" From b2e5a0ca25319af48fb0e4e93582f3ba75c3476d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 00:39:12 +0200 Subject: [PATCH 005/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 9595784045d..ed619291566 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -71,7 +71,7 @@ def astype( Examples -------- >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5])) - >>> astype(narr, np.dtype(int)).data + >>> astype(narr, np.dtype(np.int32)).data Array([1, 2], dtype=int32) """ if isinstance(x._data, _arrayapi): @@ -154,9 +154,9 @@ def mean( x: NamedArray[Any, _DType], /, *, - axis: _AxisLike | None = None, - keepdims: bool = False, dims: _Dims | None = None, + keepdims: bool = False, + axis: _AxisLike | None = None, ) -> NamedArray[Any, _DType]: """ Calculates the arithmetic mean of the input array x. From c4e72c0db2e33889eea0abd84dd803f5b9605c0c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 18:47:05 +0200 Subject: [PATCH 006/367] Update variable.py --- xarray/core/variable.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 576535eea2b..d8b220758aa 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -8,7 +8,7 @@ from collections.abc import Hashable, Mapping, Sequence from datetime import timedelta from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast +from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast, Iterable import numpy as np import pandas as pd @@ -2600,6 +2600,24 @@ def argmax( """ return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: + """Return axis number(s) corresponding to dimension(s) in this array. + + Parameters + ---------- + dim : str or iterable of str + Dimension name(s) for which to lookup axes. + + Returns + ------- + int or tuple of int + Axis number or numbers corresponding to the given dimensions. + """ + if not isinstance(dim, str) and isinstance(dim, Iterable): + return tuple(self._get_axis_num(d) for d in dim) + else: + return self._get_axis_num(dim) + class IndexVariable(Variable): """Wrapper for accommodating a pandas.Index in an xarray.Variable. From 15cf3eb26f394275075f462be6488c9e8384598a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:48:42 +0000 Subject: [PATCH 007/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/variable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d8b220758aa..21025a870d8 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -5,10 +5,10 @@ import math import numbers import warnings -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import timedelta from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast, Iterable +from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast import numpy as np import pandas as pd From 4b90908bcb55d5455847b980ef81c9b5ba051024 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 18:58:26 +0200 Subject: [PATCH 008/367] undo --- xarray/core/variable.py | 18 ------------------ xarray/namedarray/core.py | 8 ++++++-- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d8b220758aa..4f741c07126 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2600,24 +2600,6 @@ def argmax( """ return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) - def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: - """Return axis number(s) corresponding to dimension(s) in this array. - - Parameters - ---------- - dim : str or iterable of str - Dimension name(s) for which to lookup axes. - - Returns - ------- - int or tuple of int - Axis number or numbers corresponding to the given dimensions. - """ - if not isinstance(dim, str) and isinstance(dim, Iterable): - return tuple(self._get_axis_num(d) for d in dim) - else: - return self._get_axis_num(dim) - class IndexVariable(Variable): """Wrapper for accommodating a pandas.Index in an xarray.Variable. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 173720deced..6a661eb85e8 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -14,6 +14,7 @@ TypeVar, cast, overload, + Iterable, ) import numpy as np @@ -679,7 +680,7 @@ def _get_axis_num(self, dim: _Dim) -> int: except ValueError: raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") - def get_axis_num(self, dims: _Dims) -> _Axes: + def get_axis_num(self, dim: _Dims) -> _Axes: """Return axis number(s) corresponding to dimension(s) in this array. Parameters @@ -692,7 +693,10 @@ def get_axis_num(self, dims: _Dims) -> _Axes: int or tuple of int Axis number or numbers corresponding to the given dimensions. """ - return tuple(self._get_axis_num(d) for d in dims) + if not isinstance(dim, str) and isinstance(dim, Iterable): + return tuple(self._get_axis_num(d) for d in dim) + else: + return self._get_axis_num(dim) @property def chunks(self) -> _Chunks | None: From d6b13b0f342ae3b7eab7ffc899ef43e97e515e0f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:59:29 +0000 Subject: [PATCH 009/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/variable.py | 2 +- xarray/namedarray/core.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 0b87ad792e9..576535eea2b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -5,7 +5,7 @@ import math import numbers import warnings -from collections.abc import Hashable, Iterable, Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from datetime import timedelta from functools import partial from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 6a661eb85e8..8bdce1b056e 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -4,7 +4,7 @@ import math import sys import warnings -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Hashable, Iterable, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -14,7 +14,6 @@ TypeVar, cast, overload, - Iterable, ) import numpy as np From 76930adcb877cec01b4bd70a14b34cc87a70d350 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 19:05:49 +0200 Subject: [PATCH 010/367] Update core.py --- xarray/namedarray/core.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 8bdce1b056e..11d925c0c06 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -672,13 +672,6 @@ def _dask_finalize( data = array_func(results, *args, **kwargs) return type(self)(self._dims, data, attrs=self._attrs) - def _get_axis_num(self, dim: _Dim) -> int: - try: - out = self.dims.index(dim) - return out - except ValueError: - raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") - def get_axis_num(self, dim: _Dims) -> _Axes: """Return axis number(s) corresponding to dimension(s) in this array. @@ -697,6 +690,13 @@ def get_axis_num(self, dim: _Dims) -> _Axes: else: return self._get_axis_num(dim) + def _get_axis_num(self, dim: _Dim) -> int: + try: + out = self.dims.index(dim) + return out + except ValueError: + raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") + @property def chunks(self) -> _Chunks | None: """ @@ -804,7 +804,7 @@ def reduce( if getattr(data, "shape", ()) == self.shape: dims_ = self.dims else: - dims_, data = _get_remaining_dims(data, data, axis, keepdims=keepdims) + dims_, data = _get_remaining_dims(self, data, axis, keepdims=keepdims) # Return NamedArray to handle IndexVariable when data is nD return from_array(dims_, data, attrs=self._attrs) From a01a50688e0e6b004946c9d1e3fae0481e5a8474 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 19:31:10 +0200 Subject: [PATCH 011/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 11d925c0c06..0458b80bbe0 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -792,7 +792,7 @@ def reduce( "ignore", r"Mean of empty slice", category=RuntimeWarning ) if axis_ is not None: - if len(axis_) == 1: + if isinstance(axis, tuple) and len(axis) == 1: # unpack axis for the benefit of functions # like np.argmin which can't handle tuple arguments data = func(self.data, axis=axis_[0], **kwargs) From 237f2d828b0061305613986795f670982da14882 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 20:07:31 +0200 Subject: [PATCH 012/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 0458b80bbe0..c04a6c173fd 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -792,7 +792,7 @@ def reduce( "ignore", r"Mean of empty slice", category=RuntimeWarning ) if axis_ is not None: - if isinstance(axis, tuple) and len(axis) == 1: + if isinstance(axis_, tuple) and len(axis_) == 1: # unpack axis for the benefit of functions # like np.argmin which can't handle tuple arguments data = func(self.data, axis=axis_[0], **kwargs) From cc39bf06076ebf96a7adf1b4c9e87b34dc2bd616 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 20:37:38 +0200 Subject: [PATCH 013/367] Update core.py --- xarray/namedarray/core.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c04a6c173fd..d27883eda7a 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -100,12 +100,14 @@ def _dims_to_axis( if dims is not None and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim' arguments") - if dims is None: - if axis is None or isinstance(axis, tuple): - return axis + if dims is not None: + dims_: _Dims = (dims,) if isinstance(dims, str) else dims + return x.get_axis_num(dims_) + + if isinstance(axis, int): return (axis,) - else: - return x.get_axis_num(dims) + + return axis def _get_remaining_dims( From 5c8b1fa736c7c250433589dea7d85e5910ab45f4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 20:46:01 +0200 Subject: [PATCH 014/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index d27883eda7a..a5161a2a5a4 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -806,7 +806,7 @@ def reduce( if getattr(data, "shape", ()) == self.shape: dims_ = self.dims else: - dims_, data = _get_remaining_dims(self, data, axis, keepdims=keepdims) + dims_, data = _get_remaining_dims(self, data, axis_, keepdims=keepdims) # Return NamedArray to handle IndexVariable when data is nD return from_array(dims_, data, attrs=self._attrs) From 60de162adf9c7643b2b0bf50e10eb59037d8c0b1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 20:56:23 +0200 Subject: [PATCH 015/367] Update test_namedarray.py --- xarray/tests/test_namedarray.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 93bff4d6a05..0153bc0ff7c 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -27,6 +27,7 @@ _DType, _Shape, duckarray, + _IndexKeyLike, ) from xarray.namedarray.utils import Default @@ -56,6 +57,11 @@ class CustomArrayIndexable( ExplicitlyIndexed, Generic[_ShapeType_co, _DType_co], ): + def __getitem__( + self, key: _IndexKeyLike, / + ) -> CustomArrayIndexable[Any, _DType_co]: + ... + def __array_namespace__(self) -> ModuleType: return np From 41c3b2534461f0d6239ce04f1e70fc17e57888dc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 18:57:41 +0000 Subject: [PATCH 016/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/tests/test_namedarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 0153bc0ff7c..3eee13d76e3 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -25,9 +25,9 @@ _AttrsLike, _DimsLike, _DType, + _IndexKeyLike, _Shape, duckarray, - _IndexKeyLike, ) from xarray.namedarray.utils import Default From 0c87319d9c555d18820967867f23dd12be3ba72b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 21:17:44 +0200 Subject: [PATCH 017/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index ed619291566..b21b1ee5dbd 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -71,7 +71,11 @@ def astype( Examples -------- >>> narr = NamedArray(("x",), nxp.asarray([1.5, 2.5])) - >>> astype(narr, np.dtype(np.int32)).data + >>> narr + + Array([1.5, 2.5], dtype=float64) + >>> astype(narr, np.dtype(np.int32)) + Array([1, 2], dtype=int32) """ if isinstance(x._data, _arrayapi): @@ -108,7 +112,8 @@ def imag( Examples -------- >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp - >>> imag(narr).data + >>> imag(narr) + array([2., 4.]) """ xp = _get_data_namespace(x) @@ -139,7 +144,8 @@ def real( Examples -------- >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) # TODO: Use nxp - >>> real(narr).data + >>> real(narr) + array([1., 2.]) """ xp = _get_data_namespace(x) @@ -200,9 +206,11 @@ def mean( Using keepdims: - >>> mean(x, dims=("x",), keepdims=True).data + >>> mean(x, dims=("x",), keepdims=True) + Array([[2., 3.]], dtype=float64) - >>> mean(x, dims=("y",), keepdims=True).data + >>> mean(x, dims=("y",), keepdims=True) + Array([[1.5], [3.5]], dtype=float64) """ @@ -219,18 +227,3 @@ def mean( # import doctest # doctest.testmod() -# err -# arr = nxp.asarray([[1.0, 2], [3, 4]]) -# arr_m = nxp.mean(arr, axis=(1,), keepdims=True) - -# removed_axes = np.atleast_1d((0,)) % arr.ndim -# slices = tuple( -# np.newaxis if i in removed_axes else slice(None, None) for i in range(arr.ndim) -# ) -# arr_m[slices] -# err -# x = NamedArray(("x", "y"), arr) -# x_mean = mean(x, axis=(0,), keepdims=True) -# print(x_mean) - -# x_mean = mean(x, dims=("x",), keepdims=True) From a246e412ae19039db429fc88c340cffd493c8feb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 21:37:47 +0200 Subject: [PATCH 018/367] Update core.py --- xarray/namedarray/core.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index a5161a2a5a4..71804d58a69 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -137,6 +137,9 @@ def _get_remaining_dims( DESCRIPTION. """ + if data.shape == x.shape: + return x.dims, data + removed_axes: np.ndarray[Any, np.dtype[np.intp]] if axis is None: removed_axes = np.arange(x.ndim, dtype=np.intp) @@ -789,6 +792,7 @@ def reduce( d = dim axis_ = _dims_to_axis(self, d, axis) + data: duckarray[Any, Any] | ArrayLike with warnings.catch_warnings(): warnings.filterwarnings( "ignore", r"Mean of empty slice", category=RuntimeWarning @@ -803,10 +807,10 @@ def reduce( else: data = func(self.data, **kwargs) - if getattr(data, "shape", ()) == self.shape: - dims_ = self.dims - else: - dims_, data = _get_remaining_dims(self, data, axis_, keepdims=keepdims) + if not isinstance(data, duckarray): + data = np.asarray(data) + + dims_, data = _get_remaining_dims(self, data, axis_, keepdims=keepdims) # Return NamedArray to handle IndexVariable when data is nD return from_array(dims_, data, attrs=self._attrs) From f4936bb20c5e7ccc77242b4e124f14d1324da2aa Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Oct 2023 21:44:57 +0200 Subject: [PATCH 019/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 71804d58a69..64cf6f204ff 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -807,7 +807,7 @@ def reduce( else: data = func(self.data, **kwargs) - if not isinstance(data, duckarray): + if not isinstance(data, _arrayfunction_or_api): data = np.asarray(data) dims_, data = _get_remaining_dims(self, data, axis_, keepdims=keepdims) From 489b63a3d51a622f8cc35d7836e84e53b1da002a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 12:44:47 +0200 Subject: [PATCH 020/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index b21b1ee5dbd..335b702d77d 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + import warnings from types import ModuleType from typing import Any From a50496bfffe2d83b8958b64289b66394f9bd0ad6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Oct 2023 10:45:33 +0000 Subject: [PATCH 021/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 335b702d77d..f48dacbb045 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -1,6 +1,5 @@ from __future__ import annotations - import warnings from types import ModuleType from typing import Any From 814fda488840b86add63e811af6833c879ccf666 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:00:30 +0200 Subject: [PATCH 022/367] fixes --- xarray/namedarray/_typing.py | 48 +++++++++++++++++++++++---------- xarray/namedarray/core.py | 45 +++++++++++++++++++------------ xarray/tests/test_namedarray.py | 15 ++++++++++- 3 files changed, 76 insertions(+), 32 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 452f283cf0d..59b34c62d8d 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -19,18 +19,13 @@ if TYPE_CHECKING: from numpy.typing import NDArray - # https://data-apis.org/array-api/latest/API_specification/indexing.html - # TODO: np.array_api doesn't allow None for some reason, maybe they're - # recommending to use expand_dims? - _IndexKey = Union[int, slice, ellipsis, None] - _IndexKeys = tuple[_IndexKey, ...] - _IndexKeyLike = Union[_IndexKey, _IndexKeys] # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) +_dtype = np.dtype _DType = TypeVar("_DType", bound=np.dtype[Any]) _DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any]) # A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic` @@ -68,11 +63,18 @@ def dtype(self) -> _DType_co: _Dim = Hashable _Dims = tuple[_Dim, ...] - _DimsLike = Union[str, Iterable[_Dim]] -_AttrsLike = Union[Mapping[Any, Any], None] +_DimsLikeAgg = Union[_DimsLike, "ellipsis", None] -_dtype = np.dtype + +# https://data-apis.org/array-api/latest/API_specification/indexing.html +# TODO: np.array_api doesn't allow None for some reason, maybe they're +# recommending to use expand_dims? +_IndexKey = Union[int, slice, "ellipsis"] +_IndexKeys = tuple[_IndexKey | None, ...] +_IndexKeyLike = Union[_IndexKey, _IndexKeys] + +_AttrsLike = Union[Mapping[Any, Any], None] class _SupportsReal(Protocol[_T_co]): @@ -126,7 +128,22 @@ class _arrayfunction( Corresponds to np.ndarray. """ - def __getitem__(self, key: _IndexKeyLike, /) -> _arrayfunction[Any, _DType_co]: + @overload + def __getitem__(self, key: _IndexKeyLike) -> Any: + ... + + @overload + def __getitem__( + self, key: (_arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...]) + ) -> _arrayfunction[Any, _DType_co]: + ... + + @overload + def __getitem__( + self, + key: _IndexKeyLike | _arrayfunction[Any, Any], + /, + ) -> _arrayfunction[Any, _DType_co] | Any: ... # TODO: Should return the same subclass but with a new dtype generic. @@ -164,7 +181,10 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType Corresponds to np.ndarray. """ - def __getitem__(self, key: _IndexKeyLike, /) -> _arrayapi[Any, _DType_co]: + # TODO: Only integer _arrayapi: + def __getitem__( + self, key: _IndexKeyLike | _arrayapi[Any, Any], / + ) -> _arrayapi[Any, _DType_co]: ... def __array_namespace__(self) -> ModuleType: @@ -254,7 +274,7 @@ class _sparsearray( Corresponds to np.ndarray. """ - def todense(self) -> NDArray[_ScalarType_co]: + def todense(self) -> np.ndarray[Any, _DType_co]: ... @@ -272,7 +292,7 @@ class _sparsearrayfunction( Corresponds to np.ndarray. """ - def todense(self) -> NDArray[_ScalarType_co]: + def todense(self) -> np.ndarray[Any, _DType_co]: ... @@ -290,7 +310,7 @@ class _sparsearrayapi( Corresponds to np.ndarray. """ - def todense(self) -> NDArray[_ScalarType_co]: + def todense(self) -> np.ndarray[Any, _DType_co]: ... diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 64cf6f204ff..e82eecbf951 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -5,6 +5,7 @@ import sys import warnings from collections.abc import Hashable, Iterable, Mapping, Sequence +from types import EllipsisType from typing import ( TYPE_CHECKING, Any, @@ -38,7 +39,6 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray - from xarray.core.types import Dims from xarray.namedarray._typing import ( DuckArray, _AttrsLike, @@ -48,6 +48,7 @@ _Dim, _Dims, _DimsLike, + _DimsLikeAgg, _IntOrUnknown, _ScalarType, _Shape, @@ -83,7 +84,7 @@ def _dims_to_axis( - x: NamedArray[Any, Any], dims: _DimsLike | None, axis: _AxisLike | None + x: NamedArray[Any, Any], dims: _Dims | None, axis: _AxisLike | None ) -> _Axes | None: """ Convert dims to axis indices. @@ -113,7 +114,7 @@ def _dims_to_axis( def _get_remaining_dims( x: NamedArray[Any, _DType], data: duckarray[Any, _DType], - axis: _Axes, + axis: _Axes | None, *, keepdims: bool, ) -> tuple[_Dims, duckarray[Any, _DType]]: @@ -223,9 +224,9 @@ def _new( @overload def from_array( dims: _DimsLike, - data: DuckArray[_ScalarType], + data: duckarray[_ShapeType, _DType], attrs: _AttrsLike = ..., -) -> _NamedArray[_ScalarType]: +) -> NamedArray[_ShapeType, _DType]: ... @@ -234,15 +235,15 @@ def from_array( dims: _DimsLike, data: ArrayLike, attrs: _AttrsLike = ..., -) -> _NamedArray[Any]: +) -> NamedArray[Any, Any]: ... def from_array( dims: _DimsLike, - data: DuckArray[_ScalarType] | ArrayLike, + data: duckarray[_ShapeType, _DType] | ArrayLike, attrs: _AttrsLike = None, -) -> _NamedArray[_ScalarType] | _NamedArray[Any]: +) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: """ Create a Named array from an array-like object. @@ -693,7 +694,7 @@ def get_axis_num(self, dim: _Dims) -> _Axes: if not isinstance(dim, str) and isinstance(dim, Iterable): return tuple(self._get_axis_num(d) for d in dim) else: - return self._get_axis_num(dim) + return (self._get_axis_num(dim),) def _get_axis_num(self, dim: _Dim) -> int: try: @@ -752,8 +753,8 @@ def sizes(self) -> dict[_Dim, _IntOrUnknown]: def reduce( self, func: Callable[..., Any], - dim: Dims = None, - axis: int | Sequence[int] | None = None, + dim: _DimsLikeAgg = None, + axis: int | Sequence[int] | None = None, # TODO: Use _AxisLike keepdims: bool = False, **kwargs: Any, ) -> NamedArray[Any, Any]: @@ -785,12 +786,22 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ - d: _Dims - if dim == ...: - d = None - else: + + if isinstance(dim, EllipsisType): + # TODO: What's the point of ellipsis? Use either ... or None? + dim = None + d: _Dims | None + if dim is None: d = dim - axis_ = _dims_to_axis(self, d, axis) + else: + d = self._parse_dimensions(dim) + + axislike: _AxisLike | None + if axis is None or isinstance(axis, int): + axislike = axis + else: + axislike = tuple(axis) + axis_ = _dims_to_axis(self, d, axislike) data: duckarray[Any, Any] | ArrayLike with warnings.catch_warnings(): @@ -869,7 +880,7 @@ def _to_dense(self) -> Self: if isinstance(self._data, _sparsearrayfunction_or_api): # return self._replace(data=self._data.todense()) - data_: np.ndarray[Any, Any] = self._data.todense() + data_: np.ndarray[Any, _DType_co] = self._data.todense() return self._replace(data=data_) else: raise TypeError("self.data is not a sparse array") diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 3eee13d76e3..8f4261efe08 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import warnings from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Generic, cast, overload @@ -58,7 +59,7 @@ class CustomArrayIndexable( Generic[_ShapeType_co, _DType_co], ): def __getitem__( - self, key: _IndexKeyLike, / + self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], / ) -> CustomArrayIndexable[Any, _DType_co]: ... @@ -282,6 +283,18 @@ def test_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType] test_duck_array_typevar(numpy_a) test_duck_array_typevar(custom_a) + # Test numpy's array api: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"The numpy.array_api submodule is still experimental", + category=UserWarning, + ) + import numpy.array_api as nxp + + arraypi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64)) + test_duck_array_typevar(arraypi_a) + def test_new_namedarray() -> None: dtype_float = np.dtype(np.float32) From fc5590f2cda3ba6c60211abc6a47d197c7022ca2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Oct 2023 14:01:15 +0000 Subject: [PATCH 023/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_typing.py | 2 +- xarray/namedarray/core.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 59b34c62d8d..3d4154ac0d2 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -17,7 +17,7 @@ import numpy as np if TYPE_CHECKING: - from numpy.typing import NDArray + pass # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index e82eecbf951..6d31756dc60 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -40,7 +40,6 @@ from numpy.typing import ArrayLike, NDArray from xarray.namedarray._typing import ( - DuckArray, _AttrsLike, _Axes, _AxisLike, @@ -50,7 +49,6 @@ _DimsLike, _DimsLikeAgg, _IntOrUnknown, - _ScalarType, _Shape, _ShapeType, duckarray, From 80a08bf839239fac8544f740076ad5e122ea4ad2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:08:40 +0200 Subject: [PATCH 024/367] Update core.py --- xarray/namedarray/core.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index e82eecbf951..35ae1dddc75 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -83,6 +83,10 @@ ) +def _normalize_dimensions(dims: _DimsLike) -> _Dims: + return (dims,) if isinstance(dims, str) else tuple(dims) + + def _dims_to_axis( x: NamedArray[Any, Any], dims: _Dims | None, axis: _AxisLike | None ) -> _Axes | None: @@ -556,7 +560,7 @@ def dims(self, value: _DimsLike) -> None: self._dims = self._parse_dimensions(value) def _parse_dimensions(self, dims: _DimsLike) -> _Dims: - dims = (dims,) if isinstance(dims, str) else tuple(dims) + dims = _normalize_dimensions(dims) if len(dims) != self.ndim: raise ValueError( f"dimensions {dims} must have the same length as the " @@ -794,7 +798,7 @@ def reduce( if dim is None: d = dim else: - d = self._parse_dimensions(dim) + d = _normalize_dimensions(dim) axislike: _AxisLike | None if axis is None or isinstance(axis, int): From 508a9aeff918c7b984bac9c980fd2ce161544c43 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:46:13 +0200 Subject: [PATCH 025/367] more --- xarray/namedarray/_array_api.py | 4 +--- xarray/namedarray/core.py | 12 +++++------- xarray/tests/test_namedarray.py | 2 +- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index f48dacbb045..5348e24423b 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -35,9 +35,7 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: return np -def _to_nxp( - x: duckarray[_ShapeType, _DType] -) -> tuple[ModuleType, _arrayapi[_ShapeType, _DType]]: +def _to_nxp(x: duckarray[_ShapeType, _DType]) -> tuple[ModuleType, nxp.Array]: return nxp, nxp.asarray(x) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index b201132b117..83351c6478f 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -788,15 +788,13 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ - - if isinstance(dim, EllipsisType): - # TODO: What's the point of ellipsis? Use either ... or None? - dim = None d: _Dims | None - if dim is None: - d = dim + if dim is None or dim is ...: + # TODO: What's the point of ellipsis? Use either ... or None? + d = None else: - d = _normalize_dimensions(dim) + dimslike: _DimsLike = dim # type: ignore[assignment] + d = _normalize_dimensions(dimslike) axislike: _AxisLike | None if axis is None or isinstance(axis, int): diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 8f4261efe08..775f8fc0fb9 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -61,7 +61,7 @@ class CustomArrayIndexable( def __getitem__( self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], / ) -> CustomArrayIndexable[Any, _DType_co]: - ... + return type(self)(array=self.array[key]) def __array_namespace__(self) -> ModuleType: return np From 296cc485af27574dba7eb38dfc9dff4c40e2750a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Oct 2023 14:47:32 +0000 Subject: [PATCH 026/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 83351c6478f..9f28de4a908 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -5,7 +5,6 @@ import sys import warnings from collections.abc import Hashable, Iterable, Mapping, Sequence -from types import EllipsisType from typing import ( TYPE_CHECKING, Any, From 96b413fbce4b551199c7474fafa730902dce61c0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:47:44 +0200 Subject: [PATCH 027/367] Update core.py --- xarray/namedarray/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 83351c6478f..a0ed595313e 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -5,7 +5,6 @@ import sys import warnings from collections.abc import Hashable, Iterable, Mapping, Sequence -from types import EllipsisType from typing import ( TYPE_CHECKING, Any, @@ -789,7 +788,7 @@ def reduce( removed. """ d: _Dims | None - if dim is None or dim is ...: + if dim is None or dim is ...: # isinstance(dim, types.EllipsisType) # TODO: What's the point of ellipsis? Use either ... or None? d = None else: From 8b13d606780763d6b95588d24ffd6e7ebb9b35af Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:48:19 +0200 Subject: [PATCH 028/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index a0ed595313e..a5f9150da25 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -788,7 +788,7 @@ def reduce( removed. """ d: _Dims | None - if dim is None or dim is ...: # isinstance(dim, types.EllipsisType) + if dim is None or dim is ...: # TODO: isinstance(dim, types.EllipsisType) # TODO: What's the point of ellipsis? Use either ... or None? d = None else: From 78bb3bc5d7988e17e9ef1f2f257db1ecd5c4a0a5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 16:51:23 +0200 Subject: [PATCH 029/367] Update _typing.py --- xarray/namedarray/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 3d4154ac0d2..3eddc49f3f3 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -71,7 +71,7 @@ def dtype(self) -> _DType_co: # TODO: np.array_api doesn't allow None for some reason, maybe they're # recommending to use expand_dims? _IndexKey = Union[int, slice, "ellipsis"] -_IndexKeys = tuple[_IndexKey | None, ...] +_IndexKeys = tuple[Union[_IndexKey, None], ...] _IndexKeyLike = Union[_IndexKey, _IndexKeys] _AttrsLike = Union[Mapping[Any, Any], None] From 2e6e09c5acbcc9b29714f703ef1218cdceacb651 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:02:45 +0200 Subject: [PATCH 030/367] Update variable.py --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 4ad723325d3..26dcb5472d5 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2193,7 +2193,7 @@ def rolling_window( pads[d] = (win - 1, 0) padded = var.pad(pads, mode="constant", constant_values=fill_value) - axis = tuple(self.get_axis_num(d) for d in dim) + axis = self.get_axis_num(dim) new_dims = self.dims + tuple(window_dim) return Variable( new_dims, From 8ad2264e4184537faa49faaa50c2ea1014d7184a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:31:14 +0200 Subject: [PATCH 031/367] Update core.py --- xarray/namedarray/core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index a5f9150da25..ed3fbc57f04 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -41,6 +41,7 @@ from xarray.namedarray._typing import ( _AttrsLike, _Axes, + _Axis, _AxisLike, _Chunks, _Dim, @@ -679,7 +680,7 @@ def _dask_finalize( data = array_func(results, *args, **kwargs) return type(self)(self._dims, data, attrs=self._attrs) - def get_axis_num(self, dim: _Dims) -> _Axes: + def get_axis_num(self, dim: _Dims) -> _Axis | _Axes: """Return axis number(s) corresponding to dimension(s) in this array. Parameters @@ -695,7 +696,7 @@ def get_axis_num(self, dim: _Dims) -> _Axes: if not isinstance(dim, str) and isinstance(dim, Iterable): return tuple(self._get_axis_num(d) for d in dim) else: - return (self._get_axis_num(dim),) + return self._get_axis_num(dim) def _get_axis_num(self, dim: _Dim) -> int: try: From 6dd6ebb9f368a6100f1c16e46a65420c56befb84 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:32:17 +0200 Subject: [PATCH 032/367] Update core.py --- xarray/namedarray/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index ed3fbc57f04..6db6b494bbb 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -41,7 +41,6 @@ from xarray.namedarray._typing import ( _AttrsLike, _Axes, - _Axis, _AxisLike, _Chunks, _Dim, @@ -680,7 +679,7 @@ def _dask_finalize( data = array_func(results, *args, **kwargs) return type(self)(self._dims, data, attrs=self._attrs) - def get_axis_num(self, dim: _Dims) -> _Axis | _Axes: + def get_axis_num(self, dim: _Dims) -> _AxisLike: """Return axis number(s) corresponding to dimension(s) in this array. Parameters From 25ee56eca6bff9c0bcbaddfcb0ad9dca66436ea6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 17:49:10 +0200 Subject: [PATCH 033/367] Update core.py --- xarray/namedarray/core.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 6db6b494bbb..71668088cac 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -41,6 +41,7 @@ from xarray.namedarray._typing import ( _AttrsLike, _Axes, + _Axis, _AxisLike, _Chunks, _Dim, @@ -679,7 +680,15 @@ def _dask_finalize( data = array_func(results, *args, **kwargs) return type(self)(self._dims, data, attrs=self._attrs) - def get_axis_num(self, dim: _Dims) -> _AxisLike: + @overload + def get_axis_num(self, dim: _Dims) -> _Axes: + ... + + @overload + def get_axis_num(self, dim: _Dim) -> _Axis: + ... + + def get_axis_num(self, dim: _Dim | _Dims) -> _Axis | _Axes: """Return axis number(s) corresponding to dimension(s) in this array. Parameters @@ -792,7 +801,7 @@ def reduce( # TODO: What's the point of ellipsis? Use either ... or None? d = None else: - dimslike: _DimsLike = dim # type: ignore[assignment] + dimslike: _DimsLike = dim d = _normalize_dimensions(dimslike) axislike: _AxisLike | None From 1d554ad1eba1e84cf5b41e04aee5f44ec3f6047c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 19:50:27 +0200 Subject: [PATCH 034/367] Update common.py --- xarray/core/common.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index ab8a4d84261..3a9ec14ac10 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -198,7 +198,15 @@ def __iter__(self: Any) -> Iterator[Any]: raise TypeError("iteration over a 0-d array") return self._iter() - def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: + @overload + def get_axis_num(self, dim: Hashable) -> int: + ... + + @overload + def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: + ... + + def get_axis_num(self, dim: str | Iterable[Hashable]) -> int | tuple[int, ...]: """Return axis number(s) corresponding to dimension(s) in this array. Parameters From b4b44a8f1d665fe45aef4c51d31a8ffb744dbc4f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Oct 2023 19:54:18 +0200 Subject: [PATCH 035/367] Update common.py --- xarray/core/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 3a9ec14ac10..a8f8b1d54c2 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -206,7 +206,7 @@ def get_axis_num(self, dim: Hashable) -> int: def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ... - def get_axis_num(self, dim: str | Iterable[Hashable]) -> int | tuple[int, ...]: + def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: """Return axis number(s) corresponding to dimension(s) in this array. Parameters From 181ad35e4d2915854e5f111052cf82af79228dd1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Oct 2023 04:52:20 +0000 Subject: [PATCH 036/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index c2e53f94602..70a3896c24a 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -15,7 +15,6 @@ _ShapeType, _SupportsImag, _SupportsReal, - duckarray, ) from xarray.namedarray.core import NamedArray, _dims_to_axis, _get_remaining_dims From 078cb7027cc9162d0889484d44da1df076f0ceda Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 22 Oct 2023 07:17:06 +0200 Subject: [PATCH 037/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index c2e53f94602..674c8516f1b 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -15,17 +15,9 @@ _ShapeType, _SupportsImag, _SupportsReal, - duckarray, ) from xarray.namedarray.core import NamedArray, _dims_to_axis, _get_remaining_dims -with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - r"The numpy.array_api submodule is still experimental", - category=UserWarning, - ) - import numpy.array_api as nxp with warnings.catch_warnings(): warnings.filterwarnings( From 487823a958b585e23ec7aca63edea47f604b9fd4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Oct 2023 05:18:01 +0000 Subject: [PATCH 038/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 674c8516f1b..56f9ea7e68d 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -18,7 +18,6 @@ ) from xarray.namedarray.core import NamedArray, _dims_to_axis, _get_remaining_dims - with warnings.catch_warnings(): warnings.filterwarnings( "ignore", From adabb3c76652a993617f35991e83aa57df02ba70 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 23 Oct 2023 20:19:41 +0200 Subject: [PATCH 039/367] Update _typing.py --- xarray/namedarray/_typing.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 3eddc49f3f3..917fdf4f002 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -113,6 +113,11 @@ def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]: def __array__(self, dtype: _DType, /) -> np.ndarray[Any, _DType]: ... + def __array__( + self, dtype: _DType | None = ..., / + ) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: + ... + # Corresponds to np.typing.NDArray: _Array = _array[Any, np.dtype[_ScalarType_co]] @@ -134,11 +139,10 @@ def __getitem__(self, key: _IndexKeyLike) -> Any: @overload def __getitem__( - self, key: (_arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...]) + self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...] ) -> _arrayfunction[Any, _DType_co]: ... - @overload def __getitem__( self, key: _IndexKeyLike | _arrayfunction[Any, Any], @@ -182,8 +186,23 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType """ # TODO: Only integer _arrayapi: + # def __getitem__( + # self, + # key: Union[ + # int, + # slice, + # "ellipsis", + # tuple[Union[int, slice, "ellipsis", None], ...], + # _arrayapi[Any, Any], + # ], + # /, + # ) -> _arrayapi[Any, _DType_co]: + # ... + def __getitem__( - self, key: _IndexKeyLike | _arrayapi[Any, Any], / + self, + key: Any, + /, ) -> _arrayapi[Any, _DType_co]: ... From 53bb1b5368c9930dc0eca7b82d6c6da1d18fa4ab Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 23 Oct 2023 21:59:12 +0200 Subject: [PATCH 040/367] Update common.py --- xarray/core/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index a8f8b1d54c2..9d55f64c473 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -199,11 +199,11 @@ def __iter__(self: Any) -> Iterator[Any]: return self._iter() @overload - def get_axis_num(self, dim: Hashable) -> int: + def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ... @overload - def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: + def get_axis_num(self, dim: Hashable) -> int: ... def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: From 976903eece0b6d67562f4b7b2c89384dbdbf4d93 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 29 Oct 2023 12:04:10 +0100 Subject: [PATCH 041/367] fixes --- xarray/namedarray/_typing.py | 4 +- xarray/namedarray/core.py | 94 ++++++++++++++++++++++++--------- xarray/namedarray/utils.py | 10 ---- xarray/tests/test_namedarray.py | 10 ++-- 4 files changed, 79 insertions(+), 39 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index ea45d72ab79..66c238e45f1 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -1,11 +1,13 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Mapping, Sequence +from enum import Enum from types import ModuleType from typing import ( TYPE_CHECKING, Any, Callable, + Final, Protocol, SupportsIndex, TypeVar, @@ -63,7 +65,7 @@ def dtype(self) -> _DType_co: _Dim = Hashable _Dims = tuple[_Dim, ...] -_DimsLike = Union[str, Iterable[_Dim]] +_DimsLike = Union[str, Iterable[_Dim], Default] _DimsLikeAgg = Union[_DimsLike, "ellipsis", None] diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 382961cfe4c..c7fa69692fb 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -25,6 +25,7 @@ _arrayapi, _arrayfunction_or_api, _chunkedarray, + _default, _dtype, _DType_co, _ScalarType_co, @@ -32,7 +33,7 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array +from xarray.namedarray.utils import is_duck_dask_array, to_0d_object_array if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray @@ -49,11 +50,12 @@ _DimsLikeAgg, _DType, _IntOrUnknown, + _ScalarType, _Shape, _ShapeType, duckarray, + Default, ) - from xarray.namedarray.utils import Default try: from dask.typing import ( @@ -82,12 +84,33 @@ def _normalize_dimensions(dims: _DimsLike) -> _Dims: - return (dims,) if isinstance(dims, str) else tuple(dims) + """ + Normalize dimensions. + + Examples + -------- + >>> _normalize_dimensions(None) + (None,) + >>> _normalize_dimensions(1) + (1,) + >>> _normalize_dimensions("2") + ('2',) + >>> _normalize_dimensions(("time",)) + ('time', ) + >>> _normalize_dimensions(["time"]) + ('time',) + >>> _normalize_dimensions([("time", "x", "y")]) + (('time', 'x', 'y'),) + """ + if isinstance(dims, str) or not isinstance(dims, Iterable): + return (dims,) + + return tuple(dims) def _dims_to_axis( - x: NamedArray[Any, Any], dims: _Dims | None, axis: _AxisLike | None -) -> _Axes | None: + x: NamedArray[Any, Any], dims: _Dim | _Dims | Default, axis: _AxisLike | None +) -> _AxisLike | None: """ Convert dims to axis indices. @@ -100,12 +123,11 @@ def _dims_to_axis( (0,) >>> _dims_to_axis(narr, None, None) """ - if dims is not None and axis is not None: + if dims is not _default and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim' arguments") - if dims is not None: - dims_: _Dims = (dims,) if isinstance(dims, str) else dims - return x.get_axis_num(dims_) + if dims is not _default: + return x._dims_to_axes(dims) if isinstance(axis, int): return (axis,) @@ -116,7 +138,7 @@ def _dims_to_axis( def _get_remaining_dims( x: NamedArray[Any, _DType], data: duckarray[Any, _DType], - axis: _Axes | None, + axis: _AxisLike | None, *, keepdims: bool, ) -> tuple[_Dims, duckarray[Any, _DType]]: @@ -223,6 +245,15 @@ def _new( return cls_(dims_, data, attrs_) +@overload +def from_array( + dims: _DimsLike, + data: np.ma.masked_array[_ShapeType, _DType], + attrs: _AttrsLike = ..., +) -> NamedArray[_ShapeType, _DType]: + ... + + @overload def from_array( dims: _DimsLike, @@ -267,14 +298,18 @@ def from_array( ) # TODO: dask.array.ma.masked_array also exists, better way? - if isinstance(data, np.ma.MaskedArray): - mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] + reveal_type(data) + if isinstance(data, np.ma.masked_array): + data_masked = cast("np.ma.masked_array[_ShapeType, _DType]", data) + reveal_type(data_masked) + + mask = np.ma.getmaskarray(data_masked) # type: ignore[no-untyped-call] if mask.any(): # TODO: requires refactoring/vendoring xarray.core.dtypes and # xarray.core.duck_array_ops raise NotImplementedError("MaskedArray is not supported yet") - return NamedArray(dims, data, attrs) + return NamedArray(dims, data_masked, attrs) if isinstance(data, _arrayfunction_or_api): return NamedArray(dims, data, attrs) @@ -716,14 +751,20 @@ def _dask_finalize( return type(self)(self._dims, data, attrs=self._attrs) @overload - def get_axis_num(self, dim: _Dims) -> _Axes: + def _dims_to_axes(self, dims: _Dims) -> _Axes: + ... + + @overload + def _dims_to_axes(self, dims: _Dim) -> _Axis: ... @overload - def get_axis_num(self, dim: _Dim) -> _Axis: + def _dims_to_axes(self, dims: Default = _default) -> None: ... - def get_axis_num(self, dim: _Dim | _Dims) -> _Axis | _Axes: + def _dims_to_axes( + self, dims: _Dims | _Dim | Default = _default + ) -> _Axes | _Axis | None: """Return axis number(s) corresponding to dimension(s) in this array. Parameters @@ -736,12 +777,15 @@ def get_axis_num(self, dim: _Dim | _Dims) -> _Axis | _Axes: int or tuple of int Axis number or numbers corresponding to the given dimensions. """ - if not isinstance(dim, str) and isinstance(dim, Iterable): - return tuple(self._get_axis_num(d) for d in dim) - else: - return self._get_axis_num(dim) + if dims is _default: + return None + + if isinstance(dims, tuple): + return tuple(self._dim_to_axis(d) for d in dims) + + return self._dim_to_axis(dims) - def _get_axis_num(self, dim: _Dim) -> int: + def _dim_to_axis(self, dim: _Dim) -> int: try: out = self.dims.index(dim) return out @@ -798,7 +842,7 @@ def sizes(self) -> dict[_Dim, _IntOrUnknown]: def reduce( self, func: Callable[..., Any], - dim: _DimsLikeAgg = None, + dim: _DimsLikeAgg | Default = _default, axis: int | Sequence[int] | None = None, # TODO: Use _AxisLike keepdims: bool = False, **kwargs: Any, @@ -915,7 +959,7 @@ def _as_sparse( data = as_sparse(astype(self, dtype).data, fill_value=fill_value) return self._replace(data=data) - def _to_dense(self) -> Self: + def _to_dense(self) -> NamedArray[Any, _DType_co]: """ Change backend from sparse to np.array """ @@ -923,8 +967,8 @@ def _to_dense(self) -> Self: if isinstance(self._data, _sparsearrayfunction_or_api): # return self._replace(data=self._data.todense()) - data_: np.ndarray[Any, _DType_co] = self._data.todense() - return self._replace(data=data_) + data_: np.ndarray[Any, Any] = self._data.todense() + return self._new(data=data_) else: raise TypeError("self.data is not a sparse array") diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index 03eb0134231..21b55ddb249 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -2,11 +2,9 @@ import sys from collections.abc import Hashable -from enum import Enum from typing import ( TYPE_CHECKING, Any, - Final, ) import numpy as np @@ -31,14 +29,6 @@ DaskCollection: Any = NDArray # type: ignore -# Singleton type, as per https://github.com/python/typing/pull/240 -class Default(Enum): - token: Final = 0 - - -_default = Default.token - - def module_available(module: str) -> bool: """Checks whether a module is installed without importing it. diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index cc706b4ecbc..87400c03480 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -9,9 +9,13 @@ import pytest from xarray.core.indexing import ExplicitlyIndexed -from xarray.namedarray._typing import _arrayfunction_or_api, _DType_co, _ShapeType_co +from xarray.namedarray._typing import ( + _arrayfunction_or_api, + _default, + _DType_co, + _ShapeType_co, +) from xarray.namedarray.core import NamedArray, from_array -from xarray.namedarray.utils import _default if TYPE_CHECKING: from types import ModuleType @@ -25,8 +29,8 @@ _IndexKeyLike, _Shape, duckarray, + Default, ) - from xarray.namedarray.utils import Default class CustomArrayBase(Generic[_ShapeType_co, _DType_co]): From 552e025f866dbd6ced636a70cf9b2ffd2e83b64f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 11:04:50 +0000 Subject: [PATCH 042/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_typing.py | 2 -- xarray/namedarray/core.py | 2 +- xarray/tests/test_namedarray.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 66c238e45f1..37c5e4ca5de 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -1,13 +1,11 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Mapping, Sequence -from enum import Enum from types import ModuleType from typing import ( TYPE_CHECKING, Any, Callable, - Final, Protocol, SupportsIndex, TypeVar, diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c7fa69692fb..381da5f4313 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -39,6 +39,7 @@ from numpy.typing import ArrayLike, NDArray from xarray.namedarray._typing import ( + Default, _AttrsLike, _Axes, _Axis, @@ -54,7 +55,6 @@ _Shape, _ShapeType, duckarray, - Default, ) try: diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 87400c03480..3afbbce9274 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -23,13 +23,13 @@ from numpy.typing import ArrayLike, DTypeLike, NDArray from xarray.namedarray._typing import ( + Default, _AttrsLike, _DimsLike, _DType, _IndexKeyLike, _Shape, duckarray, - Default, ) From 4540906719f04d4a30078c902be3ed0cd29972e4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 29 Oct 2023 12:11:35 +0100 Subject: [PATCH 043/367] Update _typing.py --- xarray/namedarray/_typing.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 37c5e4ca5de..55e8645b61c 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Mapping, Sequence +import Enum from types import ModuleType from typing import ( TYPE_CHECKING, @@ -12,6 +13,7 @@ Union, overload, runtime_checkable, + Final, ) import numpy as np @@ -20,6 +22,13 @@ pass +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token: Final = 0 + + +_default = Default.token + # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) From f9634c478373d07a23d27b732221aa8e10ff9bcf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 29 Oct 2023 11:12:15 +0000 Subject: [PATCH 044/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 55e8645b61c..7537aae6860 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -1,21 +1,21 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Mapping, Sequence -import Enum from types import ModuleType from typing import ( TYPE_CHECKING, Any, Callable, + Final, Protocol, SupportsIndex, TypeVar, Union, overload, runtime_checkable, - Final, ) +import Enum import numpy as np if TYPE_CHECKING: From 25acb13e81cdae0c9aa56c9229104bbe6206c2f2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 29 Oct 2023 12:16:34 +0100 Subject: [PATCH 045/367] Update _typing.py --- xarray/namedarray/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 7537aae6860..e114fe0a225 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable, Mapping, Sequence +from enum import Enum from types import ModuleType from typing import ( TYPE_CHECKING, @@ -15,7 +16,6 @@ runtime_checkable, ) -import Enum import numpy as np if TYPE_CHECKING: From 77a6e95659796592431598ed25b2690132efe4e4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 29 Oct 2023 12:28:07 +0100 Subject: [PATCH 046/367] more --- xarray/core/variable.py | 5 +++++ xarray/namedarray/core.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f18c4044f40..42b818f0d2c 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1746,6 +1746,11 @@ def reduce( # type: ignore[override] Array with summarized data and the indicated dimension(s) removed. """ + if dim is None: + from xarray.namedarray._typing import _default + + dim = _default + keep_attrs_ = ( _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs ) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 381da5f4313..ed64e3ef95f 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -96,7 +96,7 @@ def _normalize_dimensions(dims: _DimsLike) -> _Dims: >>> _normalize_dimensions("2") ('2',) >>> _normalize_dimensions(("time",)) - ('time', ) + ('time',) >>> _normalize_dimensions(["time"]) ('time',) >>> _normalize_dimensions([("time", "x", "y")]) From 84a3d15e73b05e967e5e9777accc4b839f0a6082 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 31 Oct 2023 20:34:46 +0100 Subject: [PATCH 047/367] Update _typing.py --- xarray/namedarray/_typing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index cdb658e0fc4..fee818465dc 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -136,6 +136,9 @@ def __getitem__( key: _IndexKeyLike | _arrayfunction[Any, Any], /, ) -> _arrayfunction[Any, _DType_co] | Any: + ... + + @overload def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]: ... From 6b3d0e0719a2ee03f19f1547ac037d35499b3b8f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 2 Nov 2023 23:46:37 +0100 Subject: [PATCH 048/367] add expand_dims --- xarray/namedarray/_array_api.py | 58 ++++++++++++++++++++++++++++++--- 1 file changed, 54 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 56f9ea7e68d..330c34d8b4d 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -8,13 +8,17 @@ from xarray.namedarray._typing import ( _arrayapi, + _Axis, _AxisLike, + _Dim, _Dims, _DType, _ScalarType, _ShapeType, _SupportsImag, _SupportsReal, + Default, + _default, ) from xarray.namedarray.core import NamedArray, _dims_to_axis, _get_remaining_dims @@ -148,6 +152,52 @@ def real( return out +# %% Manipulation functions +def expand_dims( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dim | Default = _default, + axis: _Axis | None = None, +) -> NamedArray[Any, _DType]: + """ + Expands the shape of an array by inserting a new dimension of size one at the + position specified by dims. + + Parameters + ---------- + x : + Array to expand. + dims : + Dimension name. New dimension will be stored in the 0 position. + axis : + Axis position (zero-based). If x has rank (i.e, number of dimensions) N, + a valid axis must reside on the closed-interval [-N-1, N]. If provided a + negative axis, the axis position at which to insert a singleton dimension + must be computed as N + axis + 1. Hence, if provided -1, the resolved axis + position must be N (i.e., a singleton dimension must be appended to the + input array x). If provided -N-1, the resolved axis position must be 0 + (i.e., a singleton dimension must be prepended to the input array x). + + Returns + ------- + out : + An expanded output array having the same data type as x. + + Examples + -------- + >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> expand_dims(x, dims="new_dim") + + Array([[[1., 2.], + [3., 4.]]], dtype=float64) + """ + xp = _get_data_namespace(x) + dims_new = (dims,) + x.dims + out = x._new(dims=dims_new, data=xp.expand_dims(x._data, axis=0)) + return out + + # %% Statistical Functions @@ -155,7 +205,7 @@ def mean( x: NamedArray[Any, _DType], /, *, - dims: _Dims | None = None, + dims: _Dims | Default = _default, keepdims: bool = False, axis: _AxisLike | None = None, ) -> NamedArray[Any, _DType]: @@ -218,7 +268,7 @@ def mean( return out -# if __name__ == "__main__": -# import doctest +if __name__ == "__main__": + import doctest -# doctest.testmod() + doctest.testmod() From 2e854444f6964487adc10a6c1abff3d488369485 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Nov 2023 22:47:25 +0000 Subject: [PATCH 049/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 330c34d8b4d..e51b2f76a97 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -7,9 +7,11 @@ import numpy as np from xarray.namedarray._typing import ( + Default, _arrayapi, _Axis, _AxisLike, + _default, _Dim, _Dims, _DType, @@ -17,8 +19,6 @@ _ShapeType, _SupportsImag, _SupportsReal, - Default, - _default, ) from xarray.namedarray.core import NamedArray, _dims_to_axis, _get_remaining_dims From 0856521388ab877016168e012c53033b142faf8b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 00:50:35 +0100 Subject: [PATCH 050/367] more --- xarray/namedarray/_array_api.py | 18 ++++++++++++------ xarray/namedarray/core.py | 10 ++++++++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index e51b2f76a97..a5fff1b93c4 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -20,7 +20,12 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.core import NamedArray, _dims_to_axis, _get_remaining_dims +from xarray.namedarray.core import ( + NamedArray, + _assert_either_dim_or_axis, + _dims_to_axis, + _get_remaining_dims, +) with warnings.catch_warnings(): warnings.filterwarnings( @@ -157,8 +162,8 @@ def expand_dims( x: NamedArray[Any, _DType], /, *, - dims: _Dim | Default = _default, - axis: _Axis | None = None, + dim: _Dim | Default = _default, + axis: _Axis = 0, ) -> NamedArray[Any, _DType]: """ Expands the shape of an array by inserting a new dimension of size one at the @@ -168,7 +173,7 @@ def expand_dims( ---------- x : Array to expand. - dims : + dim : Dimension name. New dimension will be stored in the 0 position. axis : Axis position (zero-based). If x has rank (i.e, number of dimensions) N, @@ -193,8 +198,9 @@ def expand_dims( [3., 4.]]], dtype=float64) """ xp = _get_data_namespace(x) - dims_new = (dims,) + x.dims - out = x._new(dims=dims_new, data=xp.expand_dims(x._data, axis=0)) + d = list(x.dims) + d.insert(axis, dim) + out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) return out diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 03e6547241b..f07c2023345 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -109,6 +109,13 @@ def _normalize_dimensions(dims: _DimsLike) -> _Dims: return tuple(dims) +def _assert_either_dim_or_axis( + dims: _Dim | _Dims | Default, axis: _AxisLike | None +) -> None: + if dims is not _default and axis is not None: + raise ValueError("cannot supply both 'axis' and 'dim(s)' arguments") + + def _dims_to_axis( x: NamedArray[Any, Any], dims: _Dim | _Dims | Default, axis: _AxisLike | None ) -> _AxisLike | None: @@ -124,8 +131,7 @@ def _dims_to_axis( (0,) >>> _dims_to_axis(narr, None, None) """ - if dims is not _default and axis is not None: - raise ValueError("cannot supply both 'axis' and 'dim' arguments") + _assert_either_dim_or_axis(dims, axis) if dims is not _default: return x._dims_to_axes(dims) From 3cdaffbc2821944435a9f56d00691ec31168c610 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Nov 2023 23:51:16 +0000 Subject: [PATCH 051/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index a5fff1b93c4..f9f29beefe8 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -22,7 +22,6 @@ ) from xarray.namedarray.core import ( NamedArray, - _assert_either_dim_or_axis, _dims_to_axis, _get_remaining_dims, ) From 460a1cdc8479999dba805ac66e814f42d1a0eccd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 3 Nov 2023 16:13:18 +0100 Subject: [PATCH 052/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 807a657e8b3..df5c6ff6528 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -303,7 +303,7 @@ def from_array( # xarray.core.duck_array_ops raise NotImplementedError("MaskedArray is not supported yet") - return NamedArray(dims, data_masked, attrs) + return NamedArray(dims, data, attrs) if isinstance(data, _arrayfunction_or_api): return NamedArray(dims, data, attrs) From 603ec86def3188c73b4324645c318b8c1b541f0f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Dec 2023 22:39:53 +0000 Subject: [PATCH 053/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 740b052b317..0aa5b65aa43 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -278,4 +278,3 @@ def mean( import doctest doctest.testmod() - From c71c9c41141a4212d807d9a7f261af7917f6892c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:10:22 +0200 Subject: [PATCH 054/367] Add ci for array-api-tests --- .github/workflows/array-api-tests.yml | 32 +++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/workflows/array-api-tests.yml diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml new file mode 100644 index 00000000000..e0dc9168f96 --- /dev/null +++ b/.github/workflows/array-api-tests.yml @@ -0,0 +1,32 @@ +name: Test NamedArray._array_api + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + + steps: + - name: Checkout array-api-tests + uses: actions/checkout@v1 + with: + submodules: 'true' + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install xarray + python -m pip install -r requirements.txt + - name: Run the test suite + env: + ARRAY_API_TESTS_MODULE: xarray.namedarray._array_api + ARRAY_API_STRICT_API_VERSION: 2023.12 + run: | + pytest -v -rxXfE array_api_tests/ From 192b1b5b62e1e15e18a0266b297418e7430dd442 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Aug 2024 13:11:00 +0000 Subject: [PATCH 055/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 2 +- xarray/namedarray/core.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 9f2dc8ae966..63d57b2e85f 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -292,4 +292,4 @@ def mean( # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) - return out \ No newline at end of file + return out diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 5f32581c6e6..773a5ba0a30 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -52,12 +52,10 @@ if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray - from xarray.core.types import Dims, T_Chunks + from xarray.core.types import T_Chunks from xarray.namedarray._typing import ( Default, _AttrsLike, - _Axes, - _Axis, _AxisLike, _Chunks, _Dim, From 9407a2a3547d1f218b6888c245ed1597f1af05a1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:19:25 +0200 Subject: [PATCH 056/367] Update array-api-tests.yml --- .github/workflows/array-api-tests.yml | 36 ++++++++++++++++++--------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index e0dc9168f96..b0c42d20bfa 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -1,32 +1,44 @@ -name: Test NamedArray._array_api +name: Array API Tests on: [push, pull_request] -jobs: - build: +env: + PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200" + API_VERSIONS: "2023.12" +jobs: + array-api-tests: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.12"] + python-version: ['3.12'] + exclude: + - python-version: '3.8' steps: + - name: Checkout xarray + uses: actions/checkout@v4 + with: + path: xarray - name: Checkout array-api-tests - uses: actions/checkout@v1 + uses: actions/checkout@v4 with: + repository: data-apis/array-api-tests submodules: 'true' + path: array-api-tests - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + - name: Install Dependencies run: | python -m pip install --upgrade pip - python -m pip install xarray - python -m pip install -r requirements.txt - - name: Run the test suite + python -m pip install ${GITHUB_WORKSPACE}/xarray + python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + python -m pip install hypothesis + - name: Run the array API testsuite env: ARRAY_API_TESTS_MODULE: xarray.namedarray._array_api - ARRAY_API_STRICT_API_VERSION: 2023.12 run: | - pytest -v -rxXfE array_api_tests/ + cd ${GITHUB_WORKSPACE}/array-api-tests + pytest array_api_tests/ ${PYTEST_ARGS} From 2f721ae2eaabc4ce43d09d442af2f8f76c1d03e3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:26:37 +0200 Subject: [PATCH 057/367] Update array-api-tests.yml --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index b0c42d20bfa..ca6ded0d52d 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -1,4 +1,4 @@ -name: Array API Tests +name: Array API Tests - NamedArray._array_api on: [push, pull_request] From 51359b9e1f96277e8da1194cd5da4d9d860c9504 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:28:40 +0200 Subject: [PATCH 058/367] Add dtypes --- xarray/namedarray/_array_api.py | 93 +++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 63d57b2e85f..bb52fb15743 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -34,9 +34,98 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: return np -# %% Creation Functions +# %% Dtypes +# TODO: should delegate to underlying array? Cubed doesn't at the moment. +int8 = np.int8 +int16 = np.int16 +int32 = np.int32 +int64 = np.int64 +uint8 = np.uint8 +uint16 = np.uint16 +uint32 = np.uint32 +uint64 = np.uint64 +float32 = np.float32 +float64 = np.float64 +complex64 = np.complex64 +complex128 = np.complex128 +bool = np.bool + +_all_dtypes = ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, + bool, +) +_boolean_dtypes = (bool,) +_real_floating_dtypes = (float32, float64) +_floating_dtypes = (float32, float64, complex64, complex128) +_complex_floating_dtypes = (complex64, complex128) +_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_signed_integer_dtypes = (int8, int16, int32, int64) +_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) +_integer_or_boolean_dtypes = ( + bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_real_numeric_dtypes = ( + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_numeric_dtypes = ( + float32, + float64, + complex64, + complex128, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) + +_dtype_categories = { + "all": _all_dtypes, + "real numeric": _real_numeric_dtypes, + "numeric": _numeric_dtypes, + "integer": _integer_dtypes, + "integer or boolean": _integer_or_boolean_dtypes, + "boolean": _boolean_dtypes, + "real floating-point": _floating_dtypes, + "complex floating-point": _complex_floating_dtypes, + "floating-point": _floating_dtypes, +} + + +# %% Creation Functions def astype( x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True ) -> NamedArray[_ShapeType, _DType]: @@ -226,8 +315,6 @@ def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DT # %% Statistical Functions - - def mean( x: NamedArray[Any, _DType], /, From 862f6e6adee5ce061059e166a5bcaf4e10179db4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:28:53 +0200 Subject: [PATCH 059/367] Add constants --- xarray/namedarray/_array_api.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index bb52fb15743..c9adb603eaa 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -123,6 +123,12 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: "floating-point": _floating_dtypes, } +# %% Constants +e = np.e +inf = np.inf +nan = np.nan +newaxis = np.newaxis +pi = np.pi # %% Creation Functions From ef3aadd82824c8506934165ce21c368cf172a012 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Aug 2024 13:29:41 +0000 Subject: [PATCH 060/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index c9adb603eaa..e485ffb877f 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -34,7 +34,6 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: return np - # %% Dtypes # TODO: should delegate to underlying array? Cubed doesn't at the moment. int8 = np.int8 From fa49ec970aca84c713dc3f81485759335f92351c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:42:02 +0200 Subject: [PATCH 061/367] Create namedarray_array_api_skips.txt --- xarray/tests/namedarray_array_api_skips.txt | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 xarray/tests/namedarray_array_api_skips.txt diff --git a/xarray/tests/namedarray_array_api_skips.txt b/xarray/tests/namedarray_array_api_skips.txt new file mode 100644 index 00000000000..c6a15187421 --- /dev/null +++ b/xarray/tests/namedarray_array_api_skips.txt @@ -0,0 +1,5 @@ +# Known failures for the array api tests. + +# Test suite attempts in-place mutation: +array_api_tests/test_array_object.py::test_setitem +array_api_tests/test_array_object.py::test_setitem_masking From 694e2acd163227e82a6ba9cb8601d510f7757dd4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:42:05 +0200 Subject: [PATCH 062/367] Update array-api-tests.yml --- .github/workflows/array-api-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index ca6ded0d52d..7af44c0683b 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -3,7 +3,7 @@ name: Array API Tests - NamedArray._array_api on: [push, pull_request] env: - PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 200" + PYTEST_ARGS: "-v -rxXfE --ci --hypothesis-disable-deadline --max-examples 5" API_VERSIONS: "2023.12" jobs: @@ -41,4 +41,4 @@ jobs: ARRAY_API_TESTS_MODULE: xarray.namedarray._array_api run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests/ ${PYTEST_ARGS} + pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/xarray/tests/namedarray_array_api_skips.txt ${PYTEST_ARGS} From 02a8c38efa6642c469589ff84ddaac89ac721423 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:43:39 +0200 Subject: [PATCH 063/367] Update array-api-tests.yml --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 7af44c0683b..3f915662a73 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -41,4 +41,4 @@ jobs: ARRAY_API_TESTS_MODULE: xarray.namedarray._array_api run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/xarray/tests/namedarray_array_api_skips.txt ${PYTEST_ARGS} + pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/tests/namedarray_array_api_skips.txt ${PYTEST_ARGS} From 66f94d237988c8f710e6b81e35622d4b974461e0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:54:46 +0200 Subject: [PATCH 064/367] Update array-api-tests.yml --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 3f915662a73..ac59b5749cd 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -19,7 +19,7 @@ jobs: - name: Checkout xarray uses: actions/checkout@v4 with: - path: xarray + path: xarray-tests - name: Checkout array-api-tests uses: actions/checkout@v4 with: From a25696230e9f487b0566fe7ade37713d7a01a148 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:56:48 +0200 Subject: [PATCH 065/367] Update array-api-tests.yml --- .github/workflows/array-api-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index ac59b5749cd..9ef65ddc34a 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -19,7 +19,7 @@ jobs: - name: Checkout xarray uses: actions/checkout@v4 with: - path: xarray-tests + path: xarray - name: Checkout array-api-tests uses: actions/checkout@v4 with: @@ -41,4 +41,4 @@ jobs: ARRAY_API_TESTS_MODULE: xarray.namedarray._array_api run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/tests/namedarray_array_api_skips.txt ${PYTEST_ARGS} + pytest array_api_tests/ --skips-file ${GITHUB_WORKSPACE}/xarray/xarray/tests/namedarray_array_api_skips.txt ${PYTEST_ARGS} From 2ebff05542ed6b1d48f4913da04e046f7dc01ac7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 16:55:24 +0200 Subject: [PATCH 066/367] Add asarray --- xarray/namedarray/_array_api.py | 125 +++++++++++++++++++++++++++++++- xarray/namedarray/_typing.py | 2 + xarray/namedarray/core.py | 13 ++-- 3 files changed, 133 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index e485ffb877f..27ebd80b657 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -1,21 +1,27 @@ from __future__ import annotations from types import ModuleType -from typing import Any +from typing import Any, overload import numpy as np from xarray.namedarray._typing import ( Default, _arrayapi, + duckarray, + _arrayfunction_or_api, + _ArrayLike, + _AttrsLike, _Axes, _Axis, _AxisLike, _default, _Dim, _Dims, + _DimsLike, _DType, _ScalarType, + _Shape, _ShapeType, _SupportsImag, _SupportsReal, @@ -25,6 +31,9 @@ _dims_to_axis, _get_remaining_dims, ) +from xarray.namedarray.utils import ( + to_0d_object_array, +) def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: @@ -34,6 +43,9 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: return np +# %% array_api version +__array_api_version__ = "2023.12" + # %% Dtypes # TODO: should delegate to underlying array? Cubed doesn't at the moment. int8 = np.int8 @@ -131,6 +143,117 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: # %% Creation Functions +def _infer_dims( + shape: _Shape, + dims: _DimsLike | Default = _default, +) -> _DimsLike: + if dims is _default: + return tuple(f"dim_{n}" for n in range(len(shape))) + else: + return dims + + +@overload +def asarray( + obj: duckarray[_ShapeType, Any], + /, + *, + dtype: _DType, + device=..., + copy: bool | None = ..., + dims: _DimsLike = ..., + attrs: _AttrsLike = ..., +) -> NamedArray[_ShapeType, _DType]: ... +@overload +def asarray( + obj: _ArrayLike, + /, + *, + dtype: _DType, + device=..., + copy: bool | None = ..., + dims: _DimsLike = ..., + attrs: _AttrsLike = ..., +) -> NamedArray[Any, _DType]: ... +@overload +def asarray( + obj: duckarray[_ShapeType, _DType], + /, + *, + dtype: None, + device=None, + copy: bool | None = None, + dims: _DimsLike = ..., + attrs: _AttrsLike = ..., +) -> NamedArray[_ShapeType, _DType]: ... +@overload +def asarray( + obj: _ArrayLike, + /, + *, + dtype: None, + device=..., + copy: bool | None = ..., + dims: _DimsLike = ..., + attrs: _AttrsLike = ..., +) -> NamedArray[Any, _DType]: ... +def asarray( + obj: duckarray[_ShapeType, _DType] | _ArrayLike, + /, + *, + dtype: _DType | None = None, + device=None, + copy: bool | None = None, + dims: _DimsLike = _default, + attrs: _AttrsLike = None, +) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: + """ + Create a Named array from an array-like object. + + Parameters + ---------- + dims : str or iterable of str + Name(s) of the dimension(s). + data : T_DuckArray or ArrayLike + The actual data that populates the array. Should match the + shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + """ + data = obj + if isinstance(data, NamedArray): + raise TypeError( + "Array is already a Named array. Use 'data.data' to retrieve the data array" + ) + + # TODO: dask.array.ma.MaskedArray also exists, better way? + if isinstance(data, np.ma.MaskedArray): + mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] + if mask.any(): + # TODO: requires refactoring/vendoring xarray.core.dtypes and + # xarray.core.duck_array_ops + raise NotImplementedError("MaskedArray is not supported yet") + + _dims = _infer_dims(data.shape, dims) + return NamedArray(_dims, data, attrs) + + if isinstance(data, _arrayfunction_or_api): + _dims = _infer_dims(data.shape, dims) + return NamedArray(_dims, data, attrs) + + if isinstance(data, tuple): + _data = to_0d_object_array(data) + _dims = _infer_dims(_data.shape, dims) + return NamedArray(_dims, _data, attrs) + + # validate whether the data is valid data types. + _data = np.asarray(data) + _dims = _infer_dims(_data.shape, dims) + return NamedArray(_dims, _data, attrs) + + def astype( x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True ) -> NamedArray[_ShapeType, _DType]: diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index dc53ded367f..4b26d42eaf6 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -50,6 +50,8 @@ class Default(Enum): _ScalarType = TypeVar("_ScalarType", bound=np.generic) _ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True) +_ArrayLike = np.typing.ArrayLike + # A protocol for anything with the dtype attribute @runtime_checkable diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 773a5ba0a30..4d05a79a45e 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -50,11 +50,12 @@ ) if TYPE_CHECKING: - from numpy.typing import ArrayLike, NDArray + from numpy.typing import NDArray from xarray.core.types import T_Chunks from xarray.namedarray._typing import ( Default, + _ArrayLike, _AttrsLike, _AxisLike, _Chunks, @@ -274,14 +275,14 @@ def from_array( @overload def from_array( dims: _DimsLike, - data: ArrayLike, + data: _ArrayLike, attrs: _AttrsLike = ..., ) -> NamedArray[Any, Any]: ... def from_array( dims: _DimsLike, - data: duckarray[_ShapeType, _DType] | ArrayLike, + data: duckarray[_ShapeType, _DType] | _ArrayLike, attrs: _AttrsLike = None, ) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: """ @@ -291,7 +292,7 @@ def from_array( ---------- dims : str or iterable of str Name(s) of the dimension(s). - data : T_DuckArray or ArrayLike + data : T_DuckArray or _ArrayLike The actual data that populates the array. Should match the shape specified by `dims`. attrs : dict, optional @@ -1017,7 +1018,7 @@ def reduce( axislike = tuple(axis) axis_ = _dims_to_axis(self, d, axislike) - data: duckarray[Any, Any] | ArrayLike + data: duckarray[Any, Any] | _ArrayLike with warnings.catch_warnings(): warnings.filterwarnings( "ignore", r"Mean of empty slice", category=RuntimeWarning @@ -1061,7 +1062,7 @@ def _repr_html_(self) -> str: def _as_sparse( self, sparse_format: Literal["coo"] | Default = _default, - fill_value: ArrayLike | Default = _default, + fill_value: _ArrayLike | Default = _default, ) -> NamedArray[Any, _DType_co]: """ Use sparse-array as backend. From 72b31453831fbbfe6993ff22e9805232c866d526 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 16:55:28 +0200 Subject: [PATCH 067/367] Update array-api-tests.yml --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 9ef65ddc34a..c7ccef7cde5 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -1,4 +1,4 @@ -name: Array API Tests - NamedArray._array_api +name: NamedArray._array_api on: [push, pull_request] From 6cee172ae743ee0af7a69600a2ff9a4280c7998e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Aug 2024 14:56:09 +0000 Subject: [PATCH 068/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 27ebd80b657..69d3a745161 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -8,7 +8,6 @@ from xarray.namedarray._typing import ( Default, _arrayapi, - duckarray, _arrayfunction_or_api, _ArrayLike, _AttrsLike, @@ -25,6 +24,7 @@ _ShapeType, _SupportsImag, _SupportsReal, + duckarray, ) from xarray.namedarray.core import ( NamedArray, From 90d9dabcae1d9e9afef90a57193dbfa3650d8a78 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 17:11:41 +0200 Subject: [PATCH 069/367] add data type functions --- xarray/namedarray/_array_api.py | 221 +++++++++++++++++++------------- 1 file changed, 134 insertions(+), 87 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 27ebd80b657..6cf6e09074d 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -8,6 +8,7 @@ from xarray.namedarray._typing import ( Default, _arrayapi, + _dtype, duckarray, _arrayfunction_or_api, _ArrayLike, @@ -46,93 +47,6 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: # %% array_api version __array_api_version__ = "2023.12" -# %% Dtypes -# TODO: should delegate to underlying array? Cubed doesn't at the moment. -int8 = np.int8 -int16 = np.int16 -int32 = np.int32 -int64 = np.int64 -uint8 = np.uint8 -uint16 = np.uint16 -uint32 = np.uint32 -uint64 = np.uint64 -float32 = np.float32 -float64 = np.float64 -complex64 = np.complex64 -complex128 = np.complex128 -bool = np.bool - -_all_dtypes = ( - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - complex64, - complex128, - bool, -) -_boolean_dtypes = (bool,) -_real_floating_dtypes = (float32, float64) -_floating_dtypes = (float32, float64, complex64, complex128) -_complex_floating_dtypes = (complex64, complex128) -_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) -_signed_integer_dtypes = (int8, int16, int32, int64) -_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) -_integer_or_boolean_dtypes = ( - bool, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, -) -_real_numeric_dtypes = ( - float32, - float64, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, -) -_numeric_dtypes = ( - float32, - float64, - complex64, - complex128, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, -) - -_dtype_categories = { - "all": _all_dtypes, - "real numeric": _real_numeric_dtypes, - "numeric": _numeric_dtypes, - "integer": _integer_dtypes, - "integer or boolean": _integer_or_boolean_dtypes, - "boolean": _boolean_dtypes, - "real floating-point": _floating_dtypes, - "complex floating-point": _complex_floating_dtypes, - "floating-point": _floating_dtypes, -} # %% Constants e = np.e @@ -254,6 +168,97 @@ def asarray( return NamedArray(_dims, _data, attrs) +# %% Data types +# TODO: should delegate to underlying array? Cubed doesn't at the moment. +int8 = np.int8 +int16 = np.int16 +int32 = np.int32 +int64 = np.int64 +uint8 = np.uint8 +uint16 = np.uint16 +uint32 = np.uint32 +uint64 = np.uint64 +float32 = np.float32 +float64 = np.float64 +complex64 = np.complex64 +complex128 = np.complex128 +bool = np.bool + +_all_dtypes = ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, + bool, +) +_boolean_dtypes = (bool,) +_real_floating_dtypes = (float32, float64) +_floating_dtypes = (float32, float64, complex64, complex128) +_complex_floating_dtypes = (complex64, complex128) +_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_signed_integer_dtypes = (int8, int16, int32, int64) +_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) +_integer_or_boolean_dtypes = ( + bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_real_numeric_dtypes = ( + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_numeric_dtypes = ( + float32, + float64, + complex64, + complex128, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) + +_dtype_categories = { + "all": _all_dtypes, + "real numeric": _real_numeric_dtypes, + "numeric": _numeric_dtypes, + "integer": _integer_dtypes, + "integer or boolean": _integer_or_boolean_dtypes, + "boolean": _boolean_dtypes, + "real floating-point": _floating_dtypes, + "complex floating-point": _complex_floating_dtypes, + "floating-point": _floating_dtypes, +} + +# %% Data type functions + + def astype( x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True ) -> NamedArray[_ShapeType, _DType]: @@ -298,6 +303,48 @@ def astype( return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] +def can_cast(from_, to, /): + if isinstance(from_, NamedArray): + xp = _get_data_namespace(type) + from_ = from_.dtype + return xp.can_cast(from_, to) + else: + raise NotImplementedError("How to retrieve xp from dtype?") + + +def finfo(type: _dtype | NamedArray[Any, Any], /): + if isinstance(type, NamedArray): + xp = _get_data_namespace(type) + return xp.finfo(type._data) + else: + raise NotImplementedError("How to retrieve xp from dtype?") + + +def iinfo(type, /): + if isinstance(type, NamedArray): + xp = _get_data_namespace(type) + return xp.iinfo(type._data) + else: + raise NotImplementedError("How to retrieve xp from dtype?") + + +def isdtype(dtype, kind): + if isinstance(dtype, NamedArray): + xp = _get_data_namespace(dtype) + + return xp.isdtype(dtype.dtype, kind) + else: + raise NotImplementedError("How to retrieve xp from dtype?") + + +def result_type(*arrays_and_dtypes): + # TODO: Empty arg? + xp = _get_data_namespace(arrays_and_dtypes[0]) + return xp.result_type( + *(a.dtype if isinstance(a, NamedArray) else a for a in arrays_and_dtypes) + ) + + # %% Elementwise Functions From 221b33238a0303db770e4afe1db02e4eaecb1e77 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 17:12:58 +0200 Subject: [PATCH 070/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 951df190c5e..5c01e27f54c 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -9,7 +9,6 @@ Default, _arrayapi, _dtype, - duckarray, _arrayfunction_or_api, _ArrayLike, _AttrsLike, From f6443ed9dce5473a1a64896d61321f0986d54a21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Aug 2024 15:13:42 +0000 Subject: [PATCH 071/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 5c01e27f54c..537e165e1cd 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -8,7 +8,6 @@ from xarray.namedarray._typing import ( Default, _arrayapi, - _dtype, _arrayfunction_or_api, _ArrayLike, _AttrsLike, @@ -20,6 +19,7 @@ _Dims, _DimsLike, _DType, + _dtype, _ScalarType, _Shape, _ShapeType, From 9acc94b2d2c8e3633e9d19f0890529c3c7a5ecec Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 17:40:35 +0200 Subject: [PATCH 072/367] add types --- xarray/namedarray/_array_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 537e165e1cd..ae32b732427 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -303,7 +303,7 @@ def astype( return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] -def can_cast(from_, to, /): +def can_cast(from_: _dtype | NamedArray, to: _dtype, /) -> bool: if isinstance(from_, NamedArray): xp = _get_data_namespace(type) from_ = from_.dtype @@ -320,7 +320,7 @@ def finfo(type: _dtype | NamedArray[Any, Any], /): raise NotImplementedError("How to retrieve xp from dtype?") -def iinfo(type, /): +def iinfo(type: _dtype | NamedArray[Any, Any], /): if isinstance(type, NamedArray): xp = _get_data_namespace(type) return xp.iinfo(type._data) @@ -328,7 +328,7 @@ def iinfo(type, /): raise NotImplementedError("How to retrieve xp from dtype?") -def isdtype(dtype, kind): +def isdtype(dtype: _dtype, kind: _dtype | str | tuple[_dtype | str, ...]) -> bool: if isinstance(dtype, NamedArray): xp = _get_data_namespace(dtype) @@ -337,7 +337,7 @@ def isdtype(dtype, kind): raise NotImplementedError("How to retrieve xp from dtype?") -def result_type(*arrays_and_dtypes): +def result_type(*arrays_and_dtypes: NamedArray[Any, Any] | _dtype) -> _dtype: # TODO: Empty arg? xp = _get_data_namespace(arrays_and_dtypes[0]) return xp.result_type( From a3e22e8fbc0df401e11de7c9bfe9f5ba628d4fa2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 18:04:05 +0200 Subject: [PATCH 073/367] Add elementwise functions --- xarray/namedarray/_array_api.py | 363 ++++++++++++++++++++++++++++++++ 1 file changed, 363 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index ae32b732427..0b1f7a4b718 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -346,6 +346,178 @@ def result_type(*arrays_and_dtypes: NamedArray[Any, Any] | _dtype) -> _dtype: # %% Elementwise Functions +def abs(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.abs(x._data)) + return out + + +def acos(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.acos(x._data)) + return out + + +def acosh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.acosh(x._data)) + return out + + +def add(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.add(x1._data, x2._data)) + return out + + +def asin(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.asin(x._data)) + return out + + +def asinh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.asinh(x._data)) + return out + + +def atan(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.atan(x._data)) + return out + + +def atan2(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.atan2(x1._data, x2._data)) + return out + + +def atanh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.atanh(x._data)) + return out + + +def bitwise_and(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_and(x1._data, x2._data)) + return out + + +def bitwise_invert(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.bitwise_invert(x._data)) + return out + + +def bitwise_left_shift(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_left_shift(x1._data, x2._data)) + return out + + +def bitwise_or(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_or(x1._data, x2._data)) + return out + + +def bitwise_right_shift(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_right_shift(x1._data, x2._data)) + return out + + +def bitwise_xor(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_xor(x1._data, x2._data)) + return out + + +def ceil(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.ceil(x._data)) + return out + + +def conj(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.conj(x._data)) + return out + + +def cos(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.cos(x._data)) + return out + + +def cosh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.cosh(x._data)) + return out + + +def divide(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.divide(x1._data, x2._data)) + return out + + +def exp(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.exp(x._data)) + return out + + +def expm1(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.expm1(x._data)) + return out + + +def equal(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.equal(x1._data, x2._data)) + return out + + +def floor(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.floor(x._data)) + return out + + +def floor_divide(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.floor_divide(x1._data, x2._data)) + return out + + +def greater(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.greater(x1._data, x2._data)) + return out + + +def greater_equal(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.greater_equal(x1._data, x2._data)) + return out def imag( @@ -380,6 +552,129 @@ def imag( return out +def isfinite(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.isfinite(x._data)) + return out + + +def isinf(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.isinf(x._data)) + return out + + +def isnan(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.isnan(x._data)) + return out + + +def less(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.less(x1._data, x2._data)) + return out + + +def less_equal(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.less_equal(x1._data, x2._data)) + return out + + +def log(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.log(x._data)) + return out + + +def log1p(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.log1p(x._data)) + return out + + +def log2(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.log2(x._data)) + return out + + +def log10(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.log10(x._data)) + return out + + +def logaddexp(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.logaddexp(x1._data, x2._data)) + return out + + +def logical_and(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.logical_and(x1._data, x2._data)) + return out + + +def logical_not(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.logical_not(x._data)) + return out + + +def logical_or(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.logical_or(x1._data, x2._data)) + return out + + +def logical_xor(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.logical_xor(x1._data, x2._data)) + return out + + +def multiply(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.multiply(x1._data, x2._data)) + return out + + +def negative(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.negative(x._data)) + return out + + +def not_equal(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.not_equal(x1._data, x2._data)) + return out + + +def positive(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.positive(x._data)) + return out + + +def pow(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.pow(x1._data, x2._data)) + return out + + def real( x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], / # type: ignore[type-var] ) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: @@ -412,6 +707,74 @@ def real( return out +def remainder(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.remainder(x1._data, x2._data)) + return out + + +def round(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.round(x._data)) + return out + + +def sign(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.sign(x._data)) + return out + + +def sin(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.sin(x._data)) + return out + + +def sinh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.sinh(x._data)) + return out + + +def sqrt(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.sqrt(x._data)) + return out + + +def square(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.square(x._data)) + return out + + +def subtract(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.subtract(x1._data, x2._data)) + return out + + +def tan(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.tan(x._data)) + return out + + +def tanh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.tanh(x._data)) + return out + + +def trunc(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.trunc(x._data)) + return out + + # %% Manipulation functions def expand_dims( x: NamedArray[Any, _DType], From 35d10e5db0b04ad7a8a30a1e0c804ef135962d1c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 18:40:51 +0200 Subject: [PATCH 074/367] ignore getitem for now --- xarray/tests/namedarray_array_api_skips.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/tests/namedarray_array_api_skips.txt b/xarray/tests/namedarray_array_api_skips.txt index c6a15187421..926d644c9fd 100644 --- a/xarray/tests/namedarray_array_api_skips.txt +++ b/xarray/tests/namedarray_array_api_skips.txt @@ -1,5 +1,8 @@ # Known failures for the array api tests. +array_api_tests/test_array_object.py::test_getitem +array_api_tests/test_array_object.py::test_getitem_masking + # Test suite attempts in-place mutation: array_api_tests/test_array_object.py::test_setitem array_api_tests/test_array_object.py::test_setitem_masking From 7aab88179b7783f598b3a926bc78ebdf5f22a569 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 18:41:39 +0200 Subject: [PATCH 075/367] get xp from dtype.__module__ --- xarray/namedarray/_array_api.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 0b1f7a4b718..b8832afdd6a 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -259,6 +259,11 @@ def asarray( # %% Data type functions +def _get_namespace_dtype(dtype: _dtype) -> ModuleType: + xp = __import__(dtype.__module__) + return xp + + def astype( x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True ) -> NamedArray[_ShapeType, _DType]: @@ -305,11 +310,13 @@ def astype( def can_cast(from_: _dtype | NamedArray, to: _dtype, /) -> bool: if isinstance(from_, NamedArray): - xp = _get_data_namespace(type) + xp = _get_data_namespace(from_) from_ = from_.dtype return xp.can_cast(from_, to) else: - raise NotImplementedError("How to retrieve xp from dtype?") + xp = _get_namespace_dtype(from_) + from_ = from_.dtype + return xp.can_cast(from_, to) def finfo(type: _dtype | NamedArray[Any, Any], /): @@ -317,7 +324,8 @@ def finfo(type: _dtype | NamedArray[Any, Any], /): xp = _get_data_namespace(type) return xp.finfo(type._data) else: - raise NotImplementedError("How to retrieve xp from dtype?") + xp = _get_namespace_dtype(type) + return xp.finfo(type._data) def iinfo(type: _dtype | NamedArray[Any, Any], /): @@ -325,21 +333,23 @@ def iinfo(type: _dtype | NamedArray[Any, Any], /): xp = _get_data_namespace(type) return xp.iinfo(type._data) else: - raise NotImplementedError("How to retrieve xp from dtype?") + xp = _get_namespace_dtype(type) + return xp.finfo(type._data) def isdtype(dtype: _dtype, kind: _dtype | str | tuple[_dtype | str, ...]) -> bool: - if isinstance(dtype, NamedArray): - xp = _get_data_namespace(dtype) - - return xp.isdtype(dtype.dtype, kind) - else: - raise NotImplementedError("How to retrieve xp from dtype?") + xp = _get_namespace_dtype(type) + return xp.isdtype(dtype, kind) def result_type(*arrays_and_dtypes: NamedArray[Any, Any] | _dtype) -> _dtype: # TODO: Empty arg? - xp = _get_data_namespace(arrays_and_dtypes[0]) + arr_or_dtype = arrays_and_dtypes[0] + if isinstance(arr_or_dtype, NamedArray): + xp = _get_data_namespace(arr_or_dtype) + else: + xp = _get_namespace_dtype(arr_or_dtype) + return xp.result_type( *(a.dtype if isinstance(a, NamedArray) else a for a in arrays_and_dtypes) ) From ccbcf1c9d64defb1e2fae51afa35b9c118c9e4da Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 18:47:26 +0200 Subject: [PATCH 076/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index b8832afdd6a..aa898ee7b64 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -325,7 +325,7 @@ def finfo(type: _dtype | NamedArray[Any, Any], /): return xp.finfo(type._data) else: xp = _get_namespace_dtype(type) - return xp.finfo(type._data) + return xp.finfo(type) def iinfo(type: _dtype | NamedArray[Any, Any], /): @@ -334,7 +334,7 @@ def iinfo(type: _dtype | NamedArray[Any, Any], /): return xp.iinfo(type._data) else: xp = _get_namespace_dtype(type) - return xp.finfo(type._data) + return xp.finfo(type) def isdtype(dtype: _dtype, kind: _dtype | str | tuple[_dtype | str, ...]) -> bool: From 1a71ff7bf6cc7d4bf27afcfcfac625b4feefb175 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 20:07:34 +0200 Subject: [PATCH 077/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index aa898ee7b64..827c80d8160 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -37,6 +37,7 @@ ) +# %% Helper functions def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: if isinstance(x._data, _arrayapi): return x._data.__array_namespace__() @@ -44,6 +45,11 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: return np +def _get_namespace_dtype(dtype: _dtype) -> ModuleType: + xp = __import__(dtype.__module__) + return xp + + # %% array_api version __array_api_version__ = "2023.12" @@ -259,11 +265,6 @@ def asarray( # %% Data type functions -def _get_namespace_dtype(dtype: _dtype) -> ModuleType: - xp = __import__(dtype.__module__) - return xp - - def astype( x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True ) -> NamedArray[_ShapeType, _DType]: From 2a8ac9ff0f0551f82f761759b61cff5f5cb8afb3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 20:07:47 +0200 Subject: [PATCH 078/367] Add __bool__ --- xarray/namedarray/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 4d05a79a45e..4ae58bd9667 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -544,6 +544,9 @@ def __len__(self) -> _IntOrUnknown: except Exception as exc: raise TypeError("len() of unsized object") from exc + def __bool__(self, /) -> bool: + return self._data.__bool__() + @property def dtype(self) -> _DType_co: """ From db49f2bdf372ff1824bfebffdf7bd8a99352dc93 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 20:39:17 +0200 Subject: [PATCH 079/367] Add arange --- xarray/namedarray/_array_api.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 827c80d8160..39e37164f1b 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -50,6 +50,10 @@ def _get_namespace_dtype(dtype: _dtype) -> ModuleType: return xp +def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: + return np if xp is None else xp + + # %% array_api version __array_api_version__ = "2023.12" @@ -73,6 +77,21 @@ def _infer_dims( return dims +def arange( + start, + /, + stop=None, + step=1, + *, + dtype: _DType | None = None, + device=None, +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + @overload def asarray( obj: duckarray[_ShapeType, Any], From 9a852ed2408753ca78a2e6eeb8bf7a697bceaa2f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 20:39:37 +0200 Subject: [PATCH 080/367] support copy in asarray --- xarray/namedarray/_array_api.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 39e37164f1b..dcf502f8f7d 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -101,7 +101,6 @@ def asarray( device=..., copy: bool | None = ..., dims: _DimsLike = ..., - attrs: _AttrsLike = ..., ) -> NamedArray[_ShapeType, _DType]: ... @overload def asarray( @@ -112,7 +111,6 @@ def asarray( device=..., copy: bool | None = ..., dims: _DimsLike = ..., - attrs: _AttrsLike = ..., ) -> NamedArray[Any, _DType]: ... @overload def asarray( @@ -123,7 +121,6 @@ def asarray( device=None, copy: bool | None = None, dims: _DimsLike = ..., - attrs: _AttrsLike = ..., ) -> NamedArray[_ShapeType, _DType]: ... @overload def asarray( @@ -134,7 +131,6 @@ def asarray( device=..., copy: bool | None = ..., dims: _DimsLike = ..., - attrs: _AttrsLike = ..., ) -> NamedArray[Any, _DType]: ... def asarray( obj: duckarray[_ShapeType, _DType] | _ArrayLike, @@ -144,7 +140,6 @@ def asarray( device=None, copy: bool | None = None, dims: _DimsLike = _default, - attrs: _AttrsLike = None, ) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: """ Create a Named array from an array-like object. @@ -163,9 +158,10 @@ def asarray( """ data = obj if isinstance(data, NamedArray): - raise TypeError( - "Array is already a Named array. Use 'data.data' to retrieve the data array" - ) + if copy: + return data.copy() + else: + return data # TODO: dask.array.ma.MaskedArray also exists, better way? if isinstance(data, np.ma.MaskedArray): @@ -176,21 +172,21 @@ def asarray( raise NotImplementedError("MaskedArray is not supported yet") _dims = _infer_dims(data.shape, dims) - return NamedArray(_dims, data, attrs) + return NamedArray(_dims, data) if isinstance(data, _arrayfunction_or_api): _dims = _infer_dims(data.shape, dims) - return NamedArray(_dims, data, attrs) + return NamedArray(_dims, data) if isinstance(data, tuple): _data = to_0d_object_array(data) _dims = _infer_dims(_data.shape, dims) - return NamedArray(_dims, _data, attrs) + return NamedArray(_dims, _data) # validate whether the data is valid data types. - _data = np.asarray(data) + _data = np.asarray(data, dtype=dtype, device=device, copy=copy) _dims = _infer_dims(_data.shape, dims) - return NamedArray(_dims, _data, attrs) + return NamedArray(_dims, _data) # %% Data types From 51181a8b0df973c03bc34603026c228f7d052474 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Aug 2024 18:40:17 +0000 Subject: [PATCH 081/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index dcf502f8f7d..66ad240bd29 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -10,7 +10,6 @@ _arrayapi, _arrayfunction_or_api, _ArrayLike, - _AttrsLike, _Axes, _Axis, _AxisLike, From 05b676871f071db943ef69cec3bce5e900e33248 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 21:01:51 +0200 Subject: [PATCH 082/367] add utility functions --- xarray/namedarray/_array_api.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index dcf502f8f7d..f18d52b289c 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -944,3 +944,36 @@ def mean( dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out + + +# %% Utility functions +def all( + x, + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.all(x._data, axis=axis_, keepdims=False) # We fix keepdims later + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def any( + x, + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.any(x._data, axis=axis_, keepdims=False) # We fix keepdims later + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out From 8a38bda27bf922f81dd4b91bfb05263ca85de653 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 21:02:08 +0200 Subject: [PATCH 083/367] Add full ones zeroes --- xarray/namedarray/_array_api.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index f18d52b289c..38918743ba2 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -189,6 +189,27 @@ def asarray( return NamedArray(_dims, _data) +def full( + shape, fill_value, *, dtype: _DType | None = None, device=None +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.full(shape, fill_value, dtype=dtype, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def ones( + shape, *, dtype: _DType | None = None, device=None +) -> NamedArray[_ShapeType, _DType]: + return full(shape, 1, dtype=dtype, device=device) + + +def zeros( + shape, *, dtype: _DType | None = None, device=None +) -> NamedArray[_ShapeType, _DType]: + return full(shape, 0, dtype=dtype, device=device) + + # %% Data types # TODO: should delegate to underlying array? Cubed doesn't at the moment. int8 = np.int8 From 05585e344dae079976082488bc8d26157f2a8bc9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 21:24:20 +0200 Subject: [PATCH 084/367] Add very basic reshape --- xarray/namedarray/_array_api.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 04280cf7960..621a0375f21 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -898,6 +898,16 @@ def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DT return out +def reshape(x, /, shape: _Shape, *, copy: bool | None = None): + xp = _get_data_namespace(x) + _data = xp.reshape(x._data, shape) + out = asarray(_data, copy=copy) + # TODO: Have better control where the dims went. + # TODO: If reshaping should we save the dims? + # TODO: What's the xarray equivalent? + return out + + # %% Statistical Functions def mean( x: NamedArray[Any, _DType], From 3ea7332a99c1e438afe4efa3729c25b155574201 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 21:37:44 +0200 Subject: [PATCH 085/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 621a0375f21..6f9539ae753 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -370,7 +370,7 @@ def iinfo(type: _dtype | NamedArray[Any, Any], /): return xp.iinfo(type._data) else: xp = _get_namespace_dtype(type) - return xp.finfo(type) + return xp.iinfo(type) def isdtype(dtype: _dtype, kind: _dtype | str | tuple[_dtype | str, ...]) -> bool: From bd0d41adc106410ce0e6ff0d32f2e23d34347958 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 22:01:55 +0200 Subject: [PATCH 086/367] Add some typing --- xarray/namedarray/_array_api.py | 16 ++++++++++------ xarray/namedarray/core.py | 7 +++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 6f9539ae753..31b427d09d7 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -77,10 +77,10 @@ def _infer_dims( def arange( - start, + start: int | float, /, - stop=None, - step=1, + stop: int | float | None = None, + step: int | float = 1, *, dtype: _DType | None = None, device=None, @@ -189,7 +189,11 @@ def asarray( def full( - shape, fill_value, *, dtype: _DType | None = None, device=None + shape: _Shape, + fill_value: bool | int | float | complex, + *, + dtype: _DType | None = None, + device=None, ) -> NamedArray[_ShapeType, _DType]: xp = _maybe_default_namespace() _data = xp.full(shape, fill_value, dtype=dtype, device=device) @@ -198,13 +202,13 @@ def full( def ones( - shape, *, dtype: _DType | None = None, device=None + shape: _Shape, *, dtype: _DType | None = None, device=None ) -> NamedArray[_ShapeType, _DType]: return full(shape, 1, dtype=dtype, device=device) def zeros( - shape, *, dtype: _DType | None = None, device=None + shape: _Shape, *, dtype: _DType | None = None, device=None ) -> NamedArray[_ShapeType, _DType]: return full(shape, 0, dtype=dtype, device=device) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 4ae58bd9667..97a5ad97716 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -547,6 +547,13 @@ def __len__(self) -> _IntOrUnknown: def __bool__(self, /) -> bool: return self._data.__bool__() + def __getitem__(self, key): + if isinstance(key, int): + _data = self._data[key] + return self._new((), _data) + else: + raise NotImplementedError("only int supported") + @property def dtype(self) -> _DType_co: """ From 4192f8b78fdd6f5b913b381abb5a9d48aab9eacd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 22:02:04 +0200 Subject: [PATCH 087/367] Add linspace --- xarray/namedarray/_array_api.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index 31b427d09d7..aeefa8ba45e 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -201,6 +201,24 @@ def full( return NamedArray(_dims, _data) +def linspace( + start: int | float | complex, + stop: int | float | complex, + /, + num: int, + *, + dtype: _DType | None = None, + device=None, + endpoint: bool = True, +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.linspace( + start, stop, num=num, dtype=dtype, device=device, endpoint=endpoint + ) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + def ones( shape: _Shape, *, dtype: _DType | None = None, device=None ) -> NamedArray[_ShapeType, _DType]: From d41e6b09cf37a7fe1573228296e2437284b1e118 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 22:22:59 +0200 Subject: [PATCH 088/367] more basic getitem --- xarray/namedarray/core.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 97a5ad97716..1bfc6fc8460 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -64,6 +64,7 @@ _DimsLike, _DimsLikeAgg, _DType, + _IndexKeyLike, _IntOrUnknown, _ScalarType, _Shape, @@ -547,12 +548,16 @@ def __len__(self) -> _IntOrUnknown: def __bool__(self, /) -> bool: return self._data.__bool__() - def __getitem__(self, key): - if isinstance(key, int): + def __getitem__(self, key: _IndexKeyLike | NamedArray): + if isinstance(key, (int, slice, tuple)): _data = self._data[key] return self._new((), _data) + elif isinstance(key, NamedArray): + _key = self._data # TODO: Transpose, unordered dims shouldn't matter. + _data = self._data[_key] + return self._new(key._dims, _data) else: - raise NotImplementedError("only int supported") + raise NotImplementedError("{k=} is not supported") @property def dtype(self) -> _DType_co: From 6c04ca6f576b4ac3cc25c3b90e19dc92fe38cb94 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 23:38:25 +0200 Subject: [PATCH 089/367] Add comparison operators --- xarray/namedarray/core.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 1bfc6fc8460..8d3ea506605 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -548,6 +548,40 @@ def __len__(self) -> _IntOrUnknown: def __bool__(self, /) -> bool: return self._data.__bool__() + # Comparison Operators + + def __eq__(self, other, /): + from xarray.namedarray._array_api import equal + + return equal(self, other) + + def __ge__(self, other, /): + from xarray.namedarray._array_api import greater_equal + + return greater_equal(self, other) + + def __gt__(self, other, /): + from xarray.namedarray._array_api import greater + + return greater(self, other) + + def __le__(self, other, /): + from xarray.namedarray._array_api import less_equal + + return less_equal(self, other) + + def __lt__(self, other, /): + from xarray.namedarray._array_api import less + + return less(self, other) + + def __ne__(self, other, /): + from xarray.namedarray._array_api import not_equal + + return not_equal(self, other) + + # Something + def __getitem__(self, key: _IndexKeyLike | NamedArray): if isinstance(key, (int, slice, tuple)): _data = self._data[key] From 62c432647c0262fe422d27464aef784aa653ae35 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 23:45:14 +0200 Subject: [PATCH 090/367] add Arithmetic Operators --- xarray/namedarray/core.py | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 8d3ea506605..d9f4384aee8 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -548,6 +548,53 @@ def __len__(self) -> _IntOrUnknown: def __bool__(self, /) -> bool: return self._data.__bool__() + # Arithmetic Operators + + def __neg__(self, /): + from xarray.namedarray._array_api import negative + + return negative(self) + + def __pos__(self, /): + from xarray.namedarray._array_api import positive + + return positive(self) + + def __add__(self, other, /): + from xarray.namedarray._array_api import add + + return add(self, other) + + def __sub__(self, other, /): + from xarray.namedarray._array_api import subtract + + return subtract(self, other) + + def __mul__(self, other, /): + from xarray.namedarray._array_api import multiply + + return multiply(self, other) + + def __truediv__(self, other, /): + from xarray.namedarray._array_api import divide + + return divide(self, other) + + def __floordiv__(self, other, /): + from xarray.namedarray._array_api import floor_divide + + return floor_divide(self, other) + + def __mod__(self, other, /): + from xarray.namedarray._array_api import remainder + + return remainder(self, other) + + def __pow__(self, other, /): + from xarray.namedarray._array_api import pow + + return pow(self, other) + # Comparison Operators def __eq__(self, other, /): From 17eed31b430f6111e1c65f6d33816645ce802a0f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 23:47:09 +0200 Subject: [PATCH 091/367] array operators --- xarray/namedarray/core.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index d9f4384aee8..48dc3db345c 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -595,6 +595,13 @@ def __pow__(self, other, /): return pow(self, other) + # Array Operators + + def __matmul__(self, other, /): + from xarray.namedarray._array_api import matmul + + return matmul(self, other) + # Comparison Operators def __eq__(self, other, /): From a147c64ef36ac1a1ed1f9918241d784a15502a2b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 18 Aug 2024 23:49:49 +0200 Subject: [PATCH 092/367] bitwise operators --- xarray/namedarray/core.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 48dc3db345c..8b09d4e9e5b 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -602,6 +602,38 @@ def __matmul__(self, other, /): return matmul(self, other) + # Bitwise Operators + + def __invert__(self, /): + from xarray.namedarray._array_api import bitwise_invert + + return bitwise_invert(self) + + def __and__(self, other, /): + from xarray.namedarray._array_api import bitwise_and + + return bitwise_and(self) + + def __or__(self, other, /): + from xarray.namedarray._array_api import bitwise_or + + return bitwise_or(self) + + def __xor__(self, other, /): + from xarray.namedarray._array_api import bitwise_xor + + return bitwise_xor(self) + + def __lshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_left_shift + + return bitwise_left_shift(self) + + def __rshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_right_shift + + return bitwise_right_shift(self) + # Comparison Operators def __eq__(self, other, /): From a0263d443206a27ddd484cf94d03ba248dee4cee Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 00:10:50 +0200 Subject: [PATCH 093/367] reflected operators --- xarray/namedarray/core.py | 72 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 8b09d4e9e5b..3f01a82da82 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -666,6 +666,78 @@ def __ne__(self, other, /): return not_equal(self, other) + # Reflected Operators + + # (Reflected) Arithmetic Operators + + def __radd__(self, other, /): + from xarray.namedarray._array_api import add + + return add(other, self) + + def __rsub__(self, other, /): + from xarray.namedarray._array_api import subtract + + return subtract(other, self) + + def __rmul__(self, other, /): + from xarray.namedarray._array_api import multiply + + return multiply(other, self) + + def __rtruediv__(self, other, /): + from xarray.namedarray._array_api import divide + + return divide(other, self) + + def __rfloordiv__(self, other, /): + from xarray.namedarray._array_api import floor_divide + + return floor_divide(other, self) + + def __rmod__(self, other, /): + from xarray.namedarray._array_api import remainder + + return remainder(other, self) + + def __rpow__(self, other, /): + from xarray.namedarray._array_api import pow + + return pow(other, self) + + # (Reflected) Array Operators + + def __rmatmul__(self, other, /): + from xarray.namedarray._array_api import matmul + + return matmul(other, self) + + # (Reflected) Bitwise Operators + + def __rand__(self, other, /): + from xarray.namedarray._array_api import bitwise_and + + return bitwise_and(other, self) + + def __ror__(self, other, /): + from xarray.namedarray._array_api import bitwise_or + + return bitwise_or(other, self) + + def __rxor__(self, other, /): + from xarray.namedarray._array_api import bitwise_xor + + return bitwise_xor(other, self) + + def __rlshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_left_shift + + return bitwise_left_shift(other, self) + + def __rrshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_right_shift + + return bitwise_right_shift(other, self) # Something def __getitem__(self, key: _IndexKeyLike | NamedArray): From 36799a8b24093cd6a42ddac388c7c3fe2a87741c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Aug 2024 22:11:28 +0000 Subject: [PATCH 094/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 3f01a82da82..634267be4df 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -738,6 +738,7 @@ def __rrshift__(self, other, /): from xarray.namedarray._array_api import bitwise_right_shift return bitwise_right_shift(other, self) + # Something def __getitem__(self, key: _IndexKeyLike | NamedArray): From 693b3c3b2cdf5952a3b3aab8de10ae090c63ba6a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 18:37:10 +0200 Subject: [PATCH 095/367] Update _array_api.py --- xarray/namedarray/_array_api.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index aeefa8ba45e..1eed255b506 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -37,11 +37,15 @@ # %% Helper functions +def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: + return np if xp is None else xp + + def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: if isinstance(x._data, _arrayapi): return x._data.__array_namespace__() - return np + return _maybe_default_namespace() def _get_namespace_dtype(dtype: _dtype) -> ModuleType: @@ -49,10 +53,6 @@ def _get_namespace_dtype(dtype: _dtype) -> ModuleType: return xp -def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: - return np if xp is None else xp - - # %% array_api version __array_api_version__ = "2023.12" @@ -373,7 +373,6 @@ def can_cast(from_: _dtype | NamedArray, to: _dtype, /) -> bool: return xp.can_cast(from_, to) else: xp = _get_namespace_dtype(from_) - from_ = from_.dtype return xp.can_cast(from_, to) From 45591ac430cf215180f0cf1f8488fbd8cfea2512 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 19:29:52 +0200 Subject: [PATCH 096/367] Split into multiple files. --- xarray/namedarray/_array_api/__init__.py | 257 ++++++++++ xarray/namedarray/_array_api/_utils.py | 25 + xarray/namedarray/_array_api/constants.py | 9 + .../_array_api/creation_functions.py | 202 ++++++++ .../_array_api/data_type_functions.py | 130 +++++ xarray/namedarray/_array_api/dtypes.py | 124 +++++ .../_array_api/elementwise_functions.py | 466 ++++++++++++++++++ .../_array_api/indexing_functions.py | 3 + .../_array_api/linear_algebra_functions.py | 3 + .../_array_api/manipulation_functions.py | 122 +++++ .../_array_api/searching_functions.py | 3 + .../_array_api/statistical_functions.py | 104 ++++ .../_array_api/utility_functions.py | 69 +++ 13 files changed, 1517 insertions(+) create mode 100644 xarray/namedarray/_array_api/__init__.py create mode 100644 xarray/namedarray/_array_api/_utils.py create mode 100644 xarray/namedarray/_array_api/constants.py create mode 100644 xarray/namedarray/_array_api/creation_functions.py create mode 100644 xarray/namedarray/_array_api/data_type_functions.py create mode 100644 xarray/namedarray/_array_api/dtypes.py create mode 100644 xarray/namedarray/_array_api/elementwise_functions.py create mode 100644 xarray/namedarray/_array_api/indexing_functions.py create mode 100644 xarray/namedarray/_array_api/linear_algebra_functions.py create mode 100644 xarray/namedarray/_array_api/manipulation_functions.py create mode 100644 xarray/namedarray/_array_api/searching_functions.py create mode 100644 xarray/namedarray/_array_api/statistical_functions.py create mode 100644 xarray/namedarray/_array_api/utility_functions.py diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py new file mode 100644 index 00000000000..da66bc620a2 --- /dev/null +++ b/xarray/namedarray/_array_api/__init__.py @@ -0,0 +1,257 @@ +__all__ = [] + +__array_api_version__ = "2023.12" + +__all__ += ["__array_api_version__"] + +from .array_object import Array + +__all__ += ["Array"] + +from .constants import e, inf, nan, newaxis, pi + +__all__ += ["e", "inf", "nan", "newaxis", "pi"] + +from .creation_functions import ( + arange, + asarray, + empty, + empty_like, + eye, + full, + full_like, + linspace, + meshgrid, + ones, + ones_like, + tril, + triu, + zeros, + zeros_like, +) + +__all__ += [ + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", +] + +from .data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type + +__all__ += ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"] + +from .dtypes import ( + bool, + complex64, + complex128, + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) + +__all__ += [ + "bool", + "complex64", + "complex128", + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", +] + +from .elementwise_functions import ( + abs, + acos, + acosh, + add, + asin, + asinh, + atan, + atan2, + atanh, + bitwise_and, + bitwise_invert, + bitwise_left_shift, + bitwise_or, + bitwise_right_shift, + bitwise_xor, + ceil, + conj, + cos, + cosh, + divide, + equal, + exp, + expm1, + floor, + floor_divide, + greater, + greater_equal, + imag, + isfinite, + isinf, + isnan, + less, + less_equal, + log, + log1p, + log2, + log10, + logaddexp, + logical_and, + logical_not, + logical_or, + logical_xor, + multiply, + negative, + not_equal, + positive, + pow, + real, + remainder, + round, + sign, + sin, + sinh, + sqrt, + square, + subtract, + tan, + tanh, + trunc, +) + +__all__ += [ + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_invert", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "conj", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "multiply", + "negative", + "not_equal", + "positive", + "pow", + "real", + "remainder", + "round", + "sign", + "sin", + "sinh", + "sqrt", + "square", + "subtract", + "tan", + "tanh", + "trunc", +] + +from .indexing_functions import take + +__all__ += ["take"] + +from .linear_algebra_functions import matmul, matrix_transpose, outer, tensordot, vecdot + +__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"] + +from .manipulation_functions import ( + broadcast_arrays, + broadcast_to, + concat, + expand_dims, + flip, + moveaxis, + permute_dims, + reshape, + roll, + squeeze, + stack, +) + +__all__ += [ + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "reshape", + "roll", + "squeeze", + "stack", +] + +from .searching_functions import argmax, argmin, where + +__all__ += ["argmax", "argmin", "where"] + +from .statistical_functions import max, mean, min, prod, sum + +__all__ += ["max", "mean", "min", "prod", "sum"] + +from .utility_functions import all, any + +__all__ += ["all", "any"] \ No newline at end of file diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py new file mode 100644 index 00000000000..23275e5041d --- /dev/null +++ b/xarray/namedarray/_array_api/_utils.py @@ -0,0 +1,25 @@ +from typing import Any, ModuleType, TYPE_CHECKING + +import numpy as np + +from xarray.namedarray._typing import _arrayapi, _dtype + + +if TYPE_CHECKING: + from xarray.namedarray.core import NamedArray + + +def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: + return np if xp is None else xp + + +def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: + if isinstance(x._data, _arrayapi): + return x._data.__array_namespace__() + + return _maybe_default_namespace() + + +def _get_namespace_dtype(dtype: _dtype) -> ModuleType: + xp = __import__(dtype.__module__) + return xp diff --git a/xarray/namedarray/_array_api/constants.py b/xarray/namedarray/_array_api/constants.py new file mode 100644 index 00000000000..8b8da5ffbe8 --- /dev/null +++ b/xarray/namedarray/_array_api/constants.py @@ -0,0 +1,9 @@ +from xarray.namedarray._array_api._utils import _maybe_default_namespace + +_xp = _maybe_default_namespace() + +e = _xp.e +inf = _xp.inf +nan = _xp.nan +newaxis = _xp.newaxis +pi = _xp.pi diff --git a/xarray/namedarray/_array_api/creation_functions.py b/xarray/namedarray/_array_api/creation_functions.py new file mode 100644 index 00000000000..012b9dd0166 --- /dev/null +++ b/xarray/namedarray/_array_api/creation_functions.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any, overload + +import numpy as np + +from xarray.namedarray._array_api._utils import _maybe_default_namespace +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _arrayfunction_or_api, + _ArrayLike, + _Axes, + _Axis, + _AxisLike, + _default, + _Dim, + _Dims, + _DimsLike, + _DType, + _dtype, + _ScalarType, + _Shape, + _ShapeType, + _SupportsImag, + _SupportsReal, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) +from xarray.namedarray.utils import ( + to_0d_object_array, +) + + +def _infer_dims( + shape: _Shape, + dims: _DimsLike | Default = _default, +) -> _DimsLike: + if dims is _default: + return tuple(f"dim_{n}" for n in range(len(shape))) + else: + return dims + + +def arange( + start: int | float, + /, + stop: int | float | None = None, + step: int | float = 1, + *, + dtype: _DType | None = None, + device=None, +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +@overload +def asarray( + obj: duckarray[_ShapeType, Any], + /, + *, + dtype: _DType, + device=..., + copy: bool | None = ..., + dims: _DimsLike = ..., +) -> NamedArray[_ShapeType, _DType]: ... +@overload +def asarray( + obj: _ArrayLike, + /, + *, + dtype: _DType, + device=..., + copy: bool | None = ..., + dims: _DimsLike = ..., +) -> NamedArray[Any, _DType]: ... +@overload +def asarray( + obj: duckarray[_ShapeType, _DType], + /, + *, + dtype: None, + device=None, + copy: bool | None = None, + dims: _DimsLike = ..., +) -> NamedArray[_ShapeType, _DType]: ... +@overload +def asarray( + obj: _ArrayLike, + /, + *, + dtype: None, + device=..., + copy: bool | None = ..., + dims: _DimsLike = ..., +) -> NamedArray[Any, _DType]: ... +def asarray( + obj: duckarray[_ShapeType, _DType] | _ArrayLike, + /, + *, + dtype: _DType | None = None, + device=None, + copy: bool | None = None, + dims: _DimsLike = _default, +) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: + """ + Create a Named array from an array-like object. + + Parameters + ---------- + dims : str or iterable of str + Name(s) of the dimension(s). + data : T_DuckArray or ArrayLike + The actual data that populates the array. Should match the + shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or + attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + """ + data = obj + if isinstance(data, NamedArray): + if copy: + return data.copy() + else: + return data + + # TODO: dask.array.ma.MaskedArray also exists, better way? + if isinstance(data, np.ma.MaskedArray): + mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] + if mask.any(): + # TODO: requires refactoring/vendoring xarray.core.dtypes and + # xarray.core.duck_array_ops + raise NotImplementedError("MaskedArray is not supported yet") + + _dims = _infer_dims(data.shape, dims) + return NamedArray(_dims, data) + + if isinstance(data, _arrayfunction_or_api): + _dims = _infer_dims(data.shape, dims) + return NamedArray(_dims, data) + + if isinstance(data, tuple): + _data = to_0d_object_array(data) + _dims = _infer_dims(_data.shape, dims) + return NamedArray(_dims, _data) + + # validate whether the data is valid data types. + _data = np.asarray(data, dtype=dtype, device=device, copy=copy) + _dims = _infer_dims(_data.shape, dims) + return NamedArray(_dims, _data) + + +def full( + shape: _Shape, + fill_value: bool | int | float | complex, + *, + dtype: _DType | None = None, + device=None, +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.full(shape, fill_value, dtype=dtype, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def linspace( + start: int | float | complex, + stop: int | float | complex, + /, + num: int, + *, + dtype: _DType | None = None, + device=None, + endpoint: bool = True, +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.linspace( + start, stop, num=num, dtype=dtype, device=device, endpoint=endpoint + ) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def ones( + shape: _Shape, *, dtype: _DType | None = None, device=None +) -> NamedArray[_ShapeType, _DType]: + return full(shape, 1, dtype=dtype, device=device) + + +def zeros( + shape: _Shape, *, dtype: _DType | None = None, device=None +) -> NamedArray[_ShapeType, _DType]: + return full(shape, 0, dtype=dtype, device=device) diff --git a/xarray/namedarray/_array_api/data_type_functions.py b/xarray/namedarray/_array_api/data_type_functions.py new file mode 100644 index 00000000000..fc0c1be4abc --- /dev/null +++ b/xarray/namedarray/_array_api/data_type_functions.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any, overload + +import numpy as np + +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _get_namespace_dtype, +) +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _arrayfunction_or_api, + _ArrayLike, + _Axes, + _Axis, + _AxisLike, + _default, + _Dim, + _Dims, + _DimsLike, + _DType, + _dtype, + _ScalarType, + _Shape, + _ShapeType, + _SupportsImag, + _SupportsReal, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) +from xarray.namedarray.utils import ( + to_0d_object_array, +) + + +def astype( + x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True +) -> NamedArray[_ShapeType, _DType]: + """ + Copies an array to a specified data type irrespective of Type Promotion Rules rules. + + Parameters + ---------- + x : NamedArray + Array to cast. + dtype : _DType + Desired data type. + copy : bool, optional + Specifies whether to copy an array when the specified dtype matches the data + type of the input array x. + If True, a newly allocated array must always be returned. + If False and the specified dtype matches the data type of the input array, + the input array must be returned; otherwise, a newly allocated array must be + returned. Default: True. + + Returns + ------- + out : NamedArray + An array having the specified data type. The returned array must have the + same shape as x. + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.5, 2.5])) + >>> narr + Size: 16B + array([1.5, 2.5]) + >>> astype(narr, np.dtype(np.int32)) + Size: 8B + array([1, 2], dtype=int32) + """ + if isinstance(x._data, _arrayapi): + xp = x._data.__array_namespace__() + return x._new(data=xp.astype(x._data, dtype, copy=copy)) + + # np.astype doesn't exist yet: + return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] + + +def can_cast(from_: _dtype | NamedArray, to: _dtype, /) -> bool: + if isinstance(from_, NamedArray): + xp = _get_data_namespace(from_) + from_ = from_.dtype + return xp.can_cast(from_, to) + else: + xp = _get_namespace_dtype(from_) + return xp.can_cast(from_, to) + + +def finfo(type: _dtype | NamedArray[Any, Any], /): + if isinstance(type, NamedArray): + xp = _get_data_namespace(type) + return xp.finfo(type._data) + else: + xp = _get_namespace_dtype(type) + return xp.finfo(type) + + +def iinfo(type: _dtype | NamedArray[Any, Any], /): + if isinstance(type, NamedArray): + xp = _get_data_namespace(type) + return xp.iinfo(type._data) + else: + xp = _get_namespace_dtype(type) + return xp.iinfo(type) + + +def isdtype(dtype: _dtype, kind: _dtype | str | tuple[_dtype | str, ...]) -> bool: + xp = _get_namespace_dtype(type) + return xp.isdtype(dtype, kind) + + +def result_type(*arrays_and_dtypes: NamedArray[Any, Any] | _dtype) -> _dtype: + # TODO: Empty arg? + arr_or_dtype = arrays_and_dtypes[0] + if isinstance(arr_or_dtype, NamedArray): + xp = _get_data_namespace(arr_or_dtype) + else: + xp = _get_namespace_dtype(arr_or_dtype) + + return xp.result_type( + *(a.dtype if isinstance(a, NamedArray) else a for a in arrays_and_dtypes) + ) diff --git a/xarray/namedarray/_array_api/dtypes.py b/xarray/namedarray/_array_api/dtypes.py new file mode 100644 index 00000000000..016db30ff95 --- /dev/null +++ b/xarray/namedarray/_array_api/dtypes.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any, overload + +import numpy as np + +from xarray.namedarray._array_api._utils import _maybe_default_namespace +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _arrayfunction_or_api, + _ArrayLike, + _Axes, + _Axis, + _AxisLike, + _default, + _Dim, + _Dims, + _DimsLike, + _DType, + _dtype, + _ScalarType, + _Shape, + _ShapeType, + _SupportsImag, + _SupportsReal, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) +from xarray.namedarray.utils import ( + to_0d_object_array, +) + +_xp = _maybe_default_namespace() +int8 = _xp.int8 +int16 = _xp.int16 +int32 = _xp.int32 +int64 = _xp.int64 +uint8 = _xp.uint8 +uint16 = _xp.uint16 +uint32 = _xp.uint32 +uint64 = _xp.uint64 +float32 = _xp.float32 +float64 = _xp.float64 +complex64 = _xp.complex64 +complex128 = _xp.complex128 +bool = _xp.bool + +_all_dtypes = ( + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, + bool, +) +_boolean_dtypes = (bool,) +_real_floating_dtypes = (float32, float64) +_floating_dtypes = (float32, float64, complex64, complex128) +_complex_floating_dtypes = (complex64, complex128) +_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) +_signed_integer_dtypes = (int8, int16, int32, int64) +_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) +_integer_or_boolean_dtypes = ( + bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_real_numeric_dtypes = ( + float32, + float64, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) +_numeric_dtypes = ( + float32, + float64, + complex64, + complex128, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, +) + +_dtype_categories = { + "all": _all_dtypes, + "real numeric": _real_numeric_dtypes, + "numeric": _numeric_dtypes, + "integer": _integer_dtypes, + "integer or boolean": _integer_or_boolean_dtypes, + "boolean": _boolean_dtypes, + "real floating-point": _floating_dtypes, + "complex floating-point": _complex_floating_dtypes, + "floating-point": _floating_dtypes, +} diff --git a/xarray/namedarray/_array_api/elementwise_functions.py b/xarray/namedarray/_array_api/elementwise_functions.py new file mode 100644 index 00000000000..6a53a8b844b --- /dev/null +++ b/xarray/namedarray/_array_api/elementwise_functions.py @@ -0,0 +1,466 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any, overload + +import numpy as np + +from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _arrayfunction_or_api, + _ArrayLike, + _Axes, + _Axis, + _AxisLike, + _default, + _Dim, + _Dims, + _DimsLike, + _DType, + _dtype, + _ScalarType, + _Shape, + _ShapeType, + _SupportsImag, + _SupportsReal, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) +from xarray.namedarray.utils import ( + to_0d_object_array, +) + + +def abs(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.abs(x._data)) + return out + + +def acos(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.acos(x._data)) + return out + + +def acosh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.acosh(x._data)) + return out + + +def add(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.add(x1._data, x2._data)) + return out + + +def asin(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.asin(x._data)) + return out + + +def asinh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.asinh(x._data)) + return out + + +def atan(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.atan(x._data)) + return out + + +def atan2(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.atan2(x1._data, x2._data)) + return out + + +def atanh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.atanh(x._data)) + return out + + +def bitwise_and(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_and(x1._data, x2._data)) + return out + + +def bitwise_invert(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.bitwise_invert(x._data)) + return out + + +def bitwise_left_shift(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_left_shift(x1._data, x2._data)) + return out + + +def bitwise_or(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_or(x1._data, x2._data)) + return out + + +def bitwise_right_shift(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_right_shift(x1._data, x2._data)) + return out + + +def bitwise_xor(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.bitwise_xor(x1._data, x2._data)) + return out + + +def ceil(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.ceil(x._data)) + return out + + +def conj(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.conj(x._data)) + return out + + +def cos(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.cos(x._data)) + return out + + +def cosh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.cosh(x._data)) + return out + + +def divide(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.divide(x1._data, x2._data)) + return out + + +def exp(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.exp(x._data)) + return out + + +def expm1(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.expm1(x._data)) + return out + + +def equal(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.equal(x1._data, x2._data)) + return out + + +def floor(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.floor(x._data)) + return out + + +def floor_divide(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.floor_divide(x1._data, x2._data)) + return out + + +def greater(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.greater(x1._data, x2._data)) + return out + + +def greater_equal(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.greater_equal(x1._data, x2._data)) + return out + + +def imag( + x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var] +) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: + """ + Returns the imaginary component of a complex number for each element x_i of the + input array x. + + Parameters + ---------- + x : NamedArray + Input array. Should have a complex floating-point data type. + + Returns + ------- + out : NamedArray + An array containing the element-wise results. The returned array must have a + floating-point data type with the same floating-point precision as x + (e.g., if x is complex64, the returned array must have the floating-point + data type float32). + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) + >>> imag(narr) + Size: 16B + array([2., 4.]) + """ + xp = _get_data_namespace(x) + out = x._new(data=xp.imag(x._data)) + return out + + +def isfinite(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.isfinite(x._data)) + return out + + +def isinf(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.isinf(x._data)) + return out + + +def isnan(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.isnan(x._data)) + return out + + +def less(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.less(x1._data, x2._data)) + return out + + +def less_equal(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.less_equal(x1._data, x2._data)) + return out + + +def log(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.log(x._data)) + return out + + +def log1p(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.log1p(x._data)) + return out + + +def log2(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.log2(x._data)) + return out + + +def log10(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.log10(x._data)) + return out + + +def logaddexp(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.logaddexp(x1._data, x2._data)) + return out + + +def logical_and(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.logical_and(x1._data, x2._data)) + return out + + +def logical_not(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.logical_not(x._data)) + return out + + +def logical_or(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.logical_or(x1._data, x2._data)) + return out + + +def logical_xor(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.logical_xor(x1._data, x2._data)) + return out + + +def multiply(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.multiply(x1._data, x2._data)) + return out + + +def negative(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.negative(x._data)) + return out + + +def not_equal(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.not_equal(x1._data, x2._data)) + return out + + +def positive(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.positive(x._data)) + return out + + +def pow(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.pow(x1._data, x2._data)) + return out + + +def real( + x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], / # type: ignore[type-var] +) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: + """ + Returns the real component of a complex number for each element x_i of the + input array x. + + Parameters + ---------- + x : NamedArray + Input array. Should have a complex floating-point data type. + + Returns + ------- + out : NamedArray + An array containing the element-wise results. The returned array must have a + floating-point data type with the same floating-point precision as x + (e.g., if x is complex64, the returned array must have the floating-point + data type float32). + + Examples + -------- + >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) + >>> real(narr) + Size: 16B + array([1., 2.]) + """ + xp = _get_data_namespace(x) + out = x._new(data=xp.real(x._data)) + return out + + +def remainder(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.remainder(x1._data, x2._data)) + return out + + +def round(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.round(x._data)) + return out + + +def sign(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.sign(x._data)) + return out + + +def sin(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.sin(x._data)) + return out + + +def sinh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.sinh(x._data)) + return out + + +def sqrt(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.sqrt(x._data)) + return out + + +def square(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.square(x._data)) + return out + + +def subtract(x1, x2, /): + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.subtract(x1._data, x2._data)) + return out + + +def tan(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.tan(x._data)) + return out + + +def tanh(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.tanh(x._data)) + return out + + +def trunc(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.trunc(x._data)) + return out diff --git a/xarray/namedarray/_array_api/indexing_functions.py b/xarray/namedarray/_array_api/indexing_functions.py new file mode 100644 index 00000000000..b62d0a8393b --- /dev/null +++ b/xarray/namedarray/_array_api/indexing_functions.py @@ -0,0 +1,3 @@ +from xarray.namedarray._array_api._utils import _get_data_namespace + +sdf = _get_data_namespace() diff --git a/xarray/namedarray/_array_api/linear_algebra_functions.py b/xarray/namedarray/_array_api/linear_algebra_functions.py new file mode 100644 index 00000000000..b62d0a8393b --- /dev/null +++ b/xarray/namedarray/_array_api/linear_algebra_functions.py @@ -0,0 +1,3 @@ +from xarray.namedarray._array_api._utils import _get_data_namespace + +sdf = _get_data_namespace() diff --git a/xarray/namedarray/_array_api/manipulation_functions.py b/xarray/namedarray/_array_api/manipulation_functions.py new file mode 100644 index 00000000000..e38f12373d0 --- /dev/null +++ b/xarray/namedarray/_array_api/manipulation_functions.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any, overload + +import numpy as np + +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _arrayfunction_or_api, + _ArrayLike, + _Axes, + _Axis, + _AxisLike, + _default, + _Dim, + _Dims, + _DimsLike, + _DType, + _dtype, + _ScalarType, + _Shape, + _ShapeType, + _SupportsImag, + _SupportsReal, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) +from xarray.namedarray.utils import ( + to_0d_object_array, +) + + +def expand_dims( + x: NamedArray[Any, _DType], + /, + *, + dim: _Dim | Default = _default, + axis: _Axis = 0, +) -> NamedArray[Any, _DType]: + """ + Expands the shape of an array by inserting a new dimension of size one at the + position specified by dims. + + Parameters + ---------- + x : + Array to expand. + dim : + Dimension name. New dimension will be stored in the axis position. + axis : + (Not recommended) Axis position (zero-based). Default is 0. + + Returns + ------- + out : + An expanded output array having the same data type as x. + + Examples + -------- + >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> expand_dims(x) + Size: 32B + array([[[1., 2.], + [3., 4.]]]) + >>> expand_dims(x, dim="z") + Size: 32B + array([[[1., 2.], + [3., 4.]]]) + """ + xp = _get_data_namespace(x) + dims = x.dims + if dim is _default: + dim = f"dim_{len(dims)}" + d = list(dims) + d.insert(axis, dim) + out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) + return out + + +def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]: + """ + Permutes the dimensions of an array. + + Parameters + ---------- + x : + Array to permute. + axes : + Permutation of the dimensions of x. + + Returns + ------- + out : + An array with permuted dimensions. The returned array must have the same + data type as x. + + """ + + dims = x.dims + new_dims = tuple(dims[i] for i in axes) + if isinstance(x._data, _arrayapi): + xp = _get_data_namespace(x) + out = x._new(dims=new_dims, data=xp.permute_dims(x._data, axes)) + else: + out = x._new(dims=new_dims, data=x._data.transpose(axes)) # type: ignore[attr-defined] + return out + + +def reshape(x, /, shape: _Shape, *, copy: bool | None = None): + xp = _get_data_namespace(x) + _data = xp.reshape(x._data, shape) + out = asarray(_data, copy=copy) + # TODO: Have better control where the dims went. + # TODO: If reshaping should we save the dims? + # TODO: What's the xarray equivalent? + return out diff --git a/xarray/namedarray/_array_api/searching_functions.py b/xarray/namedarray/_array_api/searching_functions.py new file mode 100644 index 00000000000..b62d0a8393b --- /dev/null +++ b/xarray/namedarray/_array_api/searching_functions.py @@ -0,0 +1,3 @@ +from xarray.namedarray._array_api._utils import _get_data_namespace + +sdf = _get_data_namespace() diff --git a/xarray/namedarray/_array_api/statistical_functions.py b/xarray/namedarray/_array_api/statistical_functions.py new file mode 100644 index 00000000000..19e84126c2d --- /dev/null +++ b/xarray/namedarray/_array_api/statistical_functions.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any, overload + +import numpy as np + +from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _arrayfunction_or_api, + _ArrayLike, + _Axes, + _Axis, + _AxisLike, + _default, + _Dim, + _Dims, + _DimsLike, + _DType, + _dtype, + _ScalarType, + _Shape, + _ShapeType, + _SupportsImag, + _SupportsReal, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) +from xarray.namedarray.utils import ( + to_0d_object_array, +) + + +def mean( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + """ + Calculates the arithmetic mean of the input array x. + + Parameters + ---------- + x : + Should have a real-valued floating-point data type. + dims : + Dim or dims along which arithmetic means must be computed. By default, + the mean must be computed over the entire array. If a tuple of hashables, + arithmetic means must be computed over multiple axes. + Default: None. + keepdims : + if True, the reduced axes (dimensions) must be included in the result + as singleton dimensions, and, accordingly, the result must be compatible + with the input array (see Broadcasting). Otherwise, if False, the + reduced axes (dimensions) must not be included in the result. + Default: False. + axis : + Axis or axes along which arithmetic means must be computed. By default, + the mean must be computed over the entire array. If a tuple of integers, + arithmetic means must be computed over multiple axes. + Default: None. + + Returns + ------- + out : + If the arithmetic mean was computed over the entire array, + a zero-dimensional array containing the arithmetic mean; otherwise, + a non-zero-dimensional array containing the arithmetic means. + The returned array must have the same data type as x. + + Examples + -------- + >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> mean(x).data + Array(2.5, dtype=float64) + >>> mean(x, dims=("x",)).data + Array([2., 3.], dtype=float64) + + Using keepdims: + + >>> mean(x, dims=("x",), keepdims=True) + + Array([[2., 3.]], dtype=float64) + >>> mean(x, dims=("y",), keepdims=True) + + Array([[1.5], + [3.5]], dtype=float64) + """ + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.mean(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out diff --git a/xarray/namedarray/_array_api/utility_functions.py b/xarray/namedarray/_array_api/utility_functions.py new file mode 100644 index 00000000000..849e13233b2 --- /dev/null +++ b/xarray/namedarray/_array_api/utility_functions.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from types import ModuleType +from typing import Any, overload + +import numpy as np + +from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _arrayfunction_or_api, + _ArrayLike, + _Axes, + _Axis, + _AxisLike, + _default, + _Dim, + _Dims, + _DimsLike, + _DType, + _dtype, + _ScalarType, + _Shape, + _ShapeType, + _SupportsImag, + _SupportsReal, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) +from xarray.namedarray.utils import ( + to_0d_object_array, +) + + +def all( + x, + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.all(x._data, axis=axis_, keepdims=False) # We fix keepdims later + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def any( + x, + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.any(x._data, axis=axis_, keepdims=False) # We fix keepdims later + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out From d57905000b67c277421714b3817f02fbc93d4ea2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:30:34 +0000 Subject: [PATCH 097/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/__init__.py | 39 ++++++++++++------- xarray/namedarray/_array_api/_utils.py | 3 +- .../_array_api/creation_functions.py | 13 ------- .../_array_api/data_type_functions.py | 25 +----------- xarray/namedarray/_array_api/dtypes.py | 34 ---------------- .../_array_api/elementwise_functions.py | 23 ----------- .../_array_api/manipulation_functions.py | 21 +--------- .../_array_api/statistical_functions.py | 22 +---------- .../_array_api/utility_functions.py | 22 +---------- 9 files changed, 31 insertions(+), 171 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index da66bc620a2..9bce7c05c7a 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -4,15 +4,15 @@ __all__ += ["__array_api_version__"] -from .array_object import Array +from xarray.namedarray._array_api.array_object import Array __all__ += ["Array"] -from .constants import e, inf, nan, newaxis, pi +from xarray.namedarray._array_api.constants import e, inf, nan, newaxis, pi __all__ += ["e", "inf", "nan", "newaxis", "pi"] -from .creation_functions import ( +from xarray.namedarray._array_api.creation_functions import ( arange, asarray, empty, @@ -48,11 +48,18 @@ "zeros_like", ] -from .data_type_functions import astype, can_cast, finfo, iinfo, isdtype, result_type +from xarray.namedarray._array_api.data_type_functions import ( + astype, + can_cast, + finfo, + iinfo, + isdtype, + result_type, +) __all__ += ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"] -from .dtypes import ( +from xarray.namedarray._array_api.dtypes import ( bool, complex64, complex128, @@ -84,7 +91,7 @@ "uint64", ] -from .elementwise_functions import ( +from xarray.namedarray._array_api.elementwise_functions import ( abs, acos, acosh, @@ -208,15 +215,21 @@ "trunc", ] -from .indexing_functions import take +from xarray.namedarray._array_api.indexing_functions import take __all__ += ["take"] -from .linear_algebra_functions import matmul, matrix_transpose, outer, tensordot, vecdot +from xarray.namedarray._array_api.linear_algebra_functions import ( + matmul, + matrix_transpose, + outer, + tensordot, + vecdot, +) __all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"] -from .manipulation_functions import ( +from xarray.namedarray._array_api.manipulation_functions import ( broadcast_arrays, broadcast_to, concat, @@ -244,14 +257,14 @@ "stack", ] -from .searching_functions import argmax, argmin, where +from xarray.namedarray._array_api.searching_functions import argmax, argmin, where __all__ += ["argmax", "argmin", "where"] -from .statistical_functions import max, mean, min, prod, sum +from xarray.namedarray._array_api.statistical_functions import max, mean, min, prod, sum __all__ += ["max", "mean", "min", "prod", "sum"] -from .utility_functions import all, any +from xarray.namedarray._array_api.utility_functions import all, any -__all__ += ["all", "any"] \ No newline at end of file +__all__ += ["all", "any"] diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 23275e5041d..445012b3059 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,10 +1,9 @@ -from typing import Any, ModuleType, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, ModuleType import numpy as np from xarray.namedarray._typing import _arrayapi, _dtype - if TYPE_CHECKING: from xarray.namedarray.core import NamedArray diff --git a/xarray/namedarray/_array_api/creation_functions.py b/xarray/namedarray/_array_api/creation_functions.py index 012b9dd0166..957f11e6940 100644 --- a/xarray/namedarray/_array_api/creation_functions.py +++ b/xarray/namedarray/_array_api/creation_functions.py @@ -1,6 +1,5 @@ from __future__ import annotations -from types import ModuleType from typing import Any, overload import numpy as np @@ -8,29 +7,17 @@ from xarray.namedarray._array_api._utils import _maybe_default_namespace from xarray.namedarray._typing import ( Default, - _arrayapi, _arrayfunction_or_api, _ArrayLike, - _Axes, - _Axis, - _AxisLike, _default, - _Dim, - _Dims, _DimsLike, _DType, - _dtype, - _ScalarType, _Shape, _ShapeType, - _SupportsImag, - _SupportsReal, duckarray, ) from xarray.namedarray.core import ( NamedArray, - _dims_to_axis, - _get_remaining_dims, ) from xarray.namedarray.utils import ( to_0d_object_array, diff --git a/xarray/namedarray/_array_api/data_type_functions.py b/xarray/namedarray/_array_api/data_type_functions.py index fc0c1be4abc..fbfaaca0d0e 100644 --- a/xarray/namedarray/_array_api/data_type_functions.py +++ b/xarray/namedarray/_array_api/data_type_functions.py @@ -1,42 +1,19 @@ from __future__ import annotations -from types import ModuleType -from typing import Any, overload - -import numpy as np +from typing import Any from xarray.namedarray._array_api._utils import ( _get_data_namespace, _get_namespace_dtype, ) from xarray.namedarray._typing import ( - Default, _arrayapi, - _arrayfunction_or_api, - _ArrayLike, - _Axes, - _Axis, - _AxisLike, - _default, - _Dim, - _Dims, - _DimsLike, _DType, _dtype, - _ScalarType, - _Shape, _ShapeType, - _SupportsImag, - _SupportsReal, - duckarray, ) from xarray.namedarray.core import ( NamedArray, - _dims_to_axis, - _get_remaining_dims, -) -from xarray.namedarray.utils import ( - to_0d_object_array, ) diff --git a/xarray/namedarray/_array_api/dtypes.py b/xarray/namedarray/_array_api/dtypes.py index 016db30ff95..c796400329a 100644 --- a/xarray/namedarray/_array_api/dtypes.py +++ b/xarray/namedarray/_array_api/dtypes.py @@ -1,40 +1,6 @@ from __future__ import annotations -from types import ModuleType -from typing import Any, overload - -import numpy as np - from xarray.namedarray._array_api._utils import _maybe_default_namespace -from xarray.namedarray._typing import ( - Default, - _arrayapi, - _arrayfunction_or_api, - _ArrayLike, - _Axes, - _Axis, - _AxisLike, - _default, - _Dim, - _Dims, - _DimsLike, - _DType, - _dtype, - _ScalarType, - _Shape, - _ShapeType, - _SupportsImag, - _SupportsReal, - duckarray, -) -from xarray.namedarray.core import ( - NamedArray, - _dims_to_axis, - _get_remaining_dims, -) -from xarray.namedarray.utils import ( - to_0d_object_array, -) _xp = _maybe_default_namespace() int8 = _xp.int8 diff --git a/xarray/namedarray/_array_api/elementwise_functions.py b/xarray/namedarray/_array_api/elementwise_functions.py index 6a53a8b844b..acdd4dc5c48 100644 --- a/xarray/namedarray/_array_api/elementwise_functions.py +++ b/xarray/namedarray/_array_api/elementwise_functions.py @@ -1,39 +1,16 @@ from __future__ import annotations -from types import ModuleType -from typing import Any, overload - import numpy as np from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._typing import ( - Default, - _arrayapi, - _arrayfunction_or_api, - _ArrayLike, - _Axes, - _Axis, - _AxisLike, - _default, - _Dim, - _Dims, - _DimsLike, - _DType, - _dtype, _ScalarType, - _Shape, _ShapeType, _SupportsImag, _SupportsReal, - duckarray, ) from xarray.namedarray.core import ( NamedArray, - _dims_to_axis, - _get_remaining_dims, -) -from xarray.namedarray.utils import ( - to_0d_object_array, ) diff --git a/xarray/namedarray/_array_api/manipulation_functions.py b/xarray/namedarray/_array_api/manipulation_functions.py index e38f12373d0..3da92b76fa9 100644 --- a/xarray/namedarray/_array_api/manipulation_functions.py +++ b/xarray/namedarray/_array_api/manipulation_functions.py @@ -1,38 +1,19 @@ from __future__ import annotations -from types import ModuleType -from typing import Any, overload - -import numpy as np +from typing import Any from xarray.namedarray._typing import ( Default, _arrayapi, - _arrayfunction_or_api, - _ArrayLike, _Axes, _Axis, - _AxisLike, _default, _Dim, - _Dims, - _DimsLike, _DType, - _dtype, - _ScalarType, _Shape, - _ShapeType, - _SupportsImag, - _SupportsReal, - duckarray, ) from xarray.namedarray.core import ( NamedArray, - _dims_to_axis, - _get_remaining_dims, -) -from xarray.namedarray.utils import ( - to_0d_object_array, ) diff --git a/xarray/namedarray/_array_api/statistical_functions.py b/xarray/namedarray/_array_api/statistical_functions.py index 19e84126c2d..8aa1db92a7f 100644 --- a/xarray/namedarray/_array_api/statistical_functions.py +++ b/xarray/namedarray/_array_api/statistical_functions.py @@ -1,40 +1,20 @@ from __future__ import annotations -from types import ModuleType -from typing import Any, overload - -import numpy as np +from typing import Any from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._typing import ( Default, - _arrayapi, - _arrayfunction_or_api, - _ArrayLike, - _Axes, - _Axis, _AxisLike, _default, - _Dim, _Dims, - _DimsLike, _DType, - _dtype, - _ScalarType, - _Shape, - _ShapeType, - _SupportsImag, - _SupportsReal, - duckarray, ) from xarray.namedarray.core import ( NamedArray, _dims_to_axis, _get_remaining_dims, ) -from xarray.namedarray.utils import ( - to_0d_object_array, -) def mean( diff --git a/xarray/namedarray/_array_api/utility_functions.py b/xarray/namedarray/_array_api/utility_functions.py index 849e13233b2..86eb34c4d9c 100644 --- a/xarray/namedarray/_array_api/utility_functions.py +++ b/xarray/namedarray/_array_api/utility_functions.py @@ -1,40 +1,20 @@ from __future__ import annotations -from types import ModuleType -from typing import Any, overload - -import numpy as np +from typing import Any from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._typing import ( Default, - _arrayapi, - _arrayfunction_or_api, - _ArrayLike, - _Axes, - _Axis, _AxisLike, _default, - _Dim, _Dims, - _DimsLike, _DType, - _dtype, - _ScalarType, - _Shape, - _ShapeType, - _SupportsImag, - _SupportsReal, - duckarray, ) from xarray.namedarray.core import ( NamedArray, _dims_to_axis, _get_remaining_dims, ) -from xarray.namedarray.utils import ( - to_0d_object_array, -) def all( From 42d0293491e60ab9aebd95257c1b25a3cdd68caf Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 19:46:19 +0200 Subject: [PATCH 098/367] rename _array_api.py --- xarray/namedarray/{_array_api.py => _array_api2.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename xarray/namedarray/{_array_api.py => _array_api2.py} (100%) diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api2.py similarity index 100% rename from xarray/namedarray/_array_api.py rename to xarray/namedarray/_array_api2.py From 5a3778cb27f5e3d7efe07a8e87540c37971f30f8 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 19:46:32 +0200 Subject: [PATCH 099/367] fixes --- xarray/namedarray/_array_api/__init__.py | 137 +++++++++++------- xarray/namedarray/_array_api/_utils.py | 5 +- .../_array_api/indexing_functions.py | 2 +- .../_array_api/manipulation_functions.py | 3 + xarray/namedarray/core.py | 2 +- 5 files changed, 94 insertions(+), 55 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 9bce7c05c7a..eca394653a8 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -4,7 +4,7 @@ __all__ += ["__array_api_version__"] -from xarray.namedarray._array_api.array_object import Array +from xarray.namedarray.core import NamedArray as Array __all__ += ["Array"] @@ -15,37 +15,37 @@ from xarray.namedarray._array_api.creation_functions import ( arange, asarray, - empty, - empty_like, - eye, + # empty, + # empty_like, + # eye, full, - full_like, + # full_like, linspace, - meshgrid, + # meshgrid, ones, - ones_like, - tril, - triu, + # ones_like, + # tril, + # triu, zeros, - zeros_like, + # zeros_like, ) __all__ += [ "arange", "asarray", - "empty", - "empty_like", - "eye", + # "empty", + # "empty_like", + # "eye", "full", - "full_like", + # "full_like", "linspace", - "meshgrid", + # "meshgrid", "ones", - "ones_like", - "tril", - "triu", + # "ones_like", + # "tril", + # "triu", "zeros", - "zeros_like", + # "zeros_like", ] from xarray.namedarray._array_api.data_type_functions import ( @@ -57,7 +57,14 @@ result_type, ) -__all__ += ["astype", "can_cast", "finfo", "iinfo", "isdtype", "result_type"] +__all__ += [ + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", +] from xarray.namedarray._array_api.dtypes import ( bool, @@ -215,56 +222,82 @@ "trunc", ] -from xarray.namedarray._array_api.indexing_functions import take +# from xarray.namedarray._array_api.indexing_functions import take -__all__ += ["take"] +# __all__ += ["take"] -from xarray.namedarray._array_api.linear_algebra_functions import ( - matmul, - matrix_transpose, - outer, - tensordot, - vecdot, -) +# from xarray.namedarray._array_api.linear_algebra_functions import ( +# matmul, +# matrix_transpose, +# outer, +# tensordot, +# vecdot, +# ) -__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"] +# __all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"] from xarray.namedarray._array_api.manipulation_functions import ( - broadcast_arrays, - broadcast_to, - concat, + # broadcast_arrays, + # broadcast_to, + # concat, expand_dims, - flip, - moveaxis, + # flip, + # moveaxis, permute_dims, reshape, - roll, - squeeze, - stack, + # roll, + # squeeze, + # stack, ) __all__ += [ - "broadcast_arrays", - "broadcast_to", - "concat", + # "broadcast_arrays", + # "broadcast_to", + # "concat", "expand_dims", - "flip", - "moveaxis", + # "flip", + # "moveaxis", "permute_dims", "reshape", - "roll", - "squeeze", - "stack", + # "roll", + # "squeeze", + # "stack", ] -from xarray.namedarray._array_api.searching_functions import argmax, argmin, where +# from xarray.namedarray._array_api.searching_functions import ( +# argmax, +# argmin, +# where, +# ) -__all__ += ["argmax", "argmin", "where"] +# __all__ += [ +# "argmax", +# "argmin", +# "where", +# ] -from xarray.namedarray._array_api.statistical_functions import max, mean, min, prod, sum +from xarray.namedarray._array_api.statistical_functions import ( + # max, + mean, + # min, + # prod, + # sum, +) -__all__ += ["max", "mean", "min", "prod", "sum"] +__all__ += [ + # "max", + "mean", + # "min", + # "prod", + # "sum", +] -from xarray.namedarray._array_api.utility_functions import all, any +from xarray.namedarray._array_api.utility_functions import ( + all, + any, +) -__all__ += ["all", "any"] +__all__ += [ + "all", + "any", +] diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 445012b3059..f630f2261e6 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,4 +1,7 @@ -from typing import TYPE_CHECKING, Any, ModuleType +from __future__ import annotations + +from types import ModuleType +from typing import TYPE_CHECKING, Any import numpy as np diff --git a/xarray/namedarray/_array_api/indexing_functions.py b/xarray/namedarray/_array_api/indexing_functions.py index b62d0a8393b..12b69a6bc86 100644 --- a/xarray/namedarray/_array_api/indexing_functions.py +++ b/xarray/namedarray/_array_api/indexing_functions.py @@ -1,3 +1,3 @@ from xarray.namedarray._array_api._utils import _get_data_namespace -sdf = _get_data_namespace() +sdf = _get_data_namespace diff --git a/xarray/namedarray/_array_api/manipulation_functions.py b/xarray/namedarray/_array_api/manipulation_functions.py index 3da92b76fa9..503f0471b09 100644 --- a/xarray/namedarray/_array_api/manipulation_functions.py +++ b/xarray/namedarray/_array_api/manipulation_functions.py @@ -2,6 +2,9 @@ from typing import Any +from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api.creation_functions import asarray + from xarray.namedarray._typing import ( Default, _arrayapi, diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 634267be4df..74d78f21d26 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -788,7 +788,7 @@ def nbytes(self) -> _IntOrUnknown: If the underlying data array does not include ``nbytes``, estimates the bytes consumed based on the ``size`` and ``dtype``. """ - from xarray.namedarray._array_api import _get_data_namespace + from xarray.namedarray._array_api._utils import _get_data_namespace if hasattr(self._data, "nbytes"): return self._data.nbytes # type: ignore[no-any-return] From 7f3d50205d6fd2c47b47c133569cba3157572140 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:47:11 +0000 Subject: [PATCH 100/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/__init__.py | 7 ------- xarray/namedarray/_array_api/manipulation_functions.py | 1 - 2 files changed, 8 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index eca394653a8..7bd73181c4c 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -27,7 +27,6 @@ # tril, # triu, zeros, - # zeros_like, ) __all__ += [ @@ -245,9 +244,6 @@ # moveaxis, permute_dims, reshape, - # roll, - # squeeze, - # stack, ) __all__ += [ @@ -279,9 +275,6 @@ from xarray.namedarray._array_api.statistical_functions import ( # max, mean, - # min, - # prod, - # sum, ) __all__ += [ diff --git a/xarray/namedarray/_array_api/manipulation_functions.py b/xarray/namedarray/_array_api/manipulation_functions.py index 503f0471b09..15ab7da7868 100644 --- a/xarray/namedarray/_array_api/manipulation_functions.py +++ b/xarray/namedarray/_array_api/manipulation_functions.py @@ -4,7 +4,6 @@ from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._array_api.creation_functions import asarray - from xarray.namedarray._typing import ( Default, _arrayapi, From 7d32bdc7f95acae3eeca13a73a5c61e40600eea6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 19:48:30 +0200 Subject: [PATCH 101/367] Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" This reverts commit 7f3d50205d6fd2c47b47c133569cba3157572140. --- xarray/namedarray/_array_api/__init__.py | 7 +++++++ xarray/namedarray/_array_api/manipulation_functions.py | 1 + 2 files changed, 8 insertions(+) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 7bd73181c4c..eca394653a8 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -27,6 +27,7 @@ # tril, # triu, zeros, + # zeros_like, ) __all__ += [ @@ -244,6 +245,9 @@ # moveaxis, permute_dims, reshape, + # roll, + # squeeze, + # stack, ) __all__ += [ @@ -275,6 +279,9 @@ from xarray.namedarray._array_api.statistical_functions import ( # max, mean, + # min, + # prod, + # sum, ) __all__ += [ diff --git a/xarray/namedarray/_array_api/manipulation_functions.py b/xarray/namedarray/_array_api/manipulation_functions.py index 15ab7da7868..503f0471b09 100644 --- a/xarray/namedarray/_array_api/manipulation_functions.py +++ b/xarray/namedarray/_array_api/manipulation_functions.py @@ -4,6 +4,7 @@ from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._array_api.creation_functions import asarray + from xarray.namedarray._typing import ( Default, _arrayapi, From c1163ada6de9ea1f563d2a5eeb633c1ca769f8e3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:20:00 +0200 Subject: [PATCH 102/367] Update _typing.py --- xarray/namedarray/_typing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 4b26d42eaf6..3b88e110d26 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -50,8 +50,6 @@ class Default(Enum): _ScalarType = TypeVar("_ScalarType", bound=np.generic) _ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True) -_ArrayLike = np.typing.ArrayLike - # A protocol for anything with the dtype attribute @runtime_checkable @@ -100,6 +98,10 @@ def dtype(self) -> _DType_co: ... _AttrsLike = Union[Mapping[Any, Any], None] +_ArrayLike = np.typing.ArrayLike + +_Device = Any + class _SupportsReal(Protocol[_T_co]): @property From c32abe50b3b1932eb15ce0f96f05e7a0eae3c9d0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:21:53 +0200 Subject: [PATCH 103/367] Update core.py --- xarray/namedarray/core.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 74d78f21d26..7e2e421327a 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -59,6 +59,7 @@ _AttrsLike, _AxisLike, _Chunks, + _Device, _Dim, _Dims, _DimsLike, @@ -752,6 +753,20 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray): else: raise NotImplementedError("{k=} is not supported") + @property + def device(self) -> _Device: + """ + Device of the array’s elements. + + See Also + -------- + ndarray.device + """ + if isinstance(self._data, _arrayapi): + return self._data.device + else: + raise NotImplementedError("self._data missing device") + @property def dtype(self) -> _DType_co: """ From 1855f7fdad7dd4117c4d4551b4ffebe31a5f62f5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:22:34 +0200 Subject: [PATCH 104/367] Add more creation functions --- .../_array_api/creation_functions.py | 98 ++++++++++++++++--- 1 file changed, 87 insertions(+), 11 deletions(-) diff --git a/xarray/namedarray/_array_api/creation_functions.py b/xarray/namedarray/_array_api/creation_functions.py index 957f11e6940..f0588fba8d6 100644 --- a/xarray/namedarray/_array_api/creation_functions.py +++ b/xarray/namedarray/_array_api/creation_functions.py @@ -4,12 +4,16 @@ import numpy as np -from xarray.namedarray._array_api._utils import _maybe_default_namespace +from xarray.namedarray._array_api._utils import ( + _maybe_default_namespace, + _get_data_namespace, +) from xarray.namedarray._typing import ( Default, _arrayfunction_or_api, _ArrayLike, _default, + _Device, _DimsLike, _DType, _Shape, @@ -24,6 +28,12 @@ ) +def _like_args(x, dtype=None, device: _Device | None = None): + if dtype is None: + dtype = x.dtype + return dict(shape=x.shape, dtype=dtype, device=device) + + def _infer_dims( shape: _Shape, dims: _DimsLike | Default = _default, @@ -41,7 +51,7 @@ def arange( step: int | float = 1, *, dtype: _DType | None = None, - device=None, + device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _maybe_default_namespace() _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) @@ -55,7 +65,7 @@ def asarray( /, *, dtype: _DType, - device=..., + device: _Device | None = ..., copy: bool | None = ..., dims: _DimsLike = ..., ) -> NamedArray[_ShapeType, _DType]: ... @@ -65,7 +75,7 @@ def asarray( /, *, dtype: _DType, - device=..., + device: _Device | None = ..., copy: bool | None = ..., dims: _DimsLike = ..., ) -> NamedArray[Any, _DType]: ... @@ -75,7 +85,7 @@ def asarray( /, *, dtype: None, - device=None, + device: _Device | None = None, copy: bool | None = None, dims: _DimsLike = ..., ) -> NamedArray[_ShapeType, _DType]: ... @@ -85,7 +95,7 @@ def asarray( /, *, dtype: None, - device=..., + device: _Device | None = ..., copy: bool | None = ..., dims: _DimsLike = ..., ) -> NamedArray[Any, _DType]: ... @@ -94,7 +104,7 @@ def asarray( /, *, dtype: _DType | None = None, - device=None, + device: _Device | None = None, copy: bool | None = None, dims: _DimsLike = _default, ) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: @@ -146,12 +156,35 @@ def asarray( return NamedArray(_dims, _data) +def empty( + shape: _ShapeType, *, dtype: _DType | None = None, device: _Device | None = None +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.empty(shape, dtype=dtype, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def empty_like( + x: NamedArray[_ShapeType, _DType], + /, + *, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _dtype = x.dtype if dtype is None else dtype + _device = x.device if device is None else device + _data = xp.empty(x.shape, dtype=_dtype, device=_device) + return x._new(data=_data) + + def full( shape: _Shape, fill_value: bool | int | float | complex, *, dtype: _DType | None = None, - device=None, + device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _maybe_default_namespace() _data = xp.full(shape, fill_value, dtype=dtype, device=device) @@ -159,6 +192,21 @@ def full( return NamedArray(_dims, _data) +def full_like( + x: NamedArray[_ShapeType, _DType], + fill_value: bool | int | float | complex, + /, + *, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _dtype = x.dtype if dtype is None else dtype + _device = x.device if device is None else device + _data = xp.full(x.shape, fill_value, dtype=_dtype, device=_device) + return x._new(data=_data) + + def linspace( start: int | float | complex, stop: int | float | complex, @@ -166,7 +214,7 @@ def linspace( num: int, *, dtype: _DType | None = None, - device=None, + device: _Device | None = None, endpoint: bool = True, ) -> NamedArray[_ShapeType, _DType]: xp = _maybe_default_namespace() @@ -178,12 +226,40 @@ def linspace( def ones( - shape: _Shape, *, dtype: _DType | None = None, device=None + shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: return full(shape, 1, dtype=dtype, device=device) +def ones_like( + x: NamedArray[_ShapeType, _DType], + /, + *, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _dtype = x.dtype if dtype is None else dtype + _device = x.device if device is None else device + _data = xp.ones(x.shape, dtype=_dtype, device=_device) + return x._new(data=_data) + + def zeros( - shape: _Shape, *, dtype: _DType | None = None, device=None + shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: return full(shape, 0, dtype=dtype, device=device) + + +def zeros_like( + x: NamedArray[_ShapeType, _DType], + /, + *, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _dtype = x.dtype if dtype is None else dtype + _device = x.device if device is None else device + _data = xp.zeros(x.shape, dtype=_dtype, device=_device) + return x._new(data=_data) From cd9061827805d2a245a1c2e790a000a9815e06d9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:22:49 +0200 Subject: [PATCH 105/367] add more statistical functions --- .../_array_api/statistical_functions.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/xarray/namedarray/_array_api/statistical_functions.py b/xarray/namedarray/_array_api/statistical_functions.py index 8aa1db92a7f..d99f0f812ef 100644 --- a/xarray/namedarray/_array_api/statistical_functions.py +++ b/xarray/namedarray/_array_api/statistical_functions.py @@ -17,6 +17,23 @@ ) +def max( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.max(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + def mean( x: NamedArray[Any, _DType], /, @@ -82,3 +99,54 @@ def mean( dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out + + +def min( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.min(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def prod( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.prod(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def sum( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + axis_ = _dims_to_axis(x, dims, axis) + d = xp.sum(x._data, axis=axis_, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out From 44a16c7e04edfe9e48abdccb4eafd5250efb255e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:22:54 +0200 Subject: [PATCH 106/367] Update __init__.py --- xarray/namedarray/_array_api/__init__.py | 36 ++++++++++++------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index eca394653a8..c4c55d6ff10 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -15,37 +15,37 @@ from xarray.namedarray._array_api.creation_functions import ( arange, asarray, - # empty, - # empty_like, + empty, + empty_like, # eye, full, - # full_like, + full_like, linspace, # meshgrid, ones, - # ones_like, + ones_like, # tril, # triu, zeros, - # zeros_like, + zeros_like, ) __all__ += [ "arange", "asarray", - # "empty", - # "empty_like", + "empty", + "empty_like", # "eye", "full", - # "full_like", + "full_like", "linspace", # "meshgrid", "ones", - # "ones_like", + "ones_like", # "tril", # "triu", "zeros", - # "zeros_like", + "zeros_like", ] from xarray.namedarray._array_api.data_type_functions import ( @@ -277,19 +277,19 @@ # ] from xarray.namedarray._array_api.statistical_functions import ( - # max, + max, mean, - # min, - # prod, - # sum, + min, + prod, + sum, ) __all__ += [ - # "max", + "max", "mean", - # "min", - # "prod", - # "sum", + "min", + "prod", + "sum", ] from xarray.namedarray._array_api.utility_functions import ( From ef6efb95f6e4119c5cb7998bcaa83ea7041e2d8f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 20:23:32 +0000 Subject: [PATCH 107/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/__init__.py | 3 --- xarray/namedarray/_array_api/creation_functions.py | 2 +- xarray/namedarray/_array_api/manipulation_functions.py | 1 - 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index c4c55d6ff10..f8ef66c6655 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -245,9 +245,6 @@ # moveaxis, permute_dims, reshape, - # roll, - # squeeze, - # stack, ) __all__ += [ diff --git a/xarray/namedarray/_array_api/creation_functions.py b/xarray/namedarray/_array_api/creation_functions.py index f0588fba8d6..12204f07b9c 100644 --- a/xarray/namedarray/_array_api/creation_functions.py +++ b/xarray/namedarray/_array_api/creation_functions.py @@ -5,8 +5,8 @@ import numpy as np from xarray.namedarray._array_api._utils import ( - _maybe_default_namespace, _get_data_namespace, + _maybe_default_namespace, ) from xarray.namedarray._typing import ( Default, diff --git a/xarray/namedarray/_array_api/manipulation_functions.py b/xarray/namedarray/_array_api/manipulation_functions.py index 503f0471b09..15ab7da7868 100644 --- a/xarray/namedarray/_array_api/manipulation_functions.py +++ b/xarray/namedarray/_array_api/manipulation_functions.py @@ -4,7 +4,6 @@ from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._array_api.creation_functions import asarray - from xarray.namedarray._typing import ( Default, _arrayapi, From 878209ea5375307f6c968c78dbfb7fc705decc5d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 22:24:14 +0200 Subject: [PATCH 108/367] Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" This reverts commit ef6efb95f6e4119c5cb7998bcaa83ea7041e2d8f. --- xarray/namedarray/_array_api/__init__.py | 3 +++ xarray/namedarray/_array_api/creation_functions.py | 2 +- xarray/namedarray/_array_api/manipulation_functions.py | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index f8ef66c6655..c4c55d6ff10 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -245,6 +245,9 @@ # moveaxis, permute_dims, reshape, + # roll, + # squeeze, + # stack, ) __all__ += [ diff --git a/xarray/namedarray/_array_api/creation_functions.py b/xarray/namedarray/_array_api/creation_functions.py index 12204f07b9c..f0588fba8d6 100644 --- a/xarray/namedarray/_array_api/creation_functions.py +++ b/xarray/namedarray/_array_api/creation_functions.py @@ -5,8 +5,8 @@ import numpy as np from xarray.namedarray._array_api._utils import ( - _get_data_namespace, _maybe_default_namespace, + _get_data_namespace, ) from xarray.namedarray._typing import ( Default, diff --git a/xarray/namedarray/_array_api/manipulation_functions.py b/xarray/namedarray/_array_api/manipulation_functions.py index 15ab7da7868..503f0471b09 100644 --- a/xarray/namedarray/_array_api/manipulation_functions.py +++ b/xarray/namedarray/_array_api/manipulation_functions.py @@ -4,6 +4,7 @@ from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._array_api.creation_functions import asarray + from xarray.namedarray._typing import ( Default, _arrayapi, From 5698c969e18df375a9793530fcc65bf524c842b4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 23:20:36 +0200 Subject: [PATCH 109/367] add underscore to files, following array api strict --- xarray/namedarray/_array_api/__init__.py | 26 +-- .../{constants.py => _constants.py} | 0 ...on_functions.py => _creation_functions.py} | 0 ...e_functions.py => _data_type_functions.py} | 0 .../_array_api/{dtypes.py => _dtypes.py} | 0 ...functions.py => _elementwise_functions.py} | 0 ...ng_functions.py => _indexing_functions.py} | 0 xarray/namedarray/_array_api/_info.py | 152 ++++++++++++++++++ ...ctions.py => _linear_algebra_functions.py} | 0 ...unctions.py => _manipulation_functions.py} | 2 +- ...g_functions.py => _searching_functions.py} | 0 ...functions.py => _statistical_functions.py} | 0 ...ity_functions.py => _utility_functions.py} | 0 13 files changed, 169 insertions(+), 11 deletions(-) rename xarray/namedarray/_array_api/{constants.py => _constants.py} (100%) rename xarray/namedarray/_array_api/{creation_functions.py => _creation_functions.py} (100%) rename xarray/namedarray/_array_api/{data_type_functions.py => _data_type_functions.py} (100%) rename xarray/namedarray/_array_api/{dtypes.py => _dtypes.py} (100%) rename xarray/namedarray/_array_api/{elementwise_functions.py => _elementwise_functions.py} (100%) rename xarray/namedarray/_array_api/{indexing_functions.py => _indexing_functions.py} (100%) create mode 100644 xarray/namedarray/_array_api/_info.py rename xarray/namedarray/_array_api/{linear_algebra_functions.py => _linear_algebra_functions.py} (100%) rename xarray/namedarray/_array_api/{manipulation_functions.py => _manipulation_functions.py} (97%) rename xarray/namedarray/_array_api/{searching_functions.py => _searching_functions.py} (100%) rename xarray/namedarray/_array_api/{statistical_functions.py => _statistical_functions.py} (100%) rename xarray/namedarray/_array_api/{utility_functions.py => _utility_functions.py} (100%) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index c4c55d6ff10..370286d3c61 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -8,11 +8,11 @@ __all__ += ["Array"] -from xarray.namedarray._array_api.constants import e, inf, nan, newaxis, pi +from xarray.namedarray._array_api._constants import e, inf, nan, newaxis, pi __all__ += ["e", "inf", "nan", "newaxis", "pi"] -from xarray.namedarray._array_api.creation_functions import ( +from xarray.namedarray._array_api._creation_functions import ( arange, asarray, empty, @@ -48,7 +48,7 @@ "zeros_like", ] -from xarray.namedarray._array_api.data_type_functions import ( +from xarray.namedarray._array_api._data_type_functions import ( astype, can_cast, finfo, @@ -66,7 +66,7 @@ "result_type", ] -from xarray.namedarray._array_api.dtypes import ( +from xarray.namedarray._array_api._dtypes import ( bool, complex64, complex128, @@ -98,7 +98,7 @@ "uint64", ] -from xarray.namedarray._array_api.elementwise_functions import ( +from xarray.namedarray._array_api._elementwise_functions import ( abs, acos, acosh, @@ -226,7 +226,13 @@ # __all__ += ["take"] -# from xarray.namedarray._array_api.linear_algebra_functions import ( +from ._info import __array_namespace_info__ + +__all__ += [ + "__array_namespace_info__", +] + +# from xarray.namedarray._array_api._linear_algebra_functions import ( # matmul, # matrix_transpose, # outer, @@ -236,7 +242,7 @@ # __all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"] -from xarray.namedarray._array_api.manipulation_functions import ( +from xarray.namedarray._array_api._manipulation_functions import ( # broadcast_arrays, # broadcast_to, # concat, @@ -264,7 +270,7 @@ # "stack", ] -# from xarray.namedarray._array_api.searching_functions import ( +# from xarray.namedarray._array_api._searching_functions import ( # argmax, # argmin, # where, @@ -276,7 +282,7 @@ # "where", # ] -from xarray.namedarray._array_api.statistical_functions import ( +from xarray.namedarray._array_api._statistical_functions import ( max, mean, min, @@ -292,7 +298,7 @@ "sum", ] -from xarray.namedarray._array_api.utility_functions import ( +from xarray.namedarray._array_api._utility_functions import ( all, any, ) diff --git a/xarray/namedarray/_array_api/constants.py b/xarray/namedarray/_array_api/_constants.py similarity index 100% rename from xarray/namedarray/_array_api/constants.py rename to xarray/namedarray/_array_api/_constants.py diff --git a/xarray/namedarray/_array_api/creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py similarity index 100% rename from xarray/namedarray/_array_api/creation_functions.py rename to xarray/namedarray/_array_api/_creation_functions.py diff --git a/xarray/namedarray/_array_api/data_type_functions.py b/xarray/namedarray/_array_api/_data_type_functions.py similarity index 100% rename from xarray/namedarray/_array_api/data_type_functions.py rename to xarray/namedarray/_array_api/_data_type_functions.py diff --git a/xarray/namedarray/_array_api/dtypes.py b/xarray/namedarray/_array_api/_dtypes.py similarity index 100% rename from xarray/namedarray/_array_api/dtypes.py rename to xarray/namedarray/_array_api/_dtypes.py diff --git a/xarray/namedarray/_array_api/elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py similarity index 100% rename from xarray/namedarray/_array_api/elementwise_functions.py rename to xarray/namedarray/_array_api/_elementwise_functions.py diff --git a/xarray/namedarray/_array_api/indexing_functions.py b/xarray/namedarray/_array_api/_indexing_functions.py similarity index 100% rename from xarray/namedarray/_array_api/indexing_functions.py rename to xarray/namedarray/_array_api/_indexing_functions.py diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py new file mode 100644 index 00000000000..c50d3594f65 --- /dev/null +++ b/xarray/namedarray/_array_api/_info.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from types import ModuleType + from typing import Optional, Union, Tuple, List + from xarray.namedarray._typing import _Device + +# from ._array_object import CPU_DEVICE +from ._dtypes import ( + bool, + int8, + int16, + int32, + int64, + uint8, + uint16, + uint32, + uint64, + float32, + float64, + complex64, + complex128, +) + + +def __array_namespace_info__() -> ModuleType: + import xarray.namedarray._array_api._info + + return xarray.namedarray._array_api._info + + +def capabilities() -> dict: + return { + "boolean indexing": False, + "data-dependent shapes": False, + } + + +def default_device() -> _Device: + from xarray.namedarray._array_api._utils import _maybe_default_namespace + + xp = _maybe_default_namespace() + info = xp.__array_namespace_info__() + return info.default_device + + +def default_dtypes( + *, + device: _Device | None = None, +) -> dict: + return { + "real floating": float64, + "complex floating": complex128, + "integral": int64, + "indexing": int64, + } + + +def dtypes( + *, + device: _Device | None = None, + kind: Optional[Union[str, Tuple[str, ...]]] = None, +) -> dict: + if kind is None: + return { + "bool": bool, + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if kind == "bool": + return {"bool": bool} + if kind == "signed integer": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + } + if kind == "unsigned integer": + return { + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "integral": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + } + if kind == "real floating": + return { + "float32": float32, + "float64": float64, + } + if kind == "complex floating": + return { + "complex64": complex64, + "complex128": complex128, + } + if kind == "numeric": + return { + "int8": int8, + "int16": int16, + "int32": int32, + "int64": int64, + "uint8": uint8, + "uint16": uint16, + "uint32": uint32, + "uint64": uint64, + "float32": float32, + "float64": float64, + "complex64": complex64, + "complex128": complex128, + } + if isinstance(kind, tuple): + res = {} + for k in kind: + res.update(dtypes(kind=k)) + return res + raise ValueError(f"unsupported kind: {kind!r}") + + +def devices() -> List[_Device]: + return [default_device()] + + +__all__ = [ + "capabilities", + "default_device", + "default_dtypes", + "devices", + "dtypes", +] diff --git a/xarray/namedarray/_array_api/linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py similarity index 100% rename from xarray/namedarray/_array_api/linear_algebra_functions.py rename to xarray/namedarray/_array_api/_linear_algebra_functions.py diff --git a/xarray/namedarray/_array_api/manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py similarity index 97% rename from xarray/namedarray/_array_api/manipulation_functions.py rename to xarray/namedarray/_array_api/_manipulation_functions.py index 503f0471b09..eb04c21b1aa 100644 --- a/xarray/namedarray/_array_api/manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -3,7 +3,7 @@ from typing import Any from xarray.namedarray._array_api._utils import _get_data_namespace -from xarray.namedarray._array_api.creation_functions import asarray +from xarray.namedarray._array_api._creation_functions import asarray from xarray.namedarray._typing import ( Default, diff --git a/xarray/namedarray/_array_api/searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py similarity index 100% rename from xarray/namedarray/_array_api/searching_functions.py rename to xarray/namedarray/_array_api/_searching_functions.py diff --git a/xarray/namedarray/_array_api/statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py similarity index 100% rename from xarray/namedarray/_array_api/statistical_functions.py rename to xarray/namedarray/_array_api/_statistical_functions.py diff --git a/xarray/namedarray/_array_api/utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py similarity index 100% rename from xarray/namedarray/_array_api/utility_functions.py rename to xarray/namedarray/_array_api/_utility_functions.py From d912de64f97cbc08b1a311de306cae5e73c4e752 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 23:22:29 +0200 Subject: [PATCH 110/367] add info --- xarray/namedarray/_array_api/__init__.py | 2 +- xarray/namedarray/_array_api/_info.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 370286d3c61..3e0fa49b1cc 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -226,7 +226,7 @@ # __all__ += ["take"] -from ._info import __array_namespace_info__ +from xarray.namedarray._array_api._info import __array_namespace_info__ __all__ += [ "__array_namespace_info__", diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index c50d3594f65..006b724fccd 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -7,8 +7,7 @@ from typing import Optional, Union, Tuple, List from xarray.namedarray._typing import _Device -# from ._array_object import CPU_DEVICE -from ._dtypes import ( +from xarray.namedarray._array_api._dtypes import ( bool, int8, int16, From 0595c4228a9616b5d1fc170efceec493dc059e37 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 21:23:13 +0000 Subject: [PATCH 111/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/__init__.py | 3 --- .../namedarray/_array_api/_creation_functions.py | 2 +- xarray/namedarray/_array_api/_info.py | 15 ++++++++------- .../_array_api/_manipulation_functions.py | 3 +-- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 3e0fa49b1cc..083c3a570ff 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -251,9 +251,6 @@ # moveaxis, permute_dims, reshape, - # roll, - # squeeze, - # stack, ) __all__ += [ diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index f0588fba8d6..12204f07b9c 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -5,8 +5,8 @@ import numpy as np from xarray.namedarray._array_api._utils import ( - _maybe_default_namespace, _get_data_namespace, + _maybe_default_namespace, ) from xarray.namedarray._typing import ( Default, diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index 006b724fccd..9979ce5f854 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -4,11 +4,16 @@ if TYPE_CHECKING: from types import ModuleType - from typing import Optional, Union, Tuple, List + from typing import Optional, Union + from xarray.namedarray._typing import _Device from xarray.namedarray._array_api._dtypes import ( bool, + complex64, + complex128, + float32, + float64, int8, int16, int32, @@ -17,10 +22,6 @@ uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) @@ -60,7 +61,7 @@ def default_dtypes( def dtypes( *, device: _Device | None = None, - kind: Optional[Union[str, Tuple[str, ...]]] = None, + kind: Optional[Union[str, tuple[str, ...]]] = None, ) -> dict: if kind is None: return { @@ -138,7 +139,7 @@ def dtypes( raise ValueError(f"unsupported kind: {kind!r}") -def devices() -> List[_Device]: +def devices() -> list[_Device]: return [default_device()] diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index eb04c21b1aa..c08287a9442 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -2,9 +2,8 @@ from typing import Any -from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._array_api._creation_functions import asarray - +from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._typing import ( Default, _arrayapi, From dd8f93a9d178037b9267a6f26537684216ff8ac4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 23:27:52 +0200 Subject: [PATCH 112/367] Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" This reverts commit 0595c4228a9616b5d1fc170efceec493dc059e37. --- xarray/namedarray/_array_api/__init__.py | 3 +++ .../namedarray/_array_api/_creation_functions.py | 2 +- xarray/namedarray/_array_api/_info.py | 15 +++++++-------- .../_array_api/_manipulation_functions.py | 3 ++- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 083c3a570ff..3e0fa49b1cc 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -251,6 +251,9 @@ # moveaxis, permute_dims, reshape, + # roll, + # squeeze, + # stack, ) __all__ += [ diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 12204f07b9c..f0588fba8d6 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -5,8 +5,8 @@ import numpy as np from xarray.namedarray._array_api._utils import ( - _get_data_namespace, _maybe_default_namespace, + _get_data_namespace, ) from xarray.namedarray._typing import ( Default, diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index 9979ce5f854..006b724fccd 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -4,16 +4,11 @@ if TYPE_CHECKING: from types import ModuleType - from typing import Optional, Union - + from typing import Optional, Union, Tuple, List from xarray.namedarray._typing import _Device from xarray.namedarray._array_api._dtypes import ( bool, - complex64, - complex128, - float32, - float64, int8, int16, int32, @@ -22,6 +17,10 @@ uint16, uint32, uint64, + float32, + float64, + complex64, + complex128, ) @@ -61,7 +60,7 @@ def default_dtypes( def dtypes( *, device: _Device | None = None, - kind: Optional[Union[str, tuple[str, ...]]] = None, + kind: Optional[Union[str, Tuple[str, ...]]] = None, ) -> dict: if kind is None: return { @@ -139,7 +138,7 @@ def dtypes( raise ValueError(f"unsupported kind: {kind!r}") -def devices() -> list[_Device]: +def devices() -> List[_Device]: return [default_device()] diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index c08287a9442..eb04c21b1aa 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -2,8 +2,9 @@ from typing import Any -from xarray.namedarray._array_api._creation_functions import asarray from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api._creation_functions import asarray + from xarray.namedarray._typing import ( Default, _arrayapi, From f1efd10a87a0c540bb166d12089ff1a46b843d0e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 23:28:42 +0200 Subject: [PATCH 113/367] add dummy stack --- xarray/namedarray/_array_api/__init__.py | 2 +- xarray/namedarray/_array_api/_manipulation_functions.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 3e0fa49b1cc..06631dceba3 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -253,7 +253,7 @@ reshape, # roll, # squeeze, - # stack, + stack, ) __all__ += [ diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index eb04c21b1aa..d8735fafd31 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -104,3 +104,7 @@ def reshape(x, /, shape: _Shape, *, copy: bool | None = None): # TODO: If reshaping should we save the dims? # TODO: What's the xarray equivalent? return out + + +def stack(arrays, /, *, axis=0): + raise NotImplementedError("TODO:") From 8098fa30798cfe9623b53aa99f4557e0c55affcf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 21:29:23 +0000 Subject: [PATCH 114/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../namedarray/_array_api/_creation_functions.py | 2 +- xarray/namedarray/_array_api/_info.py | 15 ++++++++------- .../_array_api/_manipulation_functions.py | 3 +-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index f0588fba8d6..12204f07b9c 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -5,8 +5,8 @@ import numpy as np from xarray.namedarray._array_api._utils import ( - _maybe_default_namespace, _get_data_namespace, + _maybe_default_namespace, ) from xarray.namedarray._typing import ( Default, diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index 006b724fccd..9979ce5f854 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -4,11 +4,16 @@ if TYPE_CHECKING: from types import ModuleType - from typing import Optional, Union, Tuple, List + from typing import Optional, Union + from xarray.namedarray._typing import _Device from xarray.namedarray._array_api._dtypes import ( bool, + complex64, + complex128, + float32, + float64, int8, int16, int32, @@ -17,10 +22,6 @@ uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) @@ -60,7 +61,7 @@ def default_dtypes( def dtypes( *, device: _Device | None = None, - kind: Optional[Union[str, Tuple[str, ...]]] = None, + kind: Optional[Union[str, tuple[str, ...]]] = None, ) -> dict: if kind is None: return { @@ -138,7 +139,7 @@ def dtypes( raise ValueError(f"unsupported kind: {kind!r}") -def devices() -> List[_Device]: +def devices() -> list[_Device]: return [default_device()] diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index d8735fafd31..662911cbbf6 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -2,9 +2,8 @@ from typing import Any -from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._array_api._creation_functions import asarray - +from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._typing import ( Default, _arrayapi, From a0489f0bb25bfe3c0bd537717080216a62e474a0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 19 Aug 2024 23:30:33 +0200 Subject: [PATCH 115/367] Update _info.py --- xarray/namedarray/_array_api/_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index 006b724fccd..f545ab38ae3 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -60,7 +60,7 @@ def default_dtypes( def dtypes( *, device: _Device | None = None, - kind: Optional[Union[str, Tuple[str, ...]]] = None, + kind: Union[str, Tuple[str, ...]] | None = None, ) -> dict: if kind is None: return { From a2e020ef1a3ce46827fb1ef25a70ac7cc3d3408f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 21:32:38 +0000 Subject: [PATCH 116/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_info.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index 7eb34b0df0d..beb30a1d0ad 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -4,7 +4,6 @@ if TYPE_CHECKING: from types import ModuleType - from typing import Optional, Union from xarray.namedarray._typing import _Device @@ -61,7 +60,7 @@ def default_dtypes( def dtypes( *, device: _Device | None = None, - kind: str| Tuple[str, ...] | None = None, + kind: str | Tuple[str, ...] | None = None, ) -> dict: if kind is None: return { From 6eb9c7f2f5b422aa4d8efe0d30c4fa2841d4c6d9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 00:07:20 +0200 Subject: [PATCH 117/367] Delete _array_api2.py --- xarray/namedarray/_array_api2.py | 1030 ------------------------------ 1 file changed, 1030 deletions(-) delete mode 100644 xarray/namedarray/_array_api2.py diff --git a/xarray/namedarray/_array_api2.py b/xarray/namedarray/_array_api2.py deleted file mode 100644 index 1eed255b506..00000000000 --- a/xarray/namedarray/_array_api2.py +++ /dev/null @@ -1,1030 +0,0 @@ -from __future__ import annotations - -from types import ModuleType -from typing import Any, overload - -import numpy as np - -from xarray.namedarray._typing import ( - Default, - _arrayapi, - _arrayfunction_or_api, - _ArrayLike, - _Axes, - _Axis, - _AxisLike, - _default, - _Dim, - _Dims, - _DimsLike, - _DType, - _dtype, - _ScalarType, - _Shape, - _ShapeType, - _SupportsImag, - _SupportsReal, - duckarray, -) -from xarray.namedarray.core import ( - NamedArray, - _dims_to_axis, - _get_remaining_dims, -) -from xarray.namedarray.utils import ( - to_0d_object_array, -) - - -# %% Helper functions -def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: - return np if xp is None else xp - - -def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: - if isinstance(x._data, _arrayapi): - return x._data.__array_namespace__() - - return _maybe_default_namespace() - - -def _get_namespace_dtype(dtype: _dtype) -> ModuleType: - xp = __import__(dtype.__module__) - return xp - - -# %% array_api version -__array_api_version__ = "2023.12" - - -# %% Constants -e = np.e -inf = np.inf -nan = np.nan -newaxis = np.newaxis -pi = np.pi - - -# %% Creation Functions -def _infer_dims( - shape: _Shape, - dims: _DimsLike | Default = _default, -) -> _DimsLike: - if dims is _default: - return tuple(f"dim_{n}" for n in range(len(shape))) - else: - return dims - - -def arange( - start: int | float, - /, - stop: int | float | None = None, - step: int | float = 1, - *, - dtype: _DType | None = None, - device=None, -) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() - _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) - _dims = _infer_dims(_data.shape) - return NamedArray(_dims, _data) - - -@overload -def asarray( - obj: duckarray[_ShapeType, Any], - /, - *, - dtype: _DType, - device=..., - copy: bool | None = ..., - dims: _DimsLike = ..., -) -> NamedArray[_ShapeType, _DType]: ... -@overload -def asarray( - obj: _ArrayLike, - /, - *, - dtype: _DType, - device=..., - copy: bool | None = ..., - dims: _DimsLike = ..., -) -> NamedArray[Any, _DType]: ... -@overload -def asarray( - obj: duckarray[_ShapeType, _DType], - /, - *, - dtype: None, - device=None, - copy: bool | None = None, - dims: _DimsLike = ..., -) -> NamedArray[_ShapeType, _DType]: ... -@overload -def asarray( - obj: _ArrayLike, - /, - *, - dtype: None, - device=..., - copy: bool | None = ..., - dims: _DimsLike = ..., -) -> NamedArray[Any, _DType]: ... -def asarray( - obj: duckarray[_ShapeType, _DType] | _ArrayLike, - /, - *, - dtype: _DType | None = None, - device=None, - copy: bool | None = None, - dims: _DimsLike = _default, -) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: - """ - Create a Named array from an array-like object. - - Parameters - ---------- - dims : str or iterable of str - Name(s) of the dimension(s). - data : T_DuckArray or ArrayLike - The actual data that populates the array. Should match the - shape specified by `dims`. - attrs : dict, optional - A dictionary containing any additional information or - attributes you want to store with the array. - Default is None, meaning no attributes will be stored. - """ - data = obj - if isinstance(data, NamedArray): - if copy: - return data.copy() - else: - return data - - # TODO: dask.array.ma.MaskedArray also exists, better way? - if isinstance(data, np.ma.MaskedArray): - mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] - if mask.any(): - # TODO: requires refactoring/vendoring xarray.core.dtypes and - # xarray.core.duck_array_ops - raise NotImplementedError("MaskedArray is not supported yet") - - _dims = _infer_dims(data.shape, dims) - return NamedArray(_dims, data) - - if isinstance(data, _arrayfunction_or_api): - _dims = _infer_dims(data.shape, dims) - return NamedArray(_dims, data) - - if isinstance(data, tuple): - _data = to_0d_object_array(data) - _dims = _infer_dims(_data.shape, dims) - return NamedArray(_dims, _data) - - # validate whether the data is valid data types. - _data = np.asarray(data, dtype=dtype, device=device, copy=copy) - _dims = _infer_dims(_data.shape, dims) - return NamedArray(_dims, _data) - - -def full( - shape: _Shape, - fill_value: bool | int | float | complex, - *, - dtype: _DType | None = None, - device=None, -) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() - _data = xp.full(shape, fill_value, dtype=dtype, device=device) - _dims = _infer_dims(_data.shape) - return NamedArray(_dims, _data) - - -def linspace( - start: int | float | complex, - stop: int | float | complex, - /, - num: int, - *, - dtype: _DType | None = None, - device=None, - endpoint: bool = True, -) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() - _data = xp.linspace( - start, stop, num=num, dtype=dtype, device=device, endpoint=endpoint - ) - _dims = _infer_dims(_data.shape) - return NamedArray(_dims, _data) - - -def ones( - shape: _Shape, *, dtype: _DType | None = None, device=None -) -> NamedArray[_ShapeType, _DType]: - return full(shape, 1, dtype=dtype, device=device) - - -def zeros( - shape: _Shape, *, dtype: _DType | None = None, device=None -) -> NamedArray[_ShapeType, _DType]: - return full(shape, 0, dtype=dtype, device=device) - - -# %% Data types -# TODO: should delegate to underlying array? Cubed doesn't at the moment. -int8 = np.int8 -int16 = np.int16 -int32 = np.int32 -int64 = np.int64 -uint8 = np.uint8 -uint16 = np.uint16 -uint32 = np.uint32 -uint64 = np.uint64 -float32 = np.float32 -float64 = np.float64 -complex64 = np.complex64 -complex128 = np.complex128 -bool = np.bool - -_all_dtypes = ( - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - complex64, - complex128, - bool, -) -_boolean_dtypes = (bool,) -_real_floating_dtypes = (float32, float64) -_floating_dtypes = (float32, float64, complex64, complex128) -_complex_floating_dtypes = (complex64, complex128) -_integer_dtypes = (int8, int16, int32, int64, uint8, uint16, uint32, uint64) -_signed_integer_dtypes = (int8, int16, int32, int64) -_unsigned_integer_dtypes = (uint8, uint16, uint32, uint64) -_integer_or_boolean_dtypes = ( - bool, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, -) -_real_numeric_dtypes = ( - float32, - float64, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, -) -_numeric_dtypes = ( - float32, - float64, - complex64, - complex128, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, -) - -_dtype_categories = { - "all": _all_dtypes, - "real numeric": _real_numeric_dtypes, - "numeric": _numeric_dtypes, - "integer": _integer_dtypes, - "integer or boolean": _integer_or_boolean_dtypes, - "boolean": _boolean_dtypes, - "real floating-point": _floating_dtypes, - "complex floating-point": _complex_floating_dtypes, - "floating-point": _floating_dtypes, -} - -# %% Data type functions - - -def astype( - x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True -) -> NamedArray[_ShapeType, _DType]: - """ - Copies an array to a specified data type irrespective of Type Promotion Rules rules. - - Parameters - ---------- - x : NamedArray - Array to cast. - dtype : _DType - Desired data type. - copy : bool, optional - Specifies whether to copy an array when the specified dtype matches the data - type of the input array x. - If True, a newly allocated array must always be returned. - If False and the specified dtype matches the data type of the input array, - the input array must be returned; otherwise, a newly allocated array must be - returned. Default: True. - - Returns - ------- - out : NamedArray - An array having the specified data type. The returned array must have the - same shape as x. - - Examples - -------- - >>> narr = NamedArray(("x",), np.asarray([1.5, 2.5])) - >>> narr - Size: 16B - array([1.5, 2.5]) - >>> astype(narr, np.dtype(np.int32)) - Size: 8B - array([1, 2], dtype=int32) - """ - if isinstance(x._data, _arrayapi): - xp = x._data.__array_namespace__() - return x._new(data=xp.astype(x._data, dtype, copy=copy)) - - # np.astype doesn't exist yet: - return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] - - -def can_cast(from_: _dtype | NamedArray, to: _dtype, /) -> bool: - if isinstance(from_, NamedArray): - xp = _get_data_namespace(from_) - from_ = from_.dtype - return xp.can_cast(from_, to) - else: - xp = _get_namespace_dtype(from_) - return xp.can_cast(from_, to) - - -def finfo(type: _dtype | NamedArray[Any, Any], /): - if isinstance(type, NamedArray): - xp = _get_data_namespace(type) - return xp.finfo(type._data) - else: - xp = _get_namespace_dtype(type) - return xp.finfo(type) - - -def iinfo(type: _dtype | NamedArray[Any, Any], /): - if isinstance(type, NamedArray): - xp = _get_data_namespace(type) - return xp.iinfo(type._data) - else: - xp = _get_namespace_dtype(type) - return xp.iinfo(type) - - -def isdtype(dtype: _dtype, kind: _dtype | str | tuple[_dtype | str, ...]) -> bool: - xp = _get_namespace_dtype(type) - return xp.isdtype(dtype, kind) - - -def result_type(*arrays_and_dtypes: NamedArray[Any, Any] | _dtype) -> _dtype: - # TODO: Empty arg? - arr_or_dtype = arrays_and_dtypes[0] - if isinstance(arr_or_dtype, NamedArray): - xp = _get_data_namespace(arr_or_dtype) - else: - xp = _get_namespace_dtype(arr_or_dtype) - - return xp.result_type( - *(a.dtype if isinstance(a, NamedArray) else a for a in arrays_and_dtypes) - ) - - -# %% Elementwise Functions -def abs(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.abs(x._data)) - return out - - -def acos(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.acos(x._data)) - return out - - -def acosh(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.acosh(x._data)) - return out - - -def add(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.add(x1._data, x2._data)) - return out - - -def asin(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.asin(x._data)) - return out - - -def asinh(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.asinh(x._data)) - return out - - -def atan(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.atan(x._data)) - return out - - -def atan2(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.atan2(x1._data, x2._data)) - return out - - -def atanh(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.atanh(x._data)) - return out - - -def bitwise_and(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_and(x1._data, x2._data)) - return out - - -def bitwise_invert(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.bitwise_invert(x._data)) - return out - - -def bitwise_left_shift(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_left_shift(x1._data, x2._data)) - return out - - -def bitwise_or(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_or(x1._data, x2._data)) - return out - - -def bitwise_right_shift(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_right_shift(x1._data, x2._data)) - return out - - -def bitwise_xor(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_xor(x1._data, x2._data)) - return out - - -def ceil(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.ceil(x._data)) - return out - - -def conj(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.conj(x._data)) - return out - - -def cos(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.cos(x._data)) - return out - - -def cosh(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.cosh(x._data)) - return out - - -def divide(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.divide(x1._data, x2._data)) - return out - - -def exp(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.exp(x._data)) - return out - - -def expm1(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.expm1(x._data)) - return out - - -def equal(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.equal(x1._data, x2._data)) - return out - - -def floor(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.floor(x._data)) - return out - - -def floor_divide(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.floor_divide(x1._data, x2._data)) - return out - - -def greater(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.greater(x1._data, x2._data)) - return out - - -def greater_equal(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.greater_equal(x1._data, x2._data)) - return out - - -def imag( - x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var] -) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: - """ - Returns the imaginary component of a complex number for each element x_i of the - input array x. - - Parameters - ---------- - x : NamedArray - Input array. Should have a complex floating-point data type. - - Returns - ------- - out : NamedArray - An array containing the element-wise results. The returned array must have a - floating-point data type with the same floating-point precision as x - (e.g., if x is complex64, the returned array must have the floating-point - data type float32). - - Examples - -------- - >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) - >>> imag(narr) - Size: 16B - array([2., 4.]) - """ - xp = _get_data_namespace(x) - out = x._new(data=xp.imag(x._data)) - return out - - -def isfinite(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.isfinite(x._data)) - return out - - -def isinf(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.isinf(x._data)) - return out - - -def isnan(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.isnan(x._data)) - return out - - -def less(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.less(x1._data, x2._data)) - return out - - -def less_equal(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.less_equal(x1._data, x2._data)) - return out - - -def log(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.log(x._data)) - return out - - -def log1p(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.log1p(x._data)) - return out - - -def log2(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.log2(x._data)) - return out - - -def log10(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.log10(x._data)) - return out - - -def logaddexp(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.logaddexp(x1._data, x2._data)) - return out - - -def logical_and(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.logical_and(x1._data, x2._data)) - return out - - -def logical_not(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.logical_not(x._data)) - return out - - -def logical_or(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.logical_or(x1._data, x2._data)) - return out - - -def logical_xor(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.logical_xor(x1._data, x2._data)) - return out - - -def multiply(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.multiply(x1._data, x2._data)) - return out - - -def negative(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.negative(x._data)) - return out - - -def not_equal(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.not_equal(x1._data, x2._data)) - return out - - -def positive(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.positive(x._data)) - return out - - -def pow(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.pow(x1._data, x2._data)) - return out - - -def real( - x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], / # type: ignore[type-var] -) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: - """ - Returns the real component of a complex number for each element x_i of the - input array x. - - Parameters - ---------- - x : NamedArray - Input array. Should have a complex floating-point data type. - - Returns - ------- - out : NamedArray - An array containing the element-wise results. The returned array must have a - floating-point data type with the same floating-point precision as x - (e.g., if x is complex64, the returned array must have the floating-point - data type float32). - - Examples - -------- - >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j])) - >>> real(narr) - Size: 16B - array([1., 2.]) - """ - xp = _get_data_namespace(x) - out = x._new(data=xp.real(x._data)) - return out - - -def remainder(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.remainder(x1._data, x2._data)) - return out - - -def round(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.round(x._data)) - return out - - -def sign(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.sign(x._data)) - return out - - -def sin(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.sin(x._data)) - return out - - -def sinh(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.sinh(x._data)) - return out - - -def sqrt(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.sqrt(x._data)) - return out - - -def square(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.square(x._data)) - return out - - -def subtract(x1, x2, /): - xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.subtract(x1._data, x2._data)) - return out - - -def tan(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.tan(x._data)) - return out - - -def tanh(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.tanh(x._data)) - return out - - -def trunc(x, /): - xp = _get_data_namespace(x) - out = x._new(data=xp.trunc(x._data)) - return out - - -# %% Manipulation functions -def expand_dims( - x: NamedArray[Any, _DType], - /, - *, - dim: _Dim | Default = _default, - axis: _Axis = 0, -) -> NamedArray[Any, _DType]: - """ - Expands the shape of an array by inserting a new dimension of size one at the - position specified by dims. - - Parameters - ---------- - x : - Array to expand. - dim : - Dimension name. New dimension will be stored in the axis position. - axis : - (Not recommended) Axis position (zero-based). Default is 0. - - Returns - ------- - out : - An expanded output array having the same data type as x. - - Examples - -------- - >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]])) - >>> expand_dims(x) - Size: 32B - array([[[1., 2.], - [3., 4.]]]) - >>> expand_dims(x, dim="z") - Size: 32B - array([[[1., 2.], - [3., 4.]]]) - """ - xp = _get_data_namespace(x) - dims = x.dims - if dim is _default: - dim = f"dim_{len(dims)}" - d = list(dims) - d.insert(axis, dim) - out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) - return out - - -def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]: - """ - Permutes the dimensions of an array. - - Parameters - ---------- - x : - Array to permute. - axes : - Permutation of the dimensions of x. - - Returns - ------- - out : - An array with permuted dimensions. The returned array must have the same - data type as x. - - """ - - dims = x.dims - new_dims = tuple(dims[i] for i in axes) - if isinstance(x._data, _arrayapi): - xp = _get_data_namespace(x) - out = x._new(dims=new_dims, data=xp.permute_dims(x._data, axes)) - else: - out = x._new(dims=new_dims, data=x._data.transpose(axes)) # type: ignore[attr-defined] - return out - - -def reshape(x, /, shape: _Shape, *, copy: bool | None = None): - xp = _get_data_namespace(x) - _data = xp.reshape(x._data, shape) - out = asarray(_data, copy=copy) - # TODO: Have better control where the dims went. - # TODO: If reshaping should we save the dims? - # TODO: What's the xarray equivalent? - return out - - -# %% Statistical Functions -def mean( - x: NamedArray[Any, _DType], - /, - *, - dims: _Dims | Default = _default, - keepdims: bool = False, - axis: _AxisLike | None = None, -) -> NamedArray[Any, _DType]: - """ - Calculates the arithmetic mean of the input array x. - - Parameters - ---------- - x : - Should have a real-valued floating-point data type. - dims : - Dim or dims along which arithmetic means must be computed. By default, - the mean must be computed over the entire array. If a tuple of hashables, - arithmetic means must be computed over multiple axes. - Default: None. - keepdims : - if True, the reduced axes (dimensions) must be included in the result - as singleton dimensions, and, accordingly, the result must be compatible - with the input array (see Broadcasting). Otherwise, if False, the - reduced axes (dimensions) must not be included in the result. - Default: False. - axis : - Axis or axes along which arithmetic means must be computed. By default, - the mean must be computed over the entire array. If a tuple of integers, - arithmetic means must be computed over multiple axes. - Default: None. - - Returns - ------- - out : - If the arithmetic mean was computed over the entire array, - a zero-dimensional array containing the arithmetic mean; otherwise, - a non-zero-dimensional array containing the arithmetic means. - The returned array must have the same data type as x. - - Examples - -------- - >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) - >>> mean(x).data - Array(2.5, dtype=float64) - >>> mean(x, dims=("x",)).data - Array([2., 3.], dtype=float64) - - Using keepdims: - - >>> mean(x, dims=("x",), keepdims=True) - - Array([[2., 3.]], dtype=float64) - >>> mean(x, dims=("y",), keepdims=True) - - Array([[1.5], - [3.5]], dtype=float64) - """ - xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.mean(x._data, axis=axis_, keepdims=False) # We fix keepdims later - # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out - - -# %% Utility functions -def all( - x, - /, - *, - dims: _Dims | Default = _default, - keepdims: bool = False, - axis: _AxisLike | None = None, -) -> NamedArray[Any, _DType]: - xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.all(x._data, axis=axis_, keepdims=False) # We fix keepdims later - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out - - -def any( - x, - /, - *, - dims: _Dims | Default = _default, - keepdims: bool = False, - axis: _AxisLike | None = None, -) -> NamedArray[Any, _DType]: - xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.any(x._data, axis=axis_, keepdims=False) # We fix keepdims later - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out From 47bd2c57da107d8486b939d7fb927f19888fc89b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:31:30 +0200 Subject: [PATCH 118/367] test using array_api_strict instead of numpy --- xarray/namedarray/_array_api/_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index f630f2261e6..0f2169a6f37 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -3,7 +3,6 @@ from types import ModuleType from typing import TYPE_CHECKING, Any -import numpy as np from xarray.namedarray._typing import _arrayapi, _dtype @@ -12,7 +11,14 @@ def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: - return np if xp is None else xp + if xp is None: + import array_api_strict as xpd + + # import numpy as xpd + + return xpd + else: + return xp def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: From e01091b489140df5f9a83a2ab676d4319751ce46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:32:15 +0000 Subject: [PATCH 119/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 0f2169a6f37..59e57a70da5 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -3,7 +3,6 @@ from types import ModuleType from typing import TYPE_CHECKING, Any - from xarray.namedarray._typing import _arrayapi, _dtype if TYPE_CHECKING: From b4b301fa9c627671a1c7bcdcfdd048754a195f5f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:51:34 +0200 Subject: [PATCH 120/367] Update array-api-tests.yml --- .github/workflows/array-api-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index c7ccef7cde5..4403a59ebbc 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -36,6 +36,7 @@ jobs: python -m pip install ${GITHUB_WORKSPACE}/xarray python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt python -m pip install hypothesis + python -m pip install array-api-strict - name: Run the array API testsuite env: ARRAY_API_TESTS_MODULE: xarray.namedarray._array_api From 98548ca138c5b1ff6bb71302266b9b51f03ac42a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:36:43 +0200 Subject: [PATCH 121/367] try array api compat --- .github/workflows/array-api-tests.yml | 1 + xarray/namedarray/_array_api/_utils.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 4403a59ebbc..efeeaecbf66 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,6 +37,7 @@ jobs: python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt python -m pip install hypothesis python -m pip install array-api-strict + python -m pip install array-api-compat - name: Run the array API testsuite env: ARRAY_API_TESTS_MODULE: xarray.namedarray._array_api diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 59e57a70da5..1b46b711e0e 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -11,7 +11,8 @@ def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: if xp is None: - import array_api_strict as xpd + # import array_api_strict as xpd + import array_api_compat as xpd # import numpy as xpd From b5d851c2ba2f6c8efa2e39b22682740ce07a7207 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:41:09 +0200 Subject: [PATCH 122/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 1b46b711e0e..8a296ee552e 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -12,7 +12,7 @@ def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: if xp is None: # import array_api_strict as xpd - import array_api_compat as xpd + import array_api_compat.numpy as xpd # import numpy as xpd From 3c30064a7cea6d6b540156b881b2857c567e3196 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:14:31 +0200 Subject: [PATCH 123/367] Update core.py --- xarray/namedarray/core.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 7e2e421327a..9dc4297499f 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -549,6 +549,18 @@ def __len__(self) -> _IntOrUnknown: def __bool__(self, /) -> bool: return self._data.__bool__() + def __complex__(self, /) -> complex: + return self._data.__complex__() + + def __float__(self, /) -> float: + return self._data.__float__() + + def __index__(self, /) -> int: + return self._data.__index__() + + def __int__(self, /) -> int: + return self._data.__int__() + # Arithmetic Operators def __neg__(self, /): From 89026ac6df26462ecad3dd70929ed33987007d2f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:35:50 +0200 Subject: [PATCH 124/367] add more methods, reorder for easier array api comparisons --- xarray/namedarray/core.py | 124 +++++++++++++++++++++++--------------- 1 file changed, 76 insertions(+), 48 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 9dc4297499f..9c514016ff4 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -516,6 +516,45 @@ def copy( """ return self._copy(deep=deep, data=data) + def __len__(self) -> _IntOrUnknown: + try: + return self.shape[0] + except Exception as exc: + raise TypeError("len() of unsized object") from exc + + # Array API: + # Attributes: + + @property + def device(self) -> _Device: + """ + Device of the array’s elements. + + See Also + -------- + ndarray.device + """ + if isinstance(self._data, _arrayapi): + return self._data.device + else: + raise NotImplementedError("self._data missing device") + + @property + def dtype(self) -> _DType_co: + """ + Data-type of the array’s elements. + + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self._data.dtype + + @property + def mT(self): + raise NotImplementedError("Todo: ") + @property def ndim(self) -> int: """ @@ -527,6 +566,22 @@ def ndim(self) -> int: """ return len(self.shape) + @property + def shape(self) -> _Shape: + """ + Get the shape of the array. + + Returns + ------- + shape : tuple of ints + Tuple of array dimensions. + + See Also + -------- + numpy.ndarray.shape + """ + return self._data.shape + @property def size(self) -> _IntOrUnknown: """ @@ -540,11 +595,26 @@ def size(self) -> _IntOrUnknown: """ return math.prod(self.shape) - def __len__(self) -> _IntOrUnknown: - try: - return self.shape[0] - except Exception as exc: - raise TypeError("len() of unsized object") from exc + @property + def T(self): + raise NotImplementedError("Todo: ") + + # methods + def __abs__(self, /): + from xarray.namedarray._array_api import abs + + return abs(self) + + # def __array_namespace__(self, /, *, api_version=None): + # if api_version is not None and api_version not in ( + # "2021.12", + # "2022.12", + # "2023.12", + # ): + # raise ValueError(f"Unrecognized array API version: {api_version!r}") + # import xarray.namedarray._array_api as array_api + + # return array_api def __bool__(self, /) -> bool: return self._data.__bool__() @@ -752,7 +822,7 @@ def __rrshift__(self, other, /): return bitwise_right_shift(other, self) - # Something + # Indexing def __getitem__(self, key: _IndexKeyLike | NamedArray): if isinstance(key, (int, slice, tuple)): @@ -765,48 +835,6 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray): else: raise NotImplementedError("{k=} is not supported") - @property - def device(self) -> _Device: - """ - Device of the array’s elements. - - See Also - -------- - ndarray.device - """ - if isinstance(self._data, _arrayapi): - return self._data.device - else: - raise NotImplementedError("self._data missing device") - - @property - def dtype(self) -> _DType_co: - """ - Data-type of the array’s elements. - - See Also - -------- - ndarray.dtype - numpy.dtype - """ - return self._data.dtype - - @property - def shape(self) -> _Shape: - """ - Get the shape of the array. - - Returns - ------- - shape : tuple of ints - Tuple of array dimensions. - - See Also - -------- - numpy.ndarray.shape - """ - return self._data.shape - @property def nbytes(self) -> _IntOrUnknown: """ From 3f2e2f971f1eaaeebe921ffa0371d640d1f3db6f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 18:06:48 +0200 Subject: [PATCH 125/367] add more creastion functions --- xarray/namedarray/_array_api/__init__.py | 16 +++--- .../_array_api/_creation_functions.py | 51 ++++++++++++++++++- 2 files changed, 58 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 06631dceba3..b0c32cf225c 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -17,15 +17,15 @@ asarray, empty, empty_like, - # eye, + eye, full, full_like, linspace, - # meshgrid, + meshgrid, ones, ones_like, - # tril, - # triu, + tril, + triu, zeros, zeros_like, ) @@ -35,15 +35,15 @@ "asarray", "empty", "empty_like", - # "eye", + "eye", "full", "full_like", "linspace", - # "meshgrid", + "meshgrid", "ones", "ones_like", - # "tril", - # "triu", + "tril", + "triu", "zeros", "zeros_like", ] diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 12204f07b9c..1ea8418b49c 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -179,6 +179,21 @@ def empty_like( return x._new(data=_data) +def eye( + n_rows: int, + n_cols: int | None = None, + /, + *, + k: int = 0, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + def full( shape: _Shape, fill_value: bool | int | float | complex, @@ -219,12 +234,26 @@ def linspace( ) -> NamedArray[_ShapeType, _DType]: xp = _maybe_default_namespace() _data = xp.linspace( - start, stop, num=num, dtype=dtype, device=device, endpoint=endpoint + start, + stop, + num=num, + dtype=dtype, + device=device, + endpoint=endpoint, ) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) +def meshgrid(*arrays: NamedArray, indexing: str = "xy") -> list[NamedArray]: + arr = arrays[0] + xp = _get_data_namespace(arr) + _datas = xp.meshgrid(*[a._data for a in arrays], indexing=indexing) + # TODO: Can probably determine dim names from arrays, for now just default names: + _dims = _infer_dims(_datas[0].shape) + return [arr._new(_dims, _datas)] + + def ones( shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: @@ -245,6 +274,26 @@ def ones_like( return x._new(data=_data) +def tril( + x: NamedArray[_ShapeType, _DType], /, *, k: int = 0 +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _data = xp.tril(x._data, dtype=x.dtype) + # TODO: Can probably determine dim names from x, for now just default names: + _dims = _infer_dims(_data.shape) + return x._new(_dims, _data) + + +def triu( + x: NamedArray[_ShapeType, _DType], /, *, k: int = 0 +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _data = xp.triu(x._data, dtype=x.dtype) + # TODO: Can probably determine dim names from x, for now just default names: + _dims = _infer_dims(_data.shape) + return x._new(_dims, _data) + + def zeros( shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: From d0dd5f4bca14cf4d1003166fb3bed2eae3630229 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 18:18:05 +0200 Subject: [PATCH 126/367] Update core.py --- xarray/namedarray/core.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 9c514016ff4..570bf0955a2 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -50,6 +50,7 @@ ) if TYPE_CHECKING: + from enum import IntEnum from numpy.typing import NDArray from xarray.core.types import T_Chunks @@ -631,6 +632,23 @@ def __index__(self, /) -> int: def __int__(self, /) -> int: return self._data.__int__() + # dlpack + def __dlpack__( + self, + /, + *, + stream: int | Any | None = None, + max_version: tuple[int, int] | None = None, + dl_device: tuple[IntEnum, int] | None = None, + copy: bool | None = None, + ) -> Any: + return self._data.__dlpack__( + stream=stream, max_version=max_version, dl_device=dl_device, copy=copy + ) + + def __dlpack_device__(self, /) -> tuple[IntEnum, int]: + return self._data.__dlpack_device__() + # Arithmetic Operators def __neg__(self, /): From 38d2bed4fdb52e4b6322e88cec45d430d3085e96 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 18:37:51 +0200 Subject: [PATCH 127/367] fix eye --- .../namedarray/_array_api/_creation_functions.py | 15 ++++++++------- xarray/namedarray/_array_api/_utils.py | 5 ++++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 1ea8418b49c..196c5c9b9f0 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -6,7 +6,8 @@ from xarray.namedarray._array_api._utils import ( _get_data_namespace, - _maybe_default_namespace, + # _maybe_default_namespace, + _get_namespace_dtype, ) from xarray.namedarray._typing import ( Default, @@ -53,7 +54,7 @@ def arange( dtype: _DType | None = None, device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() + xp = _get_namespace_dtype(dtype) _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -159,7 +160,7 @@ def asarray( def empty( shape: _ShapeType, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() + xp = _get_namespace_dtype(dtype) _data = xp.empty(shape, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -188,8 +189,8 @@ def eye( dtype: _DType | None = None, device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() - _data = xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) + xp = _get_namespace_dtype(dtype) + _data = xp.eye(n_rows, n_cols, k=k, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -201,7 +202,7 @@ def full( dtype: _DType | None = None, device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() + xp = _get_namespace_dtype(dtype) _data = xp.full(shape, fill_value, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -232,7 +233,7 @@ def linspace( device: _Device | None = None, endpoint: bool = True, ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() + xp = _get_namespace_dtype(dtype) _data = xp.linspace( start, stop, diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 8a296ee552e..018b84ca54c 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -28,6 +28,9 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: return _maybe_default_namespace() -def _get_namespace_dtype(dtype: _dtype) -> ModuleType: +def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType: + if dtype is None: + return _maybe_default_namespace() + xp = __import__(dtype.__module__) return xp From d6177131d51b821a239fb5eef8d7b3e78d418cd7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:38:29 +0000 Subject: [PATCH 128/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 570bf0955a2..15a6d3d927e 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -51,6 +51,7 @@ if TYPE_CHECKING: from enum import IntEnum + from numpy.typing import NDArray from xarray.core.types import T_Chunks From 320a692e866fc48c2e25d6189a57dfc0598af7f1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 19:26:39 +0200 Subject: [PATCH 129/367] Add searching functions --- xarray/namedarray/_array_api/__init__.py | 23 ++--- .../_array_api/_creation_functions.py | 12 +-- .../_array_api/_searching_functions.py | 87 ++++++++++++++++++- xarray/namedarray/_array_api/_utils.py | 26 +++++- 4 files changed, 125 insertions(+), 23 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index b0c32cf225c..873bff3c5b0 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -270,17 +270,20 @@ # "stack", ] -# from xarray.namedarray._array_api._searching_functions import ( -# argmax, -# argmin, -# where, -# ) +from xarray.namedarray._array_api._searching_functions import ( + argmax, + argmin, + nonzero, + where, +) -# __all__ += [ -# "argmax", -# "argmin", -# "where", -# ] +__all__ += [ + "argmax", + "argmin", + "nonzero", + "searchsorted", + "where", +] from xarray.namedarray._array_api._statistical_functions import ( max, diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 196c5c9b9f0..31a690d424b 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -6,8 +6,8 @@ from xarray.namedarray._array_api._utils import ( _get_data_namespace, - # _maybe_default_namespace, _get_namespace_dtype, + _infer_dims, ) from xarray.namedarray._typing import ( Default, @@ -35,16 +35,6 @@ def _like_args(x, dtype=None, device: _Device | None = None): return dict(shape=x.shape, dtype=dtype, device=device) -def _infer_dims( - shape: _Shape, - dims: _DimsLike | Default = _default, -) -> _DimsLike: - if dims is _default: - return tuple(f"dim_{n}" for n in range(len(shape))) - else: - return dims - - def arange( start: int | float, /, diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index b62d0a8393b..2ee30632b2f 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -1,3 +1,88 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims +from xarray.namedarray._typing import ( + Default, + _arrayfunction_or_api, + _ArrayLike, + _default, + _Device, + _DimsLike, + _DType, + _Dims, + _Shape, + _ShapeType, + duckarray, +) +from xarray.namedarray.core import ( + NamedArray, + _dims_to_axis, + _get_remaining_dims, +) + +if TYPE_CHECKING: + from typing import Literal, Optional, Tuple + from xarray.namedarray._array_api._utils import _get_data_namespace -sdf = _get_data_namespace() + +def argmax( + x: NamedArray, + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: int | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dims, axis) + _data = xp.argmax(x._data, axis=_axis, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) + return x._new(dims=_dims, data=data_) + + +def argmin( + x: NamedArray, + /, + *, + dims: _Dims | Default = _default, + keepdims: bool = False, + axis: int | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dims, axis) + _data = xp.argmin(x._data, axis=_axis, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) + return x._new(dims=_dims, data=data_) + + +def nonzero(x: NamedArray, /) -> tuple[NamedArray, ...]: + xp = _get_data_namespace(x) + _datas = xp.nonzero(x._data) + # TODO: Verify that dims and axis matches here: + return tuple(x._new(dim, i) for dim, i in zip(x.dims, _datas)) + + +def searchsorted( + x1: NamedArray, + x2: NamedArray, + /, + *, + side: Literal["left", "right"] = "left", + sorter: NamedArray | None = None, +) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.searchsorted(x1._data, x2._data, side=side, sorter=sorter) + # TODO: Check dims, probably can do it smarter: + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def where(condition: NamedArray, x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.where(condition._data, x1._data, x2._data) + return x1._new(x1.dims, _data) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 018b84ca54c..f01ec05b146 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -3,7 +3,21 @@ from types import ModuleType from typing import TYPE_CHECKING, Any -from xarray.namedarray._typing import _arrayapi, _dtype +from xarray.namedarray._typing import ( + Default, + _arrayfunction_or_api, + _ArrayLike, + _default, + _arrayapi, + _Device, + _DimsLike, + _DType, + _Dims, + _Shape, + _ShapeType, + duckarray, + _dtype, +) if TYPE_CHECKING: from xarray.namedarray.core import NamedArray @@ -34,3 +48,13 @@ def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType: xp = __import__(dtype.__module__) return xp + + +def _infer_dims( + shape: _Shape, + dims: _DimsLike | Default = _default, +) -> _DimsLike: + if dims is _default: + return tuple(f"dim_{n}" for n in range(len(shape))) + else: + return dims From c78f7425984c0ac91151334e164e6ad990b99131 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:27:22 +0000 Subject: [PATCH 130/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../namedarray/_array_api/_creation_functions.py | 1 - .../namedarray/_array_api/_searching_functions.py | 14 ++------------ xarray/namedarray/_array_api/_utils.py | 11 ++--------- 3 files changed, 4 insertions(+), 22 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 31a690d424b..707e1f46761 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -10,7 +10,6 @@ _infer_dims, ) from xarray.namedarray._typing import ( - Default, _arrayfunction_or_api, _ArrayLike, _default, diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index 2ee30632b2f..ee6825c544c 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -1,20 +1,12 @@ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims from xarray.namedarray._typing import ( Default, - _arrayfunction_or_api, - _ArrayLike, _default, - _Device, - _DimsLike, - _DType, _Dims, - _Shape, - _ShapeType, - duckarray, ) from xarray.namedarray.core import ( NamedArray, @@ -23,9 +15,7 @@ ) if TYPE_CHECKING: - from typing import Literal, Optional, Tuple - -from xarray.namedarray._array_api._utils import _get_data_namespace + from typing import Literal def argmax( diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index f01ec05b146..4e5322186ae 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -5,18 +5,11 @@ from xarray.namedarray._typing import ( Default, - _arrayfunction_or_api, - _ArrayLike, - _default, _arrayapi, - _Device, + _default, _DimsLike, - _DType, - _Dims, - _Shape, - _ShapeType, - duckarray, _dtype, + _Shape, ) if TYPE_CHECKING: From 4e8018aa0f52187423f27452074d614b08913d1e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 21:14:57 +0200 Subject: [PATCH 131/367] add more statistical functions --- .../_array_api/_statistical_functions.py | 92 +++++++++++++++---- 1 file changed, 75 insertions(+), 17 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index d99f0f812ef..2f3a4a5259b 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -7,8 +7,10 @@ Default, _AxisLike, _default, + _Dim, _Dims, _DType, + _ShapeType, ) from xarray.namedarray.core import ( NamedArray, @@ -17,6 +19,23 @@ ) +def cumulative_sum( + x: NamedArray[_ShapeType, _DType], + /, + *, + dim: _Dim | Default = _default, + dtype: _DType | None = None, + include_initial: bool = False, + axis: int | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dim, axis) + _data = xp.cumulative_sum( + x._data, axis=_axis, dtype=dtype, include_initial=include_initial + ) + return x._new(dims=x.dims, data=_data) + + def max( x: NamedArray[Any, _DType], /, @@ -26,12 +45,11 @@ def max( axis: _AxisLike | None = None, ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.max(x._data, axis=axis_, keepdims=False) # We fix keepdims later + _axis = _dims_to_axis(x, dims, axis) + _data = xp.max(x._data, axis=_axis, keepdims=False) # We fix keepdims later # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) + return x._new(dims=dims_, data=data_) def mean( @@ -93,10 +111,10 @@ def mean( [3.5]], dtype=float64) """ xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.mean(x._data, axis=axis_, keepdims=False) # We fix keepdims later + _axis = _dims_to_axis(x, dims, axis) + _data = xp.mean(x._data, axis=_axis, keepdims=False) # We fix keepdims later # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out @@ -110,10 +128,10 @@ def min( axis: _AxisLike | None = None, ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.min(x._data, axis=axis_, keepdims=False) # We fix keepdims later + _axis = _dims_to_axis(x, dims, axis) + _data = xp.min(x._data, axis=_axis, keepdims=False) # We fix keepdims later # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out @@ -127,10 +145,30 @@ def prod( axis: _AxisLike | None = None, ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.prod(x._data, axis=axis_, keepdims=False) # We fix keepdims later + _axis = _dims_to_axis(x, dims, axis) + _data = xp.prod(x._data, axis=_axis, keepdims=False) # We fix keepdims later # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def std( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + correction: int | float = 0.0, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dims, axis) + _data = xp.std( + x._data, axis=_axis, correction=correction, keepdims=False + ) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out @@ -144,9 +182,29 @@ def sum( axis: _AxisLike | None = None, ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.sum(x._data, axis=axis_, keepdims=False) # We fix keepdims later + _axis = _dims_to_axis(x, dims, axis) + _data = xp.sum(x._data, axis=_axis, keepdims=False) # We fix keepdims later + # TODO: Why do we need to do the keepdims ourselves? + dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) + out = x._new(dims=dims_, data=data_) + return out + + +def var( + x: NamedArray[Any, _DType], + /, + *, + dims: _Dims | Default = _default, + correction: int | float = 0.0, + keepdims: bool = False, + axis: _AxisLike | None = None, +) -> NamedArray[Any, _DType]: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dims, axis) + _data = xp.var( + x._data, axis=_axis, correction=correction, keepdims=False + ) # We fix keepdims later # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) + dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out From 2c6acecbbd6ed45e6701bca65a9f8dbf7838feac Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 21:28:35 +0200 Subject: [PATCH 132/367] Update __init__.py --- xarray/namedarray/_array_api/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 873bff3c5b0..b0947d03afc 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -286,19 +286,25 @@ ] from xarray.namedarray._array_api._statistical_functions import ( + cumulative_sum, max, mean, min, prod, + std, sum, + var, ) __all__ += [ + "cumulative_sum", "max", "mean", "min", "prod", + "std", "sum", + "var", ] from xarray.namedarray._array_api._utility_functions import ( From 2860fd5ab6c40acb887621f36c0bd24d2f661c7c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 21:28:41 +0200 Subject: [PATCH 133/367] Update _creation_functions.py --- xarray/namedarray/_array_api/_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 707e1f46761..22a3b14c3d9 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -241,7 +241,7 @@ def meshgrid(*arrays: NamedArray, indexing: str = "xy") -> list[NamedArray]: _datas = xp.meshgrid(*[a._data for a in arrays], indexing=indexing) # TODO: Can probably determine dim names from arrays, for now just default names: _dims = _infer_dims(_datas[0].shape) - return [arr._new(_dims, _datas)] + return [arr._new(_dims, _data) for _data in _datas] def ones( From cee7899fa70c330d55f088f7c46851ee9ff38dc3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 21:44:43 +0200 Subject: [PATCH 134/367] add protocols for integer info --- .../_array_api/_data_type_functions.py | 8 +++++--- xarray/namedarray/_typing.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_data_type_functions.py b/xarray/namedarray/_array_api/_data_type_functions.py index fbfaaca0d0e..f6ca29a1c42 100644 --- a/xarray/namedarray/_array_api/_data_type_functions.py +++ b/xarray/namedarray/_array_api/_data_type_functions.py @@ -11,6 +11,8 @@ _DType, _dtype, _ShapeType, + _FInfo, + _IInfo, ) from xarray.namedarray.core import ( NamedArray, @@ -71,7 +73,7 @@ def can_cast(from_: _dtype | NamedArray, to: _dtype, /) -> bool: return xp.can_cast(from_, to) -def finfo(type: _dtype | NamedArray[Any, Any], /): +def finfo(type: _dtype | NamedArray[Any, Any], /) -> _FInfo: if isinstance(type, NamedArray): xp = _get_data_namespace(type) return xp.finfo(type._data) @@ -80,7 +82,7 @@ def finfo(type: _dtype | NamedArray[Any, Any], /): return xp.finfo(type) -def iinfo(type: _dtype | NamedArray[Any, Any], /): +def iinfo(type: _dtype | NamedArray[Any, Any], /) -> _IInfo: if isinstance(type, NamedArray): xp = _get_data_namespace(type) return xp.iinfo(type._data) @@ -90,7 +92,7 @@ def iinfo(type: _dtype | NamedArray[Any, Any], /): def isdtype(dtype: _dtype, kind: _dtype | str | tuple[_dtype | str, ...]) -> bool: - xp = _get_namespace_dtype(type) + xp = _get_namespace_dtype(dtype) return xp.isdtype(dtype, kind) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 3b88e110d26..e0731529528 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -103,6 +103,22 @@ def dtype(self) -> _DType_co: ... _Device = Any +class _IInfo(Protocol): + bits: int + max: int + min: int + dtype: _dtype + + +class _FInfo(Protocol): + bits: int + eps: float + max: float + min: float + smallest_normal: float + dtype: _dtype + + class _SupportsReal(Protocol[_T_co]): @property def real(self) -> _T_co: ... From 69e017469363b02931cabeb1fefef9ee7126314c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 19:46:21 +0000 Subject: [PATCH 135/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_data_type_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_data_type_functions.py b/xarray/namedarray/_array_api/_data_type_functions.py index f6ca29a1c42..895b7a50fb1 100644 --- a/xarray/namedarray/_array_api/_data_type_functions.py +++ b/xarray/namedarray/_array_api/_data_type_functions.py @@ -10,9 +10,9 @@ _arrayapi, _DType, _dtype, - _ShapeType, _FInfo, _IInfo, + _ShapeType, ) from xarray.namedarray.core import ( NamedArray, From 5b38cce9600d6603d22600254089cfef47d94f84 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 22:02:23 +0200 Subject: [PATCH 136/367] add hypot --- xarray/namedarray/_array_api/__init__.py | 2 ++ xarray/namedarray/_array_api/_elementwise_functions.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index b0947d03afc..c4a3e14912b 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -126,6 +126,7 @@ floor_divide, greater, greater_equal, + hypot, imag, isfinite, isinf, @@ -188,6 +189,7 @@ "floor_divide", "greater", "greater_equal", + "hypot", "imag", "isfinite", "isinf", diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index acdd4dc5c48..589b8638492 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -188,6 +188,13 @@ def greater_equal(x1, x2, /): return out +def hypot(x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.hypot(x1._data, x2._data)) + return out + + def imag( x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var] ) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]: From 12d3669c446b98636dc78ec875f9c4b7bab9fe67 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 22:02:42 +0200 Subject: [PATCH 137/367] workaround strange errors --- xarray/namedarray/_array_api/_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 4e5322186ae..48f228f628d 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -39,7 +39,15 @@ def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType: if dtype is None: return _maybe_default_namespace() - xp = __import__(dtype.__module__) + try: + xp = __import__(dtype.__module__) + except AttributeError: + # TODO: Fix this. + # FAILED array_api_tests/test_searching_functions.py::test_searchsorted - AttributeError: 'numpy.dtypes.Float64DType' object has no attribute '__module__'. Did you mean: '__mul__'? + # Falsifying example: test_searchsorted( + # data=data(...), + # ) + return _maybe_default_namespace() return xp From 716cb391f61728ee90d586e8af77e0f4d3e5e559 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 22:29:28 +0200 Subject: [PATCH 138/367] add more elementwise functions --- xarray/namedarray/_array_api/__init__.py | 10 +++++ .../_array_api/_elementwise_functions.py | 38 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index c4a3e14912b..1110149585a 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -115,7 +115,9 @@ bitwise_right_shift, bitwise_xor, ceil, + clip, conj, + copysign, cos, cosh, divide, @@ -142,6 +144,8 @@ logical_not, logical_or, logical_xor, + maximum, + minimum, multiply, negative, not_equal, @@ -151,6 +155,7 @@ remainder, round, sign, + signbit, sin, sinh, sqrt, @@ -178,7 +183,9 @@ "bitwise_right_shift", "bitwise_xor", "ceil", + "clip", "conj", + "copysign", "cos", "cosh", "divide", @@ -205,6 +212,8 @@ "logical_not", "logical_or", "logical_xor", + "maximum", + "minimum", "multiply", "negative", "not_equal", @@ -214,6 +223,7 @@ "remainder", "round", "sign", + "signbit", "sin", "sinh", "sqrt", diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 589b8638492..43cc19e5e61 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -117,12 +117,30 @@ def ceil(x, /): return out +def clip( + x: NamedArray, + /, + min: int | float | NamedArray | None = None, + max: int | float | NamedArray | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + out = x._new(data=xp.clip(x._data)) + return out + + def conj(x, /): xp = _get_data_namespace(x) out = x._new(data=xp.conj(x._data)) return out +def copysign(x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.copysign(x1._data, x2._data)) + return out + + def cos(x, /): xp = _get_data_namespace(x) out = x._new(data=xp.cos(x._data)) @@ -317,6 +335,20 @@ def logical_xor(x1, x2, /): return out +def maximum(x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.maximum(x1._data, x2._data)) + return out + + +def minimum(x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + # TODO: Handle attrs? will get x1 now + out = x1._new(data=xp.minimum(x1._data, x2._data)) + return out + + def multiply(x1, x2, /): xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now @@ -401,6 +433,12 @@ def sign(x, /): return out +def signbit(x, /): + xp = _get_data_namespace(x) + out = x._new(data=xp.signbit(x._data)) + return out + + def sin(x, /): xp = _get_data_namespace(x) out = x._new(data=xp.sin(x._data)) From 085c4f78a3e61d4b4a0521cddf4a1940f4a40a5d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 22:49:24 +0200 Subject: [PATCH 139/367] Update __init__.py --- xarray/namedarray/_array_api/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 1110149585a..004af60f137 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -286,6 +286,7 @@ argmax, argmin, nonzero, + searchsorted, where, ) From ef032b23f6be13ac095c9fcd5ff40680eb0378a0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 22:59:31 +0200 Subject: [PATCH 140/367] add take --- xarray/namedarray/_array_api/__init__.py | 4 +-- .../_array_api/_indexing_functions.py | 34 ++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 004af60f137..1b34ffdd021 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -234,9 +234,9 @@ "trunc", ] -# from xarray.namedarray._array_api.indexing_functions import take +from xarray.namedarray._array_api._indexing_functions import take -# __all__ += ["take"] +__all__ += ["take"] from xarray.namedarray._array_api._info import __array_namespace_info__ diff --git a/xarray/namedarray/_array_api/_indexing_functions.py b/xarray/namedarray/_array_api/_indexing_functions.py index 12b69a6bc86..b0a2f2a966a 100644 --- a/xarray/namedarray/_array_api/_indexing_functions.py +++ b/xarray/namedarray/_array_api/_indexing_functions.py @@ -1,3 +1,35 @@ +from __future__ import annotations + +from typing import Any + +from xarray.namedarray._array_api._creation_functions import asarray from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _Axes, + _Axis, + _default, + _Dim, + _DType, + _Shape, +) +from xarray.namedarray.core import ( + _dims_to_axis, + NamedArray, +) + -sdf = _get_data_namespace +def take( + x: NamedArray, + indices: NamedArray, + /, + *, + dim: _Dim | Default = _default, + axis: int | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dim, axis) + # TODO: Handle attrs? will get x1 now + out = x._new(data=xp.take(x._data, indices._data, axis=_axis)) + return out From 470fe5fc089151726af872fccde6d3941691397a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 21:01:19 +0000 Subject: [PATCH 141/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_indexing_functions.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/xarray/namedarray/_array_api/_indexing_functions.py b/xarray/namedarray/_array_api/_indexing_functions.py index b0a2f2a966a..2859171b853 100644 --- a/xarray/namedarray/_array_api/_indexing_functions.py +++ b/xarray/namedarray/_array_api/_indexing_functions.py @@ -1,22 +1,14 @@ from __future__ import annotations -from typing import Any - -from xarray.namedarray._array_api._creation_functions import asarray from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._typing import ( Default, - _arrayapi, - _Axes, - _Axis, _default, _Dim, - _DType, - _Shape, ) from xarray.namedarray.core import ( - _dims_to_axis, NamedArray, + _dims_to_axis, ) From 2de92d9d3b21b15714e99629c7c49cbc2bd26fc6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 23:14:32 +0200 Subject: [PATCH 142/367] add sorting functions --- xarray/namedarray/_array_api/__init__.py | 4 ++ .../_array_api/_sorting_functions.py | 54 +++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 xarray/namedarray/_array_api/_sorting_functions.py diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 1b34ffdd021..89c8402ec0f 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -298,6 +298,10 @@ "where", ] +from xarray.namedarray._array_api._sorting_functions import argsort, sort + +__all__ += ["argsort", "sort"] + from xarray.namedarray._array_api._statistical_functions import ( cumulative_sum, max, diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py new file mode 100644 index 00000000000..164e11f3075 --- /dev/null +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import Any + +from xarray.namedarray._array_api._creation_functions import asarray +from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._typing import ( + Default, + _arrayapi, + _Axes, + _Axis, + _default, + _Dim, + _DType, + _Shape, +) +from xarray.namedarray.core import ( + _dims_to_axis, + NamedArray, +) + + +def argsort( + x: NamedArray, + /, + *, + dim: _Dim | Default = _default, + descending: bool = False, + stable: bool = True, + axis: int = -1, +) -> NamedArray: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dim, axis) + out = x._new( + data=xp.argsort(x._data, axis=_axis, descending=descending, stable=stable) + ) + return out + + +def sort( + x: NamedArray, + /, + *, + dim: _Dim | Default = _default, + descending: bool = False, + stable: bool = True, + axis: int = -1, +) -> NamedArray: + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dim, axis) + out = x._new( + data=xp.argsort(x._data, axis=_axis, descending=descending, stable=stable) + ) + return out From 87472d19c2a7d8ae702e1b18db03b5ac3b20650c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 21:15:58 +0000 Subject: [PATCH 143/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_sorting_functions.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index 164e11f3075..201dcb65251 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -1,22 +1,14 @@ from __future__ import annotations -from typing import Any - -from xarray.namedarray._array_api._creation_functions import asarray from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._typing import ( Default, - _arrayapi, - _Axes, - _Axis, _default, _Dim, - _DType, - _Shape, ) from xarray.namedarray.core import ( - _dims_to_axis, NamedArray, + _dims_to_axis, ) From 270c93c9f4e3b80b5b70fbcefc491b42d57039d7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 20 Aug 2024 23:24:56 +0200 Subject: [PATCH 144/367] Update core.py --- xarray/namedarray/core.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 15a6d3d927e..af55abefdca 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -597,6 +597,12 @@ def size(self) -> _IntOrUnknown: """ return math.prod(self.shape) + def to_device(self, device: _Device, /, stream: None = None) -> Self: + if isinstance(self._data, _arrayapi): + return self._replace(data=self._data.to_device(device, stream=stream)) + else: + raise NotImplementedError("Only array api are valid.") + @property def T(self): raise NotImplementedError("Todo: ") From 57da327147f246d37a015512b3031db460f696f9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 21 Aug 2024 21:55:45 +0200 Subject: [PATCH 145/367] fix cumsum --- .../_array_api/_statistical_functions.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 2f3a4a5259b..ad2e39a75a5 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -30,9 +30,28 @@ def cumulative_sum( ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis) - _data = xp.cumulative_sum( - x._data, axis=_axis, dtype=dtype, include_initial=include_initial - ) + try: + _data = xp.cumulative_sum( + x._data, axis=_axis, dtype=dtype, include_initial=include_initial + ) + except AttributeError: + # Use np.cumsum until new name is introduced: + # np.cumsum does not support include_initial + if include_initial: + if axis < 0: + axis += x.ndim + d = xp.concat( + [ + xp.zeros( + x.shape[:axis] + (1,) + x.shape[axis + 1 :], dtype=x.dtype + ), + x._data, + ], + axis=axis, + ) + else: + d = x._data + _data = xp.cumsum(d, axis=axis, dtype=dtype) return x._new(dims=x.dims, data=_data) From 5178661c026e66da558749e9cab77a1363814595 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:40:08 +0200 Subject: [PATCH 146/367] Update _searching_functions.py --- xarray/namedarray/_array_api/_searching_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index ee6825c544c..259185aeed2 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -75,4 +75,6 @@ def searchsorted( def where(condition: NamedArray, x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) _data = xp.where(condition._data, x1._data, x2._data) - return x1._new(x1.dims, _data) + # TODO: Wrong, _dims should be either of the arguments. How to choose? + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) From 2e6f3da8ca0397c5a4afd9f503215086396439e7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:57:36 +0200 Subject: [PATCH 147/367] handle scalars --- xarray/namedarray/core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index af55abefdca..33555d40275 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -668,10 +668,10 @@ def __pos__(self, /): return positive(self) - def __add__(self, other, /): - from xarray.namedarray._array_api import add + def __add__(self, other: int | float | NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api import add, asarray - return add(self, other) + return add(self, asarray(other)) def __sub__(self, other, /): from xarray.namedarray._array_api import subtract @@ -743,11 +743,10 @@ def __rshift__(self, other, /): return bitwise_right_shift(self) # Comparison Operators + def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api import equal, asarray - def __eq__(self, other, /): - from xarray.namedarray._array_api import equal - - return equal(self, other) + return equal(self, asarray(other)) def __ge__(self, other, /): from xarray.namedarray._array_api import greater_equal From 85c6f4bece9c732ecab10dcc8316737f1c3cb9de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:58:20 +0000 Subject: [PATCH 148/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 33555d40275..dd668a979ba 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -744,7 +744,7 @@ def __rshift__(self, other, /): # Comparison Operators def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api import equal, asarray + from xarray.namedarray._array_api import asarray, equal return equal(self, asarray(other)) From 2af7eec2de3ee627eb22d9cdef19dd3ab1ca87b0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:01:28 +0200 Subject: [PATCH 149/367] move dim manipulation to array_api --- .../_array_api/_indexing_functions.py | 7 +- xarray/namedarray/_array_api/_info.py | 22 ++-- .../_array_api/_searching_functions.py | 9 +- .../_array_api/_sorting_functions.py | 3 +- .../_array_api/_statistical_functions.py | 8 +- .../_array_api/_utility_functions.py | 8 +- xarray/namedarray/_array_api/_utils.py | 103 +++++++++++++++++- 7 files changed, 135 insertions(+), 25 deletions(-) diff --git a/xarray/namedarray/_array_api/_indexing_functions.py b/xarray/namedarray/_array_api/_indexing_functions.py index 2859171b853..89cbe8b1da1 100644 --- a/xarray/namedarray/_array_api/_indexing_functions.py +++ b/xarray/namedarray/_array_api/_indexing_functions.py @@ -1,15 +1,12 @@ from __future__ import annotations -from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api._utils import _get_data_namespace, _dims_to_axis from xarray.namedarray._typing import ( Default, _default, _Dim, ) -from xarray.namedarray.core import ( - NamedArray, - _dims_to_axis, -) +from xarray.namedarray.core import NamedArray def take( diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index beb30a1d0ad..688fd0c43df 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -2,10 +2,6 @@ from typing import TYPE_CHECKING -if TYPE_CHECKING: - from types import ModuleType - - from xarray.namedarray._typing import _Device from xarray.namedarray._array_api._dtypes import ( bool, @@ -23,6 +19,16 @@ uint64, ) +if TYPE_CHECKING: + from types import ModuleType + + from xarray.namedarray._typing import ( + _Device, + _Capabilities, + _DefaultDataTypes, + _DataTypes, + ) + def __array_namespace_info__() -> ModuleType: import xarray.namedarray._array_api._info @@ -30,7 +36,7 @@ def __array_namespace_info__() -> ModuleType: return xarray.namedarray._array_api._info -def capabilities() -> dict: +def capabilities() -> _Capabilities: return { "boolean indexing": False, "data-dependent shapes": False, @@ -48,7 +54,7 @@ def default_device() -> _Device: def default_dtypes( *, device: _Device | None = None, -) -> dict: +) -> _DefaultDataTypes: return { "real floating": float64, "complex floating": complex128, @@ -60,8 +66,8 @@ def default_dtypes( def dtypes( *, device: _Device | None = None, - kind: str | Tuple[str, ...] | None = None, -) -> dict: + kind: str | tuple[str, ...] | None = None, +) -> _DataTypes: if kind is None: return { "bool": bool, diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index 259185aeed2..1100784c2d8 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -2,7 +2,12 @@ from typing import TYPE_CHECKING -from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _infer_dims, + _dims_to_axis, + _get_remaining_dims, +) from xarray.namedarray._typing import ( Default, _default, @@ -10,8 +15,6 @@ ) from xarray.namedarray.core import ( NamedArray, - _dims_to_axis, - _get_remaining_dims, ) if TYPE_CHECKING: diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index 201dcb65251..698704f798b 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api._utils import _get_data_namespace, _dims_to_axis from xarray.namedarray._typing import ( Default, _default, @@ -8,7 +8,6 @@ ) from xarray.namedarray.core import ( NamedArray, - _dims_to_axis, ) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index ad2e39a75a5..3cd0f451564 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -2,7 +2,11 @@ from typing import Any -from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _dims_to_axis, + _get_remaining_dims, +) from xarray.namedarray._typing import ( Default, _AxisLike, @@ -14,8 +18,6 @@ ) from xarray.namedarray.core import ( NamedArray, - _dims_to_axis, - _get_remaining_dims, ) diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index 86eb34c4d9c..e9936148b47 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -2,7 +2,11 @@ from typing import Any -from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _dims_to_axis, + _get_remaining_dims, +) from xarray.namedarray._typing import ( Default, _AxisLike, @@ -12,8 +16,6 @@ ) from xarray.namedarray.core import ( NamedArray, - _dims_to_axis, - _get_remaining_dims, ) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 48f228f628d..1bee2533006 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import ModuleType -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterable from xarray.namedarray._typing import ( Default, @@ -9,7 +9,12 @@ _default, _DimsLike, _dtype, + _AxisLike, _Shape, + _Dim, + _Dims, + _DType, + duckarray, ) if TYPE_CHECKING: @@ -59,3 +64,99 @@ def _infer_dims( return tuple(f"dim_{n}" for n in range(len(shape))) else: return dims + + +def _normalize_dimensions(dims: _DimsLike) -> _Dims: + """ + Normalize dimensions. + + Examples + -------- + >>> _normalize_dimensions(None) + (None,) + >>> _normalize_dimensions(1) + (1,) + >>> _normalize_dimensions("2") + ('2',) + >>> _normalize_dimensions(("time",)) + ('time',) + >>> _normalize_dimensions(["time"]) + ('time',) + >>> _normalize_dimensions([("time", "x", "y")]) + (('time', 'x', 'y'),) + """ + if isinstance(dims, str) or not isinstance(dims, Iterable): + return (dims,) + + return tuple(dims) + + +def _assert_either_dim_or_axis( + dims: _Dim | _Dims | Default, axis: _AxisLike | None +) -> None: + if dims is not _default and axis is not None: + raise ValueError("cannot supply both 'axis' and 'dim(s)' arguments") + + +def _dims_to_axis( + x: NamedArray[Any, Any], dims: _Dim | _Dims | Default, axis: _AxisLike | None +) -> _AxisLike | None: + """ + Convert dims to axis indices. + + Examples + -------- + >>> narr = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) + >>> _dims_to_axis(narr, ("y",), None) + (1,) + >>> _dims_to_axis(narr, None, 0) + (0,) + >>> _dims_to_axis(narr, None, None) + """ + _assert_either_dim_or_axis(dims, axis) + + if dims is not _default: + axis = () + for dim in dims: + try: + axis = (x.dims.index(dim),) + except ValueError: + raise ValueError(f"{dim!r} not found in array dimensions {x.dims!r}") + return axis + + if isinstance(axis, int): + return (axis,) + + return axis + + +def _get_remaining_dims( + x: NamedArray[Any, _DType], + data: duckarray[Any, _DType], + axis: _AxisLike | None, + *, + keepdims: bool, +) -> tuple[_Dims, duckarray[Any, _DType]]: + """ + Get the reamining dims after a reduce operation. + """ + if data.shape == x.shape: + return x.dims, data + + removed_axes: tuple[int, ...] + if axis is None: + removed_axes = tuple(v for v in range(x.ndim)) + else: + removed_axes = axis % x.ndim if isinstance(axis, tuple) else (axis % x.ndim,) + + if keepdims: + # Insert None (aka newaxis) for removed dims + slices = tuple( + None if i in removed_axes else slice(None, None) for i in range(x.ndim) + ) + data = data[slices] + dims = x.dims + else: + dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes) + + return dims, data From 14df7c80a7449b9853af3c08ac185326dc8c64c9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:01:46 +0200 Subject: [PATCH 150/367] Add more typing --- xarray/namedarray/_typing.py | 54 ++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index e0731529528..fbce5281ae3 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -15,6 +15,7 @@ Union, overload, runtime_checkable, + TypedDict, ) import numpy as np @@ -42,14 +43,14 @@ class Default(Enum): _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) +_ScalarType = TypeVar("_ScalarType", bound=np.generic) +_ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True) + _dtype = np.dtype _DType = TypeVar("_DType", bound=np.dtype[Any]) _DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any]) # A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic` -_ScalarType = TypeVar("_ScalarType", bound=np.generic) -_ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True) - # A protocol for anything with the dtype attribute @runtime_checkable @@ -58,6 +59,16 @@ class _SupportsDType(Protocol[_DType_co]): def dtype(self) -> _DType_co: ... +class _SupportsReal(Protocol[_T_co]): + @property + def real(self) -> _T_co: ... + + +class _SupportsImag(Protocol[_T_co]): + @property + def imag(self) -> _T_co: ... + + _DTypeLike = Union[ np.dtype[_ScalarType], type[_ScalarType], @@ -119,14 +130,39 @@ class _FInfo(Protocol): dtype: _dtype -class _SupportsReal(Protocol[_T_co]): - @property - def real(self) -> _T_co: ... +_Capabilities = TypedDict( + "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} +) +_DefaultDataTypes = TypedDict( + "DefaultDataTypes", + { + "real floating": _dtype, + "complex floating": _dtype, + "integral": _dtype, + "indexing": _dtype, + }, +) -class _SupportsImag(Protocol[_T_co]): - @property - def imag(self) -> _T_co: ... +_DataTypes = TypedDict( + "DataTypes", + { + "bool": _dtype, + "float32": _dtype, + "float64": _dtype, + "complex64": _dtype, + "complex128": _dtype, + "int8": _dtype, + "int16": _dtype, + "int32": _dtype, + "int64": _dtype, + "uint8": _dtype, + "uint16": _dtype, + "uint32": _dtype, + "uint64": _dtype, + }, + total=False, +) @runtime_checkable From 10a19f40317445ffc5c6830e1efed88af31fa187 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 18:02:30 +0000 Subject: [PATCH 151/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../_array_api/_indexing_functions.py | 2 +- xarray/namedarray/_array_api/_info.py | 5 ++- .../_array_api/_searching_functions.py | 4 +-- .../_array_api/_sorting_functions.py | 2 +- .../_array_api/_statistical_functions.py | 2 +- .../_array_api/_utility_functions.py | 2 +- xarray/namedarray/_array_api/_utils.py | 11 +++--- xarray/namedarray/_typing.py | 36 +++++++++---------- 8 files changed, 30 insertions(+), 34 deletions(-) diff --git a/xarray/namedarray/_array_api/_indexing_functions.py b/xarray/namedarray/_array_api/_indexing_functions.py index 89cbe8b1da1..581e8bc8c22 100644 --- a/xarray/namedarray/_array_api/_indexing_functions.py +++ b/xarray/namedarray/_array_api/_indexing_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from xarray.namedarray._array_api._utils import _get_data_namespace, _dims_to_axis +from xarray.namedarray._array_api._utils import _dims_to_axis, _get_data_namespace from xarray.namedarray._typing import ( Default, _default, diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index 688fd0c43df..b9d4749faec 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING - from xarray.namedarray._array_api._dtypes import ( bool, complex64, @@ -23,10 +22,10 @@ from types import ModuleType from xarray.namedarray._typing import ( - _Device, _Capabilities, - _DefaultDataTypes, _DataTypes, + _DefaultDataTypes, + _Device, ) diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index 1100784c2d8..81679ae06e9 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -3,10 +3,10 @@ from typing import TYPE_CHECKING from xarray.namedarray._array_api._utils import ( - _get_data_namespace, - _infer_dims, _dims_to_axis, + _get_data_namespace, _get_remaining_dims, + _infer_dims, ) from xarray.namedarray._typing import ( Default, diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index 698704f798b..0bc5f53b34a 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from xarray.namedarray._array_api._utils import _get_data_namespace, _dims_to_axis +from xarray.namedarray._array_api._utils import _dims_to_axis, _get_data_namespace from xarray.namedarray._typing import ( Default, _default, diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 3cd0f451564..09c7f6da0a7 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -3,8 +3,8 @@ from typing import Any from xarray.namedarray._array_api._utils import ( - _get_data_namespace, _dims_to_axis, + _get_data_namespace, _get_remaining_dims, ) from xarray.namedarray._typing import ( diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index e9936148b47..72f5451a471 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -3,8 +3,8 @@ from typing import Any from xarray.namedarray._array_api._utils import ( - _get_data_namespace, _dims_to_axis, + _get_data_namespace, _get_remaining_dims, ) from xarray.namedarray._typing import ( diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 1bee2533006..c9d4932d3cb 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,19 +1,20 @@ from __future__ import annotations +from collections.abc import Iterable from types import ModuleType -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any from xarray.namedarray._typing import ( Default, _arrayapi, - _default, - _DimsLike, - _dtype, _AxisLike, - _Shape, + _default, _Dim, _Dims, + _DimsLike, _DType, + _dtype, + _Shape, duckarray, ) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index fbce5281ae3..00345a9ab29 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -11,11 +11,11 @@ Literal, Protocol, SupportsIndex, + TypedDict, TypeVar, Union, overload, runtime_checkable, - TypedDict, ) import numpy as np @@ -144,25 +144,21 @@ class _FInfo(Protocol): }, ) -_DataTypes = TypedDict( - "DataTypes", - { - "bool": _dtype, - "float32": _dtype, - "float64": _dtype, - "complex64": _dtype, - "complex128": _dtype, - "int8": _dtype, - "int16": _dtype, - "int32": _dtype, - "int64": _dtype, - "uint8": _dtype, - "uint16": _dtype, - "uint32": _dtype, - "uint64": _dtype, - }, - total=False, -) + +class _DataTypes(TypedDict, total=False): + bool: _dtype + float32: _dtype + float64: _dtype + complex64: _dtype + complex128: _dtype + int8: _dtype + int16: _dtype + int32: _dtype + int64: _dtype + uint8: _dtype + uint16: _dtype + uint32: _dtype + uint64: _dtype @runtime_checkable From 82d3af9e13d7283508f6c7b468eba77eef89459a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:10:50 +0200 Subject: [PATCH 152/367] Update core.py --- xarray/namedarray/core.py | 209 ++++++++++---------------------------- 1 file changed, 51 insertions(+), 158 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index dd668a979ba..d81fcbaa67d 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -50,24 +50,17 @@ ) if TYPE_CHECKING: - from enum import IntEnum + from numpy.typing import ArrayLike, NDArray - from numpy.typing import NDArray - - from xarray.core.types import T_Chunks + from xarray.core.types import Dims, T_Chunks from xarray.namedarray._typing import ( Default, - _ArrayLike, _AttrsLike, - _AxisLike, _Chunks, - _Device, _Dim, _Dims, _DimsLike, - _DimsLikeAgg, _DType, - _IndexKeyLike, _IntOrUnknown, _ScalarType, _Shape, @@ -102,114 +95,6 @@ ) -def _normalize_dimensions(dims: _DimsLike) -> _Dims: - """ - Normalize dimensions. - - Examples - -------- - >>> _normalize_dimensions(None) - (None,) - >>> _normalize_dimensions(1) - (1,) - >>> _normalize_dimensions("2") - ('2',) - >>> _normalize_dimensions(("time",)) - ('time',) - >>> _normalize_dimensions(["time"]) - ('time',) - >>> _normalize_dimensions([("time", "x", "y")]) - (('time', 'x', 'y'),) - """ - if isinstance(dims, str) or not isinstance(dims, Iterable): - return (dims,) - - return tuple(dims) - - -def _assert_either_dim_or_axis( - dims: _Dim | _Dims | Default, axis: _AxisLike | None -) -> None: - if dims is not _default and axis is not None: - raise ValueError("cannot supply both 'axis' and 'dim(s)' arguments") - - -def _dims_to_axis( - x: NamedArray[Any, Any], dims: _Dim | _Dims | Default, axis: _AxisLike | None -) -> _AxisLike | None: - """ - Convert dims to axis indices. - - Examples - -------- - >>> narr = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) - >>> _dims_to_axis(narr, ("y",), None) - (1,) - >>> _dims_to_axis(narr, None, 0) - (0,) - >>> _dims_to_axis(narr, None, None) - """ - _assert_either_dim_or_axis(dims, axis) - - if dims is not _default: - return x._dims_to_axes(dims) - - if isinstance(axis, int): - return (axis,) - - return axis - - -def _get_remaining_dims( - x: NamedArray[Any, _DType], - data: duckarray[Any, _DType], - axis: _AxisLike | None, - *, - keepdims: bool, -) -> tuple[_Dims, duckarray[Any, _DType]]: - """ - Get the reamining dims after a reduce operation. - - Parameters - ---------- - x : - DESCRIPTION. - data : - DESCRIPTION. - axis : - DESCRIPTION. - keepdims : - DESCRIPTION. - - Returns - ------- - tuple[_Dims, duckarray[Any, _DType]] - DESCRIPTION. - - """ - if data.shape == x.shape: - return x.dims, data - - removed_axes: np.ndarray[Any, np.dtype[np.intp]] - if axis is None: - removed_axes = np.arange(x.ndim, dtype=np.intp) - else: - removed_axes = np.atleast_1d(axis) % x.ndim - - if keepdims: - # Insert np.newaxis for removed dims - slices = tuple( - np.newaxis if i in removed_axes else slice(None, None) - for i in range(x.ndim) - ) - data = data[slices] - dims = x.dims - else: - dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes) - - return dims, data - - @overload def _new( x: NamedArray[Any, _DType_co], @@ -279,14 +164,14 @@ def from_array( @overload def from_array( dims: _DimsLike, - data: _ArrayLike, + data: ArrayLike, attrs: _AttrsLike = ..., ) -> NamedArray[Any, Any]: ... def from_array( dims: _DimsLike, - data: duckarray[_ShapeType, _DType] | _ArrayLike, + data: duckarray[_ShapeType, _DType] | ArrayLike, attrs: _AttrsLike = None, ) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: """ @@ -296,7 +181,7 @@ def from_array( ---------- dims : str or iterable of str Name(s) of the dimension(s). - data : T_DuckArray or _ArrayLike + data : T_DuckArray or ArrayLike The actual data that populates the array. Should match the shape specified by `dims`. attrs : dict, optional @@ -900,7 +785,7 @@ def dims(self, value: _DimsLike) -> None: self._dims = self._parse_dimensions(value) def _parse_dimensions(self, dims: _DimsLike) -> _Dims: - dims = _normalize_dimensions(dims) + dims = (dims,) if isinstance(dims, str) else tuple(dims) if len(dims) != self.ndim: raise ValueError( f"dimensions {dims} must have the same length as the " @@ -1084,18 +969,15 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, . int or tuple of int Axis number or numbers corresponding to the given dimensions. """ - if dims is _default: - return None - - if isinstance(dims, tuple): - return tuple(self._dim_to_axis(d) for d in dims) - - return self._dim_to_axis(dims) + if not isinstance(dim, str) and isinstance(dim, Iterable): + return tuple(self._get_axis_num(d) for d in dim) + else: + return self._get_axis_num(dim) - def _dim_to_axis(self, dim: _Dim) -> int: + def _get_axis_num(self: Any, dim: Hashable) -> int: + _raise_if_any_duplicate_dimensions(self.dims) try: - out = self.dims.index(dim) - return out + return self.dims.index(dim) # type: ignore[no-any-return] except ValueError: raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") @@ -1257,8 +1139,8 @@ def as_numpy(self) -> Self: def reduce( self, func: Callable[..., Any], - dim: _DimsLikeAgg | Default = _default, - axis: int | Sequence[int] | None = None, # TODO: Use _AxisLike + dim: Dims = None, + axis: int | Sequence[int] | None = None, keepdims: bool = False, **kwargs: Any, ) -> NamedArray[Any, Any]: @@ -1290,43 +1172,54 @@ def reduce( Array with summarized data and the indicated dimension(s) removed. """ - d: _Dims | None - if dim is None or dim is ...: # TODO: isinstance(dim, types.EllipsisType) - # TODO: What's the point of ellipsis? Use either ... or None? - d = None - else: - dimslike: _DimsLike = dim - d = _normalize_dimensions(dimslike) + if dim == ...: + dim = None + if dim is not None and axis is not None: + raise ValueError("cannot supply both 'axis' and 'dim' arguments") - axislike: _AxisLike | None - if axis is None or isinstance(axis, int): - axislike = axis - else: - axislike = tuple(axis) - axis_ = _dims_to_axis(self, d, axislike) + if dim is not None: + axis = self.get_axis_num(dim) - data: duckarray[Any, Any] | _ArrayLike with warnings.catch_warnings(): warnings.filterwarnings( "ignore", r"Mean of empty slice", category=RuntimeWarning ) - if axis_ is not None: - if isinstance(axis_, tuple) and len(axis_) == 1: + if axis is not None: + if isinstance(axis, tuple) and len(axis) == 1: # unpack axis for the benefit of functions # like np.argmin which can't handle tuple arguments - data = func(self.data, axis=axis_[0], **kwargs) - else: - data = func(self.data, axis=axis_, **kwargs) + axis = axis[0] + data = func(self.data, axis=axis, **kwargs) else: data = func(self.data, **kwargs) - if not isinstance(data, _arrayfunction_or_api): - data = np.asarray(data) - - dims_, data = _get_remaining_dims(self, data, axis_, keepdims=keepdims) + if getattr(data, "shape", ()) == self.shape: + dims = self.dims + else: + removed_axes: Iterable[int] + if axis is None: + removed_axes = range(self.ndim) + else: + removed_axes = np.atleast_1d(axis) % self.ndim + if keepdims: + # Insert np.newaxis for removed dims + slices = tuple( + np.newaxis if i in removed_axes else slice(None, None) + for i in range(self.ndim) + ) + if getattr(data, "shape", None) is None: + # Reduce has produced a scalar value, not an array-like + data = np.asanyarray(data)[slices] + else: + data = data[slices] + dims = self.dims + else: + dims = tuple( + adim for n, adim in enumerate(self.dims) if n not in removed_axes + ) # Return NamedArray to handle IndexVariable when data is nD - return from_array(dims_, data, attrs=self._attrs) + return from_array(dims, data, attrs=self._attrs) def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]: """Equivalent numpy's nonzero but returns a tuple of NamedArrays.""" @@ -1349,7 +1242,7 @@ def _repr_html_(self) -> str: def _as_sparse( self, sparse_format: Literal["coo"] | Default = _default, - fill_value: _ArrayLike | Default = _default, + fill_value: ArrayLike | Default = _default, ) -> NamedArray[Any, _DType_co]: """ Use sparse-array as backend. From 83ebadc0a5baf992d6dd6e946f10dd79b58babb3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:13:24 +0200 Subject: [PATCH 153/367] Update xarray/core/variable.py --- xarray/core/variable.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 967de3babf4..3cd8e4acbd5 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1653,11 +1653,6 @@ def reduce( # type: ignore[override] Array with summarized data and the indicated dimension(s) removed. """ - if dim is None: - from xarray.namedarray._typing import _default - - dim = _default - keep_attrs_ = ( _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs ) From c18c4cbdaac99c0eb50ae09ad1514c7439654624 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:15:47 +0200 Subject: [PATCH 154/367] Apply suggestions from code review --- xarray/namedarray/_typing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 00345a9ab29..190e8e3e649 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -76,7 +76,7 @@ def imag(self) -> _T_co: ... ] # For unknown shapes Dask uses np.nan, array_api uses None: -_IntOrUnknown = int # Union[int, _Unknown] +_IntOrUnknown = int _Shape = tuple[_IntOrUnknown, ...] _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] _ShapeType = TypeVar("_ShapeType", bound=Any) @@ -95,9 +95,7 @@ def imag(self) -> _T_co: ... _Dim = Hashable _Dims = tuple[_Dim, ...] -# _DimsLike = Union[str, Iterable[_Dim], Default] _DimsLike = Union[str, Iterable[_Dim]] -_DimsLikeAgg = Union[_DimsLike, "ellipsis", None] # https://data-apis.org/array-api/latest/API_specification/indexing.html # TODO: np.array_api was bugged and didn't allow (None,), but should! From 2953f7d54bea42e16f11d4d1e7e61619e7f7456b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:20:51 +0200 Subject: [PATCH 155/367] Update core.py --- xarray/namedarray/core.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index d81fcbaa67d..c8482e2dfb2 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -50,22 +50,27 @@ ) if TYPE_CHECKING: - from numpy.typing import ArrayLike, NDArray + from enum import IntEnum - from xarray.core.types import Dims, T_Chunks + from numpy.typing import NDArray + + from xarray.core.types import T_Chunks from xarray.namedarray._typing import ( Default, + _ArrayLike, _AttrsLike, + _AxisLike, _Chunks, + _Device, _Dim, _Dims, _DimsLike, + _DimsLikeAgg, _DType, + _IndexKeyLike, _IntOrUnknown, _ScalarType, _Shape, - _ShapeType, - duckarray, ) from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint From 2ddd601fead4d047c15e9833f3f97db4b3f40dab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 18:22:41 +0000 Subject: [PATCH 156/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c8482e2dfb2..435d7302e92 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -57,15 +57,12 @@ from xarray.core.types import T_Chunks from xarray.namedarray._typing import ( Default, - _ArrayLike, _AttrsLike, - _AxisLike, _Chunks, _Device, _Dim, _Dims, _DimsLike, - _DimsLikeAgg, _DType, _IndexKeyLike, _IntOrUnknown, From d8e76e013279aa1c8cf0b1bf0ea6b3a112593e98 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:27:12 +0200 Subject: [PATCH 157/367] Update __init__.py --- xarray/namedarray/_array_api/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 89c8402ec0f..7d3f88fd3ef 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -279,7 +279,7 @@ "reshape", # "roll", # "squeeze", - # "stack", + "stack", ] from xarray.namedarray._array_api._searching_functions import ( From 832c275aca55828c81b299db27d20a818a641831 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:27:16 +0200 Subject: [PATCH 158/367] Update core.py --- xarray/namedarray/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c8482e2dfb2..b6777dc9461 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -31,6 +31,7 @@ _arrayfunction_or_api, _chunkedarray, _default, + duckarray, _dtype, _DType_co, _ScalarType_co, @@ -52,7 +53,7 @@ if TYPE_CHECKING: from enum import IntEnum - from numpy.typing import NDArray + from numpy.typing import NDArray, ArrayLike from xarray.core.types import T_Chunks from xarray.namedarray._typing import ( @@ -71,6 +72,7 @@ _IntOrUnknown, _ScalarType, _Shape, + _ShapeType, ) from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint From c9ed055d9229446983c034912aa5b2dab5b731f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 18:28:43 +0000 Subject: [PATCH 159/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 9ec05f00290..df3498a32e8 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -31,7 +31,6 @@ _arrayfunction_or_api, _chunkedarray, _default, - duckarray, _dtype, _DType_co, _ScalarType_co, @@ -39,6 +38,7 @@ _sparsearrayfunction_or_api, _SupportsImag, _SupportsReal, + duckarray, ) from xarray.namedarray.parallelcompat import guess_chunkmanager from xarray.namedarray.pycompat import to_numpy @@ -53,7 +53,7 @@ if TYPE_CHECKING: from enum import IntEnum - from numpy.typing import NDArray, ArrayLike + from numpy.typing import ArrayLike, NDArray from xarray.core.types import T_Chunks from xarray.namedarray._typing import ( From 11cda780b3d4ca0c6159cf8d1a3eb93bb26c4cec Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:36:36 +0200 Subject: [PATCH 160/367] Update core.py --- xarray/namedarray/core.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index df3498a32e8..b6c3af4d060 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -55,7 +55,7 @@ from numpy.typing import ArrayLike, NDArray - from xarray.core.types import T_Chunks + from xarray.core.types import Dims, T_Chunks from xarray.namedarray._typing import ( Default, _AttrsLike, @@ -493,8 +493,14 @@ def to_device(self, device: _Device, /, stream: None = None) -> Self: raise NotImplementedError("Only array api are valid.") @property - def T(self): - raise NotImplementedError("Todo: ") + def T(self) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions.""" + if self.ndim != 2: + raise ValueError( + f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." + ) + + return self.permute_dims() # methods def __abs__(self, /): @@ -1329,16 +1335,6 @@ def permute_dims( return permute_dims(self, axes) - @property - def T(self) -> NamedArray[Any, _DType_co]: - """Return a new object with transposed dimensions.""" - if self.ndim != 2: - raise ValueError( - f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." - ) - - return self.permute_dims() - def broadcast_to( self, dim: Mapping[_Dim, int] | None = None, **dim_kwargs: Any ) -> NamedArray[Any, _DType_co]: From 6027533da269e942bc520145aca44493530d3ab5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:39:18 +0200 Subject: [PATCH 161/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index b6c3af4d060..042b01f31fd 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -744,7 +744,7 @@ def __rrshift__(self, other, /): # Indexing def __getitem__(self, key: _IndexKeyLike | NamedArray): - if isinstance(key, (int, slice, tuple)): + if isinstance(key, int | slice | tuple): _data = self._data[key] return self._new((), _data) elif isinstance(key, NamedArray): From 1e0843fc2453efab9134b11616993f810a7c76ba Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 21:01:22 +0200 Subject: [PATCH 162/367] add linear functions --- xarray/namedarray/_array_api/__init__.py | 22 +++++---- .../_array_api/_linear_algebra_functions.py | 47 ++++++++++++++++++- xarray/namedarray/_typing.py | 4 +- 3 files changed, 61 insertions(+), 12 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 7d3f88fd3ef..363b13cc5ff 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -244,15 +244,21 @@ "__array_namespace_info__", ] -# from xarray.namedarray._array_api._linear_algebra_functions import ( -# matmul, -# matrix_transpose, -# outer, -# tensordot, -# vecdot, -# ) +from xarray.namedarray._array_api._linear_algebra_functions import ( + matmul, + matrix_transpose, + outer, + tensordot, + vecdot, +) -# __all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"] +__all__ += [ + "matmul", + "matrix_transpose", + "outer", + "tensordot", + "vecdot", +] from xarray.namedarray._array_api._manipulation_functions import ( # broadcast_arrays, diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index b62d0a8393b..c2f6ecb959a 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -1,3 +1,46 @@ -from xarray.namedarray._array_api._utils import _get_data_namespace +from __future__ import annotations -sdf = _get_data_namespace() +from typing import TYPE_CHECKING, Sequence + + +from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims + +from xarray.namedarray.core import NamedArray + + +def matmul(x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.matmul(x1._data, x2._data) + # TODO: Figure out a better way: + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def tensordot( + x1: NamedArray, + x2: NamedArray, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, +) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.tensordot(x1._data, x2._data, axes=axes) + # TODO: Figure out a better way: + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def matrix_transpose(x: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.matrix_transpose(x._data) + # TODO: Figure out a better way: + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def vecdot(x1: NamedArray, x2: NamedArray, /, *, axis: int = -1) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.vecdot(x1._data, x2._data, axis=axis) + # TODO: Figure out a better way: + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 190e8e3e649..01de017ca47 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -129,11 +129,11 @@ class _FInfo(Protocol): _Capabilities = TypedDict( - "Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} + "_Capabilities", {"boolean indexing": bool, "data-dependent shapes": bool} ) _DefaultDataTypes = TypedDict( - "DefaultDataTypes", + "_DefaultDataTypes", { "real floating": _dtype, "complex floating": _dtype, From 7bc74c59e18a828e18b961fa49910f78a6d1209b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 19:02:46 +0000 Subject: [PATCH 163/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_linear_algebra_functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index c2f6ecb959a..245f61f179e 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -1,10 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence - +from collections.abc import Sequence from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims - from xarray.namedarray.core import NamedArray From bdfde2680383057a845c53b9fd05ade3950fcf6c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 21:04:42 +0200 Subject: [PATCH 164/367] Update __init__.py --- xarray/namedarray/_array_api/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 363b13cc5ff..8167d6adc8f 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -247,7 +247,6 @@ from xarray.namedarray._array_api._linear_algebra_functions import ( matmul, matrix_transpose, - outer, tensordot, vecdot, ) @@ -255,7 +254,6 @@ __all__ += [ "matmul", "matrix_transpose", - "outer", "tensordot", "vecdot", ] From 40609ca714500f149ea98d81f0ff79a33414f330 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 21:56:03 +0200 Subject: [PATCH 165/367] Update core.py --- xarray/namedarray/core.py | 468 +++++++++++++++++++++----------------- 1 file changed, 260 insertions(+), 208 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 042b01f31fd..9cfc94f5356 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -413,100 +413,22 @@ def __len__(self) -> _IntOrUnknown: except Exception as exc: raise TypeError("len() of unsized object") from exc - # Array API: - # Attributes: + # < Array api > - @property - def device(self) -> _Device: - """ - Device of the array’s elements. - - See Also - -------- - ndarray.device - """ - if isinstance(self._data, _arrayapi): - return self._data.device - else: - raise NotImplementedError("self._data missing device") - - @property - def dtype(self) -> _DType_co: - """ - Data-type of the array’s elements. - - See Also - -------- - ndarray.dtype - numpy.dtype - """ - return self._data.dtype - - @property - def mT(self): - raise NotImplementedError("Todo: ") - - @property - def ndim(self) -> int: - """ - Number of array dimensions. - - See Also - -------- - numpy.ndarray.ndim - """ - return len(self.shape) - - @property - def shape(self) -> _Shape: - """ - Get the shape of the array. - - Returns - ------- - shape : tuple of ints - Tuple of array dimensions. - - See Also - -------- - numpy.ndarray.shape - """ - return self._data.shape - - @property - def size(self) -> _IntOrUnknown: - """ - Number of elements in the array. - - Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. - - See Also - -------- - numpy.ndarray.size - """ - return math.prod(self.shape) + def __abs__(self, /) -> Self: + from xarray.namedarray._array_api import abs - def to_device(self, device: _Device, /, stream: None = None) -> Self: - if isinstance(self._data, _arrayapi): - return self._replace(data=self._data.to_device(device, stream=stream)) - else: - raise NotImplementedError("Only array api are valid.") + return abs(self) - @property - def T(self) -> NamedArray[Any, _DType_co]: - """Return a new object with transposed dimensions.""" - if self.ndim != 2: - raise ValueError( - f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." - ) + def __add__(self, other: int | float | NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api import add, asarray - return self.permute_dims() + return add(self, asarray(other)) - # methods - def __abs__(self, /): - from xarray.namedarray._array_api import abs + def __and__(self, other, /): + from xarray.namedarray._array_api import bitwise_and - return abs(self) + return bitwise_and(self, asarray(other)) # def __array_namespace__(self, /, *, api_version=None): # if api_version is not None and api_version not in ( @@ -525,16 +447,6 @@ def __bool__(self, /) -> bool: def __complex__(self, /) -> complex: return self._data.__complex__() - def __float__(self, /) -> float: - return self._data.__float__() - - def __index__(self, /) -> int: - return self._data.__index__() - - def __int__(self, /) -> int: - return self._data.__int__() - - # dlpack def __dlpack__( self, /, @@ -551,208 +463,348 @@ def __dlpack__( def __dlpack_device__(self, /) -> tuple[IntEnum, int]: return self._data.__dlpack_device__() - # Arithmetic Operators + def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api import asarray, equal - def __neg__(self, /): - from xarray.namedarray._array_api import negative + return equal(self, asarray(other)) - return negative(self) + def __float__(self, /) -> float: + return self._data.__float__() - def __pos__(self, /): - from xarray.namedarray._array_api import positive + def __floordiv__(self, other, /): + from xarray.namedarray._array_api import floor_divide - return positive(self) + return floor_divide(self, asarray(other)) - def __add__(self, other: int | float | NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api import add, asarray + def __ge__(self, other, /): + from xarray.namedarray._array_api import greater_equal - return add(self, asarray(other)) + return greater_equal(self, asarray(other)) - def __sub__(self, other, /): - from xarray.namedarray._array_api import subtract + def __getitem__(self, key: _IndexKeyLike | NamedArray): + if isinstance(key, int | slice | tuple): + _data = self._data[key] + return self._new((), _data) + elif isinstance(key, NamedArray): + _key = self._data # TODO: Transpose, unordered dims shouldn't matter. + _data = self._data[_key] + return self._new(key._dims, _data) + else: + raise NotImplementedError("{k=} is not supported") - return subtract(self, other) + def __gt__(self, other, /): + from xarray.namedarray._array_api import greater - def __mul__(self, other, /): - from xarray.namedarray._array_api import multiply + return greater(self, asarray(other)) - return multiply(self, other) + def __index__(self, /) -> int: + return self._data.__index__() - def __truediv__(self, other, /): - from xarray.namedarray._array_api import divide + def __int__(self, /) -> int: + return self._data.__int__() - return divide(self, other) + def __invert__(self, /): + from xarray.namedarray._array_api import bitwise_invert - def __floordiv__(self, other, /): - from xarray.namedarray._array_api import floor_divide + return bitwise_invert(self) - return floor_divide(self, other) + def __iter__(self: NamedArray, /): + from xarray.namedarray._array_api import asarray - def __mod__(self, other, /): - from xarray.namedarray._array_api import remainder + # TODO: smarter way to retain dims, xarray? + return (asarray(i) for i in self._data) - return remainder(self, other) + def __le__(self, other, /): + from xarray.namedarray._array_api import less_equal - def __pow__(self, other, /): - from xarray.namedarray._array_api import pow + return less_equal(self, asarray(other)) - return pow(self, other) + def __lshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_left_shift - # Array Operators + return bitwise_left_shift(self) + + def __lt__(self, other, /): + from xarray.namedarray._array_api import less + + return less(self, asarray(other)) def __matmul__(self, other, /): from xarray.namedarray._array_api import matmul - return matmul(self, other) + return matmul(self, asarray(other)) - # Bitwise Operators + def __mod__(self, other, /): + from xarray.namedarray._array_api import remainder - def __invert__(self, /): - from xarray.namedarray._array_api import bitwise_invert + return remainder(self, asarray(other)) - return bitwise_invert(self) + def __mul__(self, other, /): + from xarray.namedarray._array_api import multiply - def __and__(self, other, /): - from xarray.namedarray._array_api import bitwise_and + return multiply(self, asarray(other)) + + def __ne__(self, other, /): + from xarray.namedarray._array_api import not_equal + + return not_equal(self, asarray(other)) + + def __neg__(self, /): + from xarray.namedarray._array_api import negative - return bitwise_and(self) + return negative(self) def __or__(self, other, /): from xarray.namedarray._array_api import bitwise_or return bitwise_or(self) - def __xor__(self, other, /): - from xarray.namedarray._array_api import bitwise_xor + def __pos__(self, /): + from xarray.namedarray._array_api import positive - return bitwise_xor(self) + return positive(self) - def __lshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_left_shift + def __pow__(self, other, /): + from xarray.namedarray._array_api import pow - return bitwise_left_shift(self) + return pow(self, asarray(other)) def __rshift__(self, other, /): from xarray.namedarray._array_api import bitwise_right_shift return bitwise_right_shift(self) - # Comparison Operators - def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api import asarray, equal + def __setitem__( + self, + key: _IndexKeyLike, + value: int | float | bool | NamedArray, + /, + ) -> None: + from xarray.namedarray._array_api import asarray - return equal(self, asarray(other)) + if isinstance(key, NamedArray): + key = key._data + self._array.__setitem__(key, asarray(value)._data) - def __ge__(self, other, /): - from xarray.namedarray._array_api import greater_equal + def __sub__(self, other, /): + from xarray.namedarray._array_api import subtract - return greater_equal(self, other) + return subtract(self, asarray(other)) - def __gt__(self, other, /): - from xarray.namedarray._array_api import greater + def __truediv__(self, other, /): + from xarray.namedarray._array_api import divide - return greater(self, other) + return divide(self, asarray(other)) - def __le__(self, other, /): - from xarray.namedarray._array_api import less_equal + def __xor__(self, other, /): + from xarray.namedarray._array_api import bitwise_xor - return less_equal(self, other) + return bitwise_xor(self) - def __lt__(self, other, /): - from xarray.namedarray._array_api import less + def __iadd__(self, other, /): + self._data.__iadd__(other._data) + return self - return less(self, other) + def __radd__(self, other, /): + from xarray.namedarray._array_api import add - def __ne__(self, other, /): - from xarray.namedarray._array_api import not_equal + return add(asarray(other), self) - return not_equal(self, other) + def __iand__(self, other, /): + self._data.__iand__(other._data) + return self - # Reflected Operators + def __rand__(self, other, /): + from xarray.namedarray._array_api import bitwise_and - # (Reflected) Arithmetic Operators + return bitwise_and(asarray(other), self) - def __radd__(self, other, /): - from xarray.namedarray._array_api import add + def __ifloordiv__(self, other, /): + self._data.__ifloordiv__(other._data) + return self - return add(other, self) + def __rfloordiv__(self, other, /): + from xarray.namedarray._array_api import floor_divide - def __rsub__(self, other, /): - from xarray.namedarray._array_api import subtract + return floor_divide(asarray(other), self) - return subtract(other, self) + def __ilshift__(self, other, /): + self._data.__ilshift__(other._data) + return self - def __rmul__(self, other, /): - from xarray.namedarray._array_api import multiply + def __rlshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_left_shift - return multiply(other, self) + return bitwise_left_shift(asarray(other), self) - def __rtruediv__(self, other, /): - from xarray.namedarray._array_api import divide + def __imatmul__(self, other, /): + self._data.__imatmul__(other._data) + return self - return divide(other, self) + def __rmatmul__(self, other, /): + from xarray.namedarray._array_api import matmul - def __rfloordiv__(self, other, /): - from xarray.namedarray._array_api import floor_divide + return matmul(asarray(other), self) - return floor_divide(other, self) + def __imod__(self, other, /): + self._data.__imod__(other._data) + return self def __rmod__(self, other, /): from xarray.namedarray._array_api import remainder - return remainder(other, self) + return remainder(asarray(other), self) + + def __imul__(self, other, /): + self._data.__imul__(other._data) + return self + + def __rmul__(self, other, /): + from xarray.namedarray._array_api import multiply + + return multiply(asarray(other), self) + + def __ior__(self, other, /): + self._data.__ior__(other._data) + return self + + def __ror__(self, other, /): + from xarray.namedarray._array_api import bitwise_or + + return bitwise_or(asarray(other), self) + + def __ipow__(self, other, /): + self._data.__ipow__(other._data) + return self def __rpow__(self, other, /): from xarray.namedarray._array_api import pow - return pow(other, self) + return pow(asarray(other), self) - # (Reflected) Array Operators + def __irshift__(self, other, /): + self._data.__irshift__(other._data) + return self - def __rmatmul__(self, other, /): - from xarray.namedarray._array_api import matmul + def __rrshift__(self, other, /): + from xarray.namedarray._array_api import bitwise_right_shift - return matmul(other, self) + return bitwise_right_shift(asarray(other), self) - # (Reflected) Bitwise Operators + def __isub__(self, other, /): + self._data.__isub__(other._data) + return self - def __rand__(self, other, /): - from xarray.namedarray._array_api import bitwise_and + def __rsub__(self, other, /): + from xarray.namedarray._array_api import subtract - return bitwise_and(other, self) + return subtract(asarray(other), self) - def __ror__(self, other, /): - from xarray.namedarray._array_api import bitwise_or + def __itruediv__(self, other, /): + self._data.__itruediv__(asarray(other)._data) + return self - return bitwise_or(other, self) + def __rtruediv__(self, other, /): + from xarray.namedarray._array_api import divide + + return divide(asarray(other), self) + + def __ixor__(self, other, /): + self._data.__ixor__(other._data) + return self def __rxor__(self, other, /): from xarray.namedarray._array_api import bitwise_xor - return bitwise_xor(other, self) - - def __rlshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_left_shift + return bitwise_xor(asarray(other), self) - return bitwise_left_shift(other, self) + def to_device(self, device: _Device, /, stream: None = None) -> Self: + if isinstance(self._data, _arrayapi): + return self._replace(data=self._data.to_device(device, stream=stream)) + else: + raise NotImplementedError("Only array api are valid.") - def __rrshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_right_shift + @property + def dtype(self) -> _DType_co: + """ + Data-type of the array’s elements. - return bitwise_right_shift(other, self) + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self._data.dtype - # Indexing + @property + def device(self) -> _Device: + """ + Device of the array’s elements. - def __getitem__(self, key: _IndexKeyLike | NamedArray): - if isinstance(key, int | slice | tuple): - _data = self._data[key] - return self._new((), _data) - elif isinstance(key, NamedArray): - _key = self._data # TODO: Transpose, unordered dims shouldn't matter. - _data = self._data[_key] - return self._new(key._dims, _data) + See Also + -------- + ndarray.device + """ + if isinstance(self._data, _arrayapi): + return self._data.device else: - raise NotImplementedError("{k=} is not supported") + raise NotImplementedError("self._data missing device") + + @property + def mT(self): + raise NotImplementedError("Todo: ") + + @property + def ndim(self) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return len(self.shape) + + @property + def shape(self) -> _Shape: + """ + Get the shape of the array. + + Returns + ------- + shape : tuple of ints + Tuple of array dimensions. + + See Also + -------- + numpy.ndarray.shape + """ + return self._data.shape + + @property + def size(self) -> _IntOrUnknown: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return math.prod(self.shape) + + @property + def T(self) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions.""" + if self.ndim != 2: + raise ValueError( + f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." + ) + + return self.permute_dims() + + # @property def nbytes(self) -> _IntOrUnknown: From 9aa110ed3e02b5cd7e060ce365542f3271e2a471 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 21:59:03 +0200 Subject: [PATCH 166/367] Update core.py --- xarray/namedarray/core.py | 54 +++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 9cfc94f5356..4f5ed7e8dfe 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -426,7 +426,7 @@ def __add__(self, other: int | float | NamedArray, /) -> NamedArray: return add(self, asarray(other)) def __and__(self, other, /): - from xarray.namedarray._array_api import bitwise_and + from xarray.namedarray._array_api import bitwise_and, asarray return bitwise_and(self, asarray(other)) @@ -472,12 +472,12 @@ def __float__(self, /) -> float: return self._data.__float__() def __floordiv__(self, other, /): - from xarray.namedarray._array_api import floor_divide + from xarray.namedarray._array_api import floor_divide, asarray return floor_divide(self, asarray(other)) def __ge__(self, other, /): - from xarray.namedarray._array_api import greater_equal + from xarray.namedarray._array_api import greater_equal, asarray return greater_equal(self, asarray(other)) @@ -493,7 +493,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray): raise NotImplementedError("{k=} is not supported") def __gt__(self, other, /): - from xarray.namedarray._array_api import greater + from xarray.namedarray._array_api import greater, asarray return greater(self, asarray(other)) @@ -515,7 +515,7 @@ def __iter__(self: NamedArray, /): return (asarray(i) for i in self._data) def __le__(self, other, /): - from xarray.namedarray._array_api import less_equal + from xarray.namedarray._array_api import less_equal, asarray return less_equal(self, asarray(other)) @@ -525,27 +525,27 @@ def __lshift__(self, other, /): return bitwise_left_shift(self) def __lt__(self, other, /): - from xarray.namedarray._array_api import less + from xarray.namedarray._array_api import less, asarray return less(self, asarray(other)) def __matmul__(self, other, /): - from xarray.namedarray._array_api import matmul + from xarray.namedarray._array_api import matmul, asarray return matmul(self, asarray(other)) def __mod__(self, other, /): - from xarray.namedarray._array_api import remainder + from xarray.namedarray._array_api import remainder, asarray return remainder(self, asarray(other)) def __mul__(self, other, /): - from xarray.namedarray._array_api import multiply + from xarray.namedarray._array_api import multiply, asarray return multiply(self, asarray(other)) def __ne__(self, other, /): - from xarray.namedarray._array_api import not_equal + from xarray.namedarray._array_api import not_equal, asarray return not_equal(self, asarray(other)) @@ -565,7 +565,7 @@ def __pos__(self, /): return positive(self) def __pow__(self, other, /): - from xarray.namedarray._array_api import pow + from xarray.namedarray._array_api import pow, asarray return pow(self, asarray(other)) @@ -587,12 +587,12 @@ def __setitem__( self._array.__setitem__(key, asarray(value)._data) def __sub__(self, other, /): - from xarray.namedarray._array_api import subtract + from xarray.namedarray._array_api import subtract, asarray return subtract(self, asarray(other)) def __truediv__(self, other, /): - from xarray.namedarray._array_api import divide + from xarray.namedarray._array_api import divide, asarray return divide(self, asarray(other)) @@ -606,7 +606,7 @@ def __iadd__(self, other, /): return self def __radd__(self, other, /): - from xarray.namedarray._array_api import add + from xarray.namedarray._array_api import add, asarray return add(asarray(other), self) @@ -615,7 +615,7 @@ def __iand__(self, other, /): return self def __rand__(self, other, /): - from xarray.namedarray._array_api import bitwise_and + from xarray.namedarray._array_api import bitwise_and, asarray return bitwise_and(asarray(other), self) @@ -624,7 +624,7 @@ def __ifloordiv__(self, other, /): return self def __rfloordiv__(self, other, /): - from xarray.namedarray._array_api import floor_divide + from xarray.namedarray._array_api import floor_divide, asarray return floor_divide(asarray(other), self) @@ -633,7 +633,7 @@ def __ilshift__(self, other, /): return self def __rlshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_left_shift + from xarray.namedarray._array_api import bitwise_left_shift, asarray return bitwise_left_shift(asarray(other), self) @@ -642,7 +642,7 @@ def __imatmul__(self, other, /): return self def __rmatmul__(self, other, /): - from xarray.namedarray._array_api import matmul + from xarray.namedarray._array_api import matmul, asarray return matmul(asarray(other), self) @@ -651,7 +651,7 @@ def __imod__(self, other, /): return self def __rmod__(self, other, /): - from xarray.namedarray._array_api import remainder + from xarray.namedarray._array_api import remainder, asarray return remainder(asarray(other), self) @@ -660,7 +660,7 @@ def __imul__(self, other, /): return self def __rmul__(self, other, /): - from xarray.namedarray._array_api import multiply + from xarray.namedarray._array_api import multiply, asarray return multiply(asarray(other), self) @@ -669,7 +669,7 @@ def __ior__(self, other, /): return self def __ror__(self, other, /): - from xarray.namedarray._array_api import bitwise_or + from xarray.namedarray._array_api import bitwise_or, asarray return bitwise_or(asarray(other), self) @@ -678,7 +678,7 @@ def __ipow__(self, other, /): return self def __rpow__(self, other, /): - from xarray.namedarray._array_api import pow + from xarray.namedarray._array_api import pow, asarray return pow(asarray(other), self) @@ -687,7 +687,7 @@ def __irshift__(self, other, /): return self def __rrshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_right_shift + from xarray.namedarray._array_api import bitwise_right_shift, asarray return bitwise_right_shift(asarray(other), self) @@ -696,16 +696,16 @@ def __isub__(self, other, /): return self def __rsub__(self, other, /): - from xarray.namedarray._array_api import subtract + from xarray.namedarray._array_api import subtract, asarray return subtract(asarray(other), self) def __itruediv__(self, other, /): - self._data.__itruediv__(asarray(other)._data) + self._data.__itruediv__(other._data) return self def __rtruediv__(self, other, /): - from xarray.namedarray._array_api import divide + from xarray.namedarray._array_api import divide, asarray return divide(asarray(other), self) @@ -714,7 +714,7 @@ def __ixor__(self, other, /): return self def __rxor__(self, other, /): - from xarray.namedarray._array_api import bitwise_xor + from xarray.namedarray._array_api import bitwise_xor, asarray return bitwise_xor(asarray(other), self) From d4f9b4898703050123087e5447bdefa3dfac164a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:00:28 +0000 Subject: [PATCH 167/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 50 +++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 4f5ed7e8dfe..0f733491e96 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -426,7 +426,7 @@ def __add__(self, other: int | float | NamedArray, /) -> NamedArray: return add(self, asarray(other)) def __and__(self, other, /): - from xarray.namedarray._array_api import bitwise_and, asarray + from xarray.namedarray._array_api import asarray, bitwise_and return bitwise_and(self, asarray(other)) @@ -472,12 +472,12 @@ def __float__(self, /) -> float: return self._data.__float__() def __floordiv__(self, other, /): - from xarray.namedarray._array_api import floor_divide, asarray + from xarray.namedarray._array_api import asarray, floor_divide return floor_divide(self, asarray(other)) def __ge__(self, other, /): - from xarray.namedarray._array_api import greater_equal, asarray + from xarray.namedarray._array_api import asarray, greater_equal return greater_equal(self, asarray(other)) @@ -493,7 +493,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray): raise NotImplementedError("{k=} is not supported") def __gt__(self, other, /): - from xarray.namedarray._array_api import greater, asarray + from xarray.namedarray._array_api import asarray, greater return greater(self, asarray(other)) @@ -515,7 +515,7 @@ def __iter__(self: NamedArray, /): return (asarray(i) for i in self._data) def __le__(self, other, /): - from xarray.namedarray._array_api import less_equal, asarray + from xarray.namedarray._array_api import asarray, less_equal return less_equal(self, asarray(other)) @@ -525,27 +525,27 @@ def __lshift__(self, other, /): return bitwise_left_shift(self) def __lt__(self, other, /): - from xarray.namedarray._array_api import less, asarray + from xarray.namedarray._array_api import asarray, less return less(self, asarray(other)) def __matmul__(self, other, /): - from xarray.namedarray._array_api import matmul, asarray + from xarray.namedarray._array_api import asarray, matmul return matmul(self, asarray(other)) def __mod__(self, other, /): - from xarray.namedarray._array_api import remainder, asarray + from xarray.namedarray._array_api import asarray, remainder return remainder(self, asarray(other)) def __mul__(self, other, /): - from xarray.namedarray._array_api import multiply, asarray + from xarray.namedarray._array_api import asarray, multiply return multiply(self, asarray(other)) def __ne__(self, other, /): - from xarray.namedarray._array_api import not_equal, asarray + from xarray.namedarray._array_api import asarray, not_equal return not_equal(self, asarray(other)) @@ -565,7 +565,7 @@ def __pos__(self, /): return positive(self) def __pow__(self, other, /): - from xarray.namedarray._array_api import pow, asarray + from xarray.namedarray._array_api import asarray, pow return pow(self, asarray(other)) @@ -587,12 +587,12 @@ def __setitem__( self._array.__setitem__(key, asarray(value)._data) def __sub__(self, other, /): - from xarray.namedarray._array_api import subtract, asarray + from xarray.namedarray._array_api import asarray, subtract return subtract(self, asarray(other)) def __truediv__(self, other, /): - from xarray.namedarray._array_api import divide, asarray + from xarray.namedarray._array_api import asarray, divide return divide(self, asarray(other)) @@ -615,7 +615,7 @@ def __iand__(self, other, /): return self def __rand__(self, other, /): - from xarray.namedarray._array_api import bitwise_and, asarray + from xarray.namedarray._array_api import asarray, bitwise_and return bitwise_and(asarray(other), self) @@ -624,7 +624,7 @@ def __ifloordiv__(self, other, /): return self def __rfloordiv__(self, other, /): - from xarray.namedarray._array_api import floor_divide, asarray + from xarray.namedarray._array_api import asarray, floor_divide return floor_divide(asarray(other), self) @@ -633,7 +633,7 @@ def __ilshift__(self, other, /): return self def __rlshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_left_shift, asarray + from xarray.namedarray._array_api import asarray, bitwise_left_shift return bitwise_left_shift(asarray(other), self) @@ -642,7 +642,7 @@ def __imatmul__(self, other, /): return self def __rmatmul__(self, other, /): - from xarray.namedarray._array_api import matmul, asarray + from xarray.namedarray._array_api import asarray, matmul return matmul(asarray(other), self) @@ -651,7 +651,7 @@ def __imod__(self, other, /): return self def __rmod__(self, other, /): - from xarray.namedarray._array_api import remainder, asarray + from xarray.namedarray._array_api import asarray, remainder return remainder(asarray(other), self) @@ -660,7 +660,7 @@ def __imul__(self, other, /): return self def __rmul__(self, other, /): - from xarray.namedarray._array_api import multiply, asarray + from xarray.namedarray._array_api import asarray, multiply return multiply(asarray(other), self) @@ -669,7 +669,7 @@ def __ior__(self, other, /): return self def __ror__(self, other, /): - from xarray.namedarray._array_api import bitwise_or, asarray + from xarray.namedarray._array_api import asarray, bitwise_or return bitwise_or(asarray(other), self) @@ -678,7 +678,7 @@ def __ipow__(self, other, /): return self def __rpow__(self, other, /): - from xarray.namedarray._array_api import pow, asarray + from xarray.namedarray._array_api import asarray, pow return pow(asarray(other), self) @@ -687,7 +687,7 @@ def __irshift__(self, other, /): return self def __rrshift__(self, other, /): - from xarray.namedarray._array_api import bitwise_right_shift, asarray + from xarray.namedarray._array_api import asarray, bitwise_right_shift return bitwise_right_shift(asarray(other), self) @@ -696,7 +696,7 @@ def __isub__(self, other, /): return self def __rsub__(self, other, /): - from xarray.namedarray._array_api import subtract, asarray + from xarray.namedarray._array_api import asarray, subtract return subtract(asarray(other), self) @@ -705,7 +705,7 @@ def __itruediv__(self, other, /): return self def __rtruediv__(self, other, /): - from xarray.namedarray._array_api import divide, asarray + from xarray.namedarray._array_api import asarray, divide return divide(asarray(other), self) @@ -714,7 +714,7 @@ def __ixor__(self, other, /): return self def __rxor__(self, other, /): - from xarray.namedarray._array_api import bitwise_xor, asarray + from xarray.namedarray._array_api import asarray, bitwise_xor return bitwise_xor(asarray(other), self) From 5d6b1308edb74586630da5fa52fc9255d3a880ff Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:16:18 +0200 Subject: [PATCH 168/367] Update core.py --- xarray/namedarray/core.py | 104 +++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 0f733491e96..25c6070b30c 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -425,21 +425,21 @@ def __add__(self, other: int | float | NamedArray, /) -> NamedArray: return add(self, asarray(other)) - def __and__(self, other, /): + def __and__(self, other: int | bool | NamedArray, /) -> NamedArray: from xarray.namedarray._array_api import asarray, bitwise_and return bitwise_and(self, asarray(other)) - # def __array_namespace__(self, /, *, api_version=None): - # if api_version is not None and api_version not in ( - # "2021.12", - # "2022.12", - # "2023.12", - # ): - # raise ValueError(f"Unrecognized array API version: {api_version!r}") - # import xarray.namedarray._array_api as array_api + def __array_namespace__(self, /, *, api_version: str | None = None): + if api_version is not None and api_version not in ( + "2021.12", + "2022.12", + "2023.12", + ): + raise ValueError(f"Unrecognized array API version: {api_version!r}") + import xarray.namedarray._array_api as array_api - # return array_api + return array_api def __bool__(self, /) -> bool: return self._data.__bool__() @@ -471,17 +471,17 @@ def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: def __float__(self, /) -> float: return self._data.__float__() - def __floordiv__(self, other, /): + def __floordiv__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, floor_divide return floor_divide(self, asarray(other)) - def __ge__(self, other, /): + def __ge__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, greater_equal return greater_equal(self, asarray(other)) - def __getitem__(self, key: _IndexKeyLike | NamedArray): + def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: if isinstance(key, int | slice | tuple): _data = self._data[key] return self._new((), _data) @@ -492,7 +492,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray): else: raise NotImplementedError("{k=} is not supported") - def __gt__(self, other, /): + def __gt__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, greater return greater(self, asarray(other)) @@ -514,37 +514,37 @@ def __iter__(self: NamedArray, /): # TODO: smarter way to retain dims, xarray? return (asarray(i) for i in self._data) - def __le__(self, other, /): + def __le__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, less_equal return less_equal(self, asarray(other)) - def __lshift__(self, other, /): + def __lshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import bitwise_left_shift return bitwise_left_shift(self) - def __lt__(self, other, /): + def __lt__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, less return less(self, asarray(other)) - def __matmul__(self, other, /): + def __matmul__(self, other: NamedArray, /): from xarray.namedarray._array_api import asarray, matmul return matmul(self, asarray(other)) - def __mod__(self, other, /): + def __mod__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, remainder return remainder(self, asarray(other)) - def __mul__(self, other, /): + def __mul__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, multiply return multiply(self, asarray(other)) - def __ne__(self, other, /): + def __ne__(self, other: int | float | bool | NamedArray, /): from xarray.namedarray._array_api import asarray, not_equal return not_equal(self, asarray(other)) @@ -554,7 +554,7 @@ def __neg__(self, /): return negative(self) - def __or__(self, other, /): + def __or__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import bitwise_or return bitwise_or(self) @@ -564,12 +564,12 @@ def __pos__(self, /): return positive(self) - def __pow__(self, other, /): + def __pow__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, pow return pow(self, asarray(other)) - def __rshift__(self, other, /): + def __rshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import bitwise_right_shift return bitwise_right_shift(self) @@ -586,130 +586,130 @@ def __setitem__( key = key._data self._array.__setitem__(key, asarray(value)._data) - def __sub__(self, other, /): + def __sub__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, subtract return subtract(self, asarray(other)) - def __truediv__(self, other, /): + def __truediv__(self, other: float | NamedArray, /): from xarray.namedarray._array_api import asarray, divide return divide(self, asarray(other)) - def __xor__(self, other, /): + def __xor__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import bitwise_xor return bitwise_xor(self) - def __iadd__(self, other, /): + def __iadd__(self, other: int | float | NamedArray, /): self._data.__iadd__(other._data) return self - def __radd__(self, other, /): + def __radd__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import add, asarray return add(asarray(other), self) - def __iand__(self, other, /): + def __iand__(self, other: int | bool | NamedArray, /): self._data.__iand__(other._data) return self - def __rand__(self, other, /): + def __rand__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_and return bitwise_and(asarray(other), self) - def __ifloordiv__(self, other, /): + def __ifloordiv__(self, other: int | float | NamedArray, /): self._data.__ifloordiv__(other._data) return self - def __rfloordiv__(self, other, /): + def __rfloordiv__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, floor_divide return floor_divide(asarray(other), self) - def __ilshift__(self, other, /): + def __ilshift__(self, other: int | NamedArray, /): self._data.__ilshift__(other._data) return self - def __rlshift__(self, other, /): + def __rlshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_left_shift return bitwise_left_shift(asarray(other), self) - def __imatmul__(self, other, /): + def __imatmul__(self, other: NamedArray, /): self._data.__imatmul__(other._data) return self - def __rmatmul__(self, other, /): + def __rmatmul__(self, other: NamedArray, /): from xarray.namedarray._array_api import asarray, matmul return matmul(asarray(other), self) - def __imod__(self, other, /): + def __imod__(self, other: int | float | NamedArray, /): self._data.__imod__(other._data) return self - def __rmod__(self, other, /): + def __rmod__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, remainder return remainder(asarray(other), self) - def __imul__(self, other, /): + def __imul__(self, other: int | float | NamedArray, /): self._data.__imul__(other._data) return self - def __rmul__(self, other, /): + def __rmul__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, multiply return multiply(asarray(other), self) - def __ior__(self, other, /): + def __ior__(self, other: int | bool | NamedArray, /): self._data.__ior__(other._data) return self - def __ror__(self, other, /): + def __ror__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_or return bitwise_or(asarray(other), self) - def __ipow__(self, other, /): + def __ipow__(self, other: int | float | NamedArray, /): self._data.__ipow__(other._data) return self - def __rpow__(self, other, /): + def __rpow__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, pow return pow(asarray(other), self) - def __irshift__(self, other, /): + def __irshift__(self, other: int | NamedArray, /): self._data.__irshift__(other._data) return self - def __rrshift__(self, other, /): + def __rrshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_right_shift return bitwise_right_shift(asarray(other), self) - def __isub__(self, other, /): + def __isub__(self, other: int | float | NamedArray, /): self._data.__isub__(other._data) return self - def __rsub__(self, other, /): + def __rsub__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, subtract return subtract(asarray(other), self) - def __itruediv__(self, other, /): + def __itruediv__(self, other: float | NamedArray, /): self._data.__itruediv__(other._data) return self - def __rtruediv__(self, other, /): + def __rtruediv__(self, other: float | NamedArray, /): from xarray.namedarray._array_api import asarray, divide return divide(asarray(other), self) - def __ixor__(self, other, /): + def __ixor__(self, other: int | bool | NamedArray, /): self._data.__ixor__(other._data) return self From 3833e04a6e4dcc6dc6b6f601e1644a9c89858967 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:23:42 +0200 Subject: [PATCH 169/367] Update core.py --- xarray/namedarray/core.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 25c6070b30c..c5e4602aca2 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -602,7 +602,7 @@ def __xor__(self, other: int | bool | NamedArray, /): return bitwise_xor(self) def __iadd__(self, other: int | float | NamedArray, /): - self._data.__iadd__(other._data) + self._data.__iadd__(asarray(other)._data) return self def __radd__(self, other: int | float | NamedArray, /): @@ -611,7 +611,7 @@ def __radd__(self, other: int | float | NamedArray, /): return add(asarray(other), self) def __iand__(self, other: int | bool | NamedArray, /): - self._data.__iand__(other._data) + self._data.__iand__(asarray(other)._data) return self def __rand__(self, other: int | bool | NamedArray, /): @@ -620,7 +620,7 @@ def __rand__(self, other: int | bool | NamedArray, /): return bitwise_and(asarray(other), self) def __ifloordiv__(self, other: int | float | NamedArray, /): - self._data.__ifloordiv__(other._data) + self._data.__ifloordiv__(asarray(other)._data) return self def __rfloordiv__(self, other: int | float | NamedArray, /): @@ -629,7 +629,7 @@ def __rfloordiv__(self, other: int | float | NamedArray, /): return floor_divide(asarray(other), self) def __ilshift__(self, other: int | NamedArray, /): - self._data.__ilshift__(other._data) + self._data.__ilshift__(asarray(other)._data) return self def __rlshift__(self, other: int | NamedArray, /): @@ -647,7 +647,7 @@ def __rmatmul__(self, other: NamedArray, /): return matmul(asarray(other), self) def __imod__(self, other: int | float | NamedArray, /): - self._data.__imod__(other._data) + self._data.__imod__(asarray(other)._data) return self def __rmod__(self, other: int | float | NamedArray, /): @@ -656,7 +656,7 @@ def __rmod__(self, other: int | float | NamedArray, /): return remainder(asarray(other), self) def __imul__(self, other: int | float | NamedArray, /): - self._data.__imul__(other._data) + self._data.__imul__(asarray(other)._data) return self def __rmul__(self, other: int | float | NamedArray, /): @@ -665,7 +665,7 @@ def __rmul__(self, other: int | float | NamedArray, /): return multiply(asarray(other), self) def __ior__(self, other: int | bool | NamedArray, /): - self._data.__ior__(other._data) + self._data.__ior__(asarray(other)._data) return self def __ror__(self, other: int | bool | NamedArray, /): @@ -674,7 +674,7 @@ def __ror__(self, other: int | bool | NamedArray, /): return bitwise_or(asarray(other), self) def __ipow__(self, other: int | float | NamedArray, /): - self._data.__ipow__(other._data) + self._data.__ipow__(asarray(other)._data) return self def __rpow__(self, other: int | float | NamedArray, /): @@ -683,7 +683,7 @@ def __rpow__(self, other: int | float | NamedArray, /): return pow(asarray(other), self) def __irshift__(self, other: int | NamedArray, /): - self._data.__irshift__(other._data) + self._data.__irshift__(asarray(other)._data) return self def __rrshift__(self, other: int | NamedArray, /): @@ -692,7 +692,7 @@ def __rrshift__(self, other: int | NamedArray, /): return bitwise_right_shift(asarray(other), self) def __isub__(self, other: int | float | NamedArray, /): - self._data.__isub__(other._data) + self._data.__isub__(asarray(other)._data) return self def __rsub__(self, other: int | float | NamedArray, /): @@ -701,7 +701,7 @@ def __rsub__(self, other: int | float | NamedArray, /): return subtract(asarray(other), self) def __itruediv__(self, other: float | NamedArray, /): - self._data.__itruediv__(other._data) + self._data.__itruediv__(asarray(other)._data) return self def __rtruediv__(self, other: float | NamedArray, /): @@ -710,7 +710,7 @@ def __rtruediv__(self, other: float | NamedArray, /): return divide(asarray(other), self) def __ixor__(self, other: int | bool | NamedArray, /): - self._data.__ixor__(other._data) + self._data.__ixor__(asarray(other)._data) return self def __rxor__(self, other, /): From 71d4876637b3f18fda6207fd000f776fcf0e336c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:27:40 +0200 Subject: [PATCH 170/367] Update core.py --- xarray/namedarray/core.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c5e4602aca2..c8ddb880ceb 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -602,6 +602,8 @@ def __xor__(self, other: int | bool | NamedArray, /): return bitwise_xor(self) def __iadd__(self, other: int | float | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__iadd__(asarray(other)._data) return self @@ -611,6 +613,8 @@ def __radd__(self, other: int | float | NamedArray, /): return add(asarray(other), self) def __iand__(self, other: int | bool | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__iand__(asarray(other)._data) return self @@ -620,6 +624,8 @@ def __rand__(self, other: int | bool | NamedArray, /): return bitwise_and(asarray(other), self) def __ifloordiv__(self, other: int | float | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__ifloordiv__(asarray(other)._data) return self @@ -629,6 +635,8 @@ def __rfloordiv__(self, other: int | float | NamedArray, /): return floor_divide(asarray(other), self) def __ilshift__(self, other: int | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__ilshift__(asarray(other)._data) return self @@ -647,6 +655,8 @@ def __rmatmul__(self, other: NamedArray, /): return matmul(asarray(other), self) def __imod__(self, other: int | float | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__imod__(asarray(other)._data) return self @@ -656,6 +666,8 @@ def __rmod__(self, other: int | float | NamedArray, /): return remainder(asarray(other), self) def __imul__(self, other: int | float | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__imul__(asarray(other)._data) return self @@ -665,6 +677,8 @@ def __rmul__(self, other: int | float | NamedArray, /): return multiply(asarray(other), self) def __ior__(self, other: int | bool | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__ior__(asarray(other)._data) return self @@ -674,6 +688,8 @@ def __ror__(self, other: int | bool | NamedArray, /): return bitwise_or(asarray(other), self) def __ipow__(self, other: int | float | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__ipow__(asarray(other)._data) return self @@ -683,6 +699,8 @@ def __rpow__(self, other: int | float | NamedArray, /): return pow(asarray(other), self) def __irshift__(self, other: int | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__irshift__(asarray(other)._data) return self @@ -692,6 +710,8 @@ def __rrshift__(self, other: int | NamedArray, /): return bitwise_right_shift(asarray(other), self) def __isub__(self, other: int | float | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__isub__(asarray(other)._data) return self @@ -701,6 +721,8 @@ def __rsub__(self, other: int | float | NamedArray, /): return subtract(asarray(other), self) def __itruediv__(self, other: float | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__itruediv__(asarray(other)._data) return self @@ -710,6 +732,8 @@ def __rtruediv__(self, other: float | NamedArray, /): return divide(asarray(other), self) def __ixor__(self, other: int | bool | NamedArray, /): + from xarray.namedarray._array_api import asarray + self._data.__ixor__(asarray(other)._data) return self From 32433b47b6d7a934c8e774e9a10f758cc76da941 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:36:45 +0200 Subject: [PATCH 171/367] Update _statistical_functions.py --- xarray/namedarray/_array_api/_statistical_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 09c7f6da0a7..40bc490fb0d 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -162,6 +162,7 @@ def prod( /, *, dims: _Dims | Default = _default, + dtype: _DType | None = None, keepdims: bool = False, axis: _AxisLike | None = None, ) -> NamedArray[Any, _DType]: @@ -199,6 +200,7 @@ def sum( /, *, dims: _Dims | Default = _default, + dtype: _DType | None = None, keepdims: bool = False, axis: _AxisLike | None = None, ) -> NamedArray[Any, _DType]: From 084e59a3dc4e5cb7dbd504dc35b3f632adbe1f9b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:52:09 +0200 Subject: [PATCH 172/367] more --- xarray/namedarray/_array_api/_creation_functions.py | 2 +- xarray/namedarray/_array_api/_data_type_functions.py | 12 ++++++++---- xarray/namedarray/_array_api/_utils.py | 5 +++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 22a3b14c3d9..3c933cca8d6 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -199,8 +199,8 @@ def full( def full_like( x: NamedArray[_ShapeType, _DType], - fill_value: bool | int | float | complex, /, + fill_value: bool | int | float | complex, *, dtype: _DType | None = None, device: _Device | None = None, diff --git a/xarray/namedarray/_array_api/_data_type_functions.py b/xarray/namedarray/_array_api/_data_type_functions.py index 895b7a50fb1..e3eca76ead7 100644 --- a/xarray/namedarray/_array_api/_data_type_functions.py +++ b/xarray/namedarray/_array_api/_data_type_functions.py @@ -12,15 +12,19 @@ _dtype, _FInfo, _IInfo, + _Device, _ShapeType, ) -from xarray.namedarray.core import ( - NamedArray, -) +from xarray.namedarray.core import NamedArray def astype( - x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True + x: NamedArray[_ShapeType, Any], + dtype: _DType, + /, + *, + copy: bool = True, + device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: """ Copies an array to a specified data type irrespective of Type Promotion Rules rules. diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index c9d4932d3cb..ea68427faea 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -24,8 +24,9 @@ def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: if xp is None: - # import array_api_strict as xpd - import array_api_compat.numpy as xpd + import array_api_strict as xpd + + # import array_api_compat.numpy as xpd # import numpy as xpd From d19d4be278c55235db25fb0f1d978153f9f8dec8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 20:52:50 +0000 Subject: [PATCH 173/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_data_type_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_data_type_functions.py b/xarray/namedarray/_array_api/_data_type_functions.py index e3eca76ead7..a213f45f771 100644 --- a/xarray/namedarray/_array_api/_data_type_functions.py +++ b/xarray/namedarray/_array_api/_data_type_functions.py @@ -8,11 +8,11 @@ ) from xarray.namedarray._typing import ( _arrayapi, + _Device, _DType, _dtype, _FInfo, _IInfo, - _Device, _ShapeType, ) from xarray.namedarray.core import NamedArray From 13eec740a46322cf5162f06c5dceeebe607e6606 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 22:55:35 +0200 Subject: [PATCH 174/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index ea68427faea..eba610ec5c1 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -24,11 +24,11 @@ def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: if xp is None: - import array_api_strict as xpd + # import array_api_strict as xpd # import array_api_compat.numpy as xpd - # import numpy as xpd + import numpy as xpd return xpd else: From 38ad910d3889204ef319931f32d5a8731846a0ca Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 22 Aug 2024 23:08:41 +0200 Subject: [PATCH 175/367] Update core.py --- xarray/namedarray/core.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c8ddb880ceb..d4bc5e67dca 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -520,9 +520,9 @@ def __le__(self, other: int | float | NamedArray, /): return less_equal(self, asarray(other)) def __lshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import bitwise_left_shift + from xarray.namedarray._array_api import bitwise_left_shift, asarray - return bitwise_left_shift(self) + return bitwise_left_shift(self, asarray(other)) def __lt__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, less @@ -555,9 +555,9 @@ def __neg__(self, /): return negative(self) def __or__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import bitwise_or + from xarray.namedarray._array_api import bitwise_or, asarray - return bitwise_or(self) + return bitwise_or(self, asarray(other)) def __pos__(self, /): from xarray.namedarray._array_api import positive @@ -570,9 +570,9 @@ def __pow__(self, other: int | float | NamedArray, /): return pow(self, asarray(other)) def __rshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import bitwise_right_shift + from xarray.namedarray._array_api import bitwise_right_shift, asarray - return bitwise_right_shift(self) + return bitwise_right_shift(self, asarray(other)) def __setitem__( self, @@ -597,9 +597,9 @@ def __truediv__(self, other: float | NamedArray, /): return divide(self, asarray(other)) def __xor__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import bitwise_xor + from xarray.namedarray._array_api import bitwise_xor, asarray - return bitwise_xor(self) + return bitwise_xor(self, asarray(other)) def __iadd__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray From 28ab5d982f499ff3435e0c94012f722b6eb7d4c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 21:09:20 +0000 Subject: [PATCH 176/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index d4bc5e67dca..80a010e8ebf 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -520,7 +520,7 @@ def __le__(self, other: int | float | NamedArray, /): return less_equal(self, asarray(other)) def __lshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import bitwise_left_shift, asarray + from xarray.namedarray._array_api import asarray, bitwise_left_shift return bitwise_left_shift(self, asarray(other)) @@ -555,7 +555,7 @@ def __neg__(self, /): return negative(self) def __or__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import bitwise_or, asarray + from xarray.namedarray._array_api import asarray, bitwise_or return bitwise_or(self, asarray(other)) @@ -570,7 +570,7 @@ def __pow__(self, other: int | float | NamedArray, /): return pow(self, asarray(other)) def __rshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import bitwise_right_shift, asarray + from xarray.namedarray._array_api import asarray, bitwise_right_shift return bitwise_right_shift(self, asarray(other)) @@ -597,7 +597,7 @@ def __truediv__(self, other: float | NamedArray, /): return divide(self, asarray(other)) def __xor__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import bitwise_xor, asarray + from xarray.namedarray._array_api import asarray, bitwise_xor return bitwise_xor(self, asarray(other)) From 9618f2ac8d4fd871b2e2497c5a7338b7aac6f121 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 11:18:13 +0200 Subject: [PATCH 177/367] Update _creation_functions.py --- xarray/namedarray/_array_api/_creation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 3c933cca8d6..2d4be82abd8 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -247,7 +247,7 @@ def meshgrid(*arrays: NamedArray, indexing: str = "xy") -> list[NamedArray]: def ones( shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: - return full(shape, 1, dtype=dtype, device=device) + return full(shape, 1.0, dtype=dtype, device=device) def ones_like( @@ -287,7 +287,7 @@ def triu( def zeros( shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: - return full(shape, 0, dtype=dtype, device=device) + return full(shape, 0.0, dtype=dtype, device=device) def zeros_like( From beb7a350cacb4cec2f1eb34a575250373aeeeea2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:12:36 +0200 Subject: [PATCH 178/367] Add manip functions --- xarray/namedarray/_array_api/__init__.py | 28 ++-- .../_array_api/_manipulation_functions.py | 126 ++++++++++++++++-- xarray/namedarray/_array_api/_utils.py | 9 ++ 3 files changed, 136 insertions(+), 27 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 8167d6adc8f..c5afda1b329 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -259,30 +259,30 @@ ] from xarray.namedarray._array_api._manipulation_functions import ( - # broadcast_arrays, - # broadcast_to, - # concat, + broadcast_arrays, + broadcast_to, + concat, expand_dims, - # flip, - # moveaxis, + flip, + moveaxis, permute_dims, reshape, - # roll, - # squeeze, + roll, + squeeze, stack, ) __all__ += [ - # "broadcast_arrays", - # "broadcast_to", - # "concat", + "broadcast_arrays", + "broadcast_to", + "concat", "expand_dims", - # "flip", - # "moveaxis", + "flip", + "moveaxis", "permute_dims", "reshape", - # "roll", - # "squeeze", + "roll", + "squeeze", "stack", ] diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 662911cbbf6..977a45f41da 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -3,7 +3,12 @@ from typing import Any from xarray.namedarray._array_api._creation_functions import asarray -from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api._data_type_functions import result_type +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _infer_dims, + _insert_dim, +) from xarray.namedarray._typing import ( Default, _arrayapi, @@ -11,12 +16,41 @@ _Axis, _default, _Dim, + _Dims, _DType, _Shape, ) -from xarray.namedarray.core import ( - NamedArray, -) +from xarray.namedarray.core import NamedArray + + +def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]: + x = arrays[0] + xp = _get_data_namespace(x) + _arrays = tuple(a._data for a in arrays) + _datas = xp.broadcast_arrays(_arrays) + out = [] + for _data in _datas: + _dims = _infer_dims(_data) # TODO: Fix dims + out.append(x._new(_dims, _data)) + return out + + +def broadcast_to(x: NamedArray, /, shape: _Shape) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.broadcast_to(x._data, shape=shape) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def concat( + arrays: tuple[NamedArray, ...] | list[NamedArray], /, *, axis: _Axis | None = 0 +) -> NamedArray: + xp = _get_data_namespace(arrays[0]) + dtype = result_type(*arrays) + arrays = tuple(a._data for a in arrays) + _data = xp.concat(arrays, axis=axis, dtype=dtype) + _dims = _infer_dims(_data) + return NamedArray(_dims, _data) def expand_dims( @@ -57,13 +91,23 @@ def expand_dims( [3., 4.]]]) """ xp = _get_data_namespace(x) - dims = x.dims - if dim is _default: - dim = f"dim_{len(dims)}" - d = list(dims) - d.insert(axis, dim) - out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) - return out + _data = xp.expand_dims(x._data, axis=axis) + _dims = _insert_dim(x.dims, dim, axis) + return x._new(_dims, _data) + + +def flip(x: NamedArray, /, *, axis: _Axes | None = None) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.flip(x._data, axis=axis) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def moveaxis(x: NamedArray, source: _Axes, destination: _Axes, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.moveaxis(x._data, source=source, destination=destination) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]: @@ -95,6 +139,19 @@ def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DT return out +def repeat( + x: NamedArray, + repeats: int | NamedArray, + /, + *, + axis: _Axis | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.repeat(x._data, repeats, axis=axis) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + def reshape(x, /, shape: _Shape, *, copy: bool | None = None): xp = _get_data_namespace(x) _data = xp.reshape(x._data, shape) @@ -105,5 +162,48 @@ def reshape(x, /, shape: _Shape, *, copy: bool | None = None): return out -def stack(arrays, /, *, axis=0): - raise NotImplementedError("TODO:") +def roll( + x: NamedArray, + /, + shift: int | tuple[int, ...], + *, + axis: _Axes | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.roll(x._data, shift=shift, axis=axis) + return x._new(_data) + + +def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.squeeze(x._data, axis=axis) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def stack( + arrays: tuple[NamedArray, ...] | list[NamedArray], /, *, axis: _Axis = 0 +) -> NamedArray: + x = arrays[0] + xp = _get_data_namespace(x) + arrays = tuple(a._data for a in arrays) + _data = xp.stack(arrays, axis=axis) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def tile(x: NamedArray, repetitions: tuple[int, ...], /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.tile(x._data, repetitions) + _dims = _infer_dims(_data) # TODO: Fix dims + return x._new(_dims, _data) + + +def unstack(x: NamedArray, /, *, axis: _Axis = 0) -> tuple[NamedArray, ...]: + xp = _get_data_namespace(x) + _datas = xp.unstack(x._data, axis=axis) + out = () + for _data in _datas: + _dims = _infer_dims(_data) # TODO: Fix dims + out += (x._new(_dims, _data),) + return out diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index eba610ec5c1..77afdeb07cd 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -14,6 +14,7 @@ _DimsLike, _DType, _dtype, + _Axis, _Shape, duckarray, ) @@ -162,3 +163,11 @@ def _get_remaining_dims( dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes) return dims, data + + +def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: + if dim is _default: + dim = f"dim_{len(dims)}" + d = list(dims) + d.insert(axis, dim) + return tuple(d) From 8caf88d0a65ccf1e0f3db968f0ff879a46ad8ff3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 10:13:15 +0000 Subject: [PATCH 179/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_manipulation_functions.py | 1 - xarray/namedarray/_array_api/_utils.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 977a45f41da..3b59d4e7368 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -16,7 +16,6 @@ _Axis, _default, _Dim, - _Dims, _DType, _Shape, ) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 77afdeb07cd..1a732df108c 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -7,6 +7,7 @@ from xarray.namedarray._typing import ( Default, _arrayapi, + _Axis, _AxisLike, _default, _Dim, @@ -14,7 +15,6 @@ _DimsLike, _DType, _dtype, - _Axis, _Shape, duckarray, ) From 4da8a10be40ca8a716b4e0a99e948bb4b4fc7251 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:27:45 +0200 Subject: [PATCH 180/367] Update core.py --- xarray/namedarray/core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 80a010e8ebf..94e3722b48d 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -483,8 +483,11 @@ def __ge__(self, other: int | float | NamedArray, /): def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: if isinstance(key, int | slice | tuple): + from xarray.namedarray._array_api._utils import _infer_dims + _data = self._data[key] - return self._new((), _data) + _dims = _infer_dims(_data.shape) # TODO: fix + return self._new(_dims, _data) elif isinstance(key, NamedArray): _key = self._data # TODO: Transpose, unordered dims shouldn't matter. _data = self._data[_key] From 15a21c9b9b5ab5c1ccc6850475aaa6929427c9e0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:36:49 +0200 Subject: [PATCH 181/367] Update _manipulation_functions.py --- xarray/namedarray/_array_api/_manipulation_functions.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 3b59d4e7368..10a68bab71f 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -27,11 +27,8 @@ def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]: xp = _get_data_namespace(x) _arrays = tuple(a._data for a in arrays) _datas = xp.broadcast_arrays(_arrays) - out = [] - for _data in _datas: - _dims = _infer_dims(_data) # TODO: Fix dims - out.append(x._new(_dims, _data)) - return out + _dims = _infer_dims(_datas[0].shape) + return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas)] def broadcast_to(x: NamedArray, /, shape: _Shape) -> NamedArray: From af70db911b4ab291d3c6d2f40ad117ce81063eaf Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:52:29 +0200 Subject: [PATCH 182/367] Update _manipulation_functions.py --- xarray/namedarray/_array_api/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 10a68bab71f..a48342b0c07 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -26,7 +26,7 @@ def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]: x = arrays[0] xp = _get_data_namespace(x) _arrays = tuple(a._data for a in arrays) - _datas = xp.broadcast_arrays(_arrays) + _datas = xp.broadcast_arrays(*_arrays) _dims = _infer_dims(_datas[0].shape) return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas)] From d3f4c14e069aafd90471c2d619d87393fdbf1f7d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 13:05:36 +0200 Subject: [PATCH 183/367] Update _manipulation_functions.py --- .../_array_api/_manipulation_functions.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index a48342b0c07..5264786ded5 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -34,7 +34,7 @@ def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]: def broadcast_to(x: NamedArray, /, shape: _Shape) -> NamedArray: xp = _get_data_namespace(x) _data = xp.broadcast_to(x._data, shape=shape) - _dims = _infer_dims(_data) # TODO: Fix dims + _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) @@ -45,7 +45,7 @@ def concat( dtype = result_type(*arrays) arrays = tuple(a._data for a in arrays) _data = xp.concat(arrays, axis=axis, dtype=dtype) - _dims = _infer_dims(_data) + _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -95,14 +95,14 @@ def expand_dims( def flip(x: NamedArray, /, *, axis: _Axes | None = None) -> NamedArray: xp = _get_data_namespace(x) _data = xp.flip(x._data, axis=axis) - _dims = _infer_dims(_data) # TODO: Fix dims + _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) def moveaxis(x: NamedArray, source: _Axes, destination: _Axes, /) -> NamedArray: xp = _get_data_namespace(x) _data = xp.moveaxis(x._data, source=source, destination=destination) - _dims = _infer_dims(_data) # TODO: Fix dims + _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) @@ -144,7 +144,7 @@ def repeat( ) -> NamedArray: xp = _get_data_namespace(x) _data = xp.repeat(x._data, repeats, axis=axis) - _dims = _infer_dims(_data) # TODO: Fix dims + _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) @@ -173,7 +173,7 @@ def roll( def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray: xp = _get_data_namespace(x) _data = xp.squeeze(x._data, axis=axis) - _dims = _infer_dims(_data) # TODO: Fix dims + _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) @@ -184,14 +184,14 @@ def stack( xp = _get_data_namespace(x) arrays = tuple(a._data for a in arrays) _data = xp.stack(arrays, axis=axis) - _dims = _infer_dims(_data) # TODO: Fix dims + _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) def tile(x: NamedArray, repetitions: tuple[int, ...], /) -> NamedArray: xp = _get_data_namespace(x) _data = xp.tile(x._data, repetitions) - _dims = _infer_dims(_data) # TODO: Fix dims + _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) @@ -200,6 +200,6 @@ def unstack(x: NamedArray, /, *, axis: _Axis = 0) -> tuple[NamedArray, ...]: _datas = xp.unstack(x._data, axis=axis) out = () for _data in _datas: - _dims = _infer_dims(_data) # TODO: Fix dims + _dims = _infer_dims(_data.shape) # TODO: Fix dims out += (x._new(_dims, _data),) return out From ac20d2cdf723c30c9d5107b037f2e24606c4483e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:08:27 +0200 Subject: [PATCH 184/367] add linalg --- xarray/namedarray/_array_api/__init__.py | 6 + .../namedarray/_array_api/_linalg/__init__.py | 53 ++++ .../namedarray/_array_api/_linalg/_linalg.py | 226 ++++++++++++++++++ 3 files changed, 285 insertions(+) create mode 100644 xarray/namedarray/_array_api/_linalg/__init__.py create mode 100644 xarray/namedarray/_array_api/_linalg/_linalg.py diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index c5afda1b329..04c91df859e 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -244,6 +244,12 @@ "__array_namespace_info__", ] +import xarray.namedarray._array_api._linalg as linalg + + +__all__ = ["linalg"] + + from xarray.namedarray._array_api._linear_algebra_functions import ( matmul, matrix_transpose, diff --git a/xarray/namedarray/_array_api/_linalg/__init__.py b/xarray/namedarray/_array_api/_linalg/__init__.py new file mode 100644 index 00000000000..d86858242e4 --- /dev/null +++ b/xarray/namedarray/_array_api/_linalg/__init__.py @@ -0,0 +1,53 @@ +__all__ = [] + +from xarray.namedarray._array_api._linalg._linalg import ( + cholesky, + cross, + det, + diagonal, + eigh, + eigvalsh, + inv, + matmul, + matrix_norm, + matrix_power, + matrix_rank, + matrix_transpose, + outer, + pinv, + qr, + slogdet, + solve, + svd, + svdvals, + tensordot, + trace, + vecdot, + vector_norm, +) + +__all__ = [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", +] diff --git a/xarray/namedarray/_array_api/_linalg/_linalg.py b/xarray/namedarray/_array_api/_linalg/_linalg.py new file mode 100644 index 00000000000..282162bd816 --- /dev/null +++ b/xarray/namedarray/_array_api/_linalg/_linalg.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, NamedTuple, Literal + +from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims +from xarray.namedarray.core import NamedArray +from xarray.namedarray._array_api._dtypes import ( + _floating_dtypes, + _numeric_dtypes, + float32, + complex64, + complex128, +) +from xarray.namedarray._array_api._data_type_functions import finfo +from xarray.namedarray._array_api._manipulation_functions import reshape +from xarray.namedarray._array_api._elementwise_functions import conj + + +if TYPE_CHECKING: + from xarray.namedarray._typing import _Axis, _DType, _Axes + + +class EighResult(NamedTuple): + eigenvalues: NamedArray + eigenvectors: NamedArray + + +class QRResult(NamedTuple): + Q: NamedArray + R: NamedArray + + +class SlogdetResult(NamedTuple): + sign: NamedArray + logabsdet: NamedArray + + +class SVDResult(NamedTuple): + U: NamedArray + S: NamedArray + Vh: NamedArray + + +def cholesky(x: NamedArray, /, *, upper: bool = False) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.cholesky(x._data, upper=upper) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +# Note: cross is the numpy top-level namespace, not np.linalg +def cross(x1: NamedArray, x2: NamedArray, /, *, axis: _Axis = -1) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.linalg.cross(x1._data, x2._data, axis=axis) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x1._new(_dims, _data) + + +def det(x: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.det(x._data) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def diagonal(x: NamedArray, /, *, offset: int = 0) -> NamedArray: + # Note: diagonal is the numpy top-level namespace, not np.linalg + xp = _get_data_namespace(x) + _data = xp.linalg.diagonal(x._data, offset=offset) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def eigh(x: NamedArray, /) -> EighResult: + xp = _get_data_namespace(x) + _datas = xp.linalg.eigh(x._data) + _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims + return EighResult(*(x._new(_dims, _data) for _data in _datas)) + + +def eigvalsh(x: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.eigvalsh(x._data) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def inv(x: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.inv(x._data) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def matrix_norm( + x: NamedArray, + /, + *, + keepdims: bool = False, + ord: int | float | Literal["fro", "nuc"] | None = "fro", +) -> NamedArray: # noqa: F821 + xp = _get_data_namespace(x) + _data = xp.linalg.matrix_norm(x._data, keepdims=keepdims, ord=ord) # ckeck xp.mean + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def matrix_power(x: NamedArray, n: int, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.matrix_power(x._data, n=n) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def matrix_rank( + x: NamedArray, /, *, rtol: float | NamedArray | None = None +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.matrix_rank(x._data, rtol=rtol) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def outer(x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.linalg.outer(x1._data, x2._data) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x1._new(_dims, _data) + + +def pinv(x: NamedArray, /, *, rtol: float | NamedArray | None = None) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.pinv(x._data, rtol=rtol) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def qr( + x: NamedArray, /, *, mode: Literal["reduced", "complete"] = "reduced" +) -> QRResult: + xp = _get_data_namespace(x) + _datas = xp.linalg.qr(x._data) + _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims + return QRResult(*(x._new(_dims, _data) for _data in _datas)) + + +def slogdet(x: NamedArray, /) -> SlogdetResult: + xp = _get_data_namespace(x) + _datas = xp.linalg.slogdet(x._data) + _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims + return SlogdetResult(*(x._new(_dims, _data) for _data in _datas)) + + +def solve(x1: NamedArray, x2: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x1) + _data = xp.linalg.solve(x1._data, x2._data) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x1._new(_dims, _data) + + +def svd(x: NamedArray, /, *, full_matrices: bool = True) -> SVDResult: + xp = _get_data_namespace(x) + _datas = xp.linalg.svd(x._data, full_matrices=full_matrices) + _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims + return SVDResult(*(x._new(_dims, _data) for _data in _datas)) + + +def svdvals(x: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.svdvals(x._data) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def trace( + x: NamedArray, /, *, offset: int = 0, dtype: _DType | None = None +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.svdvals(x._data, offset=offset) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def vector_norm( + x: NamedArray, + /, + *, + axis: _Axes | None = None, + keepdims: bool = False, + ord: int | float | None = 2, +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.linalg.svdvals(x._data, axis=axis, keepdims=keepdims, ord=ord) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def matmul(x1: NamedArray, x2: NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api._linear_algebra_functions import matmul + + return matmul(x1, x2) + + +def tensordot( + x1: NamedArray, + x2: NamedArray, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, +) -> NamedArray: + from xarray.namedarray._array_api._linear_algebra_functions import tensordot + + return tensordot(x1, x2, axes=axes) + + +def matrix_transpose(x: NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api._linear_algebra_functions import matrix_transpose + + return matrix_transpose(x) + + +def vecdot(x1: NamedArray, x2: NamedArray, /, *, axis: _Axis = -1) -> NamedArray: + from xarray.namedarray._array_api._linear_algebra_functions import vecdot + + return vecdot(x1, x2, axis=axis) From b799a9fedbc46e9cf761a6e750e2017e2bc9676b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:09:14 +0000 Subject: [PATCH 185/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/__init__.py | 1 - xarray/namedarray/_array_api/_linalg/_linalg.py | 15 ++------------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 04c91df859e..99628406abd 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -246,7 +246,6 @@ import xarray.namedarray._array_api._linalg as linalg - __all__ = ["linalg"] diff --git a/xarray/namedarray/_array_api/_linalg/_linalg.py b/xarray/namedarray/_array_api/_linalg/_linalg.py index 282162bd816..273de930b9a 100644 --- a/xarray/namedarray/_array_api/_linalg/_linalg.py +++ b/xarray/namedarray/_array_api/_linalg/_linalg.py @@ -1,24 +1,13 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, NamedTuple, Literal +from typing import TYPE_CHECKING, Literal, NamedTuple from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims from xarray.namedarray.core import NamedArray -from xarray.namedarray._array_api._dtypes import ( - _floating_dtypes, - _numeric_dtypes, - float32, - complex64, - complex128, -) -from xarray.namedarray._array_api._data_type_functions import finfo -from xarray.namedarray._array_api._manipulation_functions import reshape -from xarray.namedarray._array_api._elementwise_functions import conj - if TYPE_CHECKING: - from xarray.namedarray._typing import _Axis, _DType, _Axes + from xarray.namedarray._typing import _Axes, _Axis, _DType class EighResult(NamedTuple): From c78031acbef4dd1329bd3555651981d78b6e6da0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:32:11 +0200 Subject: [PATCH 186/367] Add set functions --- xarray/namedarray/_array_api/__init__.py | 15 ++++ .../namedarray/_array_api/_set_functions.py | 77 +++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 xarray/namedarray/_array_api/_set_functions.py diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 99628406abd..72ca9984815 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -307,6 +307,21 @@ "where", ] +from ._set_functions import ( + unique_all, + unique_counts, + unique_inverse, + unique_values, +) + +__all__ += [ + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", +] + + from xarray.namedarray._array_api._sorting_functions import argsort, sort __all__ += ["argsort", "sort"] diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py new file mode 100644 index 00000000000..61e3747a277 --- /dev/null +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple, Literal + +from xarray.namedarray._array_api._utils import ( + _dims_to_axis, + _get_data_namespace, + _get_remaining_dims, + _infer_dims, +) +from xarray.namedarray._typing import ( + Default, + _default, + _Dims, +) +from xarray.namedarray.core import NamedArray + + +class UniqueAllResult(NamedTuple): + values: NamedArray + indices: NamedArray + inverse_indices: NamedArray + counts: NamedArray + + +class UniqueCountsResult(NamedTuple): + values: NamedArray + counts: NamedArray + + +class UniqueInverseResult(NamedTuple): + values: NamedArray + inverse_indices: NamedArray + + +def unique_all(x: NamedArray, /) -> UniqueAllResult: + xp = _get_data_namespace(x) + values, indices, inverse_indices, counts = xp.unique_all(x._data) + _dims_values = _infer_dims(values.shape) # TODO: Fix + _dims_indices = _infer_dims(indices.shape) # TODO: Fix dims + _dims_inverse_indices = _infer_dims(inverse_indices.shape) # TODO: Fix dims + _dims_counts = _infer_dims(counts.shape) # TODO: Fix dims + return UniqueAllResult( + NamedArray(_dims_values, values), + NamedArray(_dims_indices, indices), + NamedArray(_dims_inverse_indices, inverse_indices), + NamedArray(_dims_counts, counts), + ) + + +def unique_counts(x: NamedArray, /) -> UniqueCountsResult: + xp = _get_data_namespace(x) + values, counts = xp.unique_counts(x._data) + _dims_values = _infer_dims(values.shape) # TODO: Fix dims + _dims_counts = _infer_dims(counts.shape) # TODO: Fix dims + return UniqueCountsResult( + NamedArray(_dims_values, values), + NamedArray(_dims_counts, counts), + ) + + +def unique_inverse(x: NamedArray, /) -> UniqueInverseResult: + xp = _get_data_namespace(x) + values, inverse_indices = xp.unique_inverse(x._data) + _dims_values = _infer_dims(values.shape) # TODO: Fix + _dims_inverse_indices = _infer_dims(inverse_indices.shape) # TODO: Fix dims + return UniqueInverseResult( + NamedArray(_dims_values, values), + NamedArray(_dims_inverse_indices, inverse_indices), + ) + + +def unique_values(x: NamedArray, /) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.unique_values(x._data) + _dims = _infer_dims(_data.shape) # TODO: Fix + return x._new(_dims, _data) From b79747254b10eecd2b1649438a49eb4023428406 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:32:49 +0000 Subject: [PATCH 187/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/__init__.py | 2 +- xarray/namedarray/_array_api/_set_functions.py | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 72ca9984815..13fd8f623f0 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -307,7 +307,7 @@ "where", ] -from ._set_functions import ( +from xarray.namedarray._array_api._set_functions import ( unique_all, unique_counts, unique_inverse, diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py index 61e3747a277..64c852b61c3 100644 --- a/xarray/namedarray/_array_api/_set_functions.py +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -1,18 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple, Literal +from typing import NamedTuple from xarray.namedarray._array_api._utils import ( - _dims_to_axis, _get_data_namespace, - _get_remaining_dims, _infer_dims, ) -from xarray.namedarray._typing import ( - Default, - _default, - _Dims, -) from xarray.namedarray.core import NamedArray From d6911a238b47d4f08dee31bbcf7f1b34e7feee66 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:40:27 +0200 Subject: [PATCH 188/367] Update _creation_functions.py --- xarray/namedarray/_array_api/_creation_functions.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 2d4be82abd8..dc9e574040a 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -184,6 +184,19 @@ def eye( return NamedArray(_dims, _data) +def from_dlpack( + x: object, + /, + *, + device: _Device | None = None, + copy: bool | None = None, +) -> NamedArray: + xp = _get_data_namespace(x) + _device = x.device if device is None else device + _data = xp.from_dlpack(x, device=_device, copy=copy) + return x._new(data=_data) + + def full( shape: _Shape, fill_value: bool | int | float | complex, From d9d23bb2283744d0b36afc75aaa3c3b464034d46 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:43:40 +0200 Subject: [PATCH 189/367] Update __init__.py --- xarray/namedarray/_array_api/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 13fd8f623f0..ed863867968 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -271,10 +271,13 @@ flip, moveaxis, permute_dims, + repeat, reshape, roll, squeeze, stack, + tile, + unstack, ) __all__ += [ @@ -285,10 +288,13 @@ "flip", "moveaxis", "permute_dims", + "repeat", "reshape", "roll", "squeeze", "stack", + "tile", + "unstack", ] from xarray.namedarray._array_api._searching_functions import ( From 69f3061b16a0618c79e6bde2c9c085e9e92e6355 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:59:09 +0200 Subject: [PATCH 190/367] Update _sorting_functions.py --- .../_array_api/_sorting_functions.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index 0bc5f53b34a..e8c20fa6525 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -6,9 +6,7 @@ _default, _Dim, ) -from xarray.namedarray.core import ( - NamedArray, -) +from xarray.namedarray.core import NamedArray def argsort( @@ -22,10 +20,21 @@ def argsort( ) -> NamedArray: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis) - out = x._new( - data=xp.argsort(x._data, axis=_axis, descending=descending, stable=stable) - ) - return out + if not descending: + _data = xp.argsort(x._data, axis=_axis, stable=stable) + else: + # As NumPy has no native descending sort, we imitate it here. Note that + # simply flipping the results of np.argsort(x._array, ...) would not + # respect the relative order like it would in native descending sorts. + _data = xp.flip( + xp.argsort(xp.flip(x._data, axis=axis), stable=stable, axis=axis), + axis=axis, + ) + # Rely on flip()/argsort() to validate axis + normalised_axis = axis if axis >= 0 else x.ndim + axis + max_i = x.shape[normalised_axis] - 1 + _data = max_i - _data + return x._new(data=_data) def sort( @@ -39,7 +48,7 @@ def sort( ) -> NamedArray: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis) - out = x._new( - data=xp.argsort(x._data, axis=_axis, descending=descending, stable=stable) - ) - return out + _data = xp.sort(x._data, axis=_axis, descending=descending, stable=stable) + if descending: + _data = xp.flip(_data, axis=axis) + return x._new(data=_data) From 95c5bc41ba5bccdeb306b1b066f2ffa3b67ca57d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:25:51 +0200 Subject: [PATCH 191/367] Add fft --- xarray/namedarray/_array_api/__init__.py | 4 + xarray/namedarray/_array_api/_fft/__init__.py | 35 ++++ xarray/namedarray/_array_api/_fft/_fft.py | 193 ++++++++++++++++++ 3 files changed, 232 insertions(+) create mode 100644 xarray/namedarray/_array_api/_fft/__init__.py create mode 100644 xarray/namedarray/_array_api/_fft/_fft.py diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index ed863867968..ce375a5c8f8 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -234,6 +234,10 @@ "trunc", ] +import xarray.namedarray._array_api._fft as fft + +__all__ = ["fft"] + from xarray.namedarray._array_api._indexing_functions import take __all__ += ["take"] diff --git a/xarray/namedarray/_array_api/_fft/__init__.py b/xarray/namedarray/_array_api/_fft/__init__.py new file mode 100644 index 00000000000..6b068ee4491 --- /dev/null +++ b/xarray/namedarray/_array_api/_fft/__init__.py @@ -0,0 +1,35 @@ +__all__ = [] + +from xarray.namedarray._array_api._linalg._linalg import ( + fft, + ifft, + fftn, + ifftn, + rfft, + irfft, + rfftn, + irfftn, + hfft, + ihfft, + fftfreq, + rfftfreq, + fftshift, + ifftshift, +) + +__all__ = [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", +] diff --git a/xarray/namedarray/_array_api/_fft/_fft.py b/xarray/namedarray/_array_api/_fft/_fft.py new file mode 100644 index 00000000000..c6053f06dca --- /dev/null +++ b/xarray/namedarray/_array_api/_fft/_fft.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Literal, NamedTuple + +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _infer_dims, + _maybe_default_namespace, +) +from xarray.namedarray.core import NamedArray + +if TYPE_CHECKING: + from xarray.namedarray._typing import _Axes, _Axis, _DType, _Device + + _Norm = Literal["backward", "ortho", "forward"] + +from xarray.namedarray._array_api._dtypes import ( + _floating_dtypes, + _real_floating_dtypes, + _complex_floating_dtypes, + float32, + complex64, +) +from xarray.namedarray._array_api._data_type_functions import astype + + +def fft( + x: NamedArray, + /, + *, + n: int | None = None, + axis: _Axis = -1, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.fft(x._data, n=n, axis=axis, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def ifft( + x: NamedArray, + /, + *, + n: int | None = None, + axis: _Axis = -1, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.ifft(x._data, n=n, axis=axis, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def fftn( + x: NamedArray, + /, + *, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.fftn(x._data, s=s, axes=axes, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def ifftn( + x: NamedArray, + /, + *, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.ifftn(x._data, s=s, axes=axes, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def rfft( + x: NamedArray, + /, + *, + n: int | None = None, + axis: _Axis = -1, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.rfft(x._data, n=n, axis=axis, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def irfft( + x: NamedArray, + /, + *, + n: int | None = None, + axis: _Axis = -1, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.irfft(x._data, n=n, axis=axis, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def rfftn( + x: NamedArray, + /, + *, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.rfftn(x._data, s=s, axes=axes, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def irfftn( + x: NamedArray, + /, + *, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.irfftn(x._data, s=s, axes=axes, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def hfft( + x: NamedArray, + /, + *, + n: int | None = None, + axis: _Axis = -1, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.hfft(x._data, n=n, axis=axis, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def ihfft( + x: NamedArray, + /, + *, + n: int | None = None, + axis: _Axis = -1, + norm: _Norm = "backward", +) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.ihfft(x._data, n=n, axis=axis, norm=norm) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def fftfreq(n: int, /, *, d: float = 1.0, device: _Device | None = None) -> NamedArray: + xp = _maybe_default_namespace() # TODO: Can use device? + _data = xp.fft.fftfreq(n, d=d, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def rfftfreq(n: int, /, *, d: float = 1.0, device: _Device | None = None) -> NamedArray: + xp = _maybe_default_namespace() # TODO: Can use device? + _data = xp.fft.rfftfreq(n, d=d, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def fftshift(x: NamedArray, /, *, axes: _Axes | None = None) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.fftshift(x._data, axes=axes) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) + + +def ifftshift(x: NamedArray, /, *, axes: _Axes | None = None) -> NamedArray: + xp = _get_data_namespace(x) + _data = xp.fft.ifftshift(x._data, axes=axes) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) From 4e7dffb510644888492b10764ea0cae3229b82d9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 13:26:29 +0000 Subject: [PATCH 192/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_fft/__init__.py | 16 ++++++++-------- xarray/namedarray/_array_api/_fft/_fft.py | 13 ++----------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/xarray/namedarray/_array_api/_fft/__init__.py b/xarray/namedarray/_array_api/_fft/__init__.py index 6b068ee4491..b5f729bc517 100644 --- a/xarray/namedarray/_array_api/_fft/__init__.py +++ b/xarray/namedarray/_array_api/_fft/__init__.py @@ -2,19 +2,19 @@ from xarray.namedarray._array_api._linalg._linalg import ( fft, - ifft, + fftfreq, fftn, + fftshift, + hfft, + ifft, ifftn, - rfft, + ifftshift, + ihfft, irfft, - rfftn, irfftn, - hfft, - ihfft, - fftfreq, + rfft, rfftfreq, - fftshift, - ifftshift, + rfftn, ) __all__ = [ diff --git a/xarray/namedarray/_array_api/_fft/_fft.py b/xarray/namedarray/_array_api/_fft/_fft.py index c6053f06dca..46e1fc2c4c5 100644 --- a/xarray/namedarray/_array_api/_fft/_fft.py +++ b/xarray/namedarray/_array_api/_fft/_fft.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal, NamedTuple +from typing import TYPE_CHECKING, Literal from xarray.namedarray._array_api._utils import ( _get_data_namespace, @@ -11,19 +11,10 @@ from xarray.namedarray.core import NamedArray if TYPE_CHECKING: - from xarray.namedarray._typing import _Axes, _Axis, _DType, _Device + from xarray.namedarray._typing import _Axes, _Axis, _Device _Norm = Literal["backward", "ortho", "forward"] -from xarray.namedarray._array_api._dtypes import ( - _floating_dtypes, - _real_floating_dtypes, - _complex_floating_dtypes, - float32, - complex64, -) -from xarray.namedarray._array_api._data_type_functions import astype - def fft( x: NamedArray, From 09e1478bf9a3312ea31b4dfd25d2fac50db91d2c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:29:08 +0200 Subject: [PATCH 193/367] Update __init__.py --- xarray/namedarray/_array_api/_fft/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_fft/__init__.py b/xarray/namedarray/_array_api/_fft/__init__.py index b5f729bc517..ac2fa15deff 100644 --- a/xarray/namedarray/_array_api/_fft/__init__.py +++ b/xarray/namedarray/_array_api/_fft/__init__.py @@ -1,6 +1,6 @@ __all__ = [] -from xarray.namedarray._array_api._linalg._linalg import ( +from xarray.namedarray._array_api._fft._fft import ( fft, fftfreq, fftn, From 9aca0cc7134486d9243038bc28142ff46b133cc3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:36:09 +0200 Subject: [PATCH 194/367] Update __init__.py --- xarray/namedarray/_array_api/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index ce375a5c8f8..3a169a3a907 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -18,6 +18,7 @@ empty, empty_like, eye, + from_dlpack, full, full_like, linspace, @@ -36,6 +37,7 @@ "empty", "empty_like", "eye", + "from_dlpack", "full", "full_like", "linspace", From a3f47a96518b0198f9b58b6c70b21d3172dd9c24 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:40:58 +0200 Subject: [PATCH 195/367] Update core.py --- xarray/namedarray/core.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 94e3722b48d..2966f6564ba 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -779,7 +779,14 @@ def device(self) -> _Device: @property def mT(self): - raise NotImplementedError("Todo: ") + if isinstance(self._data, _arrayapi): + from xarray.namedarray._array_api._utils import _infer_dims + + _data = self._data.mT + _dims = _infer_dims(_data.shape) + return self._new(_dims, _data) + else: + raise NotImplementedError("self._data missing mT") @property def ndim(self) -> int: From 3cebb62056aa9cbadc799be86644f34efe0f2529 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:43:51 +0200 Subject: [PATCH 196/367] Update _creation_functions.py --- xarray/namedarray/_array_api/_creation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index dc9e574040a..8f96ec73d92 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -281,7 +281,7 @@ def tril( x: NamedArray[_ShapeType, _DType], /, *, k: int = 0 ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _data = xp.tril(x._data, dtype=x.dtype) + _data = xp.tril(x._data, k=k) # TODO: Can probably determine dim names from x, for now just default names: _dims = _infer_dims(_data.shape) return x._new(_dims, _data) @@ -291,7 +291,7 @@ def triu( x: NamedArray[_ShapeType, _DType], /, *, k: int = 0 ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _data = xp.triu(x._data, dtype=x.dtype) + _data = xp.triu(x._data, k=k) # TODO: Can probably determine dim names from x, for now just default names: _dims = _infer_dims(_data.shape) return x._new(_dims, _data) From 042940bb5a0469c5dced61ee02d011bb18246a44 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:47:47 +0200 Subject: [PATCH 197/367] Update _sorting_functions.py --- xarray/namedarray/_array_api/_sorting_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index e8c20fa6525..7360485cc0a 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -48,7 +48,7 @@ def sort( ) -> NamedArray: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis) - _data = xp.sort(x._data, axis=_axis, descending=descending, stable=stable) + _data = xp.sort(x._data, axis=_axis, stable=stable) if descending: _data = xp.flip(_data, axis=axis) return x._new(data=_data) From 26fb6cb2ddd5f6fea75cf87d3a4ad97f5c40f774 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:58:58 +0200 Subject: [PATCH 198/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 2966f6564ba..2a26c6f9f44 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -587,7 +587,7 @@ def __setitem__( if isinstance(key, NamedArray): key = key._data - self._array.__setitem__(key, asarray(value)._data) + self._data.__setitem__(key, asarray(value)._data) def __sub__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, subtract From 8e0af56d3ff08bbb3ecc8361d79f141c6365cb3e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:09:43 +0200 Subject: [PATCH 199/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 1a732df108c..5c4c3e02d1c 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -25,11 +25,11 @@ def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: if xp is None: - # import array_api_strict as xpd + import array_api_strict as xpd # import array_api_compat.numpy as xpd - import numpy as xpd + # import numpy as xpd return xpd else: From 3341e931c9dbfb1d4f4c9070b1cc80258e41932a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:12:22 +0200 Subject: [PATCH 200/367] Revert "Update _utils.py" This reverts commit 8e0af56d3ff08bbb3ecc8361d79f141c6365cb3e. --- xarray/namedarray/_array_api/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 5c4c3e02d1c..1a732df108c 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -25,11 +25,11 @@ def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: if xp is None: - import array_api_strict as xpd + # import array_api_strict as xpd # import array_api_compat.numpy as xpd - # import numpy as xpd + import numpy as xpd return xpd else: From dc862c53d5411fd265d5a082ff7dbd09caff5d45 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:12:30 +0200 Subject: [PATCH 201/367] Update namedarray_array_api_skips.txt --- xarray/tests/namedarray_array_api_skips.txt | 43 ++++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/xarray/tests/namedarray_array_api_skips.txt b/xarray/tests/namedarray_array_api_skips.txt index 926d644c9fd..77bb4f2651c 100644 --- a/xarray/tests/namedarray_array_api_skips.txt +++ b/xarray/tests/namedarray_array_api_skips.txt @@ -1,8 +1,39 @@ -# Known failures for the array api tests. +# finfo(float32).eps returns float32 but should return float +array_api_tests/test_data_type_functions.py::test_finfo[float32] -array_api_tests/test_array_object.py::test_getitem -array_api_tests/test_array_object.py::test_getitem_masking +# NumPy deviates in some special cases for floordiv +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] -# Test suite attempts in-place mutation: -array_api_tests/test_array_object.py::test_setitem -array_api_tests/test_array_object.py::test_setitem_masking +# https://github.com/numpy/numpy/issues/21213 +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] +array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices + +# The test suite is incorrectly checking sums that have loss of significance +# (https://github.com/data-apis/array-api-tests/issues/168) +array_api_tests/test_statistical_functions.py::test_sum +array_api_tests/test_statistical_functions.py::test_prod + +# The test suite cannot properly get the signature for vecdot +# https://github.com/numpy/numpy/pull/26237 +array_api_tests/test_signatures.py::test_func_signature[vecdot] +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] \ No newline at end of file From 499bee994de168a914fe705198633cbe76340ec1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:13:07 +0000 Subject: [PATCH 202/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/tests/namedarray_array_api_skips.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/namedarray_array_api_skips.txt b/xarray/tests/namedarray_array_api_skips.txt index 77bb4f2651c..6aa54bb4639 100644 --- a/xarray/tests/namedarray_array_api_skips.txt +++ b/xarray/tests/namedarray_array_api_skips.txt @@ -36,4 +36,4 @@ array_api_tests/test_statistical_functions.py::test_prod # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] \ No newline at end of file +array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] From 2d5e81cd4fc8e280fd1ace94001f6d4bcc28d0ed Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:20:36 +0200 Subject: [PATCH 203/367] Update _linalg.py --- xarray/namedarray/_array_api/_linalg/_linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_linalg/_linalg.py b/xarray/namedarray/_array_api/_linalg/_linalg.py index 273de930b9a..189ad6b81f2 100644 --- a/xarray/namedarray/_array_api/_linalg/_linalg.py +++ b/xarray/namedarray/_array_api/_linalg/_linalg.py @@ -166,7 +166,7 @@ def trace( x: NamedArray, /, *, offset: int = 0, dtype: _DType | None = None ) -> NamedArray: xp = _get_data_namespace(x) - _data = xp.linalg.svdvals(x._data, offset=offset) + _data = xp.linalg.trace(x._data, offset=offset) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) @@ -180,7 +180,7 @@ def vector_norm( ord: int | float | None = 2, ) -> NamedArray: xp = _get_data_namespace(x) - _data = xp.linalg.svdvals(x._data, axis=axis, keepdims=keepdims, ord=ord) + _data = xp.linalg.vector_norm(x._data, axis=axis, keepdims=keepdims, ord=ord) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) From 45abdfd1cfdc5c7545e2ac37c6e66db9588320ed Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:40:26 +0200 Subject: [PATCH 204/367] Update _indexing_functions.py --- xarray/namedarray/_array_api/_indexing_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_indexing_functions.py b/xarray/namedarray/_array_api/_indexing_functions.py index 581e8bc8c22..e57d3db0056 100644 --- a/xarray/namedarray/_array_api/_indexing_functions.py +++ b/xarray/namedarray/_array_api/_indexing_functions.py @@ -18,7 +18,7 @@ def take( axis: int | None = None, ) -> NamedArray: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dim, axis) + _axis = _dims_to_axis(x, dim, axis)[0] # TODO: Handle attrs? will get x1 now out = x._new(data=xp.take(x._data, indices._data, axis=_axis)) return out From d3efb66bedbec708302a6d84e7a6f048bffe9090 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:40:30 +0200 Subject: [PATCH 205/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 1a732df108c..ec08a6913a7 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -9,6 +9,7 @@ _arrayapi, _Axis, _AxisLike, + _Axes, _default, _Dim, _Dims, @@ -103,7 +104,7 @@ def _assert_either_dim_or_axis( def _dims_to_axis( x: NamedArray[Any, Any], dims: _Dim | _Dims | Default, axis: _AxisLike | None -) -> _AxisLike | None: +) -> _Axes | None: """ Convert dims to axis indices. @@ -112,7 +113,7 @@ def _dims_to_axis( >>> narr = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) >>> _dims_to_axis(narr, ("y",), None) (1,) - >>> _dims_to_axis(narr, None, 0) + >>> _dims_to_axis(narr, _default, 0) (0,) >>> _dims_to_axis(narr, None, None) """ From 7b8592aefe1b98237c12a8684304724f39352cb9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:41:07 +0000 Subject: [PATCH 206/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index ec08a6913a7..fa2ba3b1ba3 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -7,9 +7,9 @@ from xarray.namedarray._typing import ( Default, _arrayapi, + _Axes, _Axis, _AxisLike, - _Axes, _default, _Dim, _Dims, From 02b4f413519063ef6596db05d0e3f61887cba3c3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:52:52 +0200 Subject: [PATCH 207/367] Update core.py --- xarray/namedarray/core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 2a26c6f9f44..d32df407066 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -482,16 +482,17 @@ def __ge__(self, other: int | float | NamedArray, /): return greater_equal(self, asarray(other)) def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: - if isinstance(key, int | slice | tuple): - from xarray.namedarray._array_api._utils import _infer_dims + from xarray.namedarray._array_api._utils import _infer_dims + if isinstance(key, int | slice | tuple): _data = self._data[key] _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) elif isinstance(key, NamedArray): _key = self._data # TODO: Transpose, unordered dims shouldn't matter. _data = self._data[_key] - return self._new(key._dims, _data) + _dims = _infer_dims(_data.shape) # TODO: fix + return self._new(_dims, _data) else: raise NotImplementedError("{k=} is not supported") From 04700894709db9b0fe2c81f5e1be01b63f69e9e0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 17:02:10 +0200 Subject: [PATCH 208/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index fa2ba3b1ba3..7f462bedb47 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -150,8 +150,10 @@ def _get_remaining_dims( removed_axes: tuple[int, ...] if axis is None: removed_axes = tuple(v for v in range(x.ndim)) + elif isinstance(axis, tuple): + removed_axes = tuple(a % x.ndim for a in axis) else: - removed_axes = axis % x.ndim if isinstance(axis, tuple) else (axis % x.ndim,) + removed_axes = (axis % x.ndim,) if keepdims: # Insert None (aka newaxis) for removed dims From 5ff5f664b6413486605a790f7ce663ad3f1e7bd2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 17:09:28 +0200 Subject: [PATCH 209/367] Update _sorting_functions.py --- xarray/namedarray/_array_api/_sorting_functions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index 7360485cc0a..a47ab7ae527 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -19,7 +19,7 @@ def argsort( axis: int = -1, ) -> NamedArray: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dim, axis) + _axis = _dims_to_axis(x, dim, axis)[0] if not descending: _data = xp.argsort(x._data, axis=_axis, stable=stable) else: @@ -27,11 +27,11 @@ def argsort( # simply flipping the results of np.argsort(x._array, ...) would not # respect the relative order like it would in native descending sorts. _data = xp.flip( - xp.argsort(xp.flip(x._data, axis=axis), stable=stable, axis=axis), - axis=axis, + xp.argsort(xp.flip(x._data, axis=_axis), stable=stable, axis=_axis), + axis=_axis, ) # Rely on flip()/argsort() to validate axis - normalised_axis = axis if axis >= 0 else x.ndim + axis + normalised_axis = _axis if _axis >= 0 else x.ndim + _axis max_i = x.shape[normalised_axis] - 1 _data = max_i - _data return x._new(data=_data) @@ -47,8 +47,8 @@ def sort( axis: int = -1, ) -> NamedArray: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dim, axis) + _axis = _dims_to_axis(x, dim, axis)[0] _data = xp.sort(x._data, axis=_axis, stable=stable) if descending: - _data = xp.flip(_data, axis=axis) + _data = xp.flip(_data, axis=_axis) return x._new(data=_data) From 3e1123b2f1a8c4dc0a0c2f09b342d079ce0a7667 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 17:26:55 +0200 Subject: [PATCH 210/367] fix linalg --- .../namedarray/_array_api/_linalg/_linalg.py | 42 +++++++++++++------ .../_array_api/_sorting_functions.py | 5 +-- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/xarray/namedarray/_array_api/_linalg/_linalg.py b/xarray/namedarray/_array_api/_linalg/_linalg.py index 189ad6b81f2..9470436a3ae 100644 --- a/xarray/namedarray/_array_api/_linalg/_linalg.py +++ b/xarray/namedarray/_array_api/_linalg/_linalg.py @@ -63,9 +63,13 @@ def diagonal(x: NamedArray, /, *, offset: int = 0) -> NamedArray: def eigh(x: NamedArray, /) -> EighResult: xp = _get_data_namespace(x) - _datas = xp.linalg.eigh(x._data) - _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims - return EighResult(*(x._new(_dims, _data) for _data in _datas)) + eigvals, eigvecs = xp.linalg.eigh(x._data) + _dims_vals = _infer_dims(eigvals.shape) # TODO: Fix dims + _dims_vecs = _infer_dims(eigvecs.shape) # TODO: Fix dims + return EighResult( + x._new(_dims_vals, eigvals), + x._new(_dims_vecs, eigvecs), + ) def eigvalsh(x: NamedArray, /) -> NamedArray: @@ -129,16 +133,24 @@ def qr( x: NamedArray, /, *, mode: Literal["reduced", "complete"] = "reduced" ) -> QRResult: xp = _get_data_namespace(x) - _datas = xp.linalg.qr(x._data) - _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims - return QRResult(*(x._new(_dims, _data) for _data in _datas)) + q, r = xp.linalg.qr(x._data) + _dims_q = _infer_dims(q.shape) # TODO: Fix dims + _dims_r = _infer_dims(r.shape) # TODO: Fix dims + return QRResult( + x._new(_dims_q, q), + x._new(_dims_r, r), + ) def slogdet(x: NamedArray, /) -> SlogdetResult: xp = _get_data_namespace(x) - _datas = xp.linalg.slogdet(x._data) - _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims - return SlogdetResult(*(x._new(_dims, _data) for _data in _datas)) + sign, logabsdet = xp.linalg.slogdet(x._data) + _dims_sign = _infer_dims(sign.shape) # TODO: Fix dims + _dims_logabsdet = _infer_dims(logabsdet.shape) # TODO: Fix dims + return SlogdetResult( + x._new(_dims_sign, sign), + x._new(_dims_logabsdet, logabsdet), + ) def solve(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -150,9 +162,15 @@ def solve(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def svd(x: NamedArray, /, *, full_matrices: bool = True) -> SVDResult: xp = _get_data_namespace(x) - _datas = xp.linalg.svd(x._data, full_matrices=full_matrices) - _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims - return SVDResult(*(x._new(_dims, _data) for _data in _datas)) + u, s, vh = xp.linalg.svd(x._data, full_matrices=full_matrices) + _dims_u = _infer_dims(u.shape) # TODO: Fix dims + _dims_s = _infer_dims(s.shape) # TODO: Fix dims + _dims_vh = _infer_dims(vh.shape) # TODO: Fix dims + return SVDResult( + x._new(_dims_u, u), + x._new(_dims_s, s), + x._new(_dims_vh, vh), + ) def svdvals(x: NamedArray, /) -> NamedArray: diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index a47ab7ae527..66798daccdb 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -20,12 +20,10 @@ def argsort( ) -> NamedArray: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis)[0] + # TODO: As NumPy currently has no native descending sort, we imitate it here: if not descending: _data = xp.argsort(x._data, axis=_axis, stable=stable) else: - # As NumPy has no native descending sort, we imitate it here. Note that - # simply flipping the results of np.argsort(x._array, ...) would not - # respect the relative order like it would in native descending sorts. _data = xp.flip( xp.argsort(xp.flip(x._data, axis=_axis), stable=stable, axis=_axis), axis=_axis, @@ -49,6 +47,7 @@ def sort( xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis)[0] _data = xp.sort(x._data, axis=_axis, stable=stable) + # TODO: As NumPy currently has no native descending sort, we imitate it here: if descending: _data = xp.flip(_data, axis=_axis) return x._new(data=_data) From 72156803226908a1a22b0456acf3f59616cbd154 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 17:49:30 +0200 Subject: [PATCH 211/367] Use stable repr --- xarray/namedarray/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index d32df407066..585eb15a975 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -1331,10 +1331,12 @@ def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]: ) def __repr__(self) -> str: - return formatting.array_repr(self) + # return formatting.array_repr(self) + return f"" def _repr_html_(self) -> str: - return formatting_html.array_repr(self) + # return formatting_html.array_repr(self) + return f"" def _as_sparse( self, From 3c6e8c593ea5f25892b19b4b066918dcdb3b2f36 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:50:08 +0000 Subject: [PATCH 212/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 585eb15a975..2eb9832529a 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -18,7 +18,7 @@ import numpy as np # TODO: get rid of this after migrating this class to array API -from xarray.core import dtypes, formatting, formatting_html +from xarray.core import dtypes from xarray.core.indexing import ( ExplicitlyIndexed, ImplicitToExplicitIndexingAdapter, From 9d0b1425e6bdafcc21c53cadd0e4ee7377fe43f3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 17:58:08 +0200 Subject: [PATCH 213/367] Update __init__.py --- xarray/namedarray/_array_api/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/__init__.py b/xarray/namedarray/_array_api/__init__.py index 3a169a3a907..957032b09b4 100644 --- a/xarray/namedarray/_array_api/__init__.py +++ b/xarray/namedarray/_array_api/__init__.py @@ -4,9 +4,9 @@ __all__ += ["__array_api_version__"] -from xarray.namedarray.core import NamedArray as Array +# from xarray.namedarray.core import NamedArray as Array -__all__ += ["Array"] +# __all__ += ["Array"] from xarray.namedarray._array_api._constants import e, inf, nan, newaxis, pi From 0b2f21baf10c3a5a0cf9bf0bbd3f210a641e9b00 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:05:36 +0200 Subject: [PATCH 214/367] stricter --- xarray/namedarray/core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 2eb9832529a..1cb683083ab 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -263,6 +263,10 @@ def __init__( data: duckarray[Any, _DType_co], attrs: _AttrsLike = None, ): + if not isinstance(data, _arrayfunction_or_api): + raise NotImplementedError( + f"data is not a valid duckarray, got {data=}, {dims=}" + ) self._data = data self._dims = self._parse_dimensions(dims) self._attrs = dict(attrs) if attrs else None From d54210cff78c1bcf5fb03763d135223c1882e5ac Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:06:49 +0200 Subject: [PATCH 215/367] stricter --- xarray/namedarray/core.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 1cb683083ab..7dd663e57b7 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -323,6 +323,10 @@ def _new( attributes you want to store with the array. Will copy the attrs from x by default. """ + if not isinstance(data, _arrayfunction_or_api): + raise NotImplementedError( + f"data is not a valid duckarray, got {data=}, {dims=}" + ) return _new(self, dims, data, attrs) def _replace( @@ -351,6 +355,10 @@ def _replace( attributes you want to store with the array. Will copy the attrs from x by default. """ + if not isinstance(data, _arrayfunction_or_api): + raise NotImplementedError( + f"data is not a valid duckarray, got {data=}, {dims=}" + ) return cast("Self", self._new(dims, data, attrs)) def _copy( From 5622bd7e3aeedb40ead2962229fdcf234cf9ae58 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:44:06 +0200 Subject: [PATCH 216/367] Update _elementwise_functions.py --- .../_array_api/_elementwise_functions.py | 385 +++++++++--------- 1 file changed, 195 insertions(+), 190 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 43cc19e5e61..0d91b13f41d 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -9,112 +9,117 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.core import ( - NamedArray, -) +from xarray.namedarray.core import NamedArray + + +def _atleast_0d(x, xp): + """ + Workaround for numpy sometimes returning scalars instead of 0d arrays. + """ + return xp.asarray(x) -def abs(x, /): +def abs(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.abs(x._data)) - return out + _data = _atleast_0d(xp.abs(x._data), xp) + return x._new(data=_data) -def acos(x, /): +def acos(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.acos(x._data)) - return out + _data = _atleast_0d(xp.acos(x._data), xp) + return x._new(data=_data) -def acosh(x, /): +def acosh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.acosh(x._data)) - return out + _data = _atleast_0d(xp.acosh(x._data), xp) + return x._new(data=_data) -def add(x1, x2, /): +def add(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.add(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.add(x1._data, x2._data), xp) + return x1._new(data=_data) -def asin(x, /): +def asin(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.asin(x._data)) - return out + _data = _atleast_0d(xp.asin(x._data), xp) + return x._new(data=_data) -def asinh(x, /): +def asinh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.asinh(x._data)) - return out + _data = _atleast_0d(xp.asinh(x._data), xp) + return x._new(data=_data) -def atan(x, /): +def atan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.atan(x._data)) - return out + _data = _atleast_0d(xp.atan(x._data), xp) + return x._new(data=_data) -def atan2(x1, x2, /): +def atan2(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.atan2(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.atan2(x1._data, x2._data), xp) + return x1._new(data=_data) -def atanh(x, /): +def atanh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.atanh(x._data)) - return out + _data = _atleast_0d(xp.atanh(x._data), xp) + return x._new(data=_data) -def bitwise_and(x1, x2, /): +def bitwise_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_and(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.bitwise_and(x1._data, x2._data), xp) + return x1._new(data=_data) -def bitwise_invert(x, /): +def bitwise_invert(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.bitwise_invert(x._data)) - return out + _data = _atleast_0d(xp.bitwise_invert(x._data), xp) + return x._new(data=_data) -def bitwise_left_shift(x1, x2, /): +def bitwise_left_shift(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_left_shift(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.bitwise_left_shift(x1._data, x2._data), xp) + return x1._new(data=_data) -def bitwise_or(x1, x2, /): +def bitwise_or(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_or(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.bitwise_or(x1._data, x2._data), xp) + return x1._new(data=_data) -def bitwise_right_shift(x1, x2, /): +def bitwise_right_shift(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_right_shift(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.bitwise_right_shift(x1._data, x2._data), xp) + return x1._new(data=_data) -def bitwise_xor(x1, x2, /): +def bitwise_xor(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.bitwise_xor(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.bitwise_xor(x1._data, x2._data), xp) + return x1._new(data=_data) -def ceil(x, /): +def ceil(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.ceil(x._data)) - return out + _data = _atleast_0d(xp.ceil(x._data), xp) + return x._new(data=_data) def clip( @@ -124,93 +129,93 @@ def clip( max: int | float | NamedArray | None = None, ) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.clip(x._data)) - return out + _data = _atleast_0d(xp.clip(x._data, min=min, max=max), xp) + return x._new(data=_data) -def conj(x, /): +def conj(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.conj(x._data)) - return out + _data = _atleast_0d(xp.conj(x._data), xp) + return x._new(data=_data) def copysign(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.copysign(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.copysign(x1._data, x2._data), xp) + return x1._new(data=_data) -def cos(x, /): +def cos(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.cos(x._data)) - return out + _data = _atleast_0d(xp.cos(x._data), xp) + return x._new(data=_data) -def cosh(x, /): +def cosh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.cosh(x._data)) - return out + _data = _atleast_0d(xp.cosh(x._data), xp) + return x._new(data=_data) -def divide(x1, x2, /): +def divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.divide(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.divide(x1._data, x2._data), xp) + return x1._new(data=_data) -def exp(x, /): +def exp(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.exp(x._data)) - return out + _data = _atleast_0d(xp.exp(x._data), xp) + return x._new(data=_data) -def expm1(x, /): +def expm1(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.expm1(x._data)) - return out + _data = _atleast_0d(xp.expm1(x._data), xp) + return x._new(data=_data) -def equal(x1, x2, /): +def equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.equal(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.equal(x1._data, x2._data), xp) + return x1._new(data=_data) -def floor(x, /): +def floor(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.floor(x._data)) - return out + _data = _atleast_0d(xp.floor(x._data), xp) + return x._new(data=_data) -def floor_divide(x1, x2, /): +def floor_divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.floor_divide(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.floor_divide(x1._data, x2._data), xp) + return x1._new(data=_data) -def greater(x1, x2, /): +def greater(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.greater(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.greater(x1._data, x2._data), xp) + return x1._new(data=_data) -def greater_equal(x1, x2, /): +def greater_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.greater_equal(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.greater_equal(x1._data, x2._data), xp) + return x1._new(data=_data) def hypot(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.hypot(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.hypot(x1._data, x2._data), xp) + return x1._new(data=_data) def imag( @@ -241,145 +246,145 @@ def imag( array([2., 4.]) """ xp = _get_data_namespace(x) - out = x._new(data=xp.imag(x._data)) - return out + _data = _atleast_0d(xp.imag(x._data), xp) + return x._new(data=_data) -def isfinite(x, /): +def isfinite(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.isfinite(x._data)) - return out + _data = _atleast_0d(xp.isfinite(x._data), xp) + return x._new(data=_data) -def isinf(x, /): +def isinf(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.isinf(x._data)) - return out + _data = _atleast_0d(xp.isinf(x._data), xp) + return x._new(data=_data) -def isnan(x, /): +def isnan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.isnan(x._data)) - return out + _data = _atleast_0d(xp.isnan(x._data), xp) + return x._new(data=_data) -def less(x1, x2, /): +def less(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.less(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.less(x1._data, x2._data), xp) + return x1._new(data=_data) -def less_equal(x1, x2, /): +def less_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.less_equal(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.less_equal(x1._data, x2._data), xp) + return x1._new(data=_data) -def log(x, /): +def log(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.log(x._data)) - return out + _data = _atleast_0d(xp.log(x._data), xp) + return x._new(data=_data) -def log1p(x, /): +def log1p(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.log1p(x._data)) - return out + _data = _atleast_0d(xp.log1p(x._data), xp) + return x._new(data=_data) -def log2(x, /): +def log2(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.log2(x._data)) - return out + _data = _atleast_0d(xp.log2(x._data), xp) + return x._new(data=_data) -def log10(x, /): +def log10(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.log10(x._data)) - return out + _data = _atleast_0d(xp.log10(x._data), xp) + return x._new(data=_data) -def logaddexp(x1, x2, /): +def logaddexp(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.logaddexp(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.logaddexp(x1._data, x2._data), xp) + return x1._new(data=_data) -def logical_and(x1, x2, /): +def logical_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.logical_and(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.logical_and(x1._data, x2._data), xp) + return x1._new(data=_data) -def logical_not(x, /): +def logical_not(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.logical_not(x._data)) - return out + _data = _atleast_0d(xp.logical_not(x._data), xp) + return x._new(data=_data) -def logical_or(x1, x2, /): +def logical_or(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.logical_or(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.logical_or(x1._data, x2._data), xp) + return x1._new(data=_data) -def logical_xor(x1, x2, /): +def logical_xor(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.logical_xor(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.logical_xor(x1._data, x2._data), xp) + return x1._new(data=_data) def maximum(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.maximum(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.maximum(x1._data, x2._data), xp) + return x1._new(data=_data) def minimum(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.minimum(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.minimum(x1._data, x2._data), xp) + return x1._new(data=_data) -def multiply(x1, x2, /): +def multiply(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.multiply(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.multiply(x1._data, x2._data), xp) + return x1._new(data=_data) -def negative(x, /): +def negative(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.negative(x._data)) - return out + _data = _atleast_0d(xp.negative(x._data), xp) + return x._new(data=_data) -def not_equal(x1, x2, /): +def not_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.not_equal(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.not_equal(x1._data, x2._data), xp) + return x1._new(data=_data) -def positive(x, /): +def positive(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.positive(x._data)) - return out + _data = _atleast_0d(xp.positive(x._data), xp) + return x._new(data=_data) -def pow(x1, x2, /): +def pow(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.pow(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.pow(x1._data, x2._data), xp) + return x1._new(data=_data) def real( @@ -410,79 +415,79 @@ def real( array([1., 2.]) """ xp = _get_data_namespace(x) - out = x._new(data=xp.real(x._data)) - return out + _data = _atleast_0d(xp.real(x._data), xp) + return x._new(data=_data) -def remainder(x1, x2, /): +def remainder(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.remainder(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.remainder(x1._data, x2._data), xp) + return x1._new(data=_data) -def round(x, /): +def round(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.round(x._data)) - return out + _data = _atleast_0d(xp.round(x._data), xp) + return x._new(data=_data) -def sign(x, /): +def sign(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.sign(x._data)) - return out + _data = _atleast_0d(xp.sign(x._data), xp) + return x._new(data=_data) -def signbit(x, /): +def signbit(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.signbit(x._data)) - return out + _data = _atleast_0d(xp.signbit(x._data), xp) + return x._new(data=_data) -def sin(x, /): +def sin(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.sin(x._data)) - return out + _data = _atleast_0d(xp.sin(x._data), xp) + return x._new(data=_data) -def sinh(x, /): +def sinh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.sinh(x._data)) - return out + _data = _atleast_0d(xp.sinh(x._data), xp) + return x._new(data=_data) -def sqrt(x, /): +def sqrt(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.sqrt(x._data)) - return out + _data = _atleast_0d(xp.sqrt(x._data), xp) + return x._new(data=_data) -def square(x, /): +def square(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.square(x._data)) - return out + _data = _atleast_0d(xp.square(x._data), xp) + return x._new(data=_data) -def subtract(x1, x2, /): +def subtract(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - out = x1._new(data=xp.subtract(x1._data, x2._data)) - return out + _data = _atleast_0d(xp.subtract(x1._data, x2._data), xp) + return x1._new(data=_data) -def tan(x, /): +def tan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.tan(x._data)) - return out + _data = _atleast_0d(xp.tan(x._data), xp) + return x._new(data=_data) -def tanh(x, /): +def tanh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - out = x._new(data=xp.tanh(x._data)) - return out + _data = _atleast_0d(xp.tanh(x._data), xp) + return x._new(data=_data) def trunc(x, /): xp = _get_data_namespace(x) - out = x._new(data=xp.trunc(x._data)) - return out + _data = _atleast_0d(xp.trunc(x._data), xp) + return x._new(data=_data) From ee71d042c9cdef35412517eecb325955353c84c4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:49:17 +0200 Subject: [PATCH 217/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 7dd663e57b7..406371f71ed 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -506,7 +506,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) else: - raise NotImplementedError("{k=} is not supported") + raise NotImplementedError(f"{key=} is not supported") def __gt__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, greater From ec5543469d7377b9d78236d0323036dd38c14b36 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:50:51 +0200 Subject: [PATCH 218/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 406371f71ed..09af79cf327 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -501,7 +501,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) elif isinstance(key, NamedArray): - _key = self._data # TODO: Transpose, unordered dims shouldn't matter. + _key = key._data # TODO: Transpose, unordered dims shouldn't matter. _data = self._data[_key] _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) From 93bd9ca708fba00cdee86b0bd72fe53e885c2963 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:56:18 +0200 Subject: [PATCH 219/367] more --- xarray/namedarray/_array_api/_elementwise_functions.py | 8 +------- xarray/namedarray/_array_api/_manipulation_functions.py | 2 +- xarray/namedarray/_array_api/_utils.py | 7 +++++++ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 0d91b13f41d..0fc19a340c8 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -8,17 +8,11 @@ _ShapeType, _SupportsImag, _SupportsReal, + _atleast_0d, ) from xarray.namedarray.core import NamedArray -def _atleast_0d(x, xp): - """ - Workaround for numpy sometimes returning scalars instead of 0d arrays. - """ - return xp.asarray(x) - - def abs(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) _data = _atleast_0d(xp.abs(x._data), xp) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 5264786ded5..2bf0ef9dc3e 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -167,7 +167,7 @@ def roll( ) -> NamedArray: xp = _get_data_namespace(x) _data = xp.roll(x._data, shift=shift, axis=axis) - return x._new(_data) + return x._new(data=_data) def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray: diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 7f462bedb47..fb01a4ec56e 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -174,3 +174,10 @@ def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: d = list(dims) d.insert(axis, dim) return tuple(d) + + +def _atleast_0d(x, xp): + """ + Workaround for numpy sometimes returning scalars instead of 0d arrays. + """ + return xp.asarray(x) From ec0e79ce0c5d8666747792a52fcc59b2754a3686 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:56:56 +0000 Subject: [PATCH 220/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_elementwise_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 0fc19a340c8..431e31f3ae4 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -4,11 +4,11 @@ from xarray.namedarray._array_api._utils import _get_data_namespace from xarray.namedarray._typing import ( + _atleast_0d, _ScalarType, _ShapeType, _SupportsImag, _SupportsReal, - _atleast_0d, ) from xarray.namedarray.core import NamedArray From 1da91ea789c7cd86147f704ce20ac5ded0d3f792 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 18:59:02 +0200 Subject: [PATCH 221/367] Update _elementwise_functions.py --- xarray/namedarray/_array_api/_elementwise_functions.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 431e31f3ae4..463a452b771 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -2,9 +2,8 @@ import numpy as np -from xarray.namedarray._array_api._utils import _get_data_namespace +from xarray.namedarray._array_api._utils import _atleast_0d, _get_data_namespace from xarray.namedarray._typing import ( - _atleast_0d, _ScalarType, _ShapeType, _SupportsImag, From bc4e69c6c170b398b5f2543347c243d5b8887613 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 19:16:17 +0200 Subject: [PATCH 222/367] more --- .../_array_api/_statistical_functions.py | 21 ++++++++++--------- .../_array_api/_utility_functions.py | 5 +++-- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 40bc490fb0d..20257b60df1 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -3,6 +3,7 @@ from typing import Any from xarray.namedarray._array_api._utils import ( + _atleast_0d, _dims_to_axis, _get_data_namespace, _get_remaining_dims, @@ -67,7 +68,7 @@ def max( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.max(x._data, axis=_axis, keepdims=False) # We fix keepdims later + _data = _atleast_0d(xp.max(x._data, axis=_axis, keepdims=False), xp) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) return x._new(dims=dims_, data=data_) @@ -133,7 +134,7 @@ def mean( """ xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.mean(x._data, axis=_axis, keepdims=False) # We fix keepdims later + _data = _atleast_0d(xp.mean(x._data, axis=_axis, keepdims=False), xp) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -150,7 +151,7 @@ def min( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.min(x._data, axis=_axis, keepdims=False) # We fix keepdims later + _data = _atleast_0d(xp.min(x._data, axis=_axis, keepdims=False), xp) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -168,7 +169,7 @@ def prod( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.prod(x._data, axis=_axis, keepdims=False) # We fix keepdims later + _data = _atleast_0d(xp.prod(x._data, axis=_axis, keepdims=False), xp) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -186,9 +187,9 @@ def std( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.std( - x._data, axis=_axis, correction=correction, keepdims=False - ) # We fix keepdims later + _data = _atleast_0d( + xp.std(x._data, axis=_axis, correction=correction, keepdims=False), xp + ) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -206,7 +207,7 @@ def sum( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.sum(x._data, axis=_axis, keepdims=False) # We fix keepdims later + _data = _atleast_0d(xp.sum(x._data, axis=_axis, keepdims=False), xp) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -224,9 +225,9 @@ def var( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.var( + _data = _atleast_0d(xp.var( x._data, axis=_axis, correction=correction, keepdims=False - ) # We fix keepdims later + ), xp) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index 72f5451a471..350e4216d24 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -3,6 +3,7 @@ from typing import Any from xarray.namedarray._array_api._utils import ( + _atleast_0d, _dims_to_axis, _get_data_namespace, _get_remaining_dims, @@ -29,7 +30,7 @@ def all( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) axis_ = _dims_to_axis(x, dims, axis) - d = xp.all(x._data, axis=axis_, keepdims=False) # We fix keepdims later + d = _atleast_0d(xp.all(x._data, axis=axis_, keepdims=False), xp) dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out @@ -45,7 +46,7 @@ def any( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) axis_ = _dims_to_axis(x, dims, axis) - d = xp.any(x._data, axis=axis_, keepdims=False) # We fix keepdims later + d = _atleast_0d(xp.any(x._data, axis=axis_, keepdims=False), xp) dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out From dc37d5c0ebbdea7f5b7cfed4edc5305e43307014 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 17:16:56 +0000 Subject: [PATCH 223/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_statistical_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 20257b60df1..d54c3e2d6f3 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -225,9 +225,9 @@ def var( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = _atleast_0d(xp.var( - x._data, axis=_axis, correction=correction, keepdims=False - ), xp) + _data = _atleast_0d( + xp.var(x._data, axis=_axis, correction=correction, keepdims=False), xp + ) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) From b1d7a59fb4c12286bf0ccc4c4059ace68f7ff1a1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 20:35:28 +0200 Subject: [PATCH 224/367] Update _statistical_functions.py --- xarray/namedarray/_array_api/_statistical_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index d54c3e2d6f3..29f6e802ed5 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -32,7 +32,7 @@ def cumulative_sum( axis: int | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dim, axis) + _axis = _dims_to_axis(x, dim, axis)[0] try: _data = xp.cumulative_sum( x._data, axis=_axis, dtype=dtype, include_initial=include_initial From 48e5e9ef9385bc15cfeafa57b5564171c23f34a3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 22:29:35 +0200 Subject: [PATCH 225/367] more --- .../_array_api/_manipulation_functions.py | 12 ++- .../_array_api/_statistical_functions.py | 4 +- xarray/namedarray/_array_api/_utils.py | 73 +++++++++++++++++++ 3 files changed, 87 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 2bf0ef9dc3e..33bba7ddf4f 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -6,6 +6,7 @@ from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( _get_data_namespace, + _get_broadcasted_dims, _infer_dims, _insert_dim, ) @@ -23,11 +24,20 @@ def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]: + """ + Broadcasts one or more arrays against one another. + + Examples + -------- + >>> x = xp.asarray([[1, 2, 3]]) + >>> y = xp.asarray([[4], [5]]) + >>> xp.broadcast_arrays(x, y) + """ x = arrays[0] xp = _get_data_namespace(x) + _dims, _ = _get_broadcasted_dims(*arrays) _arrays = tuple(a._data for a in arrays) _datas = xp.broadcast_arrays(*_arrays) - _dims = _infer_dims(_datas[0].shape) return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas)] diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 29f6e802ed5..43228d51bed 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -32,7 +32,9 @@ def cumulative_sum( axis: int | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dim, axis)[0] + + a = _dims_to_axis(x, dim, axis) + _axis = a if a is None else a[0] try: _data = xp.cumulative_sum( x._data, axis=_axis, dtype=dtype, include_initial=include_initial diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index fb01a4ec56e..4a82f586785 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -181,3 +181,76 @@ def _atleast_0d(x, xp): Workaround for numpy sometimes returning scalars instead of 0d arrays. """ return xp.asarray(x) + + +# %% +def _raise_if_any_duplicate_dimensions( + dims: _Dims, err_context: str = "This function" +) -> None: + if len(set(dims)) < len(dims): + repeated_dims = {d for d in dims if dims.count(d) > 1} + raise ValueError( + f"{err_context} cannot handle duplicate dimensions, " + f"but dimensions {repeated_dims} appear more than once on this object's dims: {dims}" + ) + + +def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: + """ + Get the expected broadcasted dims. + + Examples + -------- + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) + >>> b = NamedArray(("y", "z"), np.zeros((3, 4))) + >>> _get_broadcasted_dims(a, b) + (('x', 'y', 'z'), (5, 3, 4)) + + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) + >>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4))) + >>> _get_broadcasted_dims(a, b) + (('x', 'y', 'z'), (5, 3, 4)) + + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) + >>> b = NamedArray(("x", "y", "z"), np.zeros((1, 3, 4))) + >>> _get_broadcasted_dims(a, b) + (('x', 'y', 'z'), (5, 3, 4)) + + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) + >>> b = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) + >>> _get_broadcasted_dims(a, b) + (('x', 'y', 'z'), (5, 3, 4)) + + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) + >>> b = NamedArray(("x", "y", "z"), np.zeros((2, 3, 4))) + >>> _get_broadcasted_dims(a, b) + Traceback (most recent call last): + ... + ValueError: operands cannot be broadcast together with mismatched lengths for dimension 'x': (5, 2) + """ + + def broadcastable(e1: int, e2: int) -> bool: + out = e1 > 1 and e2 <= 1 + out |= e2 > 1 and e1 <= 1 + return out + + # validate dimensions + all_dims = {} + for x in arrays: + _dims = x.dims + _raise_if_any_duplicate_dimensions(_dims, err_context="Broadcasting") + + for d, s in zip(_dims, x.shape): + if d not in all_dims: + all_dims[d] = s + elif all_dims[d] != s: + if broadcastable(all_dims[d], s): + max(all_dims[d], s) + else: + raise ValueError( + "operands cannot be broadcast together " + f"with mismatched lengths for dimension {d!r}: {(all_dims[d], s)}" + ) + + # TODO: Return flag whether broadcasting is needed? + return tuple(all_dims.keys()), tuple(all_dims.values()) From 813396a7a087c9914f351f8e6ebae5b36ea9359f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 25 Aug 2024 20:30:15 +0000 Subject: [PATCH 226/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 33bba7ddf4f..10da4006472 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -5,8 +5,8 @@ from xarray.namedarray._array_api._creation_functions import asarray from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( - _get_data_namespace, _get_broadcasted_dims, + _get_data_namespace, _infer_dims, _insert_dim, ) From 94eb7b2501445ea4bdaa5aba73eb80bac51b924b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 22:42:11 +0200 Subject: [PATCH 227/367] Update core.py --- xarray/namedarray/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 09af79cf327..07221fc0281 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -494,15 +494,15 @@ def __ge__(self, other: int | float | NamedArray, /): return greater_equal(self, asarray(other)) def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: - from xarray.namedarray._array_api._utils import _infer_dims + from xarray.namedarray._array_api._utils import _atleast_0d, _infer_dims if isinstance(key, int | slice | tuple): - _data = self._data[key] + _data = _atleast_0d(self._data[key], self.__array_namespace__()) _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) elif isinstance(key, NamedArray): _key = key._data # TODO: Transpose, unordered dims shouldn't matter. - _data = self._data[_key] + _data = _atleast_0d(self._data[_key], self.__array_namespace__()) _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) else: From ce8a60abe83503e2b135f07c14194772e5db1c54 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 22:50:51 +0200 Subject: [PATCH 228/367] Update core.py --- xarray/namedarray/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 07221fc0281..263f283577d 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -497,12 +497,12 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: from xarray.namedarray._array_api._utils import _atleast_0d, _infer_dims if isinstance(key, int | slice | tuple): - _data = _atleast_0d(self._data[key], self.__array_namespace__()) + _data = _atleast_0d(self._data[key], self._data._array_namespace__()) _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) elif isinstance(key, NamedArray): _key = key._data # TODO: Transpose, unordered dims shouldn't matter. - _data = _atleast_0d(self._data[_key], self.__array_namespace__()) + _data = _atleast_0d(self._data[_key], self._data.__array_namespace__()) _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) else: From bef1eadbde2acd78b59228bc81c4d7af07b506da Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 23:00:30 +0200 Subject: [PATCH 229/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 263f283577d..edb710f5af9 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -497,7 +497,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: from xarray.namedarray._array_api._utils import _atleast_0d, _infer_dims if isinstance(key, int | slice | tuple): - _data = _atleast_0d(self._data[key], self._data._array_namespace__()) + _data = _atleast_0d(self._data[key], self._data.__array_namespace__()) _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) elif isinstance(key, NamedArray): From c13a8aea592ad30ed122e7e95500fe88e267abbd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 23:13:42 +0200 Subject: [PATCH 230/367] Update _elementwise_functions.py --- xarray/namedarray/_array_api/_elementwise_functions.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 463a452b771..a52c3529971 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -2,7 +2,11 @@ import numpy as np -from xarray.namedarray._array_api._utils import _atleast_0d, _get_data_namespace +from xarray.namedarray._array_api._utils import ( + _atleast_0d, + _get_broadcasted_dims, + _get_data_namespace, +) from xarray.namedarray._typing import ( _ScalarType, _ShapeType, @@ -172,9 +176,9 @@ def expm1(x: NamedArray, /) -> NamedArray: def equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now + _dims, _ = _get_broadcasted_dims(x1, x2) _data = _atleast_0d(xp.equal(x1._data, x2._data), xp) - return x1._new(data=_data) + return NamedArray(_dims, _data) def floor(x: NamedArray, /) -> NamedArray: From e56e7cb3dc235e7f5a48e16c65d085ac93c4a708 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 26 Aug 2024 20:55:46 +0200 Subject: [PATCH 231/367] Update core.py --- xarray/namedarray/core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index edb710f5af9..94ec2aef1bf 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -496,15 +496,15 @@ def __ge__(self, other: int | float | NamedArray, /): def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: from xarray.namedarray._array_api._utils import _atleast_0d, _infer_dims - if isinstance(key, int | slice | tuple): - _data = _atleast_0d(self._data[key], self._data.__array_namespace__()) - _dims = _infer_dims(_data.shape) # TODO: fix - return self._new(_dims, _data) - elif isinstance(key, NamedArray): + if isinstance(key, NamedArray): _key = key._data # TODO: Transpose, unordered dims shouldn't matter. _data = _atleast_0d(self._data[_key], self._data.__array_namespace__()) _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) + elif isinstance(key, int | slice | tuple) or key is None or key is ...: + _data = _atleast_0d(self._data[key], self._data.__array_namespace__()) + _dims = _infer_dims(_data.shape) # TODO: fix + return self._new(_dims, _data) else: raise NotImplementedError(f"{key=} is not supported") From acc4dc0cfee8dd6550f2a7e8fadb80cf36b9c894 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 27 Aug 2024 23:00:41 +0200 Subject: [PATCH 232/367] Simplify asarray --- .../_array_api/_creation_functions.py | 34 +++++-------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 8f96ec73d92..4e2cd0e23df 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -6,6 +6,7 @@ from xarray.namedarray._array_api._utils import ( _get_data_namespace, + _get_namespace, _get_namespace_dtype, _infer_dims, ) @@ -115,33 +116,16 @@ def asarray( """ data = obj if isinstance(data, NamedArray): - if copy: - return data.copy() - else: + xp = _get_data_namespace(data) + _dtype = data.dtype if dtype is None else dtype + new_data = xp.asarray(data._data, dtype=_dtype, device=device, copy=copy) + if new_data is data._data: return data + else: + NamedArray(data.dims, new_data, data.attrs) - # TODO: dask.array.ma.MaskedArray also exists, better way? - if isinstance(data, np.ma.MaskedArray): - mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] - if mask.any(): - # TODO: requires refactoring/vendoring xarray.core.dtypes and - # xarray.core.duck_array_ops - raise NotImplementedError("MaskedArray is not supported yet") - - _dims = _infer_dims(data.shape, dims) - return NamedArray(_dims, data) - - if isinstance(data, _arrayfunction_or_api): - _dims = _infer_dims(data.shape, dims) - return NamedArray(_dims, data) - - if isinstance(data, tuple): - _data = to_0d_object_array(data) - _dims = _infer_dims(_data.shape, dims) - return NamedArray(_dims, _data) - - # validate whether the data is valid data types. - _data = np.asarray(data, dtype=dtype, device=device, copy=copy) + xp = _get_namespace(data) + _data = xp.asarray(data, dtype=dtype, device=device, copy=copy) _dims = _infer_dims(_data.shape, dims) return NamedArray(_dims, _data) From 2466035082fb5be8fd844eff2a0b676c3c8032a2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 28 Aug 2024 00:29:21 +0200 Subject: [PATCH 233/367] promote scalars correctly --- .../_array_api/_elementwise_functions.py | 2 +- xarray/namedarray/_array_api/_utils.py | 10 +- xarray/namedarray/core.py | 112 +++++++++++------- 3 files changed, 74 insertions(+), 50 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index a52c3529971..c5d11aa8299 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -37,7 +37,7 @@ def acosh(x: NamedArray, /) -> NamedArray: def add(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.add(x1._data, x2._data), xp) + _data = xp.add(x1._data, x2._data) return x1._new(data=_data) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 4a82f586785..c5f11351cba 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -37,13 +37,17 @@ def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: return xp -def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: - if isinstance(x._data, _arrayapi): - return x._data.__array_namespace__() +def _get_namespace(x: Any) -> ModuleType: + if isinstance(x, _arrayapi): + return x.__array_namespace__() return _maybe_default_namespace() +def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: + return _get_namespace(x._data) + + def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType: if dtype is None: return _maybe_default_namespace() diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 94ec2aef1bf..91544fc014c 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -427,20 +427,40 @@ def __len__(self) -> _IntOrUnknown: # < Array api > + def _maybe_asarray( + self, x: bool | int | float | complex | NamedArray + ) -> NamedArray: + """ + If x is a scalar, use asarray with the same dtype as self. + If it is namedarray already, respect the dtype and return it. + + Array API always promotes scalars to the same dtype as the other array. + Arrays are promoted according to result_types. + """ + from xarray.namedarray._array_api import asarray + + if isinstance(x, NamedArray): + # x is proper array. Respect the chosen dtype. + return x + # x is a scalar. Use the same dtype as self. + return asarray(x, dtype=self.dtype) + + # Required methods below: + def __abs__(self, /) -> Self: from xarray.namedarray._array_api import abs return abs(self) def __add__(self, other: int | float | NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api import add, asarray + from xarray.namedarray._array_api import add - return add(self, asarray(other)) + return add(self, self._maybe_asarray(other)) def __and__(self, other: int | bool | NamedArray, /) -> NamedArray: from xarray.namedarray._array_api import asarray, bitwise_and - return bitwise_and(self, asarray(other)) + return bitwise_and(self, self._maybe_asarray(other)) def __array_namespace__(self, /, *, api_version: str | None = None): if api_version is not None and api_version not in ( @@ -478,7 +498,7 @@ def __dlpack_device__(self, /) -> tuple[IntEnum, int]: def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: from xarray.namedarray._array_api import asarray, equal - return equal(self, asarray(other)) + return equal(self, self._maybe_asarray(other)) def __float__(self, /) -> float: return self._data.__float__() @@ -486,12 +506,12 @@ def __float__(self, /) -> float: def __floordiv__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, floor_divide - return floor_divide(self, asarray(other)) + return floor_divide(self, self._maybe_asarray(other)) def __ge__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, greater_equal - return greater_equal(self, asarray(other)) + return greater_equal(self, self._maybe_asarray(other)) def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: from xarray.namedarray._array_api._utils import _atleast_0d, _infer_dims @@ -511,7 +531,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: def __gt__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, greater - return greater(self, asarray(other)) + return greater(self, self._maybe_asarray(other)) def __index__(self, /) -> int: return self._data.__index__() @@ -533,37 +553,37 @@ def __iter__(self: NamedArray, /): def __le__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, less_equal - return less_equal(self, asarray(other)) + return less_equal(self, self._maybe_asarray(other)) def __lshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_left_shift - return bitwise_left_shift(self, asarray(other)) + return bitwise_left_shift(self, self._maybe_asarray(other)) def __lt__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, less - return less(self, asarray(other)) + return less(self, self._maybe_asarray(other)) def __matmul__(self, other: NamedArray, /): from xarray.namedarray._array_api import asarray, matmul - return matmul(self, asarray(other)) + return matmul(self, self._maybe_asarray(other)) def __mod__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, remainder - return remainder(self, asarray(other)) + return remainder(self, self._maybe_asarray(other)) def __mul__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, multiply - return multiply(self, asarray(other)) + return multiply(self, self._maybe_asarray(other)) def __ne__(self, other: int | float | bool | NamedArray, /): from xarray.namedarray._array_api import asarray, not_equal - return not_equal(self, asarray(other)) + return not_equal(self, self._maybe_asarray(other)) def __neg__(self, /): from xarray.namedarray._array_api import negative @@ -573,7 +593,7 @@ def __neg__(self, /): def __or__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_or - return bitwise_or(self, asarray(other)) + return bitwise_or(self, self._maybe_asarray(other)) def __pos__(self, /): from xarray.namedarray._array_api import positive @@ -583,12 +603,12 @@ def __pos__(self, /): def __pow__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, pow - return pow(self, asarray(other)) + return pow(self, self._maybe_asarray(other)) def __rshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_right_shift - return bitwise_right_shift(self, asarray(other)) + return bitwise_right_shift(self, self._maybe_asarray(other)) def __setitem__( self, @@ -600,66 +620,66 @@ def __setitem__( if isinstance(key, NamedArray): key = key._data - self._data.__setitem__(key, asarray(value)._data) + self._data.__setitem__(key, self._maybe_asarray(value)._data) def __sub__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, subtract - return subtract(self, asarray(other)) + return subtract(self, self._maybe_asarray(other)) def __truediv__(self, other: float | NamedArray, /): from xarray.namedarray._array_api import asarray, divide - return divide(self, asarray(other)) + return divide(self, self._maybe_asarray(other)) def __xor__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_xor - return bitwise_xor(self, asarray(other)) + return bitwise_xor(self, self._maybe_asarray(other)) def __iadd__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__iadd__(asarray(other)._data) + self._data.__iadd__(self._maybe_asarray(other)._data) return self def __radd__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import add, asarray - return add(asarray(other), self) + return add(self._maybe_asarray(other), self) def __iand__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__iand__(asarray(other)._data) + self._data.__iand__(self._maybe_asarray(other)._data) return self def __rand__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_and - return bitwise_and(asarray(other), self) + return bitwise_and(self._maybe_asarray(other), self) def __ifloordiv__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__ifloordiv__(asarray(other)._data) + self._data.__ifloordiv__(self._maybe_asarray(other)._data) return self def __rfloordiv__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, floor_divide - return floor_divide(asarray(other), self) + return floor_divide(self._maybe_asarray(other), self) def __ilshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__ilshift__(asarray(other)._data) + self._data.__ilshift__(self._maybe_asarray(other)._data) return self def __rlshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_left_shift - return bitwise_left_shift(asarray(other), self) + return bitwise_left_shift(self._maybe_asarray(other), self) def __imatmul__(self, other: NamedArray, /): self._data.__imatmul__(other._data) @@ -668,95 +688,95 @@ def __imatmul__(self, other: NamedArray, /): def __rmatmul__(self, other: NamedArray, /): from xarray.namedarray._array_api import asarray, matmul - return matmul(asarray(other), self) + return matmul(self._maybe_asarray(other), self) def __imod__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__imod__(asarray(other)._data) + self._data.__imod__(self._maybe_asarray(other)._data) return self def __rmod__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, remainder - return remainder(asarray(other), self) + return remainder(self._maybe_asarray(other), self) def __imul__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__imul__(asarray(other)._data) + self._data.__imul__(self._maybe_asarray(other)._data) return self def __rmul__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, multiply - return multiply(asarray(other), self) + return multiply(self._maybe_asarray(other), self) def __ior__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__ior__(asarray(other)._data) + self._data.__ior__(self._maybe_asarray(other)._data) return self def __ror__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_or - return bitwise_or(asarray(other), self) + return bitwise_or(self._maybe_asarray(other), self) def __ipow__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__ipow__(asarray(other)._data) + self._data.__ipow__(self._maybe_asarray(other)._data) return self def __rpow__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, pow - return pow(asarray(other), self) + return pow(self._maybe_asarray(other), self) def __irshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__irshift__(asarray(other)._data) + self._data.__irshift__(self._maybe_asarray(other)._data) return self def __rrshift__(self, other: int | NamedArray, /): from xarray.namedarray._array_api import asarray, bitwise_right_shift - return bitwise_right_shift(asarray(other), self) + return bitwise_right_shift(self._maybe_asarray(other), self) def __isub__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__isub__(asarray(other)._data) + self._data.__isub__(self._maybe_asarray(other)._data) return self def __rsub__(self, other: int | float | NamedArray, /): from xarray.namedarray._array_api import asarray, subtract - return subtract(asarray(other), self) + return subtract(self._maybe_asarray(other), self) def __itruediv__(self, other: float | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__itruediv__(asarray(other)._data) + self._data.__itruediv__(self._maybe_asarray(other)._data) return self def __rtruediv__(self, other: float | NamedArray, /): from xarray.namedarray._array_api import asarray, divide - return divide(asarray(other), self) + return divide(self._maybe_asarray(other), self) def __ixor__(self, other: int | bool | NamedArray, /): from xarray.namedarray._array_api import asarray - self._data.__ixor__(asarray(other)._data) + self._data.__ixor__(self._maybe_asarray(other)._data) return self def __rxor__(self, other, /): from xarray.namedarray._array_api import asarray, bitwise_xor - return bitwise_xor(asarray(other), self) + return bitwise_xor(self._maybe_asarray(other), self) def to_device(self, device: _Device, /, stream: None = None) -> Self: if isinstance(self._data, _arrayapi): From 1dad0e3217897eded374edf2b6f95abf2a137569 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 28 Aug 2024 00:51:26 +0200 Subject: [PATCH 234/367] fix --- xarray/namedarray/_array_api/_creation_functions.py | 2 +- xarray/namedarray/core.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 4e2cd0e23df..c7d0378f1f7 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -122,7 +122,7 @@ def asarray( if new_data is data._data: return data else: - NamedArray(data.dims, new_data, data.attrs) + return NamedArray(data.dims, new_data, data.attrs) xp = _get_namespace(data) _data = xp.asarray(data, dtype=dtype, device=device, copy=copy) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 91544fc014c..fe4eda7aa65 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -518,11 +518,11 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: if isinstance(key, NamedArray): _key = key._data # TODO: Transpose, unordered dims shouldn't matter. - _data = _atleast_0d(self._data[_key], self._data.__array_namespace__()) + _data = self._data[_key] _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) elif isinstance(key, int | slice | tuple) or key is None or key is ...: - _data = _atleast_0d(self._data[key], self._data.__array_namespace__()) + _data = self._data[key] _dims = _infer_dims(_data.shape) # TODO: fix return self._new(_dims, _data) else: From a7b3a1ca1b56757950157d268709d41e3f0e5a25 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 28 Aug 2024 07:19:30 +0200 Subject: [PATCH 235/367] all shapes less or equal than 1 are broadcastable --- xarray/namedarray/_array_api/_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index c5f11351cba..9678640596f 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -234,8 +234,14 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: """ def broadcastable(e1: int, e2: int) -> bool: - out = e1 > 1 and e2 <= 1 - out |= e2 > 1 and e1 <= 1 + # out = e1 > 1 and e2 <= 1 + # out |= e2 > 1 and e1 <= 1 + + # out = e1 >= 0 and e2 <= 1 + # out |= e2 >= 0 and e1 <= 1 + + out = e1 <= 1 or e2 <= 1 + return out # validate dimensions From 8ad9d1ff41f5100c07ff8b27bf5e221b47301335 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 28 Aug 2024 07:42:06 +0200 Subject: [PATCH 236/367] Update _elementwise_functions.py --- .../_array_api/_elementwise_functions.py | 233 +++++++++--------- 1 file changed, 116 insertions(+), 117 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index c5d11aa8299..1b32193e1ab 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -3,7 +3,6 @@ import numpy as np from xarray.namedarray._array_api._utils import ( - _atleast_0d, _get_broadcasted_dims, _get_data_namespace, ) @@ -18,104 +17,104 @@ def abs(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.abs(x._data), xp) + _data = xp.abs(x._data) return x._new(data=_data) def acos(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.acos(x._data), xp) + _data = xp.acos(x._data) return x._new(data=_data) def acosh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.acosh(x._data), xp) + _data = xp.acosh(x._data) return x._new(data=_data) def add(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now + _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.add(x1._data, x2._data) - return x1._new(data=_data) + return NamedArray(_dims, _data) def asin(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.asin(x._data), xp) + _data = xp.asin(x._data) return x._new(data=_data) def asinh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.asinh(x._data), xp) + _data = xp.asinh(x._data) return x._new(data=_data) def atan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.atan(x._data), xp) + _data = xp.atan(x._data) return x._new(data=_data) def atan2(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.atan2(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.atan2(x1._data, x2._data) + return NamedArray(_dims, _data) def atanh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.atanh(x._data), xp) + _data = xp.atanh(x._data) return x._new(data=_data) def bitwise_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.bitwise_and(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.bitwise_and(x1._data, x2._data) + return NamedArray(_dims, _data) def bitwise_invert(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.bitwise_invert(x._data), xp) + _data = xp.bitwise_invert(x._data) return x._new(data=_data) def bitwise_left_shift(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.bitwise_left_shift(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.bitwise_left_shift(x1._data, x2._data) + return NamedArray(_dims, _data) def bitwise_or(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.bitwise_or(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.bitwise_or(x1._data, x2._data) + return NamedArray(_dims, _data) def bitwise_right_shift(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.bitwise_right_shift(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.bitwise_right_shift(x1._data, x2._data) + return NamedArray(_dims, _data) def bitwise_xor(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.bitwise_xor(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.bitwise_xor(x1._data, x2._data) + return NamedArray(_dims, _data) def ceil(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.ceil(x._data), xp) + _data = xp.ceil(x._data) return x._new(data=_data) @@ -126,93 +125,93 @@ def clip( max: int | float | NamedArray | None = None, ) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.clip(x._data, min=min, max=max), xp) + _data = xp.clip(x._data, min=min, max=max) return x._new(data=_data) def conj(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.conj(x._data), xp) + _data = xp.conj(x._data) return x._new(data=_data) def copysign(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.copysign(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.copysign(x1._data, x2._data) + return NamedArray(_dims, _data) def cos(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.cos(x._data), xp) + _data = xp.cos(x._data) return x._new(data=_data) def cosh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.cosh(x._data), xp) + _data = xp.cosh(x._data) return x._new(data=_data) def divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.divide(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.divide(x1._data, x2._data) + return NamedArray(_dims, _data) def exp(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.exp(x._data), xp) + _data = xp.exp(x._data) return x._new(data=_data) def expm1(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.expm1(x._data), xp) + _data = xp.expm1(x._data) return x._new(data=_data) def equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) - _data = _atleast_0d(xp.equal(x1._data, x2._data), xp) + _data = xp.equal(x1._data, x2._data) return NamedArray(_dims, _data) def floor(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.floor(x._data), xp) + _data = xp.floor(x._data) return x._new(data=_data) def floor_divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.floor_divide(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.floor_divide(x1._data, x2._data) + return NamedArray(_dims, _data) def greater(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.greater(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.greater(x1._data, x2._data) + return NamedArray(_dims, _data) def greater_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.greater_equal(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.greater_equal(x1._data, x2._data) + return NamedArray(_dims, _data) def hypot(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.hypot(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.hypot(x1._data, x2._data) + return NamedArray(_dims, _data) def imag( @@ -243,145 +242,145 @@ def imag( array([2., 4.]) """ xp = _get_data_namespace(x) - _data = _atleast_0d(xp.imag(x._data), xp) + _data = xp.imag(x._data) return x._new(data=_data) def isfinite(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.isfinite(x._data), xp) + _data = xp.isfinite(x._data) return x._new(data=_data) def isinf(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.isinf(x._data), xp) + _data = xp.isinf(x._data) return x._new(data=_data) def isnan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.isnan(x._data), xp) + _data = xp.isnan(x._data) return x._new(data=_data) def less(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.less(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.less(x1._data, x2._data) + return NamedArray(_dims, _data) def less_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.less_equal(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.less_equal(x1._data, x2._data) + return NamedArray(_dims, _data) def log(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.log(x._data), xp) + _data = xp.log(x._data) return x._new(data=_data) def log1p(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.log1p(x._data), xp) + _data = xp.log1p(x._data) return x._new(data=_data) def log2(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.log2(x._data), xp) + _data = xp.log2(x._data) return x._new(data=_data) def log10(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.log10(x._data), xp) + _data = xp.log10(x._data) return x._new(data=_data) def logaddexp(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.logaddexp(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.logaddexp(x1._data, x2._data) + return NamedArray(_dims, _data) def logical_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.logical_and(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.logical_and(x1._data, x2._data) + return NamedArray(_dims, _data) def logical_not(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.logical_not(x._data), xp) + _data = xp.logical_not(x._data) return x._new(data=_data) def logical_or(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.logical_or(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.logical_or(x1._data, x2._data) + return NamedArray(_dims, _data) def logical_xor(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.logical_xor(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.logical_xor(x1._data, x2._data) + return NamedArray(_dims, _data) def maximum(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.maximum(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.maximum(x1._data, x2._data) + return NamedArray(_dims, _data) def minimum(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.minimum(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.minimum(x1._data, x2._data) + return NamedArray(_dims, _data) def multiply(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.multiply(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.multiply(x1._data, x2._data) + return NamedArray(_dims, _data) def negative(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.negative(x._data), xp) + _data = xp.negative(x._data) return x._new(data=_data) def not_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.not_equal(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.not_equal(x1._data, x2._data) + return NamedArray(_dims, _data) def positive(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.positive(x._data), xp) + _data = xp.positive(x._data) return x._new(data=_data) def pow(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.pow(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.pow(x1._data, x2._data) + return NamedArray(_dims, _data) def real( @@ -412,79 +411,79 @@ def real( array([1., 2.]) """ xp = _get_data_namespace(x) - _data = _atleast_0d(xp.real(x._data), xp) + _data = xp.real(x._data) return x._new(data=_data) def remainder(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.remainder(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.remainder(x1._data, x2._data) + return NamedArray(_dims, _data) def round(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.round(x._data), xp) + _data = xp.round(x._data) return x._new(data=_data) def sign(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.sign(x._data), xp) + _data = xp.sign(x._data) return x._new(data=_data) def signbit(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.signbit(x._data), xp) + _data = xp.signbit(x._data) return x._new(data=_data) def sin(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.sin(x._data), xp) + _data = xp.sin(x._data) return x._new(data=_data) def sinh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.sinh(x._data), xp) + _data = xp.sinh(x._data) return x._new(data=_data) def sqrt(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.sqrt(x._data), xp) + _data = xp.sqrt(x._data) return x._new(data=_data) def square(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.square(x._data), xp) + _data = xp.square(x._data) return x._new(data=_data) def subtract(x1: NamedArray, x2: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x1) - # TODO: Handle attrs? will get x1 now - _data = _atleast_0d(xp.subtract(x1._data, x2._data), xp) - return x1._new(data=_data) + _dims, _ = _get_broadcasted_dims(x1, x2) + _data = xp.subtract(x1._data, x2._data) + return NamedArray(_dims, _data) def tan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.tan(x._data), xp) + _data = xp.tan(x._data) return x._new(data=_data) def tanh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) - _data = _atleast_0d(xp.tanh(x._data), xp) + _data = xp.tanh(x._data) return x._new(data=_data) def trunc(x, /): xp = _get_data_namespace(x) - _data = _atleast_0d(xp.trunc(x._data), xp) + _data = xp.trunc(x._data) return x._new(data=_data) From a6091eb8f4f2406cb6d794ac7afa5315df6cc900 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 28 Aug 2024 07:55:50 +0200 Subject: [PATCH 237/367] Update _elementwise_functions.py --- .../_array_api/_elementwise_functions.py | 114 ++++++++++++------ 1 file changed, 76 insertions(+), 38 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 1b32193e1ab..23ba2c69797 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -17,20 +17,23 @@ def abs(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.abs(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def acos(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.acos(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def acosh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.acosh(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def add(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -42,20 +45,23 @@ def add(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def asin(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.asin(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def asinh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.asinh(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def atan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.atan(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def atan2(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -67,8 +73,9 @@ def atan2(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def atanh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.atanh(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def bitwise_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -80,8 +87,9 @@ def bitwise_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def bitwise_invert(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.bitwise_invert(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def bitwise_left_shift(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -114,8 +122,9 @@ def bitwise_xor(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def ceil(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.ceil(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def clip( @@ -125,14 +134,16 @@ def clip( max: int | float | NamedArray | None = None, ) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.clip(x._data, min=min, max=max) - return x._new(data=_data) + return NamedArray(_dims, _data) def conj(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.conj(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def copysign(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -144,14 +155,16 @@ def copysign(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def cos(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.cos(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def cosh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.cosh(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -163,14 +176,16 @@ def divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def exp(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.exp(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def expm1(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.expm1(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -182,8 +197,9 @@ def equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def floor(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.floor(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def floor_divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -242,26 +258,30 @@ def imag( array([2., 4.]) """ xp = _get_data_namespace(x) + _dims = x.dims _data = xp.imag(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def isfinite(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.isfinite(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def isinf(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.isinf(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def isnan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.isnan(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def less(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -280,26 +300,30 @@ def less_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def log(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.log(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def log1p(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.log1p(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def log2(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.log2(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def log10(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.log10(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def logaddexp(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -318,8 +342,9 @@ def logical_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def logical_not(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.logical_not(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def logical_or(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -359,8 +384,9 @@ def multiply(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def negative(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.negative(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def not_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -372,8 +398,9 @@ def not_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def positive(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.positive(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def pow(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -411,8 +438,9 @@ def real( array([1., 2.]) """ xp = _get_data_namespace(x) + _dims = x.dims _data = xp.real(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def remainder(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -424,44 +452,51 @@ def remainder(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def round(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.round(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def sign(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.sign(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def signbit(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.signbit(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def sin(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.sin(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def sinh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.sinh(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def sqrt(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.sqrt(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def square(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.square(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def subtract(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -473,17 +508,20 @@ def subtract(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def tan(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.tan(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def tanh(x: NamedArray, /) -> NamedArray: xp = _get_data_namespace(x) + _dims = x.dims _data = xp.tanh(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) def trunc(x, /): xp = _get_data_namespace(x) + _dims = x.dims _data = xp.trunc(x._data) - return x._new(data=_data) + return NamedArray(_dims, _data) From 92217518ca7972ba65a66df5b01116e161a42f5e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 16:42:15 +0200 Subject: [PATCH 238/367] simplify like functions --- .../_array_api/_creation_functions.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index c7d0378f1f7..b65cdefc287 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -244,7 +244,10 @@ def meshgrid(*arrays: NamedArray, indexing: str = "xy") -> list[NamedArray]: def ones( shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: - return full(shape, 1.0, dtype=dtype, device=device) + xp = _get_namespace_dtype(dtype) + _data = xp.ones(shape, dtype=dtype, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) def ones_like( @@ -255,9 +258,7 @@ def ones_like( device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _dtype = x.dtype if dtype is None else dtype - _device = x.device if device is None else device - _data = xp.ones(x.shape, dtype=_dtype, device=_device) + _data = xp.ones_like(x, dtype=dtype, device=device) return x._new(data=_data) @@ -284,7 +285,10 @@ def triu( def zeros( shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: - return full(shape, 0.0, dtype=dtype, device=device) + xp = _get_namespace_dtype(dtype) + _data = xp.zeros(shape, dtype=dtype, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) def zeros_like( @@ -295,7 +299,5 @@ def zeros_like( device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _dtype = x.dtype if dtype is None else dtype - _device = x.device if device is None else device - _data = xp.zeros(x.shape, dtype=_dtype, device=_device) + _data = xp.zeros_like(x, dtype=dtype, device=device) return x._new(data=_data) From 60aa96989961661955597b485850bcc912146e8f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 16:49:11 +0200 Subject: [PATCH 239/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 84 +++++++++++++++----------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 9678640596f..c6975960024 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,6 +1,8 @@ from __future__ import annotations from collections.abc import Iterable +from itertools import zip_longest +import math from types import ModuleType from typing import TYPE_CHECKING, Any @@ -68,8 +70,21 @@ def _infer_dims( shape: _Shape, dims: _DimsLike | Default = _default, ) -> _DimsLike: + """ + Create default dim names if no dims were supplied. + + Examples + -------- + >>> _infer_dims(()) + () + >>> _infer_dims((1,)) + ('dim_0',) + >>> _infer_dims((3, 1)) + ('dim_1', 'dim_0') + """ if dims is _default: - return tuple(f"dim_{n}" for n in range(len(shape))) + ndim = len(shape) + return tuple(f"dim_{ndim - 1 - n}" for n in range(ndim)) else: return dims @@ -199,6 +214,11 @@ def _raise_if_any_duplicate_dimensions( ) +def _isnone(shape: _Shape) -> tuple[bool, ...]: + # TODO: math.isnan should not be needed for array api, but dask still uses np.nan: + return tuple(v is None and math.isnan(v) for v in shape) + + def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: """ Get the expected broadcasted dims. @@ -209,6 +229,8 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: >>> b = NamedArray(("y", "z"), np.zeros((3, 4))) >>> _get_broadcasted_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) + >>> _get_broadcasted_dims(b, a) + (('x', 'y', 'z'), (5, 3, 4)) >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4))) @@ -230,37 +252,31 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: >>> _get_broadcasted_dims(a, b) Traceback (most recent call last): ... - ValueError: operands cannot be broadcast together with mismatched lengths for dimension 'x': (5, 2) + ValueError: operands could not be broadcast together with dims = (('x', 'y', 'z'), ('x', 'y', 'z')) and shapes = ((5, 3, 4), (2, 3, 4)) """ - - def broadcastable(e1: int, e2: int) -> bool: - # out = e1 > 1 and e2 <= 1 - # out |= e2 > 1 and e1 <= 1 - - # out = e1 >= 0 and e2 <= 1 - # out |= e2 >= 0 and e1 <= 1 - - out = e1 <= 1 or e2 <= 1 - - return out - - # validate dimensions - all_dims = {} - for x in arrays: - _dims = x.dims - _raise_if_any_duplicate_dimensions(_dims, err_context="Broadcasting") - - for d, s in zip(_dims, x.shape): - if d not in all_dims: - all_dims[d] = s - elif all_dims[d] != s: - if broadcastable(all_dims[d], s): - max(all_dims[d], s) - else: - raise ValueError( - "operands cannot be broadcast together " - f"with mismatched lengths for dimension {d!r}: {(all_dims[d], s)}" - ) - - # TODO: Return flag whether broadcasting is needed? - return tuple(all_dims.keys()), tuple(all_dims.values()) + dims = tuple(a.dims for a in arrays) + shapes = tuple(a.shape for a in arrays) + + if len(shapes) == 1: + return shapes[0] + + out_dims = [] + out_shape = [] + for d, sizes in zip( + zip_longest(*map(reversed, dims), fillvalue=_default), + zip_longest(*map(reversed, shapes), fillvalue=-1), + ): + _d = dict.fromkeys(d) + _d.pop(_default, None) + _d = list(_d) + + dim = None if any(_isnone(sizes)) else max(sizes) + + if any(i not in [-1, 0, 1, dim] for i in sizes) or len(_d) != 1: + raise ValueError( + f"operands could not be broadcast together with {dims = } and {shapes = }" + ) + + out_dims.append(_d[0]) + out_shape.append(dim) + return tuple(reversed(out_dims)), tuple(reversed(out_shape)) From 451ac17342a8bfe2f2bb95bad9a388f0bd27f40b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Sep 2024 14:50:10 +0000 Subject: [PATCH 240/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../_array_api/_creation_functions.py | 6 -- xarray/namedarray/_array_api/_utils.py | 2 +- xarray/namedarray/core.py | 77 ++++++++----------- 3 files changed, 33 insertions(+), 52 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index b65cdefc287..a5750535ca5 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -2,8 +2,6 @@ from typing import Any, overload -import numpy as np - from xarray.namedarray._array_api._utils import ( _get_data_namespace, _get_namespace, @@ -11,7 +9,6 @@ _infer_dims, ) from xarray.namedarray._typing import ( - _arrayfunction_or_api, _ArrayLike, _default, _Device, @@ -24,9 +21,6 @@ from xarray.namedarray.core import ( NamedArray, ) -from xarray.namedarray.utils import ( - to_0d_object_array, -) def _like_args(x, dtype=None, device: _Device | None = None): diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index c6975960024..b914b746c7c 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,8 +1,8 @@ from __future__ import annotations +import math from collections.abc import Iterable from itertools import zip_longest -import math from types import ModuleType from typing import TYPE_CHECKING, Any diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index fe4eda7aa65..6dd91d71854 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -458,7 +458,7 @@ def __add__(self, other: int | float | NamedArray, /) -> NamedArray: return add(self, self._maybe_asarray(other)) def __and__(self, other: int | bool | NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api import asarray, bitwise_and + from xarray.namedarray._array_api import bitwise_and return bitwise_and(self, self._maybe_asarray(other)) @@ -496,7 +496,7 @@ def __dlpack_device__(self, /) -> tuple[IntEnum, int]: return self._data.__dlpack_device__() def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api import asarray, equal + from xarray.namedarray._array_api import equal return equal(self, self._maybe_asarray(other)) @@ -504,17 +504,17 @@ def __float__(self, /) -> float: return self._data.__float__() def __floordiv__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, floor_divide + from xarray.namedarray._array_api import floor_divide return floor_divide(self, self._maybe_asarray(other)) def __ge__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, greater_equal + from xarray.namedarray._array_api import greater_equal return greater_equal(self, self._maybe_asarray(other)) def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: - from xarray.namedarray._array_api._utils import _atleast_0d, _infer_dims + from xarray.namedarray._array_api._utils import _infer_dims if isinstance(key, NamedArray): _key = key._data # TODO: Transpose, unordered dims shouldn't matter. @@ -529,7 +529,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: raise NotImplementedError(f"{key=} is not supported") def __gt__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, greater + from xarray.namedarray._array_api import greater return greater(self, self._maybe_asarray(other)) @@ -551,37 +551,37 @@ def __iter__(self: NamedArray, /): return (asarray(i) for i in self._data) def __le__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, less_equal + from xarray.namedarray._array_api import less_equal return less_equal(self, self._maybe_asarray(other)) def __lshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import asarray, bitwise_left_shift + from xarray.namedarray._array_api import bitwise_left_shift return bitwise_left_shift(self, self._maybe_asarray(other)) def __lt__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, less + from xarray.namedarray._array_api import less return less(self, self._maybe_asarray(other)) def __matmul__(self, other: NamedArray, /): - from xarray.namedarray._array_api import asarray, matmul + from xarray.namedarray._array_api import matmul return matmul(self, self._maybe_asarray(other)) def __mod__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, remainder + from xarray.namedarray._array_api import remainder return remainder(self, self._maybe_asarray(other)) def __mul__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, multiply + from xarray.namedarray._array_api import multiply return multiply(self, self._maybe_asarray(other)) def __ne__(self, other: int | float | bool | NamedArray, /): - from xarray.namedarray._array_api import asarray, not_equal + from xarray.namedarray._array_api import not_equal return not_equal(self, self._maybe_asarray(other)) @@ -591,7 +591,7 @@ def __neg__(self, /): return negative(self) def __or__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import asarray, bitwise_or + from xarray.namedarray._array_api import bitwise_or return bitwise_or(self, self._maybe_asarray(other)) @@ -601,12 +601,12 @@ def __pos__(self, /): return positive(self) def __pow__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, pow + from xarray.namedarray._array_api import pow return pow(self, self._maybe_asarray(other)) def __rshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import asarray, bitwise_right_shift + from xarray.namedarray._array_api import bitwise_right_shift return bitwise_right_shift(self, self._maybe_asarray(other)) @@ -616,68 +616,63 @@ def __setitem__( value: int | float | bool | NamedArray, /, ) -> None: - from xarray.namedarray._array_api import asarray if isinstance(key, NamedArray): key = key._data self._data.__setitem__(key, self._maybe_asarray(value)._data) def __sub__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, subtract + from xarray.namedarray._array_api import subtract return subtract(self, self._maybe_asarray(other)) def __truediv__(self, other: float | NamedArray, /): - from xarray.namedarray._array_api import asarray, divide + from xarray.namedarray._array_api import divide return divide(self, self._maybe_asarray(other)) def __xor__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import asarray, bitwise_xor + from xarray.namedarray._array_api import bitwise_xor return bitwise_xor(self, self._maybe_asarray(other)) def __iadd__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__iadd__(self._maybe_asarray(other)._data) return self def __radd__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import add, asarray + from xarray.namedarray._array_api import add return add(self._maybe_asarray(other), self) def __iand__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__iand__(self._maybe_asarray(other)._data) return self def __rand__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import asarray, bitwise_and + from xarray.namedarray._array_api import bitwise_and return bitwise_and(self._maybe_asarray(other), self) def __ifloordiv__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__ifloordiv__(self._maybe_asarray(other)._data) return self def __rfloordiv__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, floor_divide + from xarray.namedarray._array_api import floor_divide return floor_divide(self._maybe_asarray(other), self) def __ilshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__ilshift__(self._maybe_asarray(other)._data) return self def __rlshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import asarray, bitwise_left_shift + from xarray.namedarray._array_api import bitwise_left_shift return bitwise_left_shift(self._maybe_asarray(other), self) @@ -686,95 +681,87 @@ def __imatmul__(self, other: NamedArray, /): return self def __rmatmul__(self, other: NamedArray, /): - from xarray.namedarray._array_api import asarray, matmul + from xarray.namedarray._array_api import matmul return matmul(self._maybe_asarray(other), self) def __imod__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__imod__(self._maybe_asarray(other)._data) return self def __rmod__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, remainder + from xarray.namedarray._array_api import remainder return remainder(self._maybe_asarray(other), self) def __imul__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__imul__(self._maybe_asarray(other)._data) return self def __rmul__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, multiply + from xarray.namedarray._array_api import multiply return multiply(self._maybe_asarray(other), self) def __ior__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__ior__(self._maybe_asarray(other)._data) return self def __ror__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import asarray, bitwise_or + from xarray.namedarray._array_api import bitwise_or return bitwise_or(self._maybe_asarray(other), self) def __ipow__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__ipow__(self._maybe_asarray(other)._data) return self def __rpow__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, pow + from xarray.namedarray._array_api import pow return pow(self._maybe_asarray(other), self) def __irshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__irshift__(self._maybe_asarray(other)._data) return self def __rrshift__(self, other: int | NamedArray, /): - from xarray.namedarray._array_api import asarray, bitwise_right_shift + from xarray.namedarray._array_api import bitwise_right_shift return bitwise_right_shift(self._maybe_asarray(other), self) def __isub__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__isub__(self._maybe_asarray(other)._data) return self def __rsub__(self, other: int | float | NamedArray, /): - from xarray.namedarray._array_api import asarray, subtract + from xarray.namedarray._array_api import subtract return subtract(self._maybe_asarray(other), self) def __itruediv__(self, other: float | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__itruediv__(self._maybe_asarray(other)._data) return self def __rtruediv__(self, other: float | NamedArray, /): - from xarray.namedarray._array_api import asarray, divide + from xarray.namedarray._array_api import divide return divide(self._maybe_asarray(other), self) def __ixor__(self, other: int | bool | NamedArray, /): - from xarray.namedarray._array_api import asarray self._data.__ixor__(self._maybe_asarray(other)._data) return self def __rxor__(self, other, /): - from xarray.namedarray._array_api import asarray, bitwise_xor + from xarray.namedarray._array_api import bitwise_xor return bitwise_xor(self._maybe_asarray(other), self) From dc070d7c7105a1ea35ff2e1ffc58d63e381b00d7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 18:29:46 +0200 Subject: [PATCH 241/367] Update _creation_functions.py --- xarray/namedarray/_array_api/_creation_functions.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index a5750535ca5..68580c77bb8 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -141,9 +141,7 @@ def empty_like( device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _dtype = x.dtype if dtype is None else dtype - _device = x.device if device is None else device - _data = xp.empty(x.shape, dtype=_dtype, device=_device) + _data = xp.empty(x._data, dtype=dtype, device=device) return x._new(data=_data) @@ -197,9 +195,7 @@ def full_like( device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _dtype = x.dtype if dtype is None else dtype - _device = x.device if device is None else device - _data = xp.full(x.shape, fill_value, dtype=_dtype, device=_device) + _data = xp.full_like(x._data, fill_value, dtype=dtype, device=device) return x._new(data=_data) @@ -252,7 +248,7 @@ def ones_like( device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _data = xp.ones_like(x, dtype=dtype, device=device) + _data = xp.ones_like(x._data, dtype=dtype, device=device) return x._new(data=_data) @@ -293,5 +289,5 @@ def zeros_like( device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _data = xp.zeros_like(x, dtype=dtype, device=device) + _data = xp.zeros_like(x._data, dtype=dtype, device=device) return x._new(data=_data) From 29fb043bd6d4b0c01671cf84ea3e5ed71843ea05 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 18:33:18 +0200 Subject: [PATCH 242/367] Update _creation_functions.py --- xarray/namedarray/_array_api/_creation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 68580c77bb8..72706a8331c 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -141,7 +141,7 @@ def empty_like( device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _data = xp.empty(x._data, dtype=dtype, device=device) + _data = xp.empty_like(x._data, dtype=dtype, device=device) return x._new(data=_data) From 9e36876caa0eea98358252515df96990f425e46b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 18:41:59 +0200 Subject: [PATCH 243/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index b914b746c7c..8be7849f240 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -225,6 +225,10 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: Examples -------- + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) + >>> _get_broadcasted_dims(a) + (('x', 'y', 'z'), (5, 3, 4)) + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("y", "z"), np.zeros((3, 4))) >>> _get_broadcasted_dims(a, b) @@ -257,9 +261,6 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: dims = tuple(a.dims for a in arrays) shapes = tuple(a.shape for a in arrays) - if len(shapes) == 1: - return shapes[0] - out_dims = [] out_shape = [] for d, sizes in zip( @@ -279,4 +280,5 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: out_dims.append(_d[0]) out_shape.append(dim) + return tuple(reversed(out_dims)), tuple(reversed(out_shape)) From 9b681c9c4178f4f8d3031dbd7e732206a3938329 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:23:53 +0200 Subject: [PATCH 244/367] add typing --- xarray/namedarray/_typing.py | 38 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 01de017ca47..9462a3ab5de 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -116,7 +116,7 @@ class _IInfo(Protocol): bits: int max: int min: int - dtype: _dtype + dtype: _dtype[Any] class _FInfo(Protocol): @@ -125,7 +125,7 @@ class _FInfo(Protocol): max: float min: float smallest_normal: float - dtype: _dtype + dtype: _dtype[Any] _Capabilities = TypedDict( @@ -135,28 +135,28 @@ class _FInfo(Protocol): _DefaultDataTypes = TypedDict( "_DefaultDataTypes", { - "real floating": _dtype, - "complex floating": _dtype, - "integral": _dtype, - "indexing": _dtype, + "real floating": _dtype[Any], + "complex floating": _dtype[Any], + "integral": _dtype[Any], + "indexing": _dtype[Any], }, ) class _DataTypes(TypedDict, total=False): - bool: _dtype - float32: _dtype - float64: _dtype - complex64: _dtype - complex128: _dtype - int8: _dtype - int16: _dtype - int32: _dtype - int64: _dtype - uint8: _dtype - uint16: _dtype - uint32: _dtype - uint64: _dtype + bool: _dtype[Any] + float32: _dtype[Any] + float64: _dtype[Any] + complex64: _dtype[Any] + complex128: _dtype[Any] + int8: _dtype[Any] + int16: _dtype[Any] + int32: _dtype[Any] + int64: _dtype[Any] + uint8: _dtype[Any] + uint16: _dtype[Any] + uint32: _dtype[Any] + uint64: _dtype[Any] @runtime_checkable From e9649ef2dc65669a4ce01b32278cfccff63363a0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:24:37 +0200 Subject: [PATCH 245/367] remove atleast_0d --- .../_array_api/_statistical_functions.py | 19 +++++++------------ .../_array_api/_utility_functions.py | 5 ++--- xarray/namedarray/_array_api/_utils.py | 10 +--------- 3 files changed, 10 insertions(+), 24 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 43228d51bed..9e7f4ff0ff4 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -3,7 +3,6 @@ from typing import Any from xarray.namedarray._array_api._utils import ( - _atleast_0d, _dims_to_axis, _get_data_namespace, _get_remaining_dims, @@ -70,7 +69,7 @@ def max( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = _atleast_0d(xp.max(x._data, axis=_axis, keepdims=False), xp) + _data = xp.max(x._data, axis=_axis, keepdims=False) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) return x._new(dims=dims_, data=data_) @@ -136,7 +135,7 @@ def mean( """ xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = _atleast_0d(xp.mean(x._data, axis=_axis, keepdims=False), xp) + _data = xp.mean(x._data, axis=_axis, keepdims=False) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -153,7 +152,7 @@ def min( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = _atleast_0d(xp.min(x._data, axis=_axis, keepdims=False), xp) + _data = xp.min(x._data, axis=_axis, keepdims=False) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -171,7 +170,7 @@ def prod( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = _atleast_0d(xp.prod(x._data, axis=_axis, keepdims=False), xp) + _data = xp.prod(x._data, axis=_axis, keepdims=False) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -189,9 +188,7 @@ def std( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = _atleast_0d( - xp.std(x._data, axis=_axis, correction=correction, keepdims=False), xp - ) + _data = xp.std(x._data, axis=_axis, correction=correction, keepdims=False) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -209,7 +206,7 @@ def sum( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = _atleast_0d(xp.sum(x._data, axis=_axis, keepdims=False), xp) + _data = xp.sum(x._data, axis=_axis, keepdims=False) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) @@ -227,9 +224,7 @@ def var( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = _atleast_0d( - xp.var(x._data, axis=_axis, correction=correction, keepdims=False), xp - ) + _data = xp.var(x._data, axis=_axis, correction=correction, keepdims=False) # TODO: Why do we need to do the keepdims ourselves? dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) out = x._new(dims=dims_, data=data_) diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index 350e4216d24..9cd119fcf2c 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -3,7 +3,6 @@ from typing import Any from xarray.namedarray._array_api._utils import ( - _atleast_0d, _dims_to_axis, _get_data_namespace, _get_remaining_dims, @@ -30,7 +29,7 @@ def all( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) axis_ = _dims_to_axis(x, dims, axis) - d = _atleast_0d(xp.all(x._data, axis=axis_, keepdims=False), xp) + d = xp.all(x._data, axis=axis_, keepdims=False) dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out @@ -46,7 +45,7 @@ def any( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) axis_ = _dims_to_axis(x, dims, axis) - d = _atleast_0d(xp.any(x._data, axis=axis_, keepdims=False), xp) + d = xp.any(x._data, axis=axis_, keepdims=False) dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) out = x._new(dims=dims_, data=data_) return out diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 8be7849f240..19c603b792f 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -50,7 +50,7 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: return _get_namespace(x._data) -def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType: +def _get_namespace_dtype(dtype: _dtype[Any] | None = None) -> ModuleType: if dtype is None: return _maybe_default_namespace() @@ -195,14 +195,6 @@ def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: return tuple(d) -def _atleast_0d(x, xp): - """ - Workaround for numpy sometimes returning scalars instead of 0d arrays. - """ - return xp.asarray(x) - - -# %% def _raise_if_any_duplicate_dimensions( dims: _Dims, err_context: str = "This function" ) -> None: From ac1265e3e0beb8e211e65ef37d2f498309964a3b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:25:53 +0200 Subject: [PATCH 246/367] use set instead, if len>1 something is odd is happening --- xarray/namedarray/_array_api/_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 19c603b792f..8c14087cb7b 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -259,9 +259,7 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: zip_longest(*map(reversed, dims), fillvalue=_default), zip_longest(*map(reversed, shapes), fillvalue=-1), ): - _d = dict.fromkeys(d) - _d.pop(_default, None) - _d = list(_d) + _d = tuple(set(d) - {_default}) dim = None if any(_isnone(sizes)) else max(sizes) From 407c8d668f05ba805e6950837348940e49da4ade Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 2 Sep 2024 21:07:14 +0200 Subject: [PATCH 247/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 8c14087cb7b..191baa21056 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -89,7 +89,7 @@ def _infer_dims( return dims -def _normalize_dimensions(dims: _DimsLike) -> _Dims: +def _normalize_dimensions(dims: _Dim | _Dims) -> _Dims: """ Normalize dimensions. @@ -129,12 +129,12 @@ def _dims_to_axis( Examples -------- - >>> narr = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) - >>> _dims_to_axis(narr, ("y",), None) + >>> x = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) + >>> _dims_to_axis(x, ("y",), None) (1,) - >>> _dims_to_axis(narr, _default, 0) + >>> _dims_to_axis(x, _default, 0) (0,) - >>> _dims_to_axis(narr, None, None) + >>> _dims_to_axis(x, None, None) """ _assert_either_dim_or_axis(dims, axis) From aab7b22e86a8b97e9bc6423d5c39a7c54616af36 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 2 Sep 2024 21:44:24 +0200 Subject: [PATCH 248/367] Update _elementwise_functions.py --- .../_array_api/_elementwise_functions.py | 182 +++++++++++------- 1 file changed, 116 insertions(+), 66 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 23ba2c69797..89c591cbbce 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Any import numpy as np @@ -8,6 +9,7 @@ ) from xarray.namedarray._typing import ( _ScalarType, + _DType, _ShapeType, _SupportsImag, _SupportsReal, @@ -15,112 +17,124 @@ from xarray.namedarray.core import NamedArray -def abs(x: NamedArray, /) -> NamedArray: +def abs(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.abs(x._data) return NamedArray(_dims, _data) -def acos(x: NamedArray, /) -> NamedArray: +def acos(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.acos(x._data) return NamedArray(_dims, _data) -def acosh(x: NamedArray, /) -> NamedArray: +def acosh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.acosh(x._data) return NamedArray(_dims, _data) -def add(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def add(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.add(x1._data, x2._data) return NamedArray(_dims, _data) -def asin(x: NamedArray, /) -> NamedArray: +def asin(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.asin(x._data) return NamedArray(_dims, _data) -def asinh(x: NamedArray, /) -> NamedArray: +def asinh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.asinh(x._data) return NamedArray(_dims, _data) -def atan(x: NamedArray, /) -> NamedArray: +def atan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.atan(x._data) return NamedArray(_dims, _data) -def atan2(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def atan2( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.atan2(x1._data, x2._data) return NamedArray(_dims, _data) -def atanh(x: NamedArray, /) -> NamedArray: +def atanh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.atanh(x._data) return NamedArray(_dims, _data) -def bitwise_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def bitwise_and( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.bitwise_and(x1._data, x2._data) return NamedArray(_dims, _data) -def bitwise_invert(x: NamedArray, /) -> NamedArray: +def bitwise_invert(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.bitwise_invert(x._data) return NamedArray(_dims, _data) -def bitwise_left_shift(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def bitwise_left_shift( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.bitwise_left_shift(x1._data, x2._data) return NamedArray(_dims, _data) -def bitwise_or(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def bitwise_or( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.bitwise_or(x1._data, x2._data) return NamedArray(_dims, _data) -def bitwise_right_shift(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def bitwise_right_shift( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.bitwise_right_shift(x1._data, x2._data) return NamedArray(_dims, _data) -def bitwise_xor(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def bitwise_xor( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.bitwise_xor(x1._data, x2._data) return NamedArray(_dims, _data) -def ceil(x: NamedArray, /) -> NamedArray: +def ceil(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.ceil(x._data) @@ -128,102 +142,116 @@ def ceil(x: NamedArray, /) -> NamedArray: def clip( - x: NamedArray, + x: NamedArray[Any, Any], /, - min: int | float | NamedArray | None = None, - max: int | float | NamedArray | None = None, -) -> NamedArray: + min: int | float | NamedArray[Any, Any] | None = None, + max: int | float | NamedArray[Any, Any] | None = None, +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.clip(x._data, min=min, max=max) return NamedArray(_dims, _data) -def conj(x: NamedArray, /) -> NamedArray: +def conj(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.conj(x._data) return NamedArray(_dims, _data) -def copysign(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def copysign( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.copysign(x1._data, x2._data) return NamedArray(_dims, _data) -def cos(x: NamedArray, /) -> NamedArray: +def cos(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.cos(x._data) return NamedArray(_dims, _data) -def cosh(x: NamedArray, /) -> NamedArray: +def cosh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.cosh(x._data) return NamedArray(_dims, _data) -def divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def divide( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.divide(x1._data, x2._data) return NamedArray(_dims, _data) -def exp(x: NamedArray, /) -> NamedArray: +def exp(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.exp(x._data) return NamedArray(_dims, _data) -def expm1(x: NamedArray, /) -> NamedArray: +def expm1(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.expm1(x._data) return NamedArray(_dims, _data) -def equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def equal( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.equal(x1._data, x2._data) return NamedArray(_dims, _data) -def floor(x: NamedArray, /) -> NamedArray: +def floor(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.floor(x._data) return NamedArray(_dims, _data) -def floor_divide(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def floor_divide( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.floor_divide(x1._data, x2._data) return NamedArray(_dims, _data) -def greater(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def greater( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.greater(x1._data, x2._data) return NamedArray(_dims, _data) -def greater_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def greater_equal( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.greater_equal(x1._data, x2._data) return NamedArray(_dims, _data) -def hypot(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def hypot( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.hypot(x1._data, x2._data) @@ -263,147 +291,165 @@ def imag( return NamedArray(_dims, _data) -def isfinite(x: NamedArray, /) -> NamedArray: +def isfinite(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isfinite(x._data) return NamedArray(_dims, _data) -def isinf(x: NamedArray, /) -> NamedArray: +def isinf(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isinf(x._data) return NamedArray(_dims, _data) -def isnan(x: NamedArray, /) -> NamedArray: +def isnan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isnan(x._data) return NamedArray(_dims, _data) -def less(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def less(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.less(x1._data, x2._data) return NamedArray(_dims, _data) -def less_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def less_equal( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.less_equal(x1._data, x2._data) return NamedArray(_dims, _data) -def log(x: NamedArray, /) -> NamedArray: +def log(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log(x._data) return NamedArray(_dims, _data) -def log1p(x: NamedArray, /) -> NamedArray: +def log1p(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log1p(x._data) return NamedArray(_dims, _data) -def log2(x: NamedArray, /) -> NamedArray: +def log2(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log2(x._data) return NamedArray(_dims, _data) -def log10(x: NamedArray, /) -> NamedArray: +def log10(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log10(x._data) return NamedArray(_dims, _data) -def logaddexp(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def logaddexp( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.logaddexp(x1._data, x2._data) return NamedArray(_dims, _data) -def logical_and(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def logical_and( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.logical_and(x1._data, x2._data) return NamedArray(_dims, _data) -def logical_not(x: NamedArray, /) -> NamedArray: +def logical_not(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.logical_not(x._data) return NamedArray(_dims, _data) -def logical_or(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def logical_or( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.logical_or(x1._data, x2._data) return NamedArray(_dims, _data) -def logical_xor(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def logical_xor( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.logical_xor(x1._data, x2._data) return NamedArray(_dims, _data) -def maximum(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def maximum( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.maximum(x1._data, x2._data) return NamedArray(_dims, _data) -def minimum(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def minimum( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.minimum(x1._data, x2._data) return NamedArray(_dims, _data) -def multiply(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def multiply( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.multiply(x1._data, x2._data) return NamedArray(_dims, _data) -def negative(x: NamedArray, /) -> NamedArray: +def negative(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.negative(x._data) return NamedArray(_dims, _data) -def not_equal(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def not_equal( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.not_equal(x1._data, x2._data) return NamedArray(_dims, _data) -def positive(x: NamedArray, /) -> NamedArray: +def positive(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.positive(x._data) return NamedArray(_dims, _data) -def pow(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def pow(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.pow(x1._data, x2._data) @@ -443,84 +489,88 @@ def real( return NamedArray(_dims, _data) -def remainder(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def remainder( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.remainder(x1._data, x2._data) return NamedArray(_dims, _data) -def round(x: NamedArray, /) -> NamedArray: +def round(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.round(x._data) return NamedArray(_dims, _data) -def sign(x: NamedArray, /) -> NamedArray: +def sign(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sign(x._data) return NamedArray(_dims, _data) -def signbit(x: NamedArray, /) -> NamedArray: +def signbit(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.signbit(x._data) return NamedArray(_dims, _data) -def sin(x: NamedArray, /) -> NamedArray: +def sin(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sin(x._data) return NamedArray(_dims, _data) -def sinh(x: NamedArray, /) -> NamedArray: +def sinh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sinh(x._data) return NamedArray(_dims, _data) -def sqrt(x: NamedArray, /) -> NamedArray: +def sqrt(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sqrt(x._data) return NamedArray(_dims, _data) -def square(x: NamedArray, /) -> NamedArray: +def square(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.square(x._data) return NamedArray(_dims, _data) -def subtract(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def subtract( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _dims, _ = _get_broadcasted_dims(x1, x2) _data = xp.subtract(x1._data, x2._data) return NamedArray(_dims, _data) -def tan(x: NamedArray, /) -> NamedArray: +def tan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.tan(x._data) return NamedArray(_dims, _data) -def tanh(x: NamedArray, /) -> NamedArray: +def tanh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.tanh(x._data) return NamedArray(_dims, _data) -def trunc(x, /): +def trunc(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.trunc(x._data) From d82a7ccb4f478bc090de256bb8921a49b57192cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 19:45:09 +0000 Subject: [PATCH 249/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_elementwise_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 89c591cbbce..0c395c22143 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -1,4 +1,5 @@ from __future__ import annotations + from typing import Any import numpy as np @@ -8,8 +9,8 @@ _get_data_namespace, ) from xarray.namedarray._typing import ( - _ScalarType, _DType, + _ScalarType, _ShapeType, _SupportsImag, _SupportsReal, From f73e7a100fbfdafdfc1f6cde47b76ca029d5f788 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 2 Sep 2024 23:08:02 +0200 Subject: [PATCH 250/367] typing --- xarray/namedarray/_array_api/_utils.py | 54 ++++++++++++++++++++------ xarray/namedarray/_typing.py | 2 +- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 191baa21056..86b7db892cd 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from itertools import zip_longest from types import ModuleType -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeGuard, cast, overload from xarray.namedarray._typing import ( Default, @@ -89,6 +89,11 @@ def _infer_dims( return dims +def _is_single_dim(dims: _Dim | _Dims) -> TypeGuard[_Dim]: + # TODO: https://peps.python.org/pep-0742/ + return isinstance(dims, str) or not isinstance(dims, Iterable) + + def _normalize_dimensions(dims: _Dim | _Dims) -> _Dims: """ Normalize dimensions. @@ -108,10 +113,10 @@ def _normalize_dimensions(dims: _Dim | _Dims) -> _Dims: >>> _normalize_dimensions([("time", "x", "y")]) (('time', 'x', 'y'),) """ - if isinstance(dims, str) or not isinstance(dims, Iterable): + if _is_single_dim(dims): return (dims,) - - return tuple(dims) + else: + return tuple(cast(_Dims, dims)) def _assert_either_dim_or_axis( @@ -121,6 +126,14 @@ def _assert_either_dim_or_axis( raise ValueError("cannot supply both 'axis' and 'dim(s)' arguments") +@overload +def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: None) -> None: ... +@overload +def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: _AxisLike) -> _Axes: ... +@overload +def _dims_to_axis( + x: NamedArray[Any, Any], dims: _Dim | _Dims, axis: _AxisLike +) -> _Axes: ... def _dims_to_axis( x: NamedArray[Any, Any], dims: _Dim | _Dims | Default, axis: _AxisLike | None ) -> _Axes | None: @@ -129,18 +142,35 @@ def _dims_to_axis( Examples -------- + + Convert to dims to axis values + >>> x = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) >>> _dims_to_axis(x, ("y",), None) (1,) >>> _dims_to_axis(x, _default, 0) (0,) + >>> _dims_to_axis(x, _default, None) + + Using Hashable dims + + >>> x = NamedArray(("x", None), np.array([[1, 2, 3], [5, 6, 7]])) >>> _dims_to_axis(x, None, None) + (1,) + + Defining both dims and axis raises an error + + >>> _dims_to_axis(x, "x", 1) + Traceback (most recent call last): + ... + ValueError: cannot supply both 'axis' and 'dim(s)' arguments """ _assert_either_dim_or_axis(dims, axis) - if dims is not _default: + _dims = _normalize_dimensions(dims) + axis = () - for dim in dims: + for dim in _dims: try: axis = (x.dims.index(dim),) except ValueError: @@ -211,7 +241,7 @@ def _isnone(shape: _Shape) -> tuple[bool, ...]: return tuple(v is None and math.isnan(v) for v in shape) -def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: +def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]: """ Get the expected broadcasted dims. @@ -253,8 +283,8 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: dims = tuple(a.dims for a in arrays) shapes = tuple(a.shape for a in arrays) - out_dims = [] - out_shape = [] + out_dims: tuple[_Dim, ...] = () + out_shape: tuple[_Axis | None, ...] = () for d, sizes in zip( zip_longest(*map(reversed, dims), fillvalue=_default), zip_longest(*map(reversed, shapes), fillvalue=-1), @@ -268,7 +298,7 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: f"operands could not be broadcast together with {dims = } and {shapes = }" ) - out_dims.append(_d[0]) - out_shape.append(dim) + out_dims += (_d[0],) + out_shape += (dim,) - return tuple(reversed(out_dims)), tuple(reversed(out_shape)) + return out_dims[::-1], out_shape[::-1] diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 9462a3ab5de..b3477874947 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -101,7 +101,7 @@ def imag(self) -> _T_co: ... # TODO: np.array_api was bugged and didn't allow (None,), but should! # https://github.com/numpy/numpy/pull/25022 # https://github.com/data-apis/array-api/pull/674 -_IndexKey = Union[int, slice, "ellipsis"] +_IndexKey = Union[int, slice, "ellipsis", None] _IndexKeys = tuple[_IndexKey, ...] # tuple[Union[_IndexKey, None], ...] _IndexKeyLike = Union[_IndexKey, _IndexKeys] From b026b3c713010ee20995d92aec536cf98e2d01b3 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 2 Sep 2024 23:16:00 +0200 Subject: [PATCH 251/367] Update _info.py --- xarray/namedarray/_array_api/_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_info.py b/xarray/namedarray/_array_api/_info.py index b9d4749faec..9876a2fea01 100644 --- a/xarray/namedarray/_array_api/_info.py +++ b/xarray/namedarray/_array_api/_info.py @@ -136,7 +136,7 @@ def dtypes( "complex128": complex128, } if isinstance(kind, tuple): - res = {} + res: _DataTypes = {} for k in kind: res.update(dtypes(kind=k)) return res From 8d3f19261f3f42c6e53888f58ad24912a9019ba9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 2 Sep 2024 23:28:24 +0200 Subject: [PATCH 252/367] methods --- xarray/namedarray/_typing.py | 8 ++++++++ xarray/namedarray/core.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index dfa30369caa..5a4bc554706 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -260,6 +260,14 @@ def __getitem__( def __array_namespace__(self) -> ModuleType: ... + def to_device(self, device: _Device, /, stream: None = None) -> Self: ... + + @property + def device(self) -> _Device: ... + + @property + def mT(self) -> _arrayapi[Any, _DType_co]: ... + # NamedArray can most likely use both __array_function__ and __array_namespace__: _arrayfunction_or_api = (_arrayfunction, _arrayapi) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 0eddc014cb1..aa03ee4e409 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -799,7 +799,7 @@ def device(self) -> _Device: raise NotImplementedError("self._data missing device") @property - def mT(self): + def mT(self) -> NamedArray[Any, _DType_co]: if isinstance(self._data, _arrayapi): from xarray.namedarray._array_api._utils import _infer_dims From 221ba483e807f515b69f62b35e83f0557e60a8d2 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 2 Sep 2024 23:47:08 +0200 Subject: [PATCH 253/367] Update _manipulation_functions.py --- .../_array_api/_manipulation_functions.py | 72 +++++++++++-------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 10da4006472..394a82655cb 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -19,11 +19,12 @@ _Dim, _DType, _Shape, + _ShapeType, ) from xarray.namedarray.core import NamedArray -def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]: +def broadcast_arrays(*arrays: NamedArray[Any, Any]) -> list[NamedArray[Any, Any]]: """ Broadcasts one or more arrays against one another. @@ -41,7 +42,9 @@ def broadcast_arrays(*arrays: NamedArray) -> list[NamedArray]: return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas)] -def broadcast_to(x: NamedArray, /, shape: _Shape) -> NamedArray: +def broadcast_to( + x: NamedArray[Any, _DType], /, shape: _ShapeType +) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _data = xp.broadcast_to(x._data, shape=shape) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -49,12 +52,15 @@ def broadcast_to(x: NamedArray, /, shape: _Shape) -> NamedArray: def concat( - arrays: tuple[NamedArray, ...] | list[NamedArray], /, *, axis: _Axis | None = 0 -) -> NamedArray: + arrays: tuple[NamedArray[Any, Any], ...] | list[NamedArray[Any, Any]], + /, + *, + axis: _Axis | None = 0, +) -> NamedArray[Any, Any]: xp = _get_data_namespace(arrays[0]) dtype = result_type(*arrays) - arrays = tuple(a._data for a in arrays) - _data = xp.concat(arrays, axis=axis, dtype=dtype) + _arrays = tuple(a._data for a in arrays) + _data = xp.concat(_arrays, axis=axis, dtype=dtype) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -102,14 +108,18 @@ def expand_dims( return x._new(_dims, _data) -def flip(x: NamedArray, /, *, axis: _Axes | None = None) -> NamedArray: +def flip( + x: NamedArray[_ShapeType, _DType], /, *, axis: _Axes | None = None +) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _data = xp.flip(x._data, axis=axis) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def moveaxis(x: NamedArray, source: _Axes, destination: _Axes, /) -> NamedArray: +def moveaxis( + x: NamedArray[Any, _DType], source: _Axes, destination: _Axes, / +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.moveaxis(x._data, source=source, destination=destination) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -146,41 +156,40 @@ def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DT def repeat( - x: NamedArray, - repeats: int | NamedArray, + x: NamedArray[Any, _DType], + repeats: int | NamedArray[Any, Any], /, *, axis: _Axis | None = None, -) -> NamedArray: +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.repeat(x._data, repeats, axis=axis) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def reshape(x, /, shape: _Shape, *, copy: bool | None = None): +def reshape( + x: NamedArray[Any, _DType], /, shape: _ShapeType, *, copy: bool | None = None +) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _data = xp.reshape(x._data, shape) - out = asarray(_data, copy=copy) - # TODO: Have better control where the dims went. - # TODO: If reshaping should we save the dims? - # TODO: What's the xarray equivalent? - return out + _data = xp.reshape(x._data, shape, copy=copy) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) def roll( - x: NamedArray, + x: NamedArray[_ShapeType, _DType], /, shift: int | tuple[int, ...], *, axis: _Axes | None = None, -) -> NamedArray: +) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _data = xp.roll(x._data, shift=shift, axis=axis) return x._new(data=_data) -def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray: +def squeeze(x: NamedArray[Any, _DType], /, axis: _Axes) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.squeeze(x._data, axis=axis) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -188,27 +197,34 @@ def squeeze(x: NamedArray, /, axis: _Axes) -> NamedArray: def stack( - arrays: tuple[NamedArray, ...] | list[NamedArray], /, *, axis: _Axis = 0 -) -> NamedArray: + arrays: tuple[NamedArray[Any, Any], ...] | list[NamedArray[Any, Any]], + /, + *, + axis: _Axis = 0, +) -> NamedArray[Any, Any]: x = arrays[0] xp = _get_data_namespace(x) - arrays = tuple(a._data for a in arrays) - _data = xp.stack(arrays, axis=axis) + _arrays = tuple(a._data for a in arrays) + _data = xp.stack(_arrays, axis=axis) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def tile(x: NamedArray, repetitions: tuple[int, ...], /) -> NamedArray: +def tile( + x: NamedArray[Any, _DType], repetitions: tuple[int, ...], / +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.tile(x._data, repetitions) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def unstack(x: NamedArray, /, *, axis: _Axis = 0) -> tuple[NamedArray, ...]: +def unstack( + x: NamedArray[Any, Any], /, *, axis: _Axis = 0 +) -> tuple[NamedArray[Any, Any], ...]: xp = _get_data_namespace(x) _datas = xp.unstack(x._data, axis=axis) - out = () + out: tuple[NamedArray[Any, Any], ...] = () for _data in _datas: _dims = _infer_dims(_data.shape) # TODO: Fix dims out += (x._new(_dims, _data),) From 157c46da2d05db1f38a8c9b17093544495e49969 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 21:47:45 +0000 Subject: [PATCH 254/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_manipulation_functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 394a82655cb..0799cec58c6 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -2,7 +2,6 @@ from typing import Any -from xarray.namedarray._array_api._creation_functions import asarray from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( _get_broadcasted_dims, @@ -18,7 +17,6 @@ _default, _Dim, _DType, - _Shape, _ShapeType, ) from xarray.namedarray.core import NamedArray From 26182d29b701c25fb3b76940b7ca2c850740fcbc Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Sep 2024 00:10:41 +0200 Subject: [PATCH 255/367] Update _fft.py --- xarray/namedarray/_array_api/_fft/_fft.py | 60 +++++++++++++---------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/xarray/namedarray/_array_api/_fft/_fft.py b/xarray/namedarray/_array_api/_fft/_fft.py index 46e1fc2c4c5..ef8c8a2e999 100644 --- a/xarray/namedarray/_array_api/_fft/_fft.py +++ b/xarray/namedarray/_array_api/_fft/_fft.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Any from xarray.namedarray._array_api._utils import ( _get_data_namespace, @@ -11,19 +11,19 @@ from xarray.namedarray.core import NamedArray if TYPE_CHECKING: - from xarray.namedarray._typing import _Axes, _Axis, _Device + from xarray.namedarray._typing import _Axes, _Axis, _Device, _DType _Norm = Literal["backward", "ortho", "forward"] def fft( - x: NamedArray, + x: NamedArray[Any, _DType], /, *, n: int | None = None, axis: _Axis = -1, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.fft.fft(x._data, n=n, axis=axis, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -31,13 +31,13 @@ def fft( def ifft( - x: NamedArray, + x: NamedArray[Any, _DType], /, *, n: int | None = None, axis: _Axis = -1, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.fft.ifft(x._data, n=n, axis=axis, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -45,13 +45,13 @@ def ifft( def fftn( - x: NamedArray, + x: NamedArray[Any, _DType], /, *, s: Sequence[int] | None = None, axes: Sequence[int] | None = None, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.fft.fftn(x._data, s=s, axes=axes, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -59,13 +59,13 @@ def fftn( def ifftn( - x: NamedArray, + x: NamedArray[Any, _DType], /, *, s: Sequence[int] | None = None, axes: Sequence[int] | None = None, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.fft.ifftn(x._data, s=s, axes=axes, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -73,13 +73,13 @@ def ifftn( def rfft( - x: NamedArray, + x: NamedArray[Any, Any], /, *, n: int | None = None, axis: _Axis = -1, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.fft.rfft(x._data, n=n, axis=axis, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -87,13 +87,13 @@ def rfft( def irfft( - x: NamedArray, + x: NamedArray[Any, Any], /, *, n: int | None = None, axis: _Axis = -1, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.fft.irfft(x._data, n=n, axis=axis, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -101,13 +101,13 @@ def irfft( def rfftn( - x: NamedArray, + x: NamedArray[Any, Any], /, *, s: Sequence[int] | None = None, axes: Sequence[int] | None = None, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.fft.rfftn(x._data, s=s, axes=axes, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -115,13 +115,13 @@ def rfftn( def irfftn( - x: NamedArray, + x: NamedArray[Any, Any], /, *, s: Sequence[int] | None = None, axes: Sequence[int] | None = None, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.fft.irfftn(x._data, s=s, axes=axes, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -129,13 +129,13 @@ def irfftn( def hfft( - x: NamedArray, + x: NamedArray[Any, Any], /, *, n: int | None = None, axis: _Axis = -1, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.fft.hfft(x._data, n=n, axis=axis, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -143,41 +143,49 @@ def hfft( def ihfft( - x: NamedArray, + x: NamedArray[Any, Any], /, *, n: int | None = None, axis: _Axis = -1, norm: _Norm = "backward", -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.fft.ihfft(x._data, n=n, axis=axis, norm=norm) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def fftfreq(n: int, /, *, d: float = 1.0, device: _Device | None = None) -> NamedArray: +def fftfreq( + n: int, /, *, d: float = 1.0, device: _Device | None = None +) -> NamedArray[Any, Any]: xp = _maybe_default_namespace() # TODO: Can use device? _data = xp.fft.fftfreq(n, d=d, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) -def rfftfreq(n: int, /, *, d: float = 1.0, device: _Device | None = None) -> NamedArray: +def rfftfreq( + n: int, /, *, d: float = 1.0, device: _Device | None = None +) -> NamedArray[Any, Any]: xp = _maybe_default_namespace() # TODO: Can use device? _data = xp.fft.rfftfreq(n, d=d, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) -def fftshift(x: NamedArray, /, *, axes: _Axes | None = None) -> NamedArray: +def fftshift( + x: NamedArray[Any, _DType], /, *, axes: _Axes | None = None +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.fft.fftshift(x._data, axes=axes) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def ifftshift(x: NamedArray, /, *, axes: _Axes | None = None) -> NamedArray: +def ifftshift( + x: NamedArray[Any, _DType], /, *, axes: _Axes | None = None +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.fft.ifftshift(x._data, axes=axes) _dims = _infer_dims(_data.shape) # TODO: Fix dims From 212dcdd641f5c202977a87bebf064f7c09fa6597 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 22:11:40 +0000 Subject: [PATCH 256/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_fft/_fft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_fft/_fft.py b/xarray/namedarray/_array_api/_fft/_fft.py index ef8c8a2e999..9830581080f 100644 --- a/xarray/namedarray/_array_api/_fft/_fft.py +++ b/xarray/namedarray/_array_api/_fft/_fft.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal, Any +from typing import TYPE_CHECKING, Any, Literal from xarray.namedarray._array_api._utils import ( _get_data_namespace, From d31fe9502e64aa3df0c55cd7f897a670342acb95 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Sep 2024 00:35:34 +0200 Subject: [PATCH 257/367] Update _linalg.py --- .../namedarray/_array_api/_linalg/_linalg.py | 140 +++++++++++------- 1 file changed, 83 insertions(+), 57 deletions(-) diff --git a/xarray/namedarray/_array_api/_linalg/_linalg.py b/xarray/namedarray/_array_api/_linalg/_linalg.py index 9470436a3ae..0c20ab03dc7 100644 --- a/xarray/namedarray/_array_api/_linalg/_linalg.py +++ b/xarray/namedarray/_array_api/_linalg/_linalg.py @@ -1,37 +1,39 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal, NamedTuple +from typing import TYPE_CHECKING, Literal, NamedTuple, Any, overload from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims from xarray.namedarray.core import NamedArray if TYPE_CHECKING: - from xarray.namedarray._typing import _Axes, _Axis, _DType + from xarray.namedarray._typing import _Axes, _Axis, _DType, _ShapeType class EighResult(NamedTuple): - eigenvalues: NamedArray - eigenvectors: NamedArray + eigenvalues: NamedArray[Any, Any] + eigenvectors: NamedArray[Any, Any] class QRResult(NamedTuple): - Q: NamedArray - R: NamedArray + Q: NamedArray[Any, Any] + R: NamedArray[Any, Any] class SlogdetResult(NamedTuple): - sign: NamedArray - logabsdet: NamedArray + sign: NamedArray[Any, Any] + logabsdet: NamedArray[Any, Any] class SVDResult(NamedTuple): - U: NamedArray - S: NamedArray - Vh: NamedArray + U: NamedArray[Any, Any] + S: NamedArray[Any, Any] + Vh: NamedArray[Any, Any] -def cholesky(x: NamedArray, /, *, upper: bool = False) -> NamedArray: +def cholesky( + x: NamedArray[_ShapeType, Any], /, *, upper: bool = False +) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _data = xp.linalg.cholesky(x._data, upper=upper) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -39,21 +41,25 @@ def cholesky(x: NamedArray, /, *, upper: bool = False) -> NamedArray: # Note: cross is the numpy top-level namespace, not np.linalg -def cross(x1: NamedArray, x2: NamedArray, /, *, axis: _Axis = -1) -> NamedArray: +def cross( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /, *, axis: _Axis = -1 +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _data = xp.linalg.cross(x1._data, x2._data, axis=axis) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x1._new(_dims, _data) -def det(x: NamedArray, /) -> NamedArray: +def det(x: NamedArray[Any, _DType], /) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.linalg.det(x._data) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def diagonal(x: NamedArray, /, *, offset: int = 0) -> NamedArray: +def diagonal( + x: NamedArray[Any, _DType], /, *, offset: int = 0 +) -> NamedArray[Any, _DType]: # Note: diagonal is the numpy top-level namespace, not np.linalg xp = _get_data_namespace(x) _data = xp.linalg.diagonal(x._data, offset=offset) @@ -61,7 +67,7 @@ def diagonal(x: NamedArray, /, *, offset: int = 0) -> NamedArray: return x._new(_dims, _data) -def eigh(x: NamedArray, /) -> EighResult: +def eigh(x: NamedArray[Any, Any], /) -> EighResult: xp = _get_data_namespace(x) eigvals, eigvecs = xp.linalg.eigh(x._data) _dims_vals = _infer_dims(eigvals.shape) # TODO: Fix dims @@ -72,14 +78,14 @@ def eigh(x: NamedArray, /) -> EighResult: ) -def eigvalsh(x: NamedArray, /) -> NamedArray: +def eigvalsh(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.linalg.eigvalsh(x._data) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def inv(x: NamedArray, /) -> NamedArray: +def inv(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _data = xp.linalg.inv(x._data) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -87,19 +93,21 @@ def inv(x: NamedArray, /) -> NamedArray: def matrix_norm( - x: NamedArray, + x: NamedArray[Any, Any], /, *, keepdims: bool = False, ord: int | float | Literal["fro", "nuc"] | None = "fro", -) -> NamedArray: # noqa: F821 +) -> NamedArray[Any, Any]: # noqa: F821 xp = _get_data_namespace(x) _data = xp.linalg.matrix_norm(x._data, keepdims=keepdims, ord=ord) # ckeck xp.mean _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def matrix_power(x: NamedArray, n: int, /) -> NamedArray: +def matrix_power( + x: NamedArray[_ShapeType, Any], n: int, / +) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _data = xp.linalg.matrix_power(x._data, n=n) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -107,22 +115,32 @@ def matrix_power(x: NamedArray, n: int, /) -> NamedArray: def matrix_rank( - x: NamedArray, /, *, rtol: float | NamedArray | None = None -) -> NamedArray: + x: NamedArray[Any, Any], /, *, rtol: float | NamedArray[Any, Any] | None = None +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.linalg.matrix_rank(x._data, rtol=rtol) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def outer(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def matrix_transpose(x: NamedArray[Any, _DType], /) -> NamedArray[Any, _DType]: + from xarray.namedarray._array_api._linear_algebra_functions import matrix_transpose + + return matrix_transpose(x) + + +def outer( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _data = xp.linalg.outer(x1._data, x2._data) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x1._new(_dims, _data) -def pinv(x: NamedArray, /, *, rtol: float | NamedArray | None = None) -> NamedArray: +def pinv( + x: NamedArray[Any, Any], /, *, rtol: float | NamedArray[Any, Any] | None = None +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.linalg.pinv(x._data, rtol=rtol) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -130,7 +148,7 @@ def pinv(x: NamedArray, /, *, rtol: float | NamedArray | None = None) -> NamedAr def qr( - x: NamedArray, /, *, mode: Literal["reduced", "complete"] = "reduced" + x: NamedArray[Any, Any], /, *, mode: Literal["reduced", "complete"] = "reduced" ) -> QRResult: xp = _get_data_namespace(x) q, r = xp.linalg.qr(x._data) @@ -142,7 +160,7 @@ def qr( ) -def slogdet(x: NamedArray, /) -> SlogdetResult: +def slogdet(x: NamedArray[Any, Any], /) -> SlogdetResult: xp = _get_data_namespace(x) sign, logabsdet = xp.linalg.slogdet(x._data) _dims_sign = _infer_dims(sign.shape) # TODO: Fix dims @@ -153,14 +171,16 @@ def slogdet(x: NamedArray, /) -> SlogdetResult: ) -def solve(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def solve( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _data = xp.linalg.solve(x1._data, x2._data) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x1._new(_dims, _data) -def svd(x: NamedArray, /, *, full_matrices: bool = True) -> SVDResult: +def svd(x: NamedArray[Any, Any], /, *, full_matrices: bool = True) -> SVDResult: xp = _get_data_namespace(x) u, s, vh = xp.linalg.svd(x._data, full_matrices=full_matrices) _dims_u = _infer_dims(u.shape) # TODO: Fix dims @@ -173,61 +193,67 @@ def svd(x: NamedArray, /, *, full_matrices: bool = True) -> SVDResult: ) -def svdvals(x: NamedArray, /) -> NamedArray: +def svdvals(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.linalg.svdvals(x._data) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) +@overload def trace( - x: NamedArray, /, *, offset: int = 0, dtype: _DType | None = None -) -> NamedArray: + x: NamedArray[Any, Any], /, *, offset: int = 0, dtype: _DType +) -> NamedArray[Any, _DType]: ... +@overload +def trace( + x: NamedArray[Any, _DType], /, *, offset: int = 0, dtype: None +) -> NamedArray[Any, _DType]: ... +def trace( + x: NamedArray[Any, _DType | Any], /, *, offset: int = 0, dtype: _DType | None = None +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _data = xp.linalg.trace(x._data, offset=offset) _dims = _infer_dims(_data.shape) # TODO: Fix dims return x._new(_dims, _data) -def vector_norm( - x: NamedArray, - /, - *, - axis: _Axes | None = None, - keepdims: bool = False, - ord: int | float | None = 2, -) -> NamedArray: - xp = _get_data_namespace(x) - _data = xp.linalg.vector_norm(x._data, axis=axis, keepdims=keepdims, ord=ord) - _dims = _infer_dims(_data.shape) # TODO: Fix dims - return x._new(_dims, _data) - - -def matmul(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def matmul( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: from xarray.namedarray._array_api._linear_algebra_functions import matmul return matmul(x1, x2) def tensordot( - x1: NamedArray, - x2: NamedArray, + x1: NamedArray[Any, Any], + x2: NamedArray[Any, Any], /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, -) -> NamedArray: +) -> NamedArray[Any, Any]: from xarray.namedarray._array_api._linear_algebra_functions import tensordot return tensordot(x1, x2, axes=axes) -def matrix_transpose(x: NamedArray, /) -> NamedArray: - from xarray.namedarray._array_api._linear_algebra_functions import matrix_transpose - - return matrix_transpose(x) - - -def vecdot(x1: NamedArray, x2: NamedArray, /, *, axis: _Axis = -1) -> NamedArray: +def vecdot( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /, *, axis: _Axis = -1 +) -> NamedArray[Any, Any]: from xarray.namedarray._array_api._linear_algebra_functions import vecdot return vecdot(x1, x2, axis=axis) + + +def vector_norm( + x: NamedArray[Any, Any], + /, + *, + axis: _Axes | None = None, + keepdims: bool = False, + ord: int | float | None = 2, +) -> NamedArray[Any, Any]: + xp = _get_data_namespace(x) + _data = xp.linalg.vector_norm(x._data, axis=axis, keepdims=keepdims, ord=ord) + _dims = _infer_dims(_data.shape) # TODO: Fix dims + return x._new(_dims, _data) From d05791c4412bded1f0b5419241bdfb5b10551990 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 22:36:10 +0000 Subject: [PATCH 258/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_linalg/_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_linalg/_linalg.py b/xarray/namedarray/_array_api/_linalg/_linalg.py index 0c20ab03dc7..1a1c2ca2004 100644 --- a/xarray/namedarray/_array_api/_linalg/_linalg.py +++ b/xarray/namedarray/_array_api/_linalg/_linalg.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Literal, NamedTuple, Any, overload +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, overload from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims from xarray.namedarray.core import NamedArray From 5fccd5f7b9602978bf663e94c9e70bb905c3c0c5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Sep 2024 00:40:28 +0200 Subject: [PATCH 259/367] Update _typing.py --- xarray/namedarray/_typing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 5a4bc554706..58fa0448063 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -260,7 +260,9 @@ def __getitem__( def __array_namespace__(self) -> ModuleType: ... - def to_device(self, device: _Device, /, stream: None = None) -> Self: ... + def to_device( + self, device: _Device, /, stream: None = None + ) -> _arrayapi[_ShapeType_co, _DType_co]: ... @property def device(self) -> _Device: ... From 39fa648bc8a0b7b8cf1659463373be2699202c7f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Sep 2024 00:56:15 +0200 Subject: [PATCH 260/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 86b7db892cd..7e08400d4c4 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -283,15 +283,19 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] dims = tuple(a.dims for a in arrays) shapes = tuple(a.shape for a in arrays) - out_dims: tuple[_Dim, ...] = () - out_shape: tuple[_Axis | None, ...] = () + out_dims: _Dims = () + out_shape: _Shape = () for d, sizes in zip( zip_longest(*map(reversed, dims), fillvalue=_default), zip_longest(*map(reversed, shapes), fillvalue=-1), ): _d = tuple(set(d) - {_default}) - dim = None if any(_isnone(sizes)) else max(sizes) + if any(_isnone(sizes)): + # dim = None + raise NotImplementedError("TODO: Handle None in shape, {shapes = }") + else: + dim = max(sizes) if any(i not in [-1, 0, 1, dim] for i in sizes) or len(_d) != 1: raise ValueError( From d37c44d64a0236e43eb823319d0e2538a46f6e92 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Sep 2024 21:23:21 +0200 Subject: [PATCH 261/367] Make default unhashable --- xarray/namedarray/_typing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 58fa0448063..31b78034a50 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -34,6 +34,7 @@ # Singleton type, as per https://github.com/python/typing/pull/240 class Default(Enum): + __hash__ = None token: Final = 0 @@ -95,6 +96,7 @@ def imag(self) -> _T_co: ... _Dim = Hashable _Dims = tuple[_Dim, ...] +_DimsLike2 = Union[_Dim, _Dims] _DimsLike = Union[str, Iterable[_Dim]] # https://data-apis.org/array-api/latest/API_specification/indexing.html From beb3016d7997ad436b2ea599f3f64e01bf4d3009 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Sep 2024 21:23:52 +0200 Subject: [PATCH 262/367] typing --- .../_array_api/_utility_functions.py | 14 ++-- xarray/namedarray/_array_api/_utils.py | 83 +++++++++++-------- 2 files changed, 56 insertions(+), 41 deletions(-) diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index 9cd119fcf2c..77965510fac 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -11,7 +11,9 @@ Default, _AxisLike, _default, + _Dim, _Dims, + _DimsLike2, _DType, ) from xarray.namedarray.core import ( @@ -20,13 +22,13 @@ def all( - x, + x: NamedArray[Any, Any], /, *, - dims: _Dims | Default = _default, + dims: _DimsLike2 | Default = _default, keepdims: bool = False, axis: _AxisLike | None = None, -) -> NamedArray[Any, _DType]: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) axis_ = _dims_to_axis(x, dims, axis) d = xp.all(x._data, axis=axis_, keepdims=False) @@ -36,13 +38,13 @@ def all( def any( - x, + x: NamedArray[Any, Any], /, *, - dims: _Dims | Default = _default, + dims: _DimsLike2 | Default = _default, keepdims: bool = False, axis: _AxisLike | None = None, -) -> NamedArray[Any, _DType]: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) axis_ = _dims_to_axis(x, dims, axis) d = xp.any(x._data, axis=axis_, keepdims=False) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 7e08400d4c4..09f8248afd3 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from itertools import zip_longest from types import ModuleType -from typing import TYPE_CHECKING, Any, TypeGuard, cast, overload +from typing import TYPE_CHECKING, Any, TypeGuard, cast, overload, NoReturn from xarray.namedarray._typing import ( Default, @@ -16,6 +16,7 @@ _Dim, _Dims, _DimsLike, + _DimsLike2, _DType, _dtype, _Shape, @@ -66,35 +67,12 @@ def _get_namespace_dtype(dtype: _dtype[Any] | None = None) -> ModuleType: return xp -def _infer_dims( - shape: _Shape, - dims: _DimsLike | Default = _default, -) -> _DimsLike: - """ - Create default dim names if no dims were supplied. - - Examples - -------- - >>> _infer_dims(()) - () - >>> _infer_dims((1,)) - ('dim_0',) - >>> _infer_dims((3, 1)) - ('dim_1', 'dim_0') - """ - if dims is _default: - ndim = len(shape) - return tuple(f"dim_{ndim - 1 - n}" for n in range(ndim)) - else: - return dims - - -def _is_single_dim(dims: _Dim | _Dims) -> TypeGuard[_Dim]: +def _is_single_dim(dims: _DimsLike2) -> TypeGuard[_Dim]: # TODO: https://peps.python.org/pep-0742/ return isinstance(dims, str) or not isinstance(dims, Iterable) -def _normalize_dimensions(dims: _Dim | _Dims) -> _Dims: +def _normalize_dimensions(dims: _DimsLike2) -> _Dims: """ Normalize dimensions. @@ -119,6 +97,36 @@ def _normalize_dimensions(dims: _Dim | _Dims) -> _Dims: return tuple(cast(_Dims, dims)) +def _infer_dims( + shape: _Shape, + dims: _DimsLike2 | Default = _default, +) -> _Dims: + """ + Create default dim names if no dims were supplied. + + Examples + -------- + >>> _infer_dims(()) + () + >>> _infer_dims((1,)) + ('dim_0',) + >>> _infer_dims((3, 1)) + ('dim_1', 'dim_0') + + >>> _infer_dims((1,), "x") + ('x',) + >>> _infer_dims((1,), None) + (None,) + >>> _infer_dims((1,), ("x",)) + ('x',) + """ + if dims is _default: + ndim = len(shape) + return tuple(f"dim_{ndim - 1 - n}" for n in range(ndim)) + else: + return _normalize_dimensions(dims) + + def _assert_either_dim_or_axis( dims: _Dim | _Dims | Default, axis: _AxisLike | None ) -> None: @@ -127,15 +135,17 @@ def _assert_either_dim_or_axis( @overload -def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: None) -> None: ... +def _dims_to_axis( + x: NamedArray[Any, Any], dims: _DimsLike2, axis: _AxisLike +) -> NoReturn: ... @overload -def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: _AxisLike) -> _Axes: ... +def _dims_to_axis(x: NamedArray[Any, Any], dims: _DimsLike2, axis: None) -> None: ... @overload +def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: _AxisLike) -> None: ... +@overload +def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: None) -> None: ... def _dims_to_axis( - x: NamedArray[Any, Any], dims: _Dim | _Dims, axis: _AxisLike -) -> _Axes: ... -def _dims_to_axis( - x: NamedArray[Any, Any], dims: _Dim | _Dims | Default, axis: _AxisLike | None + x: NamedArray[Any, Any], dims: _DimsLike2 | Default, axis: _AxisLike | None ) -> _Axes | None: """ Convert dims to axis indices. @@ -177,10 +187,13 @@ def _dims_to_axis( raise ValueError(f"{dim!r} not found in array dimensions {x.dims!r}") return axis - if isinstance(axis, int): - return (axis,) + if axis is None: + return axis - return axis + if isinstance(axis, tuple): + return axis + else: + return (axis,) def _get_remaining_dims( From e993f3b5a39ef6eeacb46b345ac43609d969f9d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 19:24:38 +0000 Subject: [PATCH 263/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_utility_functions.py | 3 --- xarray/namedarray/_array_api/_utils.py | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index 77965510fac..7ed2c631983 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -11,10 +11,7 @@ Default, _AxisLike, _default, - _Dim, - _Dims, _DimsLike2, - _DType, ) from xarray.namedarray.core import ( NamedArray, diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 09f8248afd3..bb12afcd6eb 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from itertools import zip_longest from types import ModuleType -from typing import TYPE_CHECKING, Any, TypeGuard, cast, overload, NoReturn +from typing import TYPE_CHECKING, Any, NoReturn, TypeGuard, cast, overload from xarray.namedarray._typing import ( Default, @@ -15,7 +15,6 @@ _default, _Dim, _Dims, - _DimsLike, _DimsLike2, _DType, _dtype, From b86ee936ad8fc5eb34a5d5187681033ab1fae801 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Sep 2024 21:38:03 +0200 Subject: [PATCH 264/367] Update _typing.py --- xarray/namedarray/_typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 31b78034a50..e656c380584 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -34,7 +34,7 @@ # Singleton type, as per https://github.com/python/typing/pull/240 class Default(Enum): - __hash__ = None + __hash__ = None # type: ignore[assignment] # TODO: Better way to set unhashable? token: Final = 0 From 99426cbbf37a32e2cb6846a1b1f6c1ce3549fc57 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Sep 2024 22:20:00 +0200 Subject: [PATCH 265/367] Update _statistical_functions.py --- .../_array_api/_statistical_functions.py | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 9e7f4ff0ff4..0326d47ed3e 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -16,9 +16,7 @@ _DType, _ShapeType, ) -from xarray.namedarray.core import ( - NamedArray, -) +from xarray.namedarray.core import NamedArray def cumulative_sum( @@ -33,7 +31,23 @@ def cumulative_sum( xp = _get_data_namespace(x) a = _dims_to_axis(x, dim, axis) - _axis = a if a is None else a[0] + _axis_none: int | None + if a is None: + _axis_none = a + else: + _axis_none = a[0] + + # TODO: The standard is not clear about what should happen when x.ndim == 0. + _axis: int + if _axis_none is None: + if x.ndim > 1: + raise ValueError( + "axis must be specified in cumulative_sum for more than one dimension" + ) + _axis = 0 + else: + _axis = _axis_none + try: _data = xp.cumulative_sum( x._data, axis=_axis, dtype=dtype, include_initial=include_initial @@ -42,20 +56,20 @@ def cumulative_sum( # Use np.cumsum until new name is introduced: # np.cumsum does not support include_initial if include_initial: - if axis < 0: - axis += x.ndim + if _axis < 0: + _axis += x.ndim d = xp.concat( [ xp.zeros( - x.shape[:axis] + (1,) + x.shape[axis + 1 :], dtype=x.dtype + x.shape[:_axis] + (1,) + x.shape[_axis + 1 :], dtype=x.dtype ), x._data, ], - axis=axis, + axis=_axis, ) else: d = x._data - _data = xp.cumsum(d, axis=axis, dtype=dtype) + _data = xp.cumsum(d, axis=_axis, dtype=dtype) return x._new(dims=x.dims, data=_data) From 18d5001680357b6b678e71af382a19a24d9fc3e0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 6 Sep 2024 22:59:46 +0200 Subject: [PATCH 266/367] Create a truly non hashable default by subclassing list --- xarray/namedarray/_typing.py | 43 +++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index e656c380584..2add1faae22 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -2,7 +2,7 @@ import sys from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence -from enum import Enum +from enum import Enum, EnumType from types import EllipsisType, ModuleType from typing import ( TYPE_CHECKING, @@ -16,6 +16,8 @@ Union, overload, runtime_checkable, + NoReturn, + Never, ) import numpy as np @@ -32,13 +34,42 @@ Self: Any = None -# Singleton type, as per https://github.com/python/typing/pull/240 -class Default(Enum): - __hash__ = None # type: ignore[assignment] # TODO: Better way to set unhashable? - token: Final = 0 +class Default(list[Never]): + """ + Non-Hashable default value. + + A replacement value for Optional None since it is Hashable. + Same idea as https://github.com/python/typing/pull/240 + + Examples + -------- + + Runtime checks: + + >>> _default = Default() + >>> isinstance(_default, Hashable) + False + >>> _default == _default + True + >>> _default is _default + True + + Typing usage: + + >>> x: Hashable | Default = _default + >>> if isinstance(x, Default): + >>> y: Default = x + >>> else: + >>> h: Hashable = x + + TODO: if x is _default does not narrow typing, use isinstance check instead. + """ + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}>" -_default = Default.token +_default: Final[Default] = Default() # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array _T = TypeVar("_T") From 4272a2047873763ac0f5c4d86364876b021adc99 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 6 Sep 2024 23:00:23 +0200 Subject: [PATCH 267/367] Update _utility_functions.py --- xarray/namedarray/_array_api/_utility_functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index 7ed2c631983..17cd5ad03c7 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -13,9 +13,7 @@ _default, _DimsLike2, ) -from xarray.namedarray.core import ( - NamedArray, -) +from xarray.namedarray.core import NamedArray def all( From a6cf09721e568dbd2b4533fbe488ab3dc1499117 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 19:19:46 +0200 Subject: [PATCH 268/367] typing --- .../_array_api/_searching_functions.py | 44 ++++++++++------- .../_array_api/_sorting_functions.py | 23 ++++++--- xarray/namedarray/_array_api/_utils.py | 49 +++++++++++++------ 3 files changed, 73 insertions(+), 43 deletions(-) diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index 81679ae06e9..d1ec7578033 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from xarray.namedarray._array_api._utils import ( - _dims_to_axis, + _dim_to_optional_axis, _get_data_namespace, _get_remaining_dims, _infer_dims, @@ -11,7 +11,8 @@ from xarray.namedarray._typing import ( Default, _default, - _Dims, + _Dim, + _arrayapi, ) from xarray.namedarray.core import ( NamedArray, @@ -22,15 +23,15 @@ def argmax( - x: NamedArray, + x: NamedArray[Any, Any], /, *, - dims: _Dims | Default = _default, + dim: _Dim | Default = _default, keepdims: bool = False, axis: int | None = None, -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dims, axis) + _axis = _dim_to_optional_axis(x, dim, axis) _data = xp.argmax(x._data, axis=_axis, keepdims=False) # We fix keepdims later # TODO: Why do we need to do the keepdims ourselves? _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) @@ -38,36 +39,36 @@ def argmax( def argmin( - x: NamedArray, + x: NamedArray[Any, Any], /, *, - dims: _Dims | Default = _default, + dim: _Dim | Default = _default, keepdims: bool = False, axis: int | None = None, -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dims, axis) + _axis = _dim_to_optional_axis(x, dim, axis) _data = xp.argmin(x._data, axis=_axis, keepdims=False) # We fix keepdims later # TODO: Why do we need to do the keepdims ourselves? _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) return x._new(dims=_dims, data=data_) -def nonzero(x: NamedArray, /) -> tuple[NamedArray, ...]: +def nonzero(x: NamedArray[Any, Any], /) -> tuple[NamedArray[Any, Any], ...]: xp = _get_data_namespace(x) - _datas = xp.nonzero(x._data) + _datas: tuple[_arrayapi[Any, Any], ...] = xp.nonzero(x._data) # TODO: Verify that dims and axis matches here: - return tuple(x._new(dim, i) for dim, i in zip(x.dims, _datas)) + return tuple(x._new((dim,), data) for dim, data in zip(x.dims, _datas)) def searchsorted( - x1: NamedArray, - x2: NamedArray, + x1: NamedArray[Any, Any], + x2: NamedArray[Any, Any], /, *, side: Literal["left", "right"] = "left", - sorter: NamedArray | None = None, -) -> NamedArray: + sorter: NamedArray[Any, Any] | None = None, +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _data = xp.searchsorted(x1._data, x2._data, side=side, sorter=sorter) # TODO: Check dims, probably can do it smarter: @@ -75,7 +76,12 @@ def searchsorted( return NamedArray(_dims, _data) -def where(condition: NamedArray, x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def where( + condition: NamedArray[Any, Any], + x1: NamedArray[Any, Any], + x2: NamedArray[Any, Any], + /, +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _data = xp.where(condition._data, x1._data, x2._data) # TODO: Wrong, _dims should be either of the arguments. How to choose? diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index 66798daccdb..17279b0c035 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -1,25 +1,31 @@ from __future__ import annotations - -from xarray.namedarray._array_api._utils import _dims_to_axis, _get_data_namespace +from typing import Any +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _dim_to_axis, +) from xarray.namedarray._typing import ( Default, _default, _Dim, + _DType, + _ShapeType, ) from xarray.namedarray.core import NamedArray def argsort( - x: NamedArray, + x: NamedArray[_ShapeType, Any], /, *, dim: _Dim | Default = _default, descending: bool = False, stable: bool = True, axis: int = -1, -) -> NamedArray: +) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dim, axis)[0] + _axis = _dim_to_axis(x, dim, axis) + # TODO: As NumPy currently has no native descending sort, we imitate it here: if not descending: _data = xp.argsort(x._data, axis=_axis, stable=stable) @@ -36,16 +42,17 @@ def argsort( def sort( - x: NamedArray, + x: NamedArray[_ShapeType, _DType], /, *, dim: _Dim | Default = _default, descending: bool = False, stable: bool = True, axis: int = -1, -) -> NamedArray: +) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) - _axis = _dims_to_axis(x, dim, axis)[0] + _axis = _dim_to_axis(x, dim, axis) + _data = xp.sort(x._data, axis=_axis, stable=stable) # TODO: As NumPy currently has no native descending sort, we imitate it here: if descending: diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index bb12afcd6eb..b3819698d98 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -15,6 +15,7 @@ _default, _Dim, _Dims, + _DimsLike, _DimsLike2, _DType, _dtype, @@ -119,7 +120,7 @@ def _infer_dims( >>> _infer_dims((1,), ("x",)) ('x',) """ - if dims is _default: + if isinstance(dims, Default): ndim = len(shape) return tuple(f"dim_{ndim - 1 - n}" for n in range(ndim)) else: @@ -127,22 +128,18 @@ def _infer_dims( def _assert_either_dim_or_axis( - dims: _Dim | _Dims | Default, axis: _AxisLike | None + dims: _DimsLike2 | Default, axis: _AxisLike | None ) -> None: if dims is not _default and axis is not None: raise ValueError("cannot supply both 'axis' and 'dim(s)' arguments") -@overload -def _dims_to_axis( - x: NamedArray[Any, Any], dims: _DimsLike2, axis: _AxisLike -) -> NoReturn: ... -@overload -def _dims_to_axis(x: NamedArray[Any, Any], dims: _DimsLike2, axis: None) -> None: ... -@overload -def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: _AxisLike) -> None: ... -@overload -def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: None) -> None: ... +# @overload +# def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: None) -> None: ... +# @overload +# def _dims_to_axis(x: NamedArray[Any, Any], dims: _DimsLike2, axis: None) -> _Axes: ... +# @overload +# def _dims_to_axis(x: NamedArray[Any, Any], dims: Default, axis: _AxisLike) -> _Axes: ... def _dims_to_axis( x: NamedArray[Any, Any], dims: _DimsLike2 | Default, axis: _AxisLike | None ) -> _Axes | None: @@ -175,7 +172,7 @@ def _dims_to_axis( ValueError: cannot supply both 'axis' and 'dim(s)' arguments """ _assert_either_dim_or_axis(dims, axis) - if dims is not _default: + if not isinstance(dims, Default): _dims = _normalize_dimensions(dims) axis = () @@ -195,6 +192,23 @@ def _dims_to_axis( return (axis,) +def _dim_to_optional_axis( + x: NamedArray[Any, Any], dim: _Dim | Default, axis: int | None +) -> int | None: + a = _dims_to_axis(x, dim, axis) + if a is None: + return a + + return a[0] + + +def _dim_to_axis(x: NamedArray[Any, Any], dim: _Dim | Default, axis: int) -> int: + _dim: _Dim = x.dims[axis] if isinstance(dim, Default) else dim + _axis = _dim_to_optional_axis(x, _dim, None) + assert _axis is not None # Not supposed to happen. + return _axis + + def _get_remaining_dims( x: NamedArray[Any, _DType], data: duckarray[Any, _DType], @@ -230,10 +244,13 @@ def _get_remaining_dims( def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: - if dim is _default: - dim = f"dim_{len(dims)}" + if isinstance(dim, Default): + _dim: _Dim = f"dim_{len(dims)}" + else: + _dim = dim + d = list(dims) - d.insert(axis, dim) + d.insert(axis, _dim) return tuple(d) From b1ad54cae4a86177c78ccbe92c0125efb96d5253 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Sep 2024 17:20:23 +0000 Subject: [PATCH 269/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_searching_functions.py | 2 +- xarray/namedarray/_array_api/_sorting_functions.py | 4 +++- xarray/namedarray/_array_api/_utils.py | 3 +-- xarray/namedarray/_typing.py | 4 +--- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index d1ec7578033..37388542118 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -10,9 +10,9 @@ ) from xarray.namedarray._typing import ( Default, + _arrayapi, _default, _Dim, - _arrayapi, ) from xarray.namedarray.core import ( NamedArray, diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index 17279b0c035..12554597e3c 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -1,8 +1,10 @@ from __future__ import annotations + from typing import Any + from xarray.namedarray._array_api._utils import ( - _get_data_namespace, _dim_to_axis, + _get_data_namespace, ) from xarray.namedarray._typing import ( Default, diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index b3819698d98..ff8cec000ec 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from itertools import zip_longest from types import ModuleType -from typing import TYPE_CHECKING, Any, NoReturn, TypeGuard, cast, overload +from typing import TYPE_CHECKING, Any, TypeGuard, cast from xarray.namedarray._typing import ( Default, @@ -15,7 +15,6 @@ _default, _Dim, _Dims, - _DimsLike, _DimsLike2, _DType, _dtype, diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 2add1faae22..83a3d754f10 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -2,13 +2,13 @@ import sys from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence -from enum import Enum, EnumType from types import EllipsisType, ModuleType from typing import ( TYPE_CHECKING, Any, Final, Literal, + Never, Protocol, SupportsIndex, TypedDict, @@ -16,8 +16,6 @@ Union, overload, runtime_checkable, - NoReturn, - Never, ) import numpy as np From 63db063b3c76753191cbc887743437b2b13b08f1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 19:28:24 +0200 Subject: [PATCH 270/367] Update _typing.py --- xarray/namedarray/_typing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 83a3d754f10..b5c558c843f 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -56,9 +56,9 @@ class Default(list[Never]): >>> x: Hashable | Default = _default >>> if isinstance(x, Default): - >>> y: Default = x - >>> else: - >>> h: Hashable = x + ... y: Default = x + ... else: + ... h: Hashable = x TODO: if x is _default does not narrow typing, use isinstance check instead. """ From e810921af0eac4906d0d58c9011f5e15f72ba467 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Sep 2024 17:29:00 +0000 Subject: [PATCH 271/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_typing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index b5c558c843f..c814adb1829 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -59,6 +59,7 @@ class Default(list[Never]): ... y: Default = x ... else: ... h: Hashable = x + ... TODO: if x is _default does not narrow typing, use isinstance check instead. """ From d9b516df5eb2e254a8c57f350c1d45c794649582 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:34:33 +0200 Subject: [PATCH 272/367] Update _creation_functions.py --- .../_array_api/_creation_functions.py | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 72706a8331c..19e22ceba33 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -12,21 +12,15 @@ _ArrayLike, _default, _Device, - _DimsLike, + _DimsLike2, _DType, _Shape, _ShapeType, duckarray, + Default, + _Shape1D, ) -from xarray.namedarray.core import ( - NamedArray, -) - - -def _like_args(x, dtype=None, device: _Device | None = None): - if dtype is None: - dtype = x.dtype - return dict(shape=x.shape, dtype=dtype, device=device) +from xarray.namedarray.core import NamedArray def arange( @@ -37,7 +31,7 @@ def arange( *, dtype: _DType | None = None, device: _Device | None = None, -) -> NamedArray[_ShapeType, _DType]: +) -> NamedArray[_Shape1D, _DType]: xp = _get_namespace_dtype(dtype) _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) @@ -52,7 +46,7 @@ def asarray( dtype: _DType, device: _Device | None = ..., copy: bool | None = ..., - dims: _DimsLike = ..., + dims: _DimsLike2 | Default = ..., ) -> NamedArray[_ShapeType, _DType]: ... @overload def asarray( @@ -62,7 +56,7 @@ def asarray( dtype: _DType, device: _Device | None = ..., copy: bool | None = ..., - dims: _DimsLike = ..., + dims: _DimsLike2 | Default = ..., ) -> NamedArray[Any, _DType]: ... @overload def asarray( @@ -72,7 +66,7 @@ def asarray( dtype: None, device: _Device | None = None, copy: bool | None = None, - dims: _DimsLike = ..., + dims: _DimsLike2 | Default = ..., ) -> NamedArray[_ShapeType, _DType]: ... @overload def asarray( @@ -82,7 +76,7 @@ def asarray( dtype: None, device: _Device | None = ..., copy: bool | None = ..., - dims: _DimsLike = ..., + dims: _DimsLike2 | Default = ..., ) -> NamedArray[Any, _DType]: ... def asarray( obj: duckarray[_ShapeType, _DType] | _ArrayLike, @@ -91,7 +85,7 @@ def asarray( dtype: _DType | None = None, device: _Device | None = None, copy: bool | None = None, - dims: _DimsLike = _default, + dims: _DimsLike2 | Default = _default, ) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: """ Create a Named array from an array-like object. @@ -166,11 +160,18 @@ def from_dlpack( *, device: _Device | None = None, copy: bool | None = None, -) -> NamedArray: - xp = _get_data_namespace(x) - _device = x.device if device is None else device - _data = xp.from_dlpack(x, device=_device, copy=copy) - return x._new(data=_data) +) -> NamedArray[Any, Any]: + if isinstance(x, NamedArray): + xp = _get_data_namespace(x) + _device = x.device if device is None else device + _data = xp.from_dlpack(x, device=_device, copy=copy) + _dims = x.dims + else: + xp = _get_namespace(x) + _device = device + _data = xp.from_dlpack(x, device=_device, copy=copy) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) def full( @@ -208,7 +209,7 @@ def linspace( dtype: _DType | None = None, device: _Device | None = None, endpoint: bool = True, -) -> NamedArray[_ShapeType, _DType]: +) -> NamedArray[_Shape1D, _DType]: xp = _get_namespace_dtype(dtype) _data = xp.linspace( start, @@ -222,7 +223,9 @@ def linspace( return NamedArray(_dims, _data) -def meshgrid(*arrays: NamedArray, indexing: str = "xy") -> list[NamedArray]: +def meshgrid( + *arrays: NamedArray[Any, Any], indexing: str = "xy" +) -> list[NamedArray[Any, Any]]: arr = arrays[0] xp = _get_data_namespace(arr) _datas = xp.meshgrid(*[a._data for a in arrays], indexing=indexing) From 686ef2be2abb4632a85090082102d3c8a32c9a1a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:34:36 +0200 Subject: [PATCH 273/367] Update _data_type_functions.py --- .../_array_api/_data_type_functions.py | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/xarray/namedarray/_array_api/_data_type_functions.py b/xarray/namedarray/_array_api/_data_type_functions.py index a213f45f771..fc728be4a2c 100644 --- a/xarray/namedarray/_array_api/_data_type_functions.py +++ b/xarray/namedarray/_array_api/_data_type_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast from xarray.namedarray._array_api._utils import ( _get_data_namespace, @@ -67,40 +67,42 @@ def astype( return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined] -def can_cast(from_: _dtype | NamedArray, to: _dtype, /) -> bool: +def can_cast(from_: _dtype[Any] | NamedArray[Any, Any], to: _dtype[Any], /) -> bool: if isinstance(from_, NamedArray): xp = _get_data_namespace(from_) from_ = from_.dtype - return xp.can_cast(from_, to) + return cast(bool, xp.can_cast(from_, to)) # TODO: Why is cast necessary? else: xp = _get_namespace_dtype(from_) - return xp.can_cast(from_, to) + return cast(bool, xp.can_cast(from_, to)) # TODO: Why is cast necessary? -def finfo(type: _dtype | NamedArray[Any, Any], /) -> _FInfo: +def finfo(type: _dtype[Any] | NamedArray[Any, Any], /) -> _FInfo: if isinstance(type, NamedArray): xp = _get_data_namespace(type) - return xp.finfo(type._data) + return cast(_FInfo, xp.finfo(type._data)) # TODO: Why is cast necessary? else: xp = _get_namespace_dtype(type) - return xp.finfo(type) + return cast(_FInfo, xp.finfo(type)) # TODO: Why is cast necessary? -def iinfo(type: _dtype | NamedArray[Any, Any], /) -> _IInfo: +def iinfo(type: _dtype[Any] | NamedArray[Any, Any], /) -> _IInfo: if isinstance(type, NamedArray): xp = _get_data_namespace(type) - return xp.iinfo(type._data) + return cast(_IInfo, xp.iinfo(type._data)) # TODO: Why is cast necessary? else: xp = _get_namespace_dtype(type) - return xp.iinfo(type) + return cast(_IInfo, xp.iinfo(type)) # TODO: Why is cast necessary? -def isdtype(dtype: _dtype, kind: _dtype | str | tuple[_dtype | str, ...]) -> bool: +def isdtype( + dtype: _dtype[Any], kind: _dtype[Any] | str | tuple[_dtype[Any] | str, ...] +) -> bool: xp = _get_namespace_dtype(dtype) - return xp.isdtype(dtype, kind) + return cast(bool, xp.isdtype(dtype, kind)) # TODO: Why is cast necessary? -def result_type(*arrays_and_dtypes: NamedArray[Any, Any] | _dtype) -> _dtype: +def result_type(*arrays_and_dtypes: NamedArray[Any, Any] | _dtype[Any]) -> _dtype[Any]: # TODO: Empty arg? arr_or_dtype = arrays_and_dtypes[0] if isinstance(arr_or_dtype, NamedArray): @@ -108,6 +110,9 @@ def result_type(*arrays_and_dtypes: NamedArray[Any, Any] | _dtype) -> _dtype: else: xp = _get_namespace_dtype(arr_or_dtype) - return xp.result_type( - *(a.dtype if isinstance(a, NamedArray) else a for a in arrays_and_dtypes) - ) + return cast( + _dtype[Any], + xp.result_type( + *(a.dtype if isinstance(a, NamedArray) else a for a in arrays_and_dtypes) + ), + ) # TODO: Why is cast necessary? From 937724df6036f661aef55ca14c678e9b49bcd20f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:34:42 +0200 Subject: [PATCH 274/367] Update _linear_algebra_functions.py --- .../_array_api/_linear_algebra_functions.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index 245f61f179e..f8b730907af 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -1,12 +1,15 @@ from __future__ import annotations from collections.abc import Sequence +from typing import Any from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims from xarray.namedarray.core import NamedArray -def matmul(x1: NamedArray, x2: NamedArray, /) -> NamedArray: +def matmul( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _data = xp.matmul(x1._data, x2._data) # TODO: Figure out a better way: @@ -15,12 +18,12 @@ def matmul(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def tensordot( - x1: NamedArray, - x2: NamedArray, + x1: NamedArray[Any, Any], + x2: NamedArray[Any, Any], /, *, axes: int | tuple[Sequence[int], Sequence[int]] = 2, -) -> NamedArray: +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _data = xp.tensordot(x1._data, x2._data, axes=axes) # TODO: Figure out a better way: @@ -28,7 +31,7 @@ def tensordot( return NamedArray(_dims, _data) -def matrix_transpose(x: NamedArray, /) -> NamedArray: +def matrix_transpose(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.matrix_transpose(x._data) # TODO: Figure out a better way: @@ -36,7 +39,9 @@ def matrix_transpose(x: NamedArray, /) -> NamedArray: return NamedArray(_dims, _data) -def vecdot(x1: NamedArray, x2: NamedArray, /, *, axis: int = -1) -> NamedArray: +def vecdot( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /, *, axis: int = -1 +) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) _data = xp.vecdot(x1._data, x2._data, axis=axis) # TODO: Figure out a better way: From 28dde6a4a2da717082c2f671c28ed46e0bc8f293 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:34:48 +0200 Subject: [PATCH 275/367] Update _set_functions.py --- .../namedarray/_array_api/_set_functions.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py index 64c852b61c3..009a5bb96a8 100644 --- a/xarray/namedarray/_array_api/_set_functions.py +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import NamedTuple +from typing import Any, NamedTuple from xarray.namedarray._array_api._utils import ( _get_data_namespace, @@ -10,23 +10,23 @@ class UniqueAllResult(NamedTuple): - values: NamedArray - indices: NamedArray - inverse_indices: NamedArray - counts: NamedArray + values: NamedArray[Any, Any] + indices: NamedArray[Any, Any] + inverse_indices: NamedArray[Any, Any] + counts: NamedArray[Any, Any] class UniqueCountsResult(NamedTuple): - values: NamedArray - counts: NamedArray + values: NamedArray[Any, Any] + counts: NamedArray[Any, Any] class UniqueInverseResult(NamedTuple): - values: NamedArray - inverse_indices: NamedArray + values: NamedArray[Any, Any] + inverse_indices: NamedArray[Any, Any] -def unique_all(x: NamedArray, /) -> UniqueAllResult: +def unique_all(x: NamedArray[Any, Any], /) -> UniqueAllResult: xp = _get_data_namespace(x) values, indices, inverse_indices, counts = xp.unique_all(x._data) _dims_values = _infer_dims(values.shape) # TODO: Fix @@ -41,7 +41,7 @@ def unique_all(x: NamedArray, /) -> UniqueAllResult: ) -def unique_counts(x: NamedArray, /) -> UniqueCountsResult: +def unique_counts(x: NamedArray[Any, Any], /) -> UniqueCountsResult: xp = _get_data_namespace(x) values, counts = xp.unique_counts(x._data) _dims_values = _infer_dims(values.shape) # TODO: Fix dims @@ -52,7 +52,7 @@ def unique_counts(x: NamedArray, /) -> UniqueCountsResult: ) -def unique_inverse(x: NamedArray, /) -> UniqueInverseResult: +def unique_inverse(x: NamedArray[Any, Any], /) -> UniqueInverseResult: xp = _get_data_namespace(x) values, inverse_indices = xp.unique_inverse(x._data) _dims_values = _infer_dims(values.shape) # TODO: Fix @@ -63,7 +63,7 @@ def unique_inverse(x: NamedArray, /) -> UniqueInverseResult: ) -def unique_values(x: NamedArray, /) -> NamedArray: +def unique_values(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _data = xp.unique_values(x._data) _dims = _infer_dims(_data.shape) # TODO: Fix From 9354771d84e28c506fbd7b8d56c1b5702e5fb69b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:34:51 +0200 Subject: [PATCH 276/367] Update _typing.py --- xarray/namedarray/_typing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index c814adb1829..e2be6f8ee92 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -112,6 +112,7 @@ def imag(self) -> _T_co: ... _ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]] _ShapeType = TypeVar("_ShapeType", bound=Any) _ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True) +_Shape1D = tuple[int] _Axis = int _Axes = tuple[_Axis, ...] From ee12ae56ed66e41f9c5aaf615a4109997213a89b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Sep 2024 18:35:30 +0000 Subject: [PATCH 277/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_creation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 19e22ceba33..ace65f0b43e 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -9,16 +9,16 @@ _infer_dims, ) from xarray.namedarray._typing import ( + Default, _ArrayLike, _default, _Device, _DimsLike2, _DType, _Shape, + _Shape1D, _ShapeType, duckarray, - Default, - _Shape1D, ) from xarray.namedarray.core import NamedArray From 4de05f8f36f9c892c7a557ed773530f5875bef1d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:55:00 +0200 Subject: [PATCH 278/367] Using operators instead should handle inplace methods missing --- xarray/namedarray/core.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index aa03ee4e409..b622d7082a3 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -638,8 +638,7 @@ def __xor__(self, other: int | bool | NamedArray, /): return bitwise_xor(self, self._maybe_asarray(other)) def __iadd__(self, other: int | float | NamedArray, /): - - self._data.__iadd__(self._maybe_asarray(other)._data) + self._data += self._maybe_asarray(other)._data return self def __radd__(self, other: int | float | NamedArray, /): @@ -649,7 +648,7 @@ def __radd__(self, other: int | float | NamedArray, /): def __iand__(self, other: int | bool | NamedArray, /): - self._data.__iand__(self._maybe_asarray(other)._data) + self._data &= self._maybe_asarray(other)._data return self def __rand__(self, other: int | bool | NamedArray, /): @@ -658,8 +657,7 @@ def __rand__(self, other: int | bool | NamedArray, /): return bitwise_and(self._maybe_asarray(other), self) def __ifloordiv__(self, other: int | float | NamedArray, /): - - self._data.__ifloordiv__(self._maybe_asarray(other)._data) + self._data //= self._maybe_asarray(other)._data return self def __rfloordiv__(self, other: int | float | NamedArray, /): @@ -678,7 +676,7 @@ def __rlshift__(self, other: int | NamedArray, /): return bitwise_left_shift(self._maybe_asarray(other), self) def __imatmul__(self, other: NamedArray, /): - self._data.__imatmul__(other._data) + self._data @= other._data return self def __rmatmul__(self, other: NamedArray, /): @@ -687,8 +685,7 @@ def __rmatmul__(self, other: NamedArray, /): return matmul(self._maybe_asarray(other), self) def __imod__(self, other: int | float | NamedArray, /): - - self._data.__imod__(self._maybe_asarray(other)._data) + self._data %= self._maybe_asarray(other)._data return self def __rmod__(self, other: int | float | NamedArray, /): @@ -697,8 +694,7 @@ def __rmod__(self, other: int | float | NamedArray, /): return remainder(self._maybe_asarray(other), self) def __imul__(self, other: int | float | NamedArray, /): - - self._data.__imul__(self._maybe_asarray(other)._data) + self._data *= self._maybe_asarray(other)._data return self def __rmul__(self, other: int | float | NamedArray, /): @@ -707,8 +703,8 @@ def __rmul__(self, other: int | float | NamedArray, /): return multiply(self._maybe_asarray(other), self) def __ior__(self, other: int | bool | NamedArray, /): + self._data |= self._maybe_asarray(other)._data - self._data.__ior__(self._maybe_asarray(other)._data) return self def __ror__(self, other: int | bool | NamedArray, /): @@ -717,8 +713,7 @@ def __ror__(self, other: int | bool | NamedArray, /): return bitwise_or(self._maybe_asarray(other), self) def __ipow__(self, other: int | float | NamedArray, /): - - self._data.__ipow__(self._maybe_asarray(other)._data) + self._data **= self._maybe_asarray(other)._data return self def __rpow__(self, other: int | float | NamedArray, /): @@ -727,8 +722,7 @@ def __rpow__(self, other: int | float | NamedArray, /): return pow(self._maybe_asarray(other), self) def __irshift__(self, other: int | NamedArray, /): - - self._data.__irshift__(self._maybe_asarray(other)._data) + self._data >>= self._maybe_asarray(other)._data return self def __rrshift__(self, other: int | NamedArray, /): @@ -738,7 +732,7 @@ def __rrshift__(self, other: int | NamedArray, /): def __isub__(self, other: int | float | NamedArray, /): - self._data.__isub__(self._maybe_asarray(other)._data) + self._data -= self._maybe_asarray(other)._data return self def __rsub__(self, other: int | float | NamedArray, /): @@ -747,8 +741,7 @@ def __rsub__(self, other: int | float | NamedArray, /): return subtract(self._maybe_asarray(other), self) def __itruediv__(self, other: float | NamedArray, /): - - self._data.__itruediv__(self._maybe_asarray(other)._data) + self._data /= self._maybe_asarray(other) return self def __rtruediv__(self, other: float | NamedArray, /): @@ -758,7 +751,7 @@ def __rtruediv__(self, other: float | NamedArray, /): def __ixor__(self, other: int | bool | NamedArray, /): - self._data.__ixor__(self._maybe_asarray(other)._data) + self._data ^= self._maybe_asarray(other)._data return self def __rxor__(self, other, /): From 0328e449474e044ea01db11462d6c400f0acbe2f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 21:00:12 +0200 Subject: [PATCH 279/367] Update _indexing_functions.py --- xarray/namedarray/_array_api/_indexing_functions.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_indexing_functions.py b/xarray/namedarray/_array_api/_indexing_functions.py index e57d3db0056..1a4dac96ce8 100644 --- a/xarray/namedarray/_array_api/_indexing_functions.py +++ b/xarray/namedarray/_array_api/_indexing_functions.py @@ -1,22 +1,25 @@ from __future__ import annotations +from typing import Any + from xarray.namedarray._array_api._utils import _dims_to_axis, _get_data_namespace from xarray.namedarray._typing import ( Default, _default, _Dim, + _DType, ) from xarray.namedarray.core import NamedArray def take( - x: NamedArray, - indices: NamedArray, + x: NamedArray[Any, _DType], + indices: NamedArray[Any, Any], /, *, dim: _Dim | Default = _default, axis: int | None = None, -) -> NamedArray: +) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis)[0] # TODO: Handle attrs? will get x1 now From 7f9c8bb7a19f5fefa6bf0143c17d09aad4080768 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 21:28:24 +0200 Subject: [PATCH 280/367] Update core.py --- xarray/namedarray/core.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index b622d7082a3..564724b9d53 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -264,10 +264,6 @@ def __init__( data: duckarray[Any, _DType_co], attrs: _AttrsLike = None, ): - if not isinstance(data, _arrayfunction_or_api): - raise NotImplementedError( - f"data is not a valid duckarray, got {data=}, {dims=}" - ) self._data = data self._dims = self._parse_dimensions(dims) self._attrs = dict(attrs) if attrs else None @@ -324,10 +320,6 @@ def _new( attributes you want to store with the array. Will copy the attrs from x by default. """ - if not isinstance(data, _arrayfunction_or_api): - raise NotImplementedError( - f"data is not a valid duckarray, got {data=}, {dims=}" - ) return _new(self, dims, data, attrs) def _replace( @@ -356,10 +348,6 @@ def _replace( attributes you want to store with the array. Will copy the attrs from x by default. """ - if not isinstance(data, _arrayfunction_or_api): - raise NotImplementedError( - f"data is not a valid duckarray, got {data=}, {dims=}" - ) return cast("Self", self._new(dims, data, attrs)) def _copy( From 60142355c8e4b74ce8e801734d56691fdb2d43a0 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 21:42:42 +0200 Subject: [PATCH 281/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index ff8cec000ec..52ab5752d33 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -317,8 +317,7 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] zip_longest(*map(reversed, dims), fillvalue=_default), zip_longest(*map(reversed, shapes), fillvalue=-1), ): - _d = tuple(set(d) - {_default}) - + _d = tuple(v for v in d if v is not _default) if any(_isnone(sizes)): # dim = None raise NotImplementedError("TODO: Handle None in shape, {shapes = }") From 117548d8b8be57af9ba325e8b2cd7ea5fbd1a6b5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 21:47:13 +0200 Subject: [PATCH 282/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 52ab5752d33..0c1813ed02a 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -22,8 +22,7 @@ duckarray, ) -if TYPE_CHECKING: - from xarray.namedarray.core import NamedArray +from xarray.namedarray.core import NamedArray def _maybe_default_namespace(xp: ModuleType | None = None) -> ModuleType: From 64278e5c66a24d1d62cbfbfe91e715d0cb5fcd23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Sep 2024 19:48:48 +0000 Subject: [PATCH 283/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 0c1813ed02a..5491890d973 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from itertools import zip_longest from types import ModuleType -from typing import TYPE_CHECKING, Any, TypeGuard, cast +from typing import Any, TypeGuard, cast from xarray.namedarray._typing import ( Default, @@ -21,7 +21,6 @@ _Shape, duckarray, ) - from xarray.namedarray.core import NamedArray From 783be0d7a49191d8e0ef0078e87b1dc7abf528e6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 22:13:21 +0200 Subject: [PATCH 284/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 0c1813ed02a..6dd62bd8ccb 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -316,7 +316,7 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] zip_longest(*map(reversed, dims), fillvalue=_default), zip_longest(*map(reversed, shapes), fillvalue=-1), ): - _d = tuple(v for v in d if v is not _default) + _d = tuple(set(v for v in d if v is not _default)) if any(_isnone(sizes)): # dim = None raise NotImplementedError("TODO: Handle None in shape, {shapes = }") From 5ce92c902bf81be8840d4aa8818c642e0cb37967 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 22:26:59 +0200 Subject: [PATCH 285/367] Update core.py --- xarray/namedarray/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 564724b9d53..d331e05b3cf 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -654,8 +654,7 @@ def __rfloordiv__(self, other: int | float | NamedArray, /): return floor_divide(self._maybe_asarray(other), self) def __ilshift__(self, other: int | NamedArray, /): - - self._data.__ilshift__(self._maybe_asarray(other)._data) + self._data <<= self._maybe_asarray(other)._data return self def __rlshift__(self, other: int | NamedArray, /): @@ -729,7 +728,7 @@ def __rsub__(self, other: int | float | NamedArray, /): return subtract(self._maybe_asarray(other), self) def __itruediv__(self, other: float | NamedArray, /): - self._data /= self._maybe_asarray(other) + self._data /= self._maybe_asarray(other)._data return self def __rtruediv__(self, other: float | NamedArray, /): From de99d36a0385904b72c86b5c3b2c8d54e3d977fa Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 22:38:10 +0200 Subject: [PATCH 286/367] skip cumulative sum --- xarray/tests/namedarray_array_api_skips.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/tests/namedarray_array_api_skips.txt b/xarray/tests/namedarray_array_api_skips.txt index 6aa54bb4639..28fb8c836fe 100644 --- a/xarray/tests/namedarray_array_api_skips.txt +++ b/xarray/tests/namedarray_array_api_skips.txt @@ -37,3 +37,8 @@ array_api_tests/test_statistical_functions.py::test_prod # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] + +# numpy scalars missing __complex__: +# AttributeError: 'numpy.uint64' object has no attribute '__complex__' +# https://github.com/numpy/numpy/issues/27305 +array_api_tests/test_statistical_functions.py::test_cumulative_sum From 1835514e86eb401bd4080b1a7b05dd2c8dd91295 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 22:42:03 +0200 Subject: [PATCH 287/367] Update _typing.py --- xarray/namedarray/_typing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index e2be6f8ee92..424997c08f0 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -8,7 +8,6 @@ Any, Final, Literal, - Never, Protocol, SupportsIndex, TypedDict, @@ -22,9 +21,9 @@ try: if sys.version_info >= (3, 11): - from typing import TypeAlias + from typing import Never, TypeAlias else: - from typing import TypeAlias + from typing_extensions import Never, TypeAlias except ImportError: if TYPE_CHECKING: raise From 68c271ccc3655d772d7849b67bb0b702fe0cd7c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Sep 2024 20:42:37 +0000 Subject: [PATCH 288/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_typing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 424997c08f0..93a71a2b267 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -23,7 +23,9 @@ if sys.version_info >= (3, 11): from typing import Never, TypeAlias else: - from typing_extensions import Never, TypeAlias + from typing import TypeAlias + + from typing_extensions import Never except ImportError: if TYPE_CHECKING: raise From 92ee20328b7682dcfbabf870b1772737001ae575 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 22:45:19 +0200 Subject: [PATCH 289/367] Update _typing.py --- xarray/namedarray/_typing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 424997c08f0..248fa6d23a5 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: raise else: + Never: Any = None Self: Any = None From 019a6103fc14bb3c5d9dd9015559201d1958fe66 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Sep 2024 23:11:57 +0200 Subject: [PATCH 290/367] Update core.py --- xarray/namedarray/core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index d331e05b3cf..ac1431bc366 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -143,15 +143,15 @@ def _new( attributes you want to store with the array. Will copy the attrs from x by default. """ - dims_ = copy.copy(x._dims) if dims is _default else dims + dims_ = copy.copy(x._dims) if isinstance(dims, Default) else dims attrs_: Mapping[Any, Any] | None - if attrs is _default: + if isinstance(attrs, _default): attrs_ = None if x._attrs is None else x._attrs.copy() else: attrs_ = attrs - if data is _default: + if isinstance(data, Default): return type(x)(dims_, copy.copy(x._data), attrs_) else: cls_ = cast("type[NamedArray[_ShapeType, _DType]]", type(x)) @@ -1351,12 +1351,12 @@ def _as_sparse( from xarray.namedarray._array_api import astype # TODO: what to do if dask-backended? - if fill_value is _default: + if isinstance(fill_value, Default): dtype, fill_value = dtypes.maybe_promote(self.dtype) else: dtype = dtypes.result_type(self.dtype, fill_value) - if sparse_format is _default: + if isinstance(sparse_format, Default): sparse_format = "coo" try: as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") From 6e7de32a6f58da00b06a1befddc5c9788bbca3a9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 9 Sep 2024 07:24:50 +0200 Subject: [PATCH 291/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index ac1431bc366..5c1d21d884a 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -27,6 +27,7 @@ ) from xarray.namedarray._aggregations import NamedArrayAggregations from xarray.namedarray._typing import ( + Default, ErrorOptionsWithWarn, _arrayapi, _arrayfunction_or_api, @@ -58,7 +59,6 @@ from xarray.core.types import Dims, T_Chunks from xarray.namedarray._typing import ( - Default, _AttrsLike, _Chunks, _Device, From c302e6585268c24006abdd94906e5abead15c3e9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 9 Sep 2024 07:30:05 +0200 Subject: [PATCH 292/367] Update core.py --- xarray/namedarray/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 5c1d21d884a..2f2ced0d09b 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -146,7 +146,7 @@ def _new( dims_ = copy.copy(x._dims) if isinstance(dims, Default) else dims attrs_: Mapping[Any, Any] | None - if isinstance(attrs, _default): + if isinstance(attrs, Default): attrs_ = None if x._attrs is None else x._attrs.copy() else: attrs_ = attrs From ad6d7e494e10d3e16249c511690839ca1f8bb6ae Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 10 Sep 2024 21:26:32 +0200 Subject: [PATCH 293/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index c54605a0627..c0fa1399eb5 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -149,6 +149,8 @@ def _dims_to_axis( Convert to dims to axis values >>> x = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) + >>> _dims_to_axis(x, ("y", "x"), None) + (1, 0) >>> _dims_to_axis(x, ("y",), None) (1,) >>> _dims_to_axis(x, _default, 0) @@ -175,7 +177,7 @@ def _dims_to_axis( axis = () for dim in _dims: try: - axis = (x.dims.index(dim),) + axis += (x.dims.index(dim),) except ValueError: raise ValueError(f"{dim!r} not found in array dimensions {x.dims!r}") return axis From 8fa6d5862ec4f1d7b2fa1025f637d3e50e57e8bd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 10 Sep 2024 23:09:21 +0200 Subject: [PATCH 294/367] Update _manipulation_functions.py --- .../_array_api/_manipulation_functions.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 0799cec58c6..d024487ab2d 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -8,6 +8,7 @@ _get_data_namespace, _infer_dims, _insert_dim, + _dims_to_axis, ) from xarray.namedarray._typing import ( Default, @@ -18,6 +19,7 @@ _Dim, _DType, _ShapeType, + _Dims, ) from xarray.namedarray.core import NamedArray @@ -124,7 +126,13 @@ def moveaxis( return x._new(_dims, _data) -def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]: +def permute_dims( + x: NamedArray[Any, _DType], + /, + axes: _Axes | None = None, + *, + dims: _Dims | Default = _default, +) -> NamedArray[Any, _DType]: """ Permutes the dimensions of an array. @@ -141,16 +149,24 @@ def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DT An array with permuted dimensions. The returned array must have the same data type as x. + Examples + -------- + >>> x = NamedArray(("x", "y", "z"), np.zeros((3, 4, 5))) + >>> y = permute_dims(x, (2, 1, 0)) + >>> y.dims, y.shape + (('z', 'y', 'x'), (5, 4, 3)) + >>> y = permute_dims(x, dims=("y", "x", "z")) + >>> y.dims, y.shape + (('y', 'x', 'z'), (4, 3, 5)) """ - - dims = x.dims - new_dims = tuple(dims[i] for i in axes) - if isinstance(x._data, _arrayapi): - xp = _get_data_namespace(x) - out = x._new(dims=new_dims, data=xp.permute_dims(x._data, axes)) - else: - out = x._new(dims=new_dims, data=x._data.transpose(axes)) # type: ignore[attr-defined] - return out + xp = _get_data_namespace(x) + _axis = _dims_to_axis(x, dims, axes) + if _axis is None: + raise TypeError("permute_dims missing argument axes or dims") + old_dims = x.dims + _dims = tuple(old_dims[i] for i in _axis) + _data = xp.permute_dims(x._data, _axis) + return x._new(_dims, _data) def repeat( From b65de7a5d01b347f3715045f298dbd3ab7844389 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:36:40 +0200 Subject: [PATCH 295/367] Update _manipulation_functions.py --- xarray/namedarray/_array_api/_manipulation_functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index d024487ab2d..a58cfb30f9a 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -15,6 +15,7 @@ _arrayapi, _Axes, _Axis, + _AxisLike, _default, _Dim, _DType, @@ -164,8 +165,8 @@ def permute_dims( if _axis is None: raise TypeError("permute_dims missing argument axes or dims") old_dims = x.dims - _dims = tuple(old_dims[i] for i in _axis) _data = xp.permute_dims(x._data, _axis) + _dims = tuple(old_dims[i] for i in _axis) return x._new(_dims, _data) From 0d0ea3b89ee6b9032cf4b72d18590c60429608eb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:39:29 +0200 Subject: [PATCH 296/367] normalize negative values to positive --- xarray/namedarray/_array_api/_utils.py | 120 +++++++++++++++++++++++-- 1 file changed, 113 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index c0fa1399eb5..45f83b52ea2 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -124,6 +124,107 @@ def _infer_dims( return _normalize_dimensions(dims) +def _normalize_axis_index(axis: int, ndim: int) -> int: + """ + Parameters + ---------- + axis : int + The un-normalized index of the axis. Can be negative + ndim : int + The number of dimensions of the array that `axis` should be normalized + against + + Returns + ------- + normalized_axis : int + The normalized axis index, such that `0 <= normalized_axis < ndim` + + Raises + ------ + AxisError + If the axis index is invalid, when `-ndim <= axis < ndim` is false. + + Examples + -------- + >>> _normalize_axis_index(0, ndim=3) + 0 + >>> _normalize_axis_index(1, ndim=3) + 1 + >>> _normalize_axis_index(-1, ndim=3) + 2 + + >>> _normalize_axis_index(3, ndim=3) + Traceback (most recent call last): + ... + AxisError: axis 3 is out of bounds for array of dimension 3 + >>> _normalize_axis_index(-4, ndim=3, msg_prefix='axes_arg') + Traceback (most recent call last): + ... + AxisError: axes_arg: axis -4 is out of bounds for array of dimension 3 + """ + + if -ndim > axis >= ndim: + raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}") + + return axis % ndim + + +def _normalize_axis_tuple( + axis: _AxisLike, + ndim: int, + argname: str | None = None, + allow_duplicate: bool = False, +) -> _Axes: + """ + Normalizes an axis argument into a tuple of non-negative integer axes. + + This handles shorthands such as ``1`` and converts them to ``(1,)``, + as well as performing the handling of negative indices covered by + `normalize_axis_index`. + + By default, this forbids axes from being specified multiple times. + + + Parameters + ---------- + axis : int, iterable of int + The un-normalized index or indices of the axis. + ndim : int + The number of dimensions of the array that `axis` should be normalized + against. + argname : str, optional + A prefix to put before the error message, typically the name of the + argument. + allow_duplicate : bool, optional + If False, the default, disallow an axis from being specified twice. + + Returns + ------- + normalized_axes : tuple of int + The normalized axis index, such that `0 <= normalized_axis < ndim` + + Raises + ------ + AxisError + If any axis provided is out of range + ValueError + If an axis is repeated + """ + if not isinstance(axis, tuple): + _axis = (axis,) + else: + _axis = axis + + # Going via an iterator directly is slower than via list comprehension. + _axis = tuple([_normalize_axis_index(ax, ndim) for ax in _axis]) + if not allow_duplicate and len(set(_axis)) != len(_axis): + if argname: + raise ValueError(f"repeated axis in `{argname}` argument") + else: + raise ValueError("repeated axis") + return _axis + + def _assert_either_dim_or_axis( dims: _DimsLike2 | Default, axis: _AxisLike | None ) -> None: @@ -148,18 +249,26 @@ def _dims_to_axis( Convert to dims to axis values - >>> x = NamedArray(("x", "y"), np.array([[1, 2, 3], [5, 6, 7]])) + >>> x = NamedArray(("x", "y", "z"), np.zeros((1, 2, 3))) >>> _dims_to_axis(x, ("y", "x"), None) (1, 0) >>> _dims_to_axis(x, ("y",), None) (1,) >>> _dims_to_axis(x, _default, 0) (0,) - >>> _dims_to_axis(x, _default, None) + >>> type(_dims_to_axis(x, _default, None)) + NoneType + + Normalizes negative integers + + >>> _dims_to_axis(x, _default, -1) + (2,) + >>> _dims_to_axis(x, _default, (-2, -1)) + (1, 2) Using Hashable dims - >>> x = NamedArray(("x", None), np.array([[1, 2, 3], [5, 6, 7]])) + >>> x = NamedArray(("x", None), np.zeros((1, 2))) >>> _dims_to_axis(x, None, None) (1,) @@ -185,10 +294,7 @@ def _dims_to_axis( if axis is None: return axis - if isinstance(axis, tuple): - return axis - else: - return (axis,) + return _normalize_axis_tuple(axis, x.ndim) def _dim_to_optional_axis( From 52343d9d6e45d0b210c69231dd8c1f31d2d50428 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 17 Sep 2024 21:36:58 +0200 Subject: [PATCH 297/367] more tests and fixes to broadcasted dims --- xarray/namedarray/_array_api/_utils.py | 65 ++++++++++++++++++-------- 1 file changed, 45 insertions(+), 20 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 45f83b52ea2..9b95da2b0dd 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -385,12 +385,7 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] >>> _get_broadcasted_dims(a) (('x', 'y', 'z'), (5, 3, 4)) - >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) - >>> b = NamedArray(("y", "z"), np.zeros((3, 4))) - >>> _get_broadcasted_dims(a, b) - (('x', 'y', 'z'), (5, 3, 4)) - >>> _get_broadcasted_dims(b, a) - (('x', 'y', 'z'), (5, 3, 4)) + Broadcasting 0- and 1-sized dims >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4))) @@ -407,6 +402,23 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] >>> _get_broadcasted_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) + Broadcasting different dims + + >>> a = NamedArray(("x",), np.zeros((5,))) + >>> b = NamedArray(("y",), np.zeros((3,))) + >>> _get_broadcasted_dims(a, b) + (('x', 'y'), (5, 3)) + + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) + >>> b = NamedArray(("y", "z"), np.zeros((3, 4))) + >>> _get_broadcasted_dims(a, b) + (('x', 'y', 'z'), (5, 3, 4)) + >>> _get_broadcasted_dims(b, a) + (('x', 'y', 'z'), (5, 3, 4)) + + + # Errors + >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((2, 3, 4))) >>> _get_broadcasted_dims(a, b) @@ -414,21 +426,34 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] ... ValueError: operands could not be broadcast together with dims = (('x', 'y', 'z'), ('x', 'y', 'z')) and shapes = ((5, 3, 4), (2, 3, 4)) """ - dims = tuple(a.dims for a in arrays) - shapes = tuple(a.shape for a in arrays) - - out_dims: _Dims = () - out_shape: _Shape = () - for d, sizes in zip( - zip_longest(*map(reversed, dims), fillvalue=_default), - zip_longest(*map(reversed, shapes), fillvalue=-1), + arrays_dims = tuple(a.dims for a in arrays) + arrays_shapes = tuple(a.shape for a in arrays) + + sizes: dict[Any, Any] = {} + for dims, shape in zip( + zip_longest(*map(reversed, arrays_dims), fillvalue=_default), + zip_longest(*map(reversed, arrays_shapes), fillvalue=-1), ): - _d = tuple(set(v for v in d if v is not _default)) - if any(_isnone(sizes)): - # dim = None - raise NotImplementedError("TODO: Handle None in shape, {shapes = }") - else: - dim = max(sizes) + for d, s in zip(reversed(dims), reversed(shape)): + if isinstance(d, Default): + continue + + if s is None: + raise NotImplementedError("TODO: Handle None in shape, {shapes = }") + + s_prev = sizes.get(d, -1) + if s_prev not in (-1, 0, 1, s): + raise ValueError( + "operands could not be broadcast together with " + f"dims = {arrays_dims} and shapes = {arrays_shapes}" + ) + + sizes[d] = max(s, s_prev) + + out_dims: _Dims = tuple(reversed(sizes.keys())) + out_shape: _Shape = tuple(reversed(sizes.values())) + return out_dims, out_shape + if any(i not in [-1, 0, 1, dim] for i in sizes) or len(_d) != 1: raise ValueError( From 5e3da644a314ed88882b7adbeac1feae9139bdd4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 17 Sep 2024 21:56:27 +0200 Subject: [PATCH 298/367] Add dims for basic array creation --- xarray/namedarray/_array_api/_creation_functions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index ace65f0b43e..c7948435f9d 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -31,10 +31,11 @@ def arange( *, dtype: _DType | None = None, device: _Device | None = None, + dims: _DimsLike2 | Default = _default, ) -> NamedArray[_Shape1D, _DType]: xp = _get_namespace_dtype(dtype) _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) - _dims = _infer_dims(_data.shape) + _dims = _infer_dims(_data.shape, dims) return NamedArray(_dims, _data) @@ -209,6 +210,7 @@ def linspace( dtype: _DType | None = None, device: _Device | None = None, endpoint: bool = True, + dims: _DimsLike2 | Default = _default, ) -> NamedArray[_Shape1D, _DType]: xp = _get_namespace_dtype(dtype) _data = xp.linspace( @@ -219,7 +221,7 @@ def linspace( device=device, endpoint=endpoint, ) - _dims = _infer_dims(_data.shape) + _dims = _infer_dims(_data.shape, dims) return NamedArray(_dims, _data) From f45429d5e813d7b5cb42aef9f57805f67ee8516c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 19 Sep 2024 22:57:13 +0200 Subject: [PATCH 299/367] add examples --- .../_array_api/_manipulation_functions.py | 48 ++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index a58cfb30f9a..cb832e092c7 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -31,9 +31,21 @@ def broadcast_arrays(*arrays: NamedArray[Any, Any]) -> list[NamedArray[Any, Any] Examples -------- - >>> x = xp.asarray([[1, 2, 3]]) - >>> y = xp.asarray([[4], [5]]) - >>> xp.broadcast_arrays(x, y) + >>> import numpy as np + >>> x = NamedArray(("x",), np.zeros((3,))) + >>> y = NamedArray(("y", "x"), np.zeros((2, 1))) + >>> x_new, y_new = broadcast_arrays(x, y) + >>> x_new.dims, x_new.shape, y_new.dims, y_new.shape + (('y', 'x'), (2, 3), ('y', 'x'), (2, 3)) + + Errors + + >>> x = NamedArray(("x",), np.zeros((3,))) + >>> y = NamedArray(("x",), np.zeros((2))) + >>> x_new, y_new = broadcast_arrays(x, y) + Traceback (most recent call last): + ... + ValueError: operands could not be broadcast together with dims = (('x',), ('x',)) and shapes = ((3,), (2,)) """ x = arrays[0] xp = _get_data_namespace(x) @@ -46,6 +58,20 @@ def broadcast_arrays(*arrays: NamedArray[Any, Any]) -> list[NamedArray[Any, Any] def broadcast_to( x: NamedArray[Any, _DType], /, shape: _ShapeType ) -> NamedArray[_ShapeType, _DType]: + """ + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x",), np.arange(0, 3)) + >>> x_new = broadcast_to(x, (1, 1, 3)) + >>> x_new.dims, x_new.shape + (('dim_1', 'dim_0', 'x'), (1, 1, 3)) + + >>> x_new = broadcast_to(x, shape=(1, 1, 3), dims=("y", "x")) + >>> x_new.dims, x_new.shape + (('dim_0', 'y', 'x'), (1, 1, 3)) + """ xp = _get_data_namespace(x) _data = xp.broadcast_to(x._data, shape=shape) _dims = _infer_dims(_data.shape) # TODO: Fix dims @@ -93,15 +119,14 @@ def expand_dims( Examples -------- + >>> import numpy as np >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]])) - >>> expand_dims(x) - Size: 32B - array([[[1., 2.], - [3., 4.]]]) - >>> expand_dims(x, dim="z") - Size: 32B - array([[[1., 2.], - [3., 4.]]]) + >>> x_new = expand_dims(x) + >>> x_new.dims, x_new.shape + (('dim_2', 'x', 'y'), (1, 2, 2)) + >>> x_new = expand_dims(x, dim="z") + >>> x_new.dims, x_new.shape + (('z', 'x', 'y'), (1, 2, 2)) """ xp = _get_data_namespace(x) _data = xp.expand_dims(x._data, axis=axis) @@ -152,6 +177,7 @@ def permute_dims( Examples -------- + >>> import numpy as np >>> x = NamedArray(("x", "y", "z"), np.zeros((3, 4, 5))) >>> y = permute_dims(x, (2, 1, 0)) >>> y.dims, y.shape From 7efcef28562f8fd7345b0054387dcb6c9560399f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 19 Sep 2024 22:58:46 +0200 Subject: [PATCH 300/367] add dim keywords --- xarray/namedarray/_array_api/_manipulation_functions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index cb832e092c7..4181a80e03c 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -56,7 +56,11 @@ def broadcast_arrays(*arrays: NamedArray[Any, Any]) -> list[NamedArray[Any, Any] def broadcast_to( - x: NamedArray[Any, _DType], /, shape: _ShapeType + x: NamedArray[Any, _DType], + /, + shape: _ShapeType, + *, + dims: _DimsLike2 | Default = _default, ) -> NamedArray[_ShapeType, _DType]: """ @@ -96,8 +100,8 @@ def expand_dims( x: NamedArray[Any, _DType], /, *, - dim: _Dim | Default = _default, axis: _Axis = 0, + dim: _Dim | Default = _default, ) -> NamedArray[Any, _DType]: """ Expands the shape of an array by inserting a new dimension of size one at the From 04b01b37b0a9b467375610b06c1a1462674e0101 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 19 Sep 2024 22:59:18 +0200 Subject: [PATCH 301/367] Update _manipulation_functions.py --- xarray/namedarray/_array_api/_manipulation_functions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 4181a80e03c..869be62cb93 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -78,7 +78,7 @@ def broadcast_to( """ xp = _get_data_namespace(x) _data = xp.broadcast_to(x._data, shape=shape) - _dims = _infer_dims(_data.shape) # TODO: Fix dims + _dims = _infer_dims(_data.shape, x.dims if isinstance(dims, Default) else dims) return x._new(_dims, _data) @@ -114,7 +114,7 @@ def expand_dims( dim : Dimension name. New dimension will be stored in the axis position. axis : - (Not recommended) Axis position (zero-based). Default is 0. + Axis position (zero-based). Default is 0. Returns ------- @@ -132,6 +132,9 @@ def expand_dims( >>> x_new.dims, x_new.shape (('z', 'x', 'y'), (1, 2, 2)) """ + # Array Api does not support multiple axes, but maybe in the future: + # https://github.com/data-apis/array-api/issues/760 + # xref: https://github.com/numpy/numpy/blob/3b246c6488cf246d488bbe5726ca58dc26b6ea74/numpy/lib/_shape_base_impl.py#L509C17-L509C24 xp = _get_data_namespace(x) _data = xp.expand_dims(x._data, axis=axis) _dims = _insert_dim(x.dims, dim, axis) From 4a9736438ab77453700aaa2282e866e106173be7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 19 Sep 2024 22:59:40 +0200 Subject: [PATCH 302/367] Add arithmetic broadcasting --- .../_array_api/_manipulation_functions.py | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 869be62cb93..235fc9e0522 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -21,6 +21,8 @@ _DType, _ShapeType, _Dims, + _DimsLike2, + _Shape, ) from xarray.namedarray.core import NamedArray @@ -277,3 +279,138 @@ def unstack( _dims = _infer_dims(_data.shape) # TODO: Fix dims out += (x._new(_dims, _data),) return out + + +# %% Automatic broadcasting +_OPTIONS = {} +_OPTIONS["arithmetic_broadcast"] = True + + +def _set_dims( + x: NamedArray[Any, Any], dim: _Dims, shape: _Shape | None +) -> NamedArray[Any, Any]: + """ + Return a new array with given set of dimensions. + This method might be used to attach new dimension(s) to array. + + When possible, this operation does not copy this variable's data. + + Parameters + ---------- + dim : + Dimensions to include on the new variable. + shape : + Shape of the dimensions. If None, new dimensions are inserted with length 1. + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x",), np.asarray([1, 2, 3])) + >>> x_new = _set_dims(x, ("y", "x"), None) + >>> x_new.dims, x_new.shape + (('y', 'x'), (1, 3)) + >>> x_new = _set_dims(x, ("x", "y"), None) + >>> x_new.dims, x_new.shape + (('x', 'y'), (3, 1)) + + With shape: + + >>> x_new = _set_dims(x, ("y", "x"), (2, 3)) + >>> x_new.dims, x_new.shape + (('y', 'x'), (2, 3)) + + No operation + + >>> x_new = _set_dims(x, ("x",), None) + >>> x_new.dims, x_new.shape + (('x',), (3,)) + + Error + + >>> x_new = _set_dims(x, (), None) + Traceback (most recent call last): + ... + ValueError: new dimensions () must be a superset of existing dimensions ('x',) + """ + if x.dims == dim: + # No operation. Don't use broadcast_to unless necessary so the result + # remains writeable as long as possible: + return x + + extra_dims = tuple(d for d in dim if d not in x.dims) + if not extra_dims: + raise ValueError( + f"new dimensions {dim!r} must be a superset of " + f"existing dimensions {x.dims!r}" + ) + + if shape is not None: + # Add dimensions, with same size as shape: + dims_map = dict(zip(dim, shape)) + expanded_dims = extra_dims + x.dims + tmp_shape = tuple(dims_map[d] for d in expanded_dims) + return permute_dims(broadcast_to(x, tmp_shape, dims=expanded_dims), dims=dim) + else: + # Add dimensions, with size 1 only: + out = x + for d in extra_dims: + out = expand_dims(out, dim=d) + return permute_dims(out, dims=dim) + + +def _broadcast_arrays(*arrays: NamedArray[Any, Any]): + """ + TODO: Can this become xp.broadcast_arrays? + + Given any number of variables, return variables with matching dimensions + and broadcast data. + + The data on the returned variables may be a view of the data on the + corresponding original arrays but dimensions will be reordered and + inserted so that both broadcast arrays have the same dimensions. The new + dimensions are sorted in order of appearance in the first variable's + dimensions followed by the second variable's dimensions. + """ + dims, shape = _get_broadcasted_dims(*arrays) + return tuple(_set_dims(var, dims, shape) for var in arrays) + + +def _broadcast_arrays_with_minimal_size(*arrays: NamedArray[Any, Any]): + """ + Given any number of variables, return variables with matching dimensions. + + Unlike the result of broadcast_variables(), variables with missing dimensions + will have them added with size 1 instead of the size of the broadcast dimension. + """ + dims, _ = _get_broadcasted_dims(*arrays) + return tuple(_set_dims(var, dims, None) for var in arrays) + + +def _arithmetic_broadcast(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any]): + """ + Fu + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x",), np.asarray([1, 2, 3])) + >>> y = NamedArray(("y",), np.asarray([4, 5])) + >>> x_new, y_new = _arithmetic_broadcast(x, y) + >>> x_new.dims, x_new.shape, y_new.dims, y_new.shape + (('x', 'y'), (3, 1), ('x', 'y'), (1, 2)) + """ + if not _OPTIONS["arithmetic_broadcast"]: + if x1.dims != x2.dims: + raise ValueError( + "Broadcasting is necessary but automatic broadcasting is disabled via " + "global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + ) + + return _broadcast_arrays_with_minimal_size(x1, x2) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() From f4f32681dc4e38108e71dc65edb9a3f9d66a313e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 19 Sep 2024 23:00:02 +0200 Subject: [PATCH 303/367] Use arithmetic broadcasting --- .../_array_api/_elementwise_functions.py | 136 +++++++++++------- 1 file changed, 82 insertions(+), 54 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 0c395c22143..7adac52e476 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -4,6 +4,7 @@ import numpy as np +from xarray.namedarray._array_api._manipulation_functions import _arithmetic_broadcast from xarray.namedarray._array_api._utils import ( _get_broadcasted_dims, _get_data_namespace, @@ -41,8 +42,9 @@ def acosh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: def add(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.add(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.add(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -71,8 +73,9 @@ def atan2( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.atan2(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.atan2(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -87,8 +90,9 @@ def bitwise_and( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.bitwise_and(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.bitwise_and(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -103,8 +107,9 @@ def bitwise_left_shift( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.bitwise_left_shift(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.bitwise_left_shift(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -112,8 +117,9 @@ def bitwise_or( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.bitwise_or(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.bitwise_or(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -121,8 +127,9 @@ def bitwise_right_shift( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.bitwise_right_shift(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.bitwise_right_shift(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -130,8 +137,9 @@ def bitwise_xor( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.bitwise_xor(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.bitwise_xor(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -165,8 +173,9 @@ def copysign( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.copysign(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.copysign(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -188,8 +197,9 @@ def divide( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.divide(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.divide(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -211,8 +221,9 @@ def equal( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.equal(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.equal(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -227,8 +238,9 @@ def floor_divide( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.floor_divide(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.floor_divide(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -236,8 +248,9 @@ def greater( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.greater(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.greater(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -245,8 +258,9 @@ def greater_equal( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.greater_equal(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.greater_equal(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -254,8 +268,9 @@ def hypot( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.hypot(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.hypot(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -315,8 +330,9 @@ def isnan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: def less(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.less(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.less(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -324,8 +340,9 @@ def less_equal( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.less_equal(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.less_equal(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -361,8 +378,9 @@ def logaddexp( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.logaddexp(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.logaddexp(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -370,8 +388,9 @@ def logical_and( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.logical_and(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.logical_and(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -386,8 +405,9 @@ def logical_or( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.logical_or(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.logical_or(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -395,8 +415,9 @@ def logical_xor( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.logical_xor(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.logical_xor(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -404,8 +425,9 @@ def maximum( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.maximum(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.maximum(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -413,8 +435,9 @@ def minimum( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.minimum(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.minimum(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -422,8 +445,9 @@ def multiply( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.multiply(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.multiply(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -438,8 +462,9 @@ def not_equal( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.not_equal(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.not_equal(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -452,8 +477,9 @@ def positive(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: def pow(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.pow(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.pow(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -494,8 +520,9 @@ def remainder( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.remainder(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.remainder(x1_new._data, x2_new._data) return NamedArray(_dims, _data) @@ -552,8 +579,9 @@ def subtract( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x1) - _dims, _ = _get_broadcasted_dims(x1, x2) - _data = xp.subtract(x1._data, x2._data) + x1_new, x2_new = _arithmetic_broadcast(x1, x2) + _dims = x1_new.dims + _data = xp.subtract(x1_new._data, x2_new._data) return NamedArray(_dims, _data) From 656ab62ace5c9f48ed9dc1a2ba5df9fb9cde7ffc Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 19 Sep 2024 23:01:17 +0200 Subject: [PATCH 304/367] Add doctests --- xarray/namedarray/_array_api/_utils.py | 54 ++++++++++++++++---------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 9b95da2b0dd..f086b629014 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -116,16 +116,28 @@ def _infer_dims( (None,) >>> _infer_dims((1,), ("x",)) ('x',) + >>> _infer_dims((1, 3), ("x",)) + ('dim_0', 'x') + >>> _infer_dims((1, 1, 3), ("x",)) + ('dim_1', 'dim_0', 'x') """ if isinstance(dims, Default): ndim = len(shape) return tuple(f"dim_{ndim - 1 - n}" for n in range(ndim)) + + _dims = _normalize_dimensions(dims) + diff = len(shape) - len(_dims) + if diff > 0: + # TODO: Leads to ('dim_0', 'x'), should it be ('dim_1', 'x')? + return _infer_dims(shape[:diff], _default) + _dims else: - return _normalize_dimensions(dims) + return _dims def _normalize_axis_index(axis: int, ndim: int) -> int: """ + Normalize axis index to positive values. + Parameters ---------- axis : int @@ -139,10 +151,6 @@ def _normalize_axis_index(axis: int, ndim: int) -> int: normalized_axis : int The normalized axis index, such that `0 <= normalized_axis < ndim` - Raises - ------ - AxisError - If the axis index is invalid, when `-ndim <= axis < ndim` is false. Examples -------- @@ -150,20 +158,28 @@ def _normalize_axis_index(axis: int, ndim: int) -> int: 0 >>> _normalize_axis_index(1, ndim=3) 1 + >>> _normalize_axis_index(2, ndim=3) + 2 >>> _normalize_axis_index(-1, ndim=3) 2 + >>> _normalize_axis_index(-2, ndim=3) + 1 + >>> _normalize_axis_index(-3, ndim=3) + 0 + + Errors >>> _normalize_axis_index(3, ndim=3) Traceback (most recent call last): - ... - AxisError: axis 3 is out of bounds for array of dimension 3 - >>> _normalize_axis_index(-4, ndim=3, msg_prefix='axes_arg') + ... + ValueError: axis 3 is out of bounds for array of dimension 3 + >>> _normalize_axis_index(-4, ndim=3) Traceback (most recent call last): - ... - AxisError: axes_arg: axis -4 is out of bounds for array of dimension 3 + ... + ValueError: axis -4 is out of bounds for array of dimension 3 """ - if -ndim > axis >= ndim: + if -ndim > axis or axis >= ndim: raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}") return axis % ndim @@ -248,7 +264,7 @@ def _dims_to_axis( -------- Convert to dims to axis values - + >>> import numpy as np >>> x = NamedArray(("x", "y", "z"), np.zeros((1, 2, 3))) >>> _dims_to_axis(x, ("y", "x"), None) (1, 0) @@ -257,7 +273,7 @@ def _dims_to_axis( >>> _dims_to_axis(x, _default, 0) (0,) >>> type(_dims_to_axis(x, _default, None)) - NoneType + Normalizes negative integers @@ -381,6 +397,7 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] Examples -------- + >>> import numpy as np >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> _get_broadcasted_dims(a) (('x', 'y', 'z'), (5, 3, 4)) @@ -455,12 +472,7 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] return out_dims, out_shape - if any(i not in [-1, 0, 1, dim] for i in sizes) or len(_d) != 1: - raise ValueError( - f"operands could not be broadcast together with {dims = } and {shapes = }" - ) - - out_dims += (_d[0],) - out_shape += (dim,) +if __name__ == "__main__": + import doctest - return out_dims[::-1], out_shape[::-1] + doctest.testmod() From 95f8490dc77baea8b2b48690dbeffe44584b173b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 19 Sep 2024 21:02:00 +0000 Subject: [PATCH 305/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../namedarray/_array_api/_elementwise_functions.py | 1 - .../namedarray/_array_api/_manipulation_functions.py | 12 +++++------- xarray/namedarray/_array_api/_searching_functions.py | 4 +++- xarray/namedarray/_array_api/_utils.py | 3 ++- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 7adac52e476..4585c668bc1 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -6,7 +6,6 @@ from xarray.namedarray._array_api._manipulation_functions import _arithmetic_broadcast from xarray.namedarray._array_api._utils import ( - _get_broadcasted_dims, _get_data_namespace, ) from xarray.namedarray._typing import ( diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 235fc9e0522..7a7b15fc449 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -4,25 +4,23 @@ from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( + _dims_to_axis, _get_broadcasted_dims, _get_data_namespace, _infer_dims, _insert_dim, - _dims_to_axis, ) from xarray.namedarray._typing import ( Default, - _arrayapi, _Axes, _Axis, - _AxisLike, _default, _Dim, - _DType, - _ShapeType, _Dims, _DimsLike2, + _DType, _Shape, + _ShapeType, ) from xarray.namedarray.core import NamedArray @@ -54,7 +52,7 @@ def broadcast_arrays(*arrays: NamedArray[Any, Any]) -> list[NamedArray[Any, Any] _dims, _ = _get_broadcasted_dims(*arrays) _arrays = tuple(a._data for a in arrays) _datas = xp.broadcast_arrays(*_arrays) - return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas)] + return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas, strict=False)] def broadcast_to( @@ -346,7 +344,7 @@ def _set_dims( if shape is not None: # Add dimensions, with same size as shape: - dims_map = dict(zip(dim, shape)) + dims_map = dict(zip(dim, shape, strict=False)) expanded_dims = extra_dims + x.dims tmp_shape = tuple(dims_map[d] for d in expanded_dims) return permute_dims(broadcast_to(x, tmp_shape, dims=expanded_dims), dims=dim) diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index 37388542118..38e851bf5f8 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -58,7 +58,9 @@ def nonzero(x: NamedArray[Any, Any], /) -> tuple[NamedArray[Any, Any], ...]: xp = _get_data_namespace(x) _datas: tuple[_arrayapi[Any, Any], ...] = xp.nonzero(x._data) # TODO: Verify that dims and axis matches here: - return tuple(x._new((dim,), data) for dim, data in zip(x.dims, _datas)) + return tuple( + x._new((dim,), data) for dim, data in zip(x.dims, _datas, strict=False) + ) def searchsorted( diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index f086b629014..30549494f98 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -450,8 +450,9 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] for dims, shape in zip( zip_longest(*map(reversed, arrays_dims), fillvalue=_default), zip_longest(*map(reversed, arrays_shapes), fillvalue=-1), + strict=False, ): - for d, s in zip(reversed(dims), reversed(shape)): + for d, s in zip(reversed(dims), reversed(shape), strict=False): if isinstance(d, Default): continue From 9d4b9d2c2e5e579cd0932d800bd194abf7eddaeb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 19 Sep 2024 23:05:41 +0200 Subject: [PATCH 306/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 30549494f98..777d7f84d39 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -303,8 +303,10 @@ def _dims_to_axis( for dim in _dims: try: axis += (x.dims.index(dim),) - except ValueError: - raise ValueError(f"{dim!r} not found in array dimensions {x.dims!r}") + except ValueError as err: + raise ValueError( + f"{dim!r} not found in array dimensions {x.dims!r}" + ) from err return axis if axis is None: From f64f96782df17fe90cdbb501838ca0d6db94908b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 00:11:20 +0200 Subject: [PATCH 307/367] reflexive should broadcast as well --- xarray/namedarray/_array_api/_utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 777d7f84d39..49dc8a24f88 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -410,6 +410,8 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] >>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4))) >>> _get_broadcasted_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) + >>> _get_broadcasted_dims(b, a) + (('x', 'y', 'z'), (5, 3, 4)) >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((1, 3, 4))) @@ -445,13 +447,17 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] ... ValueError: operands could not be broadcast together with dims = (('x', 'y', 'z'), ('x', 'y', 'z')) and shapes = ((5, 3, 4), (2, 3, 4)) """ + DEFAULT_SIZE = -1 + BROADCASTABLE_SIZES = (0, 1) + BROADCASTABLE_SIZES_OR_DEFAULT = BROADCASTABLE_SIZES + (DEFAULT_SIZE,) + arrays_dims = tuple(a.dims for a in arrays) arrays_shapes = tuple(a.shape for a in arrays) sizes: dict[Any, Any] = {} for dims, shape in zip( zip_longest(*map(reversed, arrays_dims), fillvalue=_default), - zip_longest(*map(reversed, arrays_shapes), fillvalue=-1), + zip_longest(*map(reversed, arrays_shapes), fillvalue=DEFAULT_SIZE), strict=False, ): for d, s in zip(reversed(dims), reversed(shape), strict=False): @@ -461,8 +467,11 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] if s is None: raise NotImplementedError("TODO: Handle None in shape, {shapes = }") - s_prev = sizes.get(d, -1) - if s_prev not in (-1, 0, 1, s): + s_prev = sizes.get(d, DEFAULT_SIZE) + if not ( + s == s_prev + or any(v in BROADCASTABLE_SIZES_OR_DEFAULT for v in (s, s_prev)) + ): raise ValueError( "operands could not be broadcast together with " f"dims = {arrays_dims} and shapes = {arrays_shapes}" From 5540702c9f8fb309ee34b03d14b7b9c1598e13f6 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 00:20:37 +0200 Subject: [PATCH 308/367] typing --- .../namedarray/_array_api/_manipulation_functions.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 7a7b15fc449..a54066cac1b 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -356,7 +356,7 @@ def _set_dims( return permute_dims(out, dims=dim) -def _broadcast_arrays(*arrays: NamedArray[Any, Any]): +def _broadcast_arrays(*arrays: NamedArray[Any, Any]) -> NamedArray[Any, Any]: """ TODO: Can this become xp.broadcast_arrays? @@ -373,7 +373,9 @@ def _broadcast_arrays(*arrays: NamedArray[Any, Any]): return tuple(_set_dims(var, dims, shape) for var in arrays) -def _broadcast_arrays_with_minimal_size(*arrays: NamedArray[Any, Any]): +def _broadcast_arrays_with_minimal_size( + *arrays: NamedArray[Any, Any] +) -> NamedArray[Any, Any]: """ Given any number of variables, return variables with matching dimensions. @@ -384,7 +386,9 @@ def _broadcast_arrays_with_minimal_size(*arrays: NamedArray[Any, Any]): return tuple(_set_dims(var, dims, None) for var in arrays) -def _arithmetic_broadcast(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any]): +def _arithmetic_broadcast( + x1: NamedArray[Any, Any], x2: NamedArray[Any, Any] +) -> NamedArray[Any, Any]: """ Fu From b2696509d4d60b7997b8c440e9a664b0fcb10122 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 05:57:26 +0200 Subject: [PATCH 309/367] Handle unordered dims --- .../_array_api/_manipulation_functions.py | 22 ++++++++++++++----- xarray/namedarray/core.py | 1 + 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index a54066cac1b..114a3480efe 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -323,8 +323,17 @@ def _set_dims( >>> x_new.dims, x_new.shape (('x',), (3,)) + + Unordered dims + + >>> x = NamedArray(("y", "x"), np.zeros((2, 3))) + >>> x_new = _set_dims(x, ("x", "y"), None) + >>> x_new.dims, x_new.shape + (('x', 'y'), (3, 2)) + Error + >>> x = NamedArray(("x",), np.asarray([1, 2, 3])) >>> x_new = _set_dims(x, (), None) Traceback (most recent call last): ... @@ -335,13 +344,15 @@ def _set_dims( # remains writeable as long as possible: return x - extra_dims = tuple(d for d in dim if d not in x.dims) - if not extra_dims: + missing_dims = set(x.dims) - set(dim) + if missing_dims: raise ValueError( f"new dimensions {dim!r} must be a superset of " f"existing dimensions {x.dims!r}" ) + extra_dims = tuple(d for d in dim if d not in x.dims) + if shape is not None: # Add dimensions, with same size as shape: dims_map = dict(zip(dim, shape, strict=False)) @@ -404,9 +415,10 @@ def _arithmetic_broadcast( if not _OPTIONS["arithmetic_broadcast"]: if x1.dims != x2.dims: raise ValueError( - "Broadcasting is necessary but automatic broadcasting is disabled via " - "global option `'arithmetic_broadcast'`. " - "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting." + "Broadcasting is necessary but automatic broadcasting is disabled " + "via global option `'arithmetic_broadcast'`. " + "Use `xr.set_options(arithmetic_broadcast=True)` to enable " + "automatic broadcasting." ) return _broadcast_arrays_with_minimal_size(x1, x2) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index e3f98a7f205..42983104319 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -432,6 +432,7 @@ def _maybe_asarray( # x is proper array. Respect the chosen dtype. return x # x is a scalar. Use the same dtype as self. + # TODO: Is this a good idea? x[Any, int] + 1.4 => int result then. return asarray(x, dtype=self.dtype) # Required methods below: From 218eba1aab9d17951e12f187395a347e5adcdc3c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 06:05:28 +0200 Subject: [PATCH 310/367] Update _manipulation_functions.py --- xarray/namedarray/_array_api/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 114a3480efe..85b5065258f 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -331,7 +331,7 @@ def _set_dims( >>> x_new.dims, x_new.shape (('x', 'y'), (3, 2)) - Error + Errors >>> x = NamedArray(("x",), np.asarray([1, 2, 3])) >>> x_new = _set_dims(x, (), None) From 8ee6172c7f075882cbaeef76e1ad4ca5707938bb Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 06:43:06 +0200 Subject: [PATCH 311/367] subclasses needs to be passed down --- .../_array_api/_elementwise_functions.py | 130 +++++++++--------- 1 file changed, 65 insertions(+), 65 deletions(-) diff --git a/xarray/namedarray/_array_api/_elementwise_functions.py b/xarray/namedarray/_array_api/_elementwise_functions.py index 4585c668bc1..4158a595bab 100644 --- a/xarray/namedarray/_array_api/_elementwise_functions.py +++ b/xarray/namedarray/_array_api/_elementwise_functions.py @@ -22,21 +22,21 @@ def abs(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.abs(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def acos(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.acos(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def acosh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.acosh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def add(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: @@ -44,28 +44,28 @@ def add(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.add(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def asin(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.asin(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def asinh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.asinh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def atan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.atan(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def atan2( @@ -75,14 +75,14 @@ def atan2( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.atan2(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def atanh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.atanh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def bitwise_and( @@ -92,14 +92,14 @@ def bitwise_and( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_and(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def bitwise_invert(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.bitwise_invert(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def bitwise_left_shift( @@ -109,7 +109,7 @@ def bitwise_left_shift( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_left_shift(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def bitwise_or( @@ -119,7 +119,7 @@ def bitwise_or( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_or(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def bitwise_right_shift( @@ -129,7 +129,7 @@ def bitwise_right_shift( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_right_shift(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def bitwise_xor( @@ -139,14 +139,14 @@ def bitwise_xor( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.bitwise_xor(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def ceil(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.ceil(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def clip( @@ -158,14 +158,14 @@ def clip( xp = _get_data_namespace(x) _dims = x.dims _data = xp.clip(x._data, min=min, max=max) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def conj(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.conj(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def copysign( @@ -175,21 +175,21 @@ def copysign( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.copysign(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def cos(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.cos(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def cosh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.cosh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def divide( @@ -199,21 +199,21 @@ def divide( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.divide(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def exp(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.exp(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def expm1(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.expm1(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def equal( @@ -223,14 +223,14 @@ def equal( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.equal(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def floor(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.floor(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def floor_divide( @@ -240,7 +240,7 @@ def floor_divide( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.floor_divide(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def greater( @@ -250,7 +250,7 @@ def greater( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.greater(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def greater_equal( @@ -260,7 +260,7 @@ def greater_equal( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.greater_equal(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def hypot( @@ -270,7 +270,7 @@ def hypot( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.hypot(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def imag( @@ -303,28 +303,28 @@ def imag( xp = _get_data_namespace(x) _dims = x.dims _data = xp.imag(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def isfinite(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isfinite(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def isinf(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isinf(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def isnan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.isnan(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def less(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: @@ -332,7 +332,7 @@ def less(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[An x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.less(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def less_equal( @@ -342,35 +342,35 @@ def less_equal( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.less_equal(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def log(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def log1p(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log1p(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def log2(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log2(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def log10(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.log10(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def logaddexp( @@ -380,7 +380,7 @@ def logaddexp( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.logaddexp(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def logical_and( @@ -390,14 +390,14 @@ def logical_and( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.logical_and(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def logical_not(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.logical_not(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def logical_or( @@ -407,7 +407,7 @@ def logical_or( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.logical_or(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def logical_xor( @@ -417,7 +417,7 @@ def logical_xor( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.logical_xor(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def maximum( @@ -427,7 +427,7 @@ def maximum( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.maximum(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def minimum( @@ -437,7 +437,7 @@ def minimum( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.minimum(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def multiply( @@ -447,14 +447,14 @@ def multiply( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.multiply(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def negative(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.negative(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def not_equal( @@ -464,14 +464,14 @@ def not_equal( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.not_equal(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def positive(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.positive(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def pow(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: @@ -479,7 +479,7 @@ def pow(x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /) -> NamedArray[Any x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.pow(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def real( @@ -512,7 +512,7 @@ def real( xp = _get_data_namespace(x) _dims = x.dims _data = xp.real(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def remainder( @@ -522,56 +522,56 @@ def remainder( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.remainder(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def round(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.round(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def sign(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sign(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def signbit(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.signbit(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def sin(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sin(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def sinh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sinh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def sqrt(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.sqrt(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def square(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.square(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def subtract( @@ -581,25 +581,25 @@ def subtract( x1_new, x2_new = _arithmetic_broadcast(x1, x2) _dims = x1_new.dims _data = xp.subtract(x1_new._data, x2_new._data) - return NamedArray(_dims, _data) + return x1._new(_dims, _data) def tan(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.tan(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def tanh(x: NamedArray[_ShapeType, Any], /) -> NamedArray[_ShapeType, Any]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.tanh(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) def trunc(x: NamedArray[_ShapeType, _DType], /) -> NamedArray[_ShapeType, _DType]: xp = _get_data_namespace(x) _dims = x.dims _data = xp.trunc(x._data) - return NamedArray(_dims, _data) + return x._new(_dims, _data) From 90e4819f0a890c135305919f25984c85159db057 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:59:21 +0200 Subject: [PATCH 312/367] prioritize variable arithmetics --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index a8c1e004616..9cf1e541bee 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -349,7 +349,7 @@ def _as_array_or_item(data): return data -class Variable(NamedArray, AbstractArray, VariableArithmetic): +class Variable(AbstractArray, VariableArithmetic, NamedArray): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully described outside the context of its parent Dataset (if you want such a From c295100f9eb2e846e59ba60f59e76e10a3306b2b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 08:48:15 +0200 Subject: [PATCH 313/367] Update indexing.py --- xarray/core/indexing.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 67912908a2b..e79570d93d0 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -505,7 +505,9 @@ class ExplicitlyIndexed: __slots__ = () - def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + def __array__( + self, dtype: np.typing.DTypeLike = None, /, *, copy: None | bool = None + ) -> np.ndarray: # Leave casting to an array up to the underlying array type. return np.asarray(self.get_duck_array(), dtype=dtype) @@ -520,10 +522,10 @@ def get_duck_array(self): key = BasicIndexer((slice(None),) * self.ndim) return self[key] - def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: - # This is necessary because we apply the indexing key in self.get_duck_array() - # Note this is the base class for all lazy indexing classes - return np.asarray(self.get_duck_array(), dtype=dtype) + # def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + # # This is necessary because we apply the indexing key in self.get_duck_array() + # # Note this is the base class for all lazy indexing classes + # return np.asarray(self.get_duck_array(), dtype=dtype) def _oindex_get(self, indexer: OuterIndexer): raise NotImplementedError( @@ -570,7 +572,9 @@ def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer): self.array = as_indexable(array) self.indexer_cls = indexer_cls - def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + def __array__( + self, dtype: np.typing.DTypeLike = None, /, *, copy: None | bool = None + ) -> np.ndarray: return np.asarray(self.get_duck_array(), dtype=dtype) def get_duck_array(self): @@ -830,8 +834,8 @@ def __init__(self, array): def _ensure_cached(self): self.array = as_indexable(self.array.get_duck_array()) - def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: - return np.asarray(self.get_duck_array(), dtype=dtype) + # def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + # return np.asarray(self.get_duck_array(), dtype=dtype) def get_duck_array(self): self._ensure_cached() @@ -1674,7 +1678,9 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None): def dtype(self) -> np.dtype: return self._dtype - def __array__(self, dtype: DTypeLike = None) -> np.ndarray: + def __array__( + self, dtype: np.typing.DTypeLike = None, /, *, copy: None | bool = None + ) -> np.ndarray: if dtype is None: dtype = self.dtype array = self.array @@ -1831,7 +1837,9 @@ def __init__( super().__init__(array, dtype) self.level = level - def __array__(self, dtype: DTypeLike = None) -> np.ndarray: + def __array__( + self, dtype: np.typing.DTypeLike = None, /, *, copy: None | bool = None + ) -> np.ndarray: if dtype is None: dtype = self.dtype if self.level is not None: From d02f4f2c8f55fcf49f861177331db43694839615 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 08:59:26 +0200 Subject: [PATCH 314/367] add copy to __array__ --- xarray/core/indexing.py | 21 ++++++++------------- xarray/namedarray/_typing.py | 6 +++--- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index e79570d93d0..55b179a6b7e 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -506,10 +506,10 @@ class ExplicitlyIndexed: __slots__ = () def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: None | bool = None + self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None ) -> np.ndarray: # Leave casting to an array up to the underlying array type. - return np.asarray(self.get_duck_array(), dtype=dtype) + return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy) def get_duck_array(self): return self.array @@ -522,11 +522,6 @@ def get_duck_array(self): key = BasicIndexer((slice(None),) * self.ndim) return self[key] - # def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: - # # This is necessary because we apply the indexing key in self.get_duck_array() - # # Note this is the base class for all lazy indexing classes - # return np.asarray(self.get_duck_array(), dtype=dtype) - def _oindex_get(self, indexer: OuterIndexer): raise NotImplementedError( f"{self.__class__.__name__}._oindex_get method should be overridden" @@ -573,9 +568,9 @@ def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer): self.indexer_cls = indexer_cls def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: None | bool = None + self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None ) -> np.ndarray: - return np.asarray(self.get_duck_array(), dtype=dtype) + return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy) def get_duck_array(self): return self.array.get_duck_array() @@ -1679,7 +1674,7 @@ def dtype(self) -> np.dtype: return self._dtype def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: None | bool = None + self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None ) -> np.ndarray: if dtype is None: dtype = self.dtype @@ -1688,7 +1683,7 @@ def __array__( with suppress(AttributeError): # this might not be public API array = array.astype("object") - return np.asarray(array.values, dtype=dtype) + return np.asarray(array.values, dtype=dtype, copy=copy) def get_duck_array(self) -> np.ndarray: return np.asarray(self) @@ -1838,7 +1833,7 @@ def __init__( self.level = level def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: None | bool = None + self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None ) -> np.ndarray: if dtype is None: dtype = self.dtype @@ -1847,7 +1842,7 @@ def __array__( self.array.get_level_values(self.level).values, dtype=dtype ) else: - return super().__array__(dtype) + return super().__array__(dtype, copy=copy) def _convert_scalar(self, item): if isinstance(item, tuple) and self.level is not None: diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 9473db580dd..cb5ebac3613 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -239,15 +239,15 @@ def __getitem__( @overload def __array__( - self, dtype: None = ..., /, *, copy: None | bool = ... + self, dtype: None = ..., /, *, copy: bool | None = ... ) -> np.ndarray[Any, _DType_co]: ... @overload def __array__( - self, dtype: _DType, /, *, copy: None | bool = ... + self, dtype: _DType, /, *, copy: bool | None = ... ) -> np.ndarray[Any, _DType]: ... def __array__( - self, dtype: _DType | None = ..., /, *, copy: None | bool = ... + self, dtype: _DType | None = ..., /, *, copy: bool | None = ... ) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: ... # TODO: Should return the same subclass but with a new dtype generic. From 825805bb6c9d295416f85dc7f56d2be307336d53 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 15:35:44 +0200 Subject: [PATCH 315/367] calculate dims from tuple indexing --- xarray/namedarray/_array_api/_utils.py | 58 +++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 49dc8a24f88..ec168dae629 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -20,6 +20,7 @@ _dtype, _Shape, duckarray, + _IndexKeys, ) from xarray.namedarray.core import NamedArray @@ -366,9 +367,31 @@ def _get_remaining_dims( return dims, data +def _new_unique_dim_name(dims: _Dims, i=None) -> _Dim: + """ + Get a new unique dimension name. + + Examples + -------- + >>> _new_unique_dim_name(()) + 'dim_0' + >>> _new_unique_dim_name(("dim_0",)) + 'dim_1' + >>> _new_unique_dim_name(("dim_1", "dim_0")) + 'dim_2' + >>> _new_unique_dim_name(("dim_0", "dim_2")) + 'dim_3' + >>> _new_unique_dim_name(("dim_3", "dim_2")) + 'dim_4' + """ + i = len(dims) if i is None else i + _dim: _Dim = f"dim_{i}" + return _new_unique_dim_name(dims, i=i + 1) if _dim in dims else _dim + + def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: if isinstance(dim, Default): - _dim: _Dim = f"dim_{len(dims)}" + _dim: _Dim = _new_unique_dim_name(dims) else: _dim = dim @@ -377,6 +400,39 @@ def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: return tuple(d) +def dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: + """ + Get the expected dims when using tuples in __getitem__. + + Examples + -------- + >>> dims_from_tuple_indexing(("x", "y"), ()) + ('x', 'y') + >>> dims_from_tuple_indexing(("x", "y"), (0,)) + ('y',) + >>> dims_from_tuple_indexing(("x", "y"), (0, 0)) + () + >>> dims_from_tuple_indexing(("x", "y"), (0, ...)) + ('y',) + >>> dims_from_tuple_indexing(("x", "y"), (0, slice(0))) + ('y',) + >>> dims_from_tuple_indexing(("x", "y"), (None,)) + ('dim_2', 'x', 'y') + >>> dims_from_tuple_indexing(("x", "y"), (0, None, None, 0)) + ('dim_1', 'dim_2') + """ + _dims = list(dims) + j = 0 + for i, v in enumerate(key): + if v is None: + _dims.insert(j, _new_unique_dim_name(tuple(_dims))) + elif isinstance(v, int): + _dims.pop(j) + j -= 1 + j += 1 + return tuple(_dims) + + def _raise_if_any_duplicate_dimensions( dims: _Dims, err_context: str = "This function" ) -> None: From e29287da159dfea0495a7df803436bd140ae760a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 15:36:10 +0200 Subject: [PATCH 316/367] Handle getitem better --- xarray/namedarray/core.py | 93 +++++++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 13 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 42983104319..f8017116d39 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -504,16 +504,82 @@ def __ge__(self, other: int | float | NamedArray, /): return greater_equal(self, self._maybe_asarray(other)) def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: - from xarray.namedarray._array_api._utils import _infer_dims + """ + sdfd + + Some rules + * Integers removes the dim. + * Slices and ellipsis maintains same dim. + * None adds a dim. + * tuple follows above but on that specific axis. + + Examples + -------- + + 1D + + >>> x = NamedArray(("x",), np.array([0, 1, 2])) + >>> mask = NamedArray(("x",), np.array([1, 0, 0], dtype=bool)) + >>> xm = x[mask] + >>> xm.dims, xm.shape + (('x',), (1,)) + + >>> x = NamedArray(("x",), np.array([0, 1, 2])) + >>> mask = NamedArray(("x",), np.array([0, 0, 0], dtype=bool)) + >>> xm = x[mask] + >>> xm.dims, xm.shape + (('x',), (0,)) + + Setup a ND array: + + >>> x = NamedArray(("x", "y"), np.arange(3*4).reshape((3, 4))) + >>> xm = x[0] + >>> xm.dims, xm.shape + (('y',), (4,)) + >>> xm = x[slice(0)] + >>> xm.dims, xm.shape + (('x', 'y'), (0, 4)) + >>> xm = x[None] + >>> xm.dims, xm.shape + (('dim_2', 'x', 'y'), (1, 3, 4)) + + >>> mask = NamedArray(("x", "y"), np.ones((3, 4), dtype=bool)) + >>> xm = x[mask] + >>> xm.dims, xm.shape + ((('x', 'y'),), (12,)) + """ + from xarray.namedarray._array_api._manipulation_functions import ( + _broadcast_arrays, + expand_dims, + ) + from xarray.namedarray._array_api._utils import dims_from_tuple_indexing if isinstance(key, NamedArray): - _key = key._data # TODO: Transpose, unordered dims shouldn't matter. - _data = self._data[_key] - _dims = _infer_dims(_data.shape) # TODO: fix + self_new, key_new = _broadcast_arrays(self, key) + _data = self_new._data[key_new._data] + if self_new.ndim > 1: + # ND-arrays are raveled and then masked: + # _dims = f"{'_'.join(self_new.dims)}_ravel" + _dims = (self_new.dims,) + else: + _dims = self_new.dims return self._new(_dims, _data) + # elif isinstance(key, int): + # return self._new(self.dims[1:], self._data[key]) + # elif isinstance(key, slice) or key is ...: + # return self._new(self.dims, self._data[key]) + # elif key is None: + # return expand_dims(self) + # elif isinstance(key, tuple): + # _dims = dims_from_tuple_indexing(self.dims, key) + # return self._new(_dims, self._data[key]) + elif isinstance(key, int | slice | tuple) or key is None or key is ...: + # TODO: __getitem__ not always available, use expand_dims _data = self._data[key] - _dims = _infer_dims(_data.shape) # TODO: fix + _dims = dims_from_tuple_indexing( + self.dims, key if isinstance(key, tuple) else (key,) + ) return self._new(_dims, _data) else: raise NotImplementedError(f"{key=} is not supported") @@ -1456,7 +1522,7 @@ def broadcast_to( Examples -------- >>> data = np.asarray([[1.0, 2.0], [3.0, 4.0]]) - >>> array = xr.NamedArray(("x", "y"), data) + >>> array = NamedArray(("x", "y"), data) >>> array.sizes {'x': 2, 'y': 2} @@ -1511,16 +1577,11 @@ def expand_dims( Examples -------- - >>> data = np.asarray([[1.0, 2.0], [3.0, 4.0]]) - >>> array = xr.NamedArray(("x", "y"), data) - - - # expand dimensions by specifying a new dimension name - >>> expanded = array.expand_dims(dim="z") + >>> x = NamedArray(("x", "y"), data) + >>> expanded = x.expand_dims(dim="z") >>> expanded.dims ('z', 'x', 'y') - """ from xarray.namedarray._array_api import expand_dims @@ -1539,3 +1600,9 @@ def _raise_if_any_duplicate_dimensions( raise ValueError( f"{err_context} cannot handle duplicate dimensions, but dimensions {repeated_dims} appear more than once on this object's dims: {dims}" ) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() From 28df3071aa3fc0c873e4338a6690eb4f347bf6bf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Sep 2024 13:37:01 +0000 Subject: [PATCH 317/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_utils.py | 2 +- xarray/namedarray/core.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index ec168dae629..5032e73a4ae 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -18,9 +18,9 @@ _DimsLike2, _DType, _dtype, + _IndexKeys, _Shape, duckarray, - _IndexKeys, ) from xarray.namedarray.core import NamedArray diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index f8017116d39..ca2d1612173 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -532,7 +532,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: Setup a ND array: - >>> x = NamedArray(("x", "y"), np.arange(3*4).reshape((3, 4))) + >>> x = NamedArray(("x", "y"), np.arange(3 * 4).reshape((3, 4))) >>> xm = x[0] >>> xm.dims, xm.shape (('y',), (4,)) @@ -550,7 +550,6 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: """ from xarray.namedarray._array_api._manipulation_functions import ( _broadcast_arrays, - expand_dims, ) from xarray.namedarray._array_api._utils import dims_from_tuple_indexing From eb73c2d4d849a607afeacdf60540b12508e73485 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:10:36 +0200 Subject: [PATCH 318/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 5032e73a4ae..2ffd2b24ebb 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -433,6 +433,10 @@ def dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: return tuple(_dims) +def _flattened_dims(dims: _Dims, ndim: int) -> _Dims: + return (dims,) if ndim > 1 else dims + + def _raise_if_any_duplicate_dimensions( dims: _Dims, err_context: str = "This function" ) -> None: From a450e8aa1deb432cb424ee56c563a72245c20ff9 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:11:06 +0200 Subject: [PATCH 319/367] Update core.py --- xarray/namedarray/core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index ca2d1612173..0021b9b96e1 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -71,6 +71,7 @@ _ScalarType, _Shape, _ShapeType, + _IndexKeys, ) from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -551,17 +552,15 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: from xarray.namedarray._array_api._manipulation_functions import ( _broadcast_arrays, ) - from xarray.namedarray._array_api._utils import dims_from_tuple_indexing + from xarray.namedarray._array_api._utils import ( + dims_from_tuple_indexing, + _flattened_dims, + ) if isinstance(key, NamedArray): self_new, key_new = _broadcast_arrays(self, key) _data = self_new._data[key_new._data] - if self_new.ndim > 1: - # ND-arrays are raveled and then masked: - # _dims = f"{'_'.join(self_new.dims)}_ravel" - _dims = (self_new.dims,) - else: - _dims = self_new.dims + _dims = _flattened_dims(self_new.dims, self_new.ndim) return self._new(_dims, _data) # elif isinstance(key, int): # return self._new(self.dims[1:], self._data[key]) From 9c3f6c13a8c4a56e57bbb770ce4c99b8770ef486 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Sep 2024 17:16:02 +0000 Subject: [PATCH 320/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 0021b9b96e1..2db3f609c72 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -71,7 +71,6 @@ _ScalarType, _Shape, _ShapeType, - _IndexKeys, ) from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -553,8 +552,8 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: _broadcast_arrays, ) from xarray.namedarray._array_api._utils import ( - dims_from_tuple_indexing, _flattened_dims, + dims_from_tuple_indexing, ) if isinstance(key, NamedArray): From 73a7d2fe83ff444985705a8e8cfbaf9bbd398b37 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:17:30 +0200 Subject: [PATCH 321/367] improve dims handling --- .../_array_api/_manipulation_functions.py | 63 +++++++++++++-- .../namedarray/_array_api/_set_functions.py | 78 +++++++++++++++---- 2 files changed, 119 insertions(+), 22 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 85b5065258f..b6e904f3033 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -1,14 +1,18 @@ from __future__ import annotations +import math from typing import Any from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( _dims_to_axis, + _dim_to_optional_axis, _get_broadcasted_dims, _get_data_namespace, _infer_dims, _insert_dim, + _new_unique_dim_name, + _flattened_dims, ) from xarray.namedarray._typing import ( Default, @@ -88,12 +92,29 @@ def concat( *, axis: _Axis | None = 0, ) -> NamedArray[Any, Any]: - xp = _get_data_namespace(arrays[0]) + """ + Joins a sequence of arrays along an existing axis. + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x",), np.zeros((3,))) + >>> x1 = concat((x, 1+x)) + >>> x1.dims, x1.shape + (('x',), (6,)) + + >>> x = NamedArray(("x", "y"), np.zeros((3, 4))) + >>> x1 = concat((x, 1+x)) + >>> x1.dims, x1.shape + (('x', 'y'), (6, 4)) + """ + x = arrays[0] + xp = _get_data_namespace(x) + _axis = axis # TODO: add support for dim? dtype = result_type(*arrays) _arrays = tuple(a._data for a in arrays) - _data = xp.concat(_arrays, axis=axis, dtype=dtype) - _dims = _infer_dims(_data.shape) - return NamedArray(_dims, _data) + _data = xp.concat(_arrays, axis=_axis, dtype=dtype) + return NamedArray(x.dims, _data) def expand_dims( @@ -219,9 +240,41 @@ def repeat( def reshape( x: NamedArray[Any, _DType], /, shape: _ShapeType, *, copy: bool | None = None ) -> NamedArray[_ShapeType, _DType]: + """ + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x",), np.zeros((3,))) + >>> x1 = reshape(x, (-1,)) + >>> x1.dims, x1.shape + (('x',), (3,)) + + To N-dimensions + + >>> x1 = reshape(x, (1, -1, 1)) + >>> x1.dims, x1.shape + (('dim_0', 'x', 'dim_2'), (1, 3, 1)) + + >>> x = NamedArray(("x", "y"), np.zeros((3, 4))) + >>> x1 = reshape(x, (-1,)) + >>> x1.dims, x1.shape + ((('x', 'y'),), (12,)) + + """ xp = _get_data_namespace(x) _data = xp.reshape(x._data, shape, copy=copy) - _dims = _infer_dims(_data.shape) # TODO: Fix dims + + if math.prod(shape) == -1: + # Flattening operations merges all dimensions to 1: + dims_raveled = _flattened_dims(x.dims, x.ndim) + dim = dims_raveled[0] + d = [] + for v in shape: + d.append(dim if v == -1 else _new_unique_dim_name(tuple(d))) + _dims = tuple(d) + else: + _dims = _infer_dims(_data.shape, x.dims) return x._new(_dims, _data) diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py index 009a5bb96a8..21ca04c4612 100644 --- a/xarray/namedarray/_array_api/_set_functions.py +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -5,6 +5,7 @@ from xarray.namedarray._array_api._utils import ( _get_data_namespace, _infer_dims, + _flattened_dims, ) from xarray.namedarray.core import NamedArray @@ -29,42 +30,85 @@ class UniqueInverseResult(NamedTuple): def unique_all(x: NamedArray[Any, Any], /) -> UniqueAllResult: xp = _get_data_namespace(x) values, indices, inverse_indices, counts = xp.unique_all(x._data) - _dims_values = _infer_dims(values.shape) # TODO: Fix - _dims_indices = _infer_dims(indices.shape) # TODO: Fix dims - _dims_inverse_indices = _infer_dims(inverse_indices.shape) # TODO: Fix dims - _dims_counts = _infer_dims(counts.shape) # TODO: Fix dims + _dims = _flattened_dims(x.dims, x.ndim) return UniqueAllResult( - NamedArray(_dims_values, values), - NamedArray(_dims_indices, indices), - NamedArray(_dims_inverse_indices, inverse_indices), - NamedArray(_dims_counts, counts), + NamedArray(_dims, values), + NamedArray(_dims, indices), + NamedArray(_dims, inverse_indices), + NamedArray(_dims, counts), ) def unique_counts(x: NamedArray[Any, Any], /) -> UniqueCountsResult: + """ + Returns the unique elements of an input array x and the corresponding + counts for each unique element in x. + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int)) + >>> x_unique = unique_counts(x) + >>> x_unique.values + >>> x_unique.counts + + >>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2))) + >>> x_unique = unique_counts(x) + >>> x_unique.values + >>> x_unique.counts + """ xp = _get_data_namespace(x) values, counts = xp.unique_counts(x._data) - _dims_values = _infer_dims(values.shape) # TODO: Fix dims - _dims_counts = _infer_dims(counts.shape) # TODO: Fix dims + _dims = _flattened_dims(x.dims, x.ndim) return UniqueCountsResult( - NamedArray(_dims_values, values), - NamedArray(_dims_counts, counts), + NamedArray(_dims, values), + NamedArray(_dims, counts), ) def unique_inverse(x: NamedArray[Any, Any], /) -> UniqueInverseResult: + """ + Returns the unique elements of an input array x and the indices + from the set of unique elements that reconstruct x. + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int)) + >>> x_unique = unique_inverse(x) + >>> x_unique.values + >>> x_unique.counts + >>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2))) + >>> x_unique = unique_inverse(x) + >>> x_unique.dims, x_unique.shape + (('x',), (3,)) + """ xp = _get_data_namespace(x) values, inverse_indices = xp.unique_inverse(x._data) - _dims_values = _infer_dims(values.shape) # TODO: Fix - _dims_inverse_indices = _infer_dims(inverse_indices.shape) # TODO: Fix dims + _dims = _flattened_dims(x.dims, x.ndim) return UniqueInverseResult( - NamedArray(_dims_values, values), - NamedArray(_dims_inverse_indices, inverse_indices), + NamedArray(_dims, values), + NamedArray(_dims, inverse_indices), ) def unique_values(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: + """ + Returns the unique elements of an input array x. + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int)) + >>> x_unique = unique_values(x) + >>> x_unique.dims, x_unique.shape + (('x',), (3,)) + >>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2))) + >>> x_unique = unique_values(x) + >>> x_unique.dims, x_unique.shape + (('x',), (3,)) + """ xp = _get_data_namespace(x) _data = xp.unique_values(x._data) - _dims = _infer_dims(_data.shape) # TODO: Fix + _dims = _flattened_dims(x.dims, x.ndim) return x._new(_dims, _data) From 45c6db008411c7e40516e540f42e036469355e81 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Sep 2024 17:18:37 +0000 Subject: [PATCH 322/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_manipulation_functions.py | 7 +++---- xarray/namedarray/_array_api/_set_functions.py | 3 +-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index b6e904f3033..38964742780 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -6,13 +6,12 @@ from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( _dims_to_axis, - _dim_to_optional_axis, + _flattened_dims, _get_broadcasted_dims, _get_data_namespace, _infer_dims, _insert_dim, _new_unique_dim_name, - _flattened_dims, ) from xarray.namedarray._typing import ( Default, @@ -99,12 +98,12 @@ def concat( -------- >>> import numpy as np >>> x = NamedArray(("x",), np.zeros((3,))) - >>> x1 = concat((x, 1+x)) + >>> x1 = concat((x, 1 + x)) >>> x1.dims, x1.shape (('x',), (6,)) >>> x = NamedArray(("x", "y"), np.zeros((3, 4))) - >>> x1 = concat((x, 1+x)) + >>> x1 = concat((x, 1 + x)) >>> x1.dims, x1.shape (('x', 'y'), (6, 4)) """ diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py index 21ca04c4612..dc8e544196f 100644 --- a/xarray/namedarray/_array_api/_set_functions.py +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -3,9 +3,8 @@ from typing import Any, NamedTuple from xarray.namedarray._array_api._utils import ( - _get_data_namespace, - _infer_dims, _flattened_dims, + _get_data_namespace, ) from xarray.namedarray.core import NamedArray From e90508e8fb9e3f3950708f110679774318a336cd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:20:37 +0200 Subject: [PATCH 323/367] typo --- xarray/namedarray/_array_api/_utils.py | 2 +- xarray/namedarray/core.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 2ffd2b24ebb..fa5f6e27c89 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -400,7 +400,7 @@ def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: return tuple(d) -def dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: +def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: """ Get the expected dims when using tuples in __getitem__. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 2db3f609c72..d454abecb2c 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -552,8 +552,8 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: _broadcast_arrays, ) from xarray.namedarray._array_api._utils import ( + _dims_from_tuple_indexing, _flattened_dims, - dims_from_tuple_indexing, ) if isinstance(key, NamedArray): @@ -568,13 +568,13 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: # elif key is None: # return expand_dims(self) # elif isinstance(key, tuple): - # _dims = dims_from_tuple_indexing(self.dims, key) + # _dims = _dims_from_tuple_indexing(self.dims, key) # return self._new(_dims, self._data[key]) elif isinstance(key, int | slice | tuple) or key is None or key is ...: # TODO: __getitem__ not always available, use expand_dims _data = self._data[key] - _dims = dims_from_tuple_indexing( + _dims = _dims_from_tuple_indexing( self.dims, key if isinstance(key, tuple) else (key,) ) return self._new(_dims, _data) From de0aa20da37344418f89e6cef4a3e515c9ca5b35 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:23:24 +0200 Subject: [PATCH 324/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index fa5f6e27c89..30a125d5f6c 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -423,10 +423,10 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: """ _dims = list(dims) j = 0 - for i, v in enumerate(key): - if v is None: + for k in key: + if k is None: _dims.insert(j, _new_unique_dim_name(tuple(_dims))) - elif isinstance(v, int): + elif isinstance(k, int): _dims.pop(j) j -= 1 j += 1 From 17346f7b63e498769711abcd1e16e7c75287fe7f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:55:59 +0200 Subject: [PATCH 325/367] Update indexing.py --- xarray/core/indexing.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 55b179a6b7e..d3a6d340d90 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -505,9 +505,7 @@ class ExplicitlyIndexed: __slots__ = () - def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None - ) -> np.ndarray: + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: # Leave casting to an array up to the underlying array type. return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy) @@ -522,6 +520,11 @@ def get_duck_array(self): key = BasicIndexer((slice(None),) * self.ndim) return self[key] + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + # This is necessary because we apply the indexing key in self.get_duck_array() + # Note this is the base class for all lazy indexing classes + return np.asarray(self.get_duck_array(), dtype=dtype) + def _oindex_get(self, indexer: OuterIndexer): raise NotImplementedError( f"{self.__class__.__name__}._oindex_get method should be overridden" @@ -567,10 +570,8 @@ def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer): self.array = as_indexable(array) self.indexer_cls = indexer_cls - def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None - ) -> np.ndarray: - return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy) + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + return np.asarray(self.get_duck_array(), dtype=dtype) def get_duck_array(self): return self.array.get_duck_array() @@ -829,8 +830,8 @@ def __init__(self, array): def _ensure_cached(self): self.array = as_indexable(self.array.get_duck_array()) - # def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: - # return np.asarray(self.get_duck_array(), dtype=dtype) + def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: + return np.asarray(self.get_duck_array(), dtype=dtype) def get_duck_array(self): self._ensure_cached() @@ -1673,9 +1674,7 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None): def dtype(self) -> np.dtype: return self._dtype - def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None - ) -> np.ndarray: + def __array__(self, dtype: DTypeLike = None) -> np.ndarray: if dtype is None: dtype = self.dtype array = self.array @@ -1832,9 +1831,7 @@ def __init__( super().__init__(array, dtype) self.level = level - def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None - ) -> np.ndarray: + def __array__(self, dtype: DTypeLike = None) -> np.ndarray: if dtype is None: dtype = self.dtype if self.level is not None: From 1173cc094f99a1c1a66de585bca714f51c720081 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 19:59:04 +0200 Subject: [PATCH 326/367] Update indexing.py --- xarray/core/indexing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index d3a6d340d90..67912908a2b 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -507,7 +507,7 @@ class ExplicitlyIndexed: def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: # Leave casting to an array up to the underlying array type. - return np.asarray(self.get_duck_array(), dtype=dtype, copy=copy) + return np.asarray(self.get_duck_array(), dtype=dtype) def get_duck_array(self): return self.array @@ -1682,7 +1682,7 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: with suppress(AttributeError): # this might not be public API array = array.astype("object") - return np.asarray(array.values, dtype=dtype, copy=copy) + return np.asarray(array.values, dtype=dtype) def get_duck_array(self) -> np.ndarray: return np.asarray(self) @@ -1839,7 +1839,7 @@ def __array__(self, dtype: DTypeLike = None) -> np.ndarray: self.array.get_level_values(self.level).values, dtype=dtype ) else: - return super().__array__(dtype, copy=copy) + return super().__array__(dtype) def _convert_scalar(self, item): if isinstance(item, tuple) and self.level is not None: From 8a5a041134fc40801f776360bfe149f2c1556d5d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 21 Sep 2024 20:35:27 +0200 Subject: [PATCH 327/367] Update _manipulation_functions.py --- xarray/namedarray/_array_api/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 38964742780..3e55c9bf608 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -264,7 +264,7 @@ def reshape( xp = _get_data_namespace(x) _data = xp.reshape(x._data, shape, copy=copy) - if math.prod(shape) == -1: + if math.prod(shape) == -1 and False: # Flattening operations merges all dimensions to 1: dims_raveled = _flattened_dims(x.dims, x.ndim) dim = dims_raveled[0] From fae50aa0b67d806ed3b193e904429a988b4b1cdd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 22 Sep 2024 00:38:38 +0200 Subject: [PATCH 328/367] add examples --- .../_array_api/_linear_algebra_functions.py | 27 +++++++++++++++++++ xarray/namedarray/_array_api/_utils.py | 16 ++++++----- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index f8b730907af..f5e2cc0f389 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -10,6 +10,33 @@ def matmul( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], / ) -> NamedArray[Any, Any]: + """ + Matrix product of two arrays. + + Examples + -------- + For 2-D arrays it is the matrix product: + + >>> import numpy as np + >>> a = NamedArray(("y", "x"), np.array([[1, 0], [0, 1]])) + >>> b = NamedArray(("y", "x"), np.array([[4, 1], [2, 2]])) + >>> matmul(a, b) + + + For 2-D mixed with 1-D, the result is the usual. + + >>> a = NamedArray(("y", "x"), np.array([[1, 0], [0, 1]])) + >>> b = NamedArray(("x",), np.array([1, 2])) + >>> matmul(a, b) + + Broadcasting is conventional for stacks of arrays + + >>> a = NamedArray(("z", "y", "x"), np.arange(2 * 2 * 4).reshape((2, 2, 4))) + >>> b = NamedArray(("z", "y", "x"), np.arange(2 * 2 * 4).reshape((2, 4, 2))) + >>> axb = matmul(a,b) + >>> axb.dims, axb.shape + """ xp = _get_data_namespace(x1) _data = xp.matmul(x1._data, x2._data) # TODO: Figure out a better way: diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 30a125d5f6c..c0a6611427d 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -111,6 +111,8 @@ def _infer_dims( >>> _infer_dims((3, 1)) ('dim_1', 'dim_0') + >>> _infer_dims((), ()) + () >>> _infer_dims((1,), "x") ('x',) >>> _infer_dims((1,), None) @@ -406,19 +408,19 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: Examples -------- - >>> dims_from_tuple_indexing(("x", "y"), ()) + >>> _dims_from_tuple_indexing(("x", "y"), ()) ('x', 'y') - >>> dims_from_tuple_indexing(("x", "y"), (0,)) + >>> _dims_from_tuple_indexing(("x", "y"), (0,)) ('y',) - >>> dims_from_tuple_indexing(("x", "y"), (0, 0)) + >>> _dims_from_tuple_indexing(("x", "y"), (0, 0)) () - >>> dims_from_tuple_indexing(("x", "y"), (0, ...)) + >>> _dims_from_tuple_indexing(("x", "y"), (0, ...)) ('y',) - >>> dims_from_tuple_indexing(("x", "y"), (0, slice(0))) + >>> _dims_from_tuple_indexing(("x", "y"), (0, slice(0))) ('y',) - >>> dims_from_tuple_indexing(("x", "y"), (None,)) + >>> _dims_from_tuple_indexing(("x", "y"), (None,)) ('dim_2', 'x', 'y') - >>> dims_from_tuple_indexing(("x", "y"), (0, None, None, 0)) + >>> _dims_from_tuple_indexing(("x", "y"), (0, None, None, 0)) ('dim_1', 'dim_2') """ _dims = list(dims) From 52f7edf26530b7cb0022763a6895d5d17b0cba74 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Sep 2024 22:52:21 +0000 Subject: [PATCH 329/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_linear_algebra_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index f5e2cc0f389..ae40344c5c2 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -34,7 +34,7 @@ def matmul( >>> a = NamedArray(("z", "y", "x"), np.arange(2 * 2 * 4).reshape((2, 2, 4))) >>> b = NamedArray(("z", "y", "x"), np.arange(2 * 2 * 4).reshape((2, 4, 2))) - >>> axb = matmul(a,b) + >>> axb = matmul(a, b) >>> axb.dims, axb.shape """ xp = _get_data_namespace(x1) From 805a0c59fb739366467d8d3dbd7f984c48b29a16 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 22 Sep 2024 16:36:13 +0200 Subject: [PATCH 330/367] Update _manipulation_functions.py --- xarray/namedarray/_array_api/_manipulation_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 3e55c9bf608..ab1b2e31000 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -273,7 +273,9 @@ def reshape( d.append(dim if v == -1 else _new_unique_dim_name(tuple(d))) _dims = tuple(d) else: - _dims = _infer_dims(_data.shape, x.dims) + # _dims = _infer_dims(_data.shape, x.dims) + _dims = _infer_dims(_data.shape) + return x._new(_dims, _data) From 4621e26e4aff749bea219c7172bca6b0849e5e9b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:45:08 +0200 Subject: [PATCH 331/367] Clarify the paths --- xarray/namedarray/_array_api/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index c0a6611427d..6a058547bba 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -427,10 +427,15 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: j = 0 for k in key: if k is None: + # None adds 1 dimension: _dims.insert(j, _new_unique_dim_name(tuple(_dims))) elif isinstance(k, int): + # Integer removes 1 dimension: _dims.pop(j) j -= 1 + else: + # Slices and Ellipsis maintains same dimensions: + pass j += 1 return tuple(_dims) From f6e90d052d39ca21b182980d5d9329219d46d63f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:50:20 +0200 Subject: [PATCH 332/367] Shorten name --- xarray/namedarray/_array_api/_utils.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 6a058547bba..5da70faaf4e 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -440,8 +440,22 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: return tuple(_dims) -def _flattened_dims(dims: _Dims, ndim: int) -> _Dims: - return (dims,) if ndim > 1 else dims +def _flatten_dims(dims: _Dims) -> _Dims: + """ + Flatten multidimensional dims to 1-dimensional. + + Examples + -------- + >>> _flatten_dims(()) + () + >>> _flatten_dims(("x",)) + ('x',) + >>> _flatten_dims(("x", "y")) + (('x', 'y'),) + """ + return (dims,) if len(dims) > 1 else dims + + def _raise_if_any_duplicate_dimensions( From 31ef484432f97cedd7b95105b4ab467ef6e785b4 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:50:54 +0200 Subject: [PATCH 333/367] Add atleast1d for dims, unique_values seems to need it --- xarray/namedarray/_array_api/_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 5da70faaf4e..cec92037c1c 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -456,6 +456,20 @@ def _flatten_dims(dims: _Dims) -> _Dims: return (dims,) if len(dims) > 1 else dims +def _atleast1d_dims(dims: _Dims) -> _Dims: + """ + Set dims atleast 1-dimensional. + + Examples + -------- + >>> _atleast1d_dims(()) + ('dim_0',) + >>> _atleast1d_dims(("x",)) + ('x',) + >>> _atleast1d_dims(("x", "y")) + ('x', 'y') + """ + return (_new_unique_dim_name(dims),) if len(dims) < 1 else dims def _raise_if_any_duplicate_dimensions( From c049f983b6f3244245a59d3bf360fd7957f2a0ac Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:51:24 +0200 Subject: [PATCH 334/367] unique_values is atleast1d --- .../namedarray/_array_api/_set_functions.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py index dc8e544196f..fe3903ae466 100644 --- a/xarray/namedarray/_array_api/_set_functions.py +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -3,7 +3,8 @@ from typing import Any, NamedTuple from xarray.namedarray._array_api._utils import ( - _flattened_dims, + _atleast1d_dims, + _flatten_dims, _get_data_namespace, ) from xarray.namedarray.core import NamedArray @@ -29,7 +30,7 @@ class UniqueInverseResult(NamedTuple): def unique_all(x: NamedArray[Any, Any], /) -> UniqueAllResult: xp = _get_data_namespace(x) values, indices, inverse_indices, counts = xp.unique_all(x._data) - _dims = _flattened_dims(x.dims, x.ndim) + _dims = _atleast1d_dims(_flatten_dims(x.dims)) return UniqueAllResult( NamedArray(_dims, values), NamedArray(_dims, indices), @@ -58,7 +59,7 @@ def unique_counts(x: NamedArray[Any, Any], /) -> UniqueCountsResult: """ xp = _get_data_namespace(x) values, counts = xp.unique_counts(x._data) - _dims = _flattened_dims(x.dims, x.ndim) + _dims = _flatten_dims(_atleast1d_dims(x.dims)) return UniqueCountsResult( NamedArray(_dims, values), NamedArray(_dims, counts), @@ -76,7 +77,7 @@ def unique_inverse(x: NamedArray[Any, Any], /) -> UniqueInverseResult: >>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int)) >>> x_unique = unique_inverse(x) >>> x_unique.values - >>> x_unique.counts + >>> x_unique.inverse_indices >>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2))) >>> x_unique = unique_inverse(x) >>> x_unique.dims, x_unique.shape @@ -84,7 +85,7 @@ def unique_inverse(x: NamedArray[Any, Any], /) -> UniqueInverseResult: """ xp = _get_data_namespace(x) values, inverse_indices = xp.unique_inverse(x._data) - _dims = _flattened_dims(x.dims, x.ndim) + _dims = _flatten_dims(_atleast1d_dims(x.dims)) return UniqueInverseResult( NamedArray(_dims, values), NamedArray(_dims, inverse_indices), @@ -106,8 +107,15 @@ def unique_values(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: >>> x_unique = unique_values(x) >>> x_unique.dims, x_unique.shape (('x',), (3,)) + + # Scalars becomes 1-dimensional + + >>> x = NamedArray((), np.array(0, dtype=int)) + x_unique = unique_values(x) + >>> x_unique.dims, x_unique.shape + (('dim_0',), (1,)) """ xp = _get_data_namespace(x) _data = xp.unique_values(x._data) - _dims = _flattened_dims(x.dims, x.ndim) + _dims = _flatten_dims(_atleast1d_dims(x.dims)) return x._new(_dims, _data) From 749235bca502a14fe227a5659ad307a1d61e7bbd Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:51:37 +0200 Subject: [PATCH 335/367] Update _manipulation_functions.py --- xarray/namedarray/_array_api/_manipulation_functions.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index ab1b2e31000..7b310ed6b08 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -6,7 +6,7 @@ from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( _dims_to_axis, - _flattened_dims, + _flatten_dims, _get_broadcasted_dims, _get_data_namespace, _infer_dims, @@ -259,14 +259,13 @@ def reshape( >>> x1 = reshape(x, (-1,)) >>> x1.dims, x1.shape ((('x', 'y'),), (12,)) - """ xp = _get_data_namespace(x) _data = xp.reshape(x._data, shape, copy=copy) if math.prod(shape) == -1 and False: # Flattening operations merges all dimensions to 1: - dims_raveled = _flattened_dims(x.dims, x.ndim) + dims_raveled = _flatten_dims(x.dims) dim = dims_raveled[0] d = [] for v in shape: From c77feb1c93a32571d21e7cd2202dee1ed78a444b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:51:54 +0200 Subject: [PATCH 336/367] Update core.py --- xarray/namedarray/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index d454abecb2c..72c22c838b9 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -553,13 +553,13 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: ) from xarray.namedarray._array_api._utils import ( _dims_from_tuple_indexing, - _flattened_dims, + _flatten_dims, ) if isinstance(key, NamedArray): self_new, key_new = _broadcast_arrays(self, key) _data = self_new._data[key_new._data] - _dims = _flattened_dims(self_new.dims, self_new.ndim) + _dims = _flatten_dims(self_new.dims) return self._new(_dims, _data) # elif isinstance(key, int): # return self._new(self.dims[1:], self._data[key]) From 00115d1a2e8816d835ac0b54e634c34100c3a68c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:01:27 +0200 Subject: [PATCH 337/367] Update _set_functions.py --- xarray/namedarray/_array_api/_set_functions.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py index fe3903ae466..721fb861045 100644 --- a/xarray/namedarray/_array_api/_set_functions.py +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -30,7 +30,7 @@ class UniqueInverseResult(NamedTuple): def unique_all(x: NamedArray[Any, Any], /) -> UniqueAllResult: xp = _get_data_namespace(x) values, indices, inverse_indices, counts = xp.unique_all(x._data) - _dims = _atleast1d_dims(_flatten_dims(x.dims)) + _dims = _flatten_dims(_atleast1d_dims(x.dims)) return UniqueAllResult( NamedArray(_dims, values), NamedArray(_dims, indices), @@ -85,10 +85,9 @@ def unique_inverse(x: NamedArray[Any, Any], /) -> UniqueInverseResult: """ xp = _get_data_namespace(x) values, inverse_indices = xp.unique_inverse(x._data) - _dims = _flatten_dims(_atleast1d_dims(x.dims)) return UniqueInverseResult( - NamedArray(_dims, values), - NamedArray(_dims, inverse_indices), + NamedArray(_flatten_dims(_atleast1d_dims(x.dims)), values), + NamedArray(_flatten_dims(x.dims), inverse_indices), ) From 6cf9fae421a99767939cdd00314523c60cb0d05a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:07:01 +0200 Subject: [PATCH 338/367] Update _set_functions.py --- xarray/namedarray/_array_api/_set_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py index 721fb861045..46d6c50bb0a 100644 --- a/xarray/namedarray/_array_api/_set_functions.py +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -34,7 +34,7 @@ def unique_all(x: NamedArray[Any, Any], /) -> UniqueAllResult: return UniqueAllResult( NamedArray(_dims, values), NamedArray(_dims, indices), - NamedArray(_dims, inverse_indices), + NamedArray(_flatten_dims(x.dims), inverse_indices), NamedArray(_dims, counts), ) From 18400d53fc6333b8180987b42be100a86b0f6027 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:29:44 +0200 Subject: [PATCH 339/367] indexing are always 1d? --- xarray/namedarray/core.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 72c22c838b9..5d2c40bdb45 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -505,9 +505,9 @@ def __ge__(self, other: int | float | NamedArray, /): def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: """ - sdfd + Returns self[key]. - Some rules + Some rules: * Integers removes the dim. * Slices and ellipsis maintains same dim. * None adds a dim. @@ -519,14 +519,14 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: 1D >>> x = NamedArray(("x",), np.array([0, 1, 2])) - >>> mask = NamedArray(("x",), np.array([1, 0, 0], dtype=bool)) - >>> xm = x[mask] + >>> key = NamedArray(("x",), np.array([1, 0, 0], dtype=bool)) + >>> xm = x[key] >>> xm.dims, xm.shape (('x',), (1,)) >>> x = NamedArray(("x",), np.array([0, 1, 2])) - >>> mask = NamedArray(("x",), np.array([0, 0, 0], dtype=bool)) - >>> xm = x[mask] + >>> key = NamedArray(("x",), np.array([0, 0, 0], dtype=bool)) + >>> xm = x[key] >>> xm.dims, xm.shape (('x',), (0,)) @@ -543,15 +543,24 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: >>> xm.dims, xm.shape (('dim_2', 'x', 'y'), (1, 3, 4)) - >>> mask = NamedArray(("x", "y"), np.ones((3, 4), dtype=bool)) - >>> xm = x[mask] + >>> key = NamedArray(("x", "y"), np.ones((3, 4), dtype=bool)) + >>> xm = x[key] >>> xm.dims, xm.shape ((('x', 'y'),), (12,)) + + 0D + + >>> x = NamedArray((), np.array(False, dtype=np.bool)) + >>> key = NamedArray((), np.array(False, dtype=np.bool)) + >>> xm = x[key] + >>> xm.dims, xm.shape + (('dim_0',), (0,)) """ from xarray.namedarray._array_api._manipulation_functions import ( _broadcast_arrays, ) from xarray.namedarray._array_api._utils import ( + _atleast1d_dims, _dims_from_tuple_indexing, _flatten_dims, ) @@ -559,7 +568,7 @@ def __getitem__(self, key: _IndexKeyLike | NamedArray) -> NamedArray: if isinstance(key, NamedArray): self_new, key_new = _broadcast_arrays(self, key) _data = self_new._data[key_new._data] - _dims = _flatten_dims(self_new.dims) + _dims = _flatten_dims(_atleast1d_dims(self_new.dims)) return self._new(_dims, _data) # elif isinstance(key, int): # return self._new(self.dims[1:], self._data[key]) From 842b2896be805636ae7391f24d1f21a6bd81d194 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:55:10 +0200 Subject: [PATCH 340/367] 0D concatenation --- .../_array_api/_manipulation_functions.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 7b310ed6b08..034932699d8 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -5,6 +5,7 @@ from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( + _atleast1d_dims, _dims_to_axis, _flatten_dims, _get_broadcasted_dims, @@ -98,14 +99,22 @@ def concat( -------- >>> import numpy as np >>> x = NamedArray(("x",), np.zeros((3,))) - >>> x1 = concat((x, 1 + x)) - >>> x1.dims, x1.shape + >>> xc = concat((x, 1 + x)) + >>> xc.dims, xc.shape (('x',), (6,)) >>> x = NamedArray(("x", "y"), np.zeros((3, 4))) - >>> x1 = concat((x, 1 + x)) - >>> x1.dims, x1.shape + >>> xc = concat((x, 1 + x)) + >>> xc.dims, xc.shape (('x', 'y'), (6, 4)) + + 0D + + >>> x1 = NamedArray((), np.array(0)) + >>> x2 = NamedArray((), np.array(0)) + >>> xc = concat((x1, x2), axis=None) + >>> xc.dims, xc.shape + (('dim_0',), (2,)) """ x = arrays[0] xp = _get_data_namespace(x) @@ -113,7 +122,8 @@ def concat( dtype = result_type(*arrays) _arrays = tuple(a._data for a in arrays) _data = xp.concat(_arrays, axis=_axis, dtype=dtype) - return NamedArray(x.dims, _data) + _dims = _atleast1d_dims(x.dims) if axis is None else x.dims + return NamedArray(_dims, _data) def expand_dims( From 733a45a9170588a17c19b79f0bfdb89d174ae51b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 21:56:41 +0200 Subject: [PATCH 341/367] lets try reshape again --- xarray/namedarray/_array_api/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 034932699d8..0cff3a1da15 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -273,7 +273,7 @@ def reshape( xp = _get_data_namespace(x) _data = xp.reshape(x._data, shape, copy=copy) - if math.prod(shape) == -1 and False: + if math.prod(shape) == -1: # Flattening operations merges all dimensions to 1: dims_raveled = _flatten_dims(x.dims) dim = dims_raveled[0] From b7ceeb44d6df2d7251a4faddf104960aea2d3c82 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 24 Sep 2024 22:13:32 +0200 Subject: [PATCH 342/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index cec92037c1c..52b9da1884a 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -422,6 +422,8 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: ('dim_2', 'x', 'y') >>> _dims_from_tuple_indexing(("x", "y"), (0, None, None, 0)) ('dim_1', 'dim_2') + >>> _dims_from_tuple_indexing(("x",), (..., 0)) + () """ _dims = list(dims) j = 0 @@ -429,14 +431,14 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: if k is None: # None adds 1 dimension: _dims.insert(j, _new_unique_dim_name(tuple(_dims))) + j += 1 elif isinstance(k, int): # Integer removes 1 dimension: _dims.pop(j) - j -= 1 else: # Slices and Ellipsis maintains same dimensions: pass - j += 1 + return tuple(_dims) From 43f757e25aeb159fbc0211ba4dd44d085590fbdc Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:55:30 +0200 Subject: [PATCH 343/367] Ellipsis are often converted to slices --- xarray/namedarray/_typing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index cb5ebac3613..fb654e0fc54 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -136,7 +136,9 @@ def imag(self) -> _T_co: ... # TODO: np.array_api was bugged and didn't allow (None,), but should! # https://github.com/numpy/numpy/pull/25022 # https://github.com/data-apis/array-api/pull/674 -_IndexKey = Union[int, slice, EllipsisType, None] +_IndexKeyNoEllipsis = Union[int, slice, None] +_IndexKey = Union[_IndexKeyNoEllipsis, EllipsisType] +_IndexKeysNoEllipsis = tuple[_IndexKeyNoEllipsis, ...] _IndexKeys = tuple[_IndexKey, ...] # tuple[Union[_IndexKey, None], ...] _IndexKeyLike = Union[_IndexKey, _IndexKeys] From e8b2b4ead6ee73c87069ec2ecc79339b8a1e8195 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:56:58 +0200 Subject: [PATCH 344/367] Convert ellipsis to slices --- xarray/namedarray/_array_api/_utils.py | 99 ++++++++++++++++++++------ 1 file changed, 78 insertions(+), 21 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 52b9da1884a..a453d5e582f 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,10 +1,10 @@ from __future__ import annotations import math -from collections.abc import Iterable +from collections.abc import Iterable, Iterator from itertools import zip_longest from types import ModuleType -from typing import Any, TypeGuard, cast +from typing import Any, TypeGuard, cast, Callable from xarray.namedarray._typing import ( Default, @@ -19,7 +19,9 @@ _DType, _dtype, _IndexKeys, + _IndexKeysNoEllipsis, _Shape, + _T, duckarray, ) from xarray.namedarray.core import NamedArray @@ -229,10 +231,10 @@ def _normalize_axis_tuple( ValueError If an axis is repeated """ - if not isinstance(axis, tuple): - _axis = (axis,) - else: + if isinstance(axis, tuple): _axis = axis + else: + _axis = (axis,) # Going via an iterator directly is slower than via list comprehension. _axis = tuple([_normalize_axis_index(ax, ndim) for ax in _axis]) @@ -351,10 +353,8 @@ def _get_remaining_dims( removed_axes: tuple[int, ...] if axis is None: removed_axes = tuple(v for v in range(x.ndim)) - elif isinstance(axis, tuple): - removed_axes = tuple(a % x.ndim for a in axis) else: - removed_axes = (axis % x.ndim,) + removed_axes = _normalize_axis_tuple(axis, x.ndim) if keepdims: # Insert None (aka newaxis) for removed dims @@ -369,7 +369,7 @@ def _get_remaining_dims( return dims, data -def _new_unique_dim_name(dims: _Dims, i=None) -> _Dim: +def _new_unique_dim_name(dims: _Dims, i: int | None = None) -> _Dim: """ Get a new unique dimension name. @@ -402,22 +402,78 @@ def _insert_dim(dims: _Dims, dim: _Dim | Default, axis: _Axis) -> _Dims: return tuple(d) +def _filter_next_false( + predicate: Callable[..., bool], iterable: Iterable[_T] +) -> Iterator[_T]: + """ + Make an iterator that filters elements from the iterable returning only those + for which the predicate returns a false value for the second time. + + Variant on itertools.filterfalse but doesn't filter until the 2 second False. + + Examples + -------- + >>> tuple(_filter_next_false(lambda x: x is not None, (1, None, 3, None, 4))) + (1, None, 3, 4) + """ + predicate_has_been_false = False + for x in iterable: + if not predicate(x): + if predicate_has_been_false: + continue + predicate_has_been_false = True + yield x + + +def _replace_ellipsis(key: _IndexKeys, ndim: int) -> _IndexKeysNoEllipsis: + """ + Replace ... with slices, :, : ,: + + >>> _replace_ellipsis((3, Ellipsis, 2), 4) + (3, slice(None, None, None), slice(None, None, None), 2) + + >>> _replace_ellipsis((Ellipsis, None), 2) + (slice(None, None, None), slice(None, None, None), None) + >>> _replace_ellipsis((Ellipsis, None, Ellipsis), 2) + (slice(None, None, None), slice(None, None, None), None) + """ + # https://github.com/dask/dask/blob/569abf8e8048cbfb1d750900468dda0de7c56358/dask/array/slicing.py#L701 + key = tuple(_filter_next_false(lambda x: x is not Ellipsis, key)) + expanded_dims = sum(i is None for i in key) + extra_dimensions = ndim - (len(key) - expanded_dims - 1) + replaced_slices = (slice(None, None, None),) * extra_dimensions + + out: _IndexKeysNoEllipsis = () + for k in key: + if k is Ellipsis: + out += replaced_slices + else: + out += (k,) + return out + + def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: """ Get the expected dims when using tuples in __getitem__. Examples -------- - >>> _dims_from_tuple_indexing(("x", "y"), ()) - ('x', 'y') - >>> _dims_from_tuple_indexing(("x", "y"), (0,)) - ('y',) - >>> _dims_from_tuple_indexing(("x", "y"), (0, 0)) + >>> _dims_from_tuple_indexing(("x", "y", "z"), ()) + ('x', 'y', 'z') + >>> _dims_from_tuple_indexing(("x", "y", "z"), (0,)) + ('y', 'z') + >>> _dims_from_tuple_indexing(("x", "y", "z"), (0, 0)) + ('z',) + >>> _dims_from_tuple_indexing(("x", "y", "z"), (0, 0, 0)) () - >>> _dims_from_tuple_indexing(("x", "y"), (0, ...)) - ('y',) - >>> _dims_from_tuple_indexing(("x", "y"), (0, slice(0))) + >>> _dims_from_tuple_indexing(("x", "y", "z"), (0, 0, 0, ...)) + () + >>> _dims_from_tuple_indexing(("x", "y", "z"), (0, ...)) + ('y', 'z') + >>> _dims_from_tuple_indexing(("x", "y", "z"), (0, ..., 0)) ('y',) + >>> _dims_from_tuple_indexing(("x", "y", "z"), (0, slice(0))) + ('y', 'z') >>> _dims_from_tuple_indexing(("x", "y"), (None,)) ('dim_2', 'x', 'y') >>> _dims_from_tuple_indexing(("x", "y"), (0, None, None, 0)) @@ -425,9 +481,10 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: >>> _dims_from_tuple_indexing(("x",), (..., 0)) () """ + key_no_ellipsis = _replace_ellipsis(key, len(dims)) _dims = list(dims) j = 0 - for k in key: + for k in key_no_ellipsis: if k is None: # None adds 1 dimension: _dims.insert(j, _new_unique_dim_name(tuple(_dims))) @@ -435,9 +492,9 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims: elif isinstance(k, int): # Integer removes 1 dimension: _dims.pop(j) - else: - # Slices and Ellipsis maintains same dimensions: - pass + elif isinstance(k, slice): + # Slice retains the dimension. + j += 1 return tuple(_dims) From ec3880e0fbf7ec17d85ad759bea56ebf39164e90 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:58:15 +0000 Subject: [PATCH 345/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index a453d5e582f..98292aaec97 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,12 +1,13 @@ from __future__ import annotations import math -from collections.abc import Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator from itertools import zip_longest from types import ModuleType -from typing import Any, TypeGuard, cast, Callable +from typing import Any, TypeGuard, cast from xarray.namedarray._typing import ( + _T, Default, _arrayapi, _Axes, @@ -21,7 +22,6 @@ _IndexKeys, _IndexKeysNoEllipsis, _Shape, - _T, duckarray, ) from xarray.namedarray.core import NamedArray From 4431175265d101b187364748cdab7180cc294b2e Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 20:45:03 +0200 Subject: [PATCH 346/367] Update _manipulation_functions.py --- .../_array_api/_manipulation_functions.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 0cff3a1da15..6dad340e585 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -6,6 +6,7 @@ from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( _atleast1d_dims, + _dims_from_tuple_indexing, _dims_to_axis, _flatten_dims, _get_broadcasted_dims, @@ -333,11 +334,26 @@ def tile( def unstack( x: NamedArray[Any, Any], /, *, axis: _Axis = 0 ) -> tuple[NamedArray[Any, Any], ...]: + """ + Splits an array into a sequence of arrays along the given axis. + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x", "y", "z"), np.arange(1*2*3).reshape((1, 2, 3))) + >>> x_y0, x_y1 = unstack(x, axis=1) + >>> x_y0 + + >>> x_y1 + + """ xp = _get_data_namespace(x) _datas = xp.unstack(x._data, axis=axis) + key = [slice(None)] * x.ndim + key[axis] = 0 + _dims = _dims_from_tuple_indexing(x.dims, tuple(key)) out: tuple[NamedArray[Any, Any], ...] = () for _data in _datas: - _dims = _infer_dims(_data.shape) # TODO: Fix dims out += (x._new(_dims, _data),) return out From 12677506c28b0a82095b32b69991056867d8a8f8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 18:45:50 +0000 Subject: [PATCH 347/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_manipulation_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 6dad340e585..eb67c13f430 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -340,7 +340,7 @@ def unstack( Examples -------- >>> import numpy as np - >>> x = NamedArray(("x", "y", "z"), np.arange(1*2*3).reshape((1, 2, 3))) + >>> x = NamedArray(("x", "y", "z"), np.arange(1 * 2 * 3).reshape((1, 2, 3))) >>> x_y0, x_y1 = unstack(x, axis=1) >>> x_y0 From 5bafb8cec52385b838a032f36769d7a87438ef64 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 21:23:28 +0200 Subject: [PATCH 348/367] Try simplify mean --- .../_array_api/_statistical_functions.py | 12 +++---- xarray/namedarray/_array_api/_utils.py | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 0326d47ed3e..11f8defe141 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -6,6 +6,7 @@ _dims_to_axis, _get_data_namespace, _get_remaining_dims, + _reduce_dims, ) from xarray.namedarray._typing import ( Default, @@ -131,7 +132,8 @@ def mean( Examples -------- - >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) + >>> import numpy as np + >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]])) >>> mean(x).data Array(2.5, dtype=float64) >>> mean(x, dims=("x",)).data @@ -149,11 +151,9 @@ def mean( """ xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.mean(x._data, axis=_axis, keepdims=False) - # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _data = xp.mean(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def min( diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 98292aaec97..03ecbe8954d 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -369,6 +369,40 @@ def _get_remaining_dims( return dims, data +def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Dims: + """ + Reduce dims according to axis. + + Examples + -------- + >>> _reduce_dims(("x", "y", "z"), axis=None, keepdims=False) + () + >>> _reduce_dims(("x", "y", "z"), axis=1, keepdims=False) + ('x', 'z') + >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=False) + ('x', 'y') + + keepdims retains the same dims + + >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=True) + ('x', 'y', 'z') + """ + if keepdims: + return dims + + ndim = len(dims) + if axis is None: + _axis = tuple(v for v in range(ndim)) + else: + _axis = _normalize_axis_tuple(axis, ndim) + + key = [slice(None)] * ndim + for i, v in enumerate(_axis): + key[v] = 0 + + return _dims_from_tuple_indexing(dims, tuple(key)) + + def _new_unique_dim_name(dims: _Dims, i: int | None = None) -> _Dim: """ Get a new unique dimension name. From c59fb54fda6a7265984d69586469d586113ed8db Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 21:28:50 +0200 Subject: [PATCH 349/367] Seems successfull do it on the rest --- .../_array_api/_statistical_functions.py | 48 +++++++------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index 11f8defe141..e2617722e8b 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -5,7 +5,6 @@ from xarray.namedarray._array_api._utils import ( _dims_to_axis, _get_data_namespace, - _get_remaining_dims, _reduce_dims, ) from xarray.namedarray._typing import ( @@ -84,10 +83,9 @@ def max( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.max(x._data, axis=_axis, keepdims=False) - # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - return x._new(dims=dims_, data=data_) + _data = xp.max(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def mean( @@ -166,11 +164,9 @@ def min( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.min(x._data, axis=_axis, keepdims=False) - # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _data = xp.min(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def prod( @@ -184,11 +180,9 @@ def prod( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.prod(x._data, axis=_axis, keepdims=False) - # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _data = xp.prod(x._data, axis=_axis, dtype=dtype, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def std( @@ -202,11 +196,9 @@ def std( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.std(x._data, axis=_axis, correction=correction, keepdims=False) - # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _data = xp.std(x._data, axis=_axis, correction=correction, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def sum( @@ -220,11 +212,9 @@ def sum( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.sum(x._data, axis=_axis, keepdims=False) - # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _data = xp.sum(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def var( @@ -238,8 +228,6 @@ def var( ) -> NamedArray[Any, _DType]: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) - _data = xp.var(x._data, axis=_axis, correction=correction, keepdims=False) - # TODO: Why do we need to do the keepdims ourselves? - dims_, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _data = xp.var(x._data, axis=_axis, correction=correction, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) From 889feaa5c72dbe71f6df438eb4d12b4f0d9b653f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 21:41:10 +0200 Subject: [PATCH 350/367] more places to simplify --- .../_array_api/_searching_functions.py | 16 ++- .../_array_api/_utility_functions.py | 20 ++-- xarray/namedarray/_array_api/_utils.py | 100 ++++++------------ 3 files changed, 50 insertions(+), 86 deletions(-) diff --git a/xarray/namedarray/_array_api/_searching_functions.py b/xarray/namedarray/_array_api/_searching_functions.py index 38e851bf5f8..755ee16b588 100644 --- a/xarray/namedarray/_array_api/_searching_functions.py +++ b/xarray/namedarray/_array_api/_searching_functions.py @@ -5,8 +5,8 @@ from xarray.namedarray._array_api._utils import ( _dim_to_optional_axis, _get_data_namespace, - _get_remaining_dims, _infer_dims, + _reduce_dims, ) from xarray.namedarray._typing import ( Default, @@ -32,10 +32,9 @@ def argmax( ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _axis = _dim_to_optional_axis(x, dim, axis) - _data = xp.argmax(x._data, axis=_axis, keepdims=False) # We fix keepdims later - # TODO: Why do we need to do the keepdims ourselves? - _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - return x._new(dims=_dims, data=data_) + _data = xp.argmax(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def argmin( @@ -48,10 +47,9 @@ def argmin( ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) _axis = _dim_to_optional_axis(x, dim, axis) - _data = xp.argmin(x._data, axis=_axis, keepdims=False) # We fix keepdims later - # TODO: Why do we need to do the keepdims ourselves? - _dims, data_ = _get_remaining_dims(x, _data, _axis, keepdims=keepdims) - return x._new(dims=_dims, data=data_) + _data = xp.argmin(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def nonzero(x: NamedArray[Any, Any], /) -> tuple[NamedArray[Any, Any], ...]: diff --git a/xarray/namedarray/_array_api/_utility_functions.py b/xarray/namedarray/_array_api/_utility_functions.py index 17cd5ad03c7..211046e04f2 100644 --- a/xarray/namedarray/_array_api/_utility_functions.py +++ b/xarray/namedarray/_array_api/_utility_functions.py @@ -5,7 +5,7 @@ from xarray.namedarray._array_api._utils import ( _dims_to_axis, _get_data_namespace, - _get_remaining_dims, + _reduce_dims, ) from xarray.namedarray._typing import ( Default, @@ -25,11 +25,10 @@ def all( axis: _AxisLike | None = None, ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.all(x._data, axis=axis_, keepdims=False) - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _axis = _dims_to_axis(x, dims, axis) + _data = xp.all(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) def any( @@ -41,8 +40,7 @@ def any( axis: _AxisLike | None = None, ) -> NamedArray[Any, Any]: xp = _get_data_namespace(x) - axis_ = _dims_to_axis(x, dims, axis) - d = xp.any(x._data, axis=axis_, keepdims=False) - dims_, data_ = _get_remaining_dims(x, d, axis_, keepdims=keepdims) - out = x._new(dims=dims_, data=data_) - return out + _axis = _dims_to_axis(x, dims, axis) + _data = xp.any(x._data, axis=_axis, keepdims=keepdims) + _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) + return x._new(dims=_dims, data=_data) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 03ecbe8954d..8257d471e7b 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -337,72 +337,6 @@ def _dim_to_axis(x: NamedArray[Any, Any], dim: _Dim | Default, axis: int) -> int return _axis -def _get_remaining_dims( - x: NamedArray[Any, _DType], - data: duckarray[Any, _DType], - axis: _AxisLike | None, - *, - keepdims: bool, -) -> tuple[_Dims, duckarray[Any, _DType]]: - """ - Get the reamining dims after a reduce operation. - """ - if data.shape == x.shape: - return x.dims, data - - removed_axes: tuple[int, ...] - if axis is None: - removed_axes = tuple(v for v in range(x.ndim)) - else: - removed_axes = _normalize_axis_tuple(axis, x.ndim) - - if keepdims: - # Insert None (aka newaxis) for removed dims - slices = tuple( - None if i in removed_axes else slice(None, None) for i in range(x.ndim) - ) - data = data[slices] - dims = x.dims - else: - dims = tuple(adim for n, adim in enumerate(x.dims) if n not in removed_axes) - - return dims, data - - -def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Dims: - """ - Reduce dims according to axis. - - Examples - -------- - >>> _reduce_dims(("x", "y", "z"), axis=None, keepdims=False) - () - >>> _reduce_dims(("x", "y", "z"), axis=1, keepdims=False) - ('x', 'z') - >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=False) - ('x', 'y') - - keepdims retains the same dims - - >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=True) - ('x', 'y', 'z') - """ - if keepdims: - return dims - - ndim = len(dims) - if axis is None: - _axis = tuple(v for v in range(ndim)) - else: - _axis = _normalize_axis_tuple(axis, ndim) - - key = [slice(None)] * ndim - for i, v in enumerate(_axis): - key[v] = 0 - - return _dims_from_tuple_indexing(dims, tuple(key)) - - def _new_unique_dim_name(dims: _Dims, i: int | None = None) -> _Dim: """ Get a new unique dimension name. @@ -565,6 +499,40 @@ def _atleast1d_dims(dims: _Dims) -> _Dims: return (_new_unique_dim_name(dims),) if len(dims) < 1 else dims +def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Dims: + """ + Reduce dims according to axis. + + Examples + -------- + >>> _reduce_dims(("x", "y", "z"), axis=None, keepdims=False) + () + >>> _reduce_dims(("x", "y", "z"), axis=1, keepdims=False) + ('x', 'z') + >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=False) + ('x', 'y') + + keepdims retains the same dims + + >>> _reduce_dims(("x", "y", "z"), axis=-1, keepdims=True) + ('x', 'y', 'z') + """ + if keepdims: + return dims + + ndim = len(dims) + if axis is None: + _axis = tuple(v for v in range(ndim)) + else: + _axis = _normalize_axis_tuple(axis, ndim) + + key = [slice(None)] * ndim + for i, v in enumerate(_axis): + key[v] = 0 + + return _dims_from_tuple_indexing(dims, tuple(key)) + + def _raise_if_any_duplicate_dimensions( dims: _Dims, err_context: str = "This function" ) -> None: From f6ffdaffacbcd5f93c4dd13b2a3904c42a7f1022 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:41:57 +0000 Subject: [PATCH 351/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 8257d471e7b..0c13a482fa5 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -17,12 +17,10 @@ _Dim, _Dims, _DimsLike2, - _DType, _dtype, _IndexKeys, _IndexKeysNoEllipsis, _Shape, - duckarray, ) from xarray.namedarray.core import NamedArray From 793d166c79d1913245bc82b9c13e606369727b48 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 22:22:20 +0200 Subject: [PATCH 352/367] squeeze --- .../_array_api/_manipulation_functions.py | 17 +++++++++++++++-- xarray/namedarray/_array_api/_utils.py | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index eb67c13f430..316bff7d868 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -14,11 +14,13 @@ _infer_dims, _insert_dim, _new_unique_dim_name, + _squeeze_dims, ) from xarray.namedarray._typing import ( Default, _Axes, _Axis, + _AxisLike, _default, _Dim, _Dims, @@ -301,10 +303,21 @@ def roll( return x._new(data=_data) -def squeeze(x: NamedArray[Any, _DType], /, axis: _Axes) -> NamedArray[Any, _DType]: +def squeeze(x: NamedArray[Any, _DType], /, axis: _AxisLike) -> NamedArray[Any, _DType]: + """ + Removes singleton dimensions (axes) from x. + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x", "y", "z"), np.arange(1 * 2 * 3).reshape((1, 2, 3))) + >>> xs = squeeze(x, axis=0) + >>> xs.dims, xs.shape + (('y', 'z'), (2, 3)) + """ xp = _get_data_namespace(x) _data = xp.squeeze(x._data, axis=axis) - _dims = _infer_dims(_data.shape) # TODO: Fix dims + _dims = _squeeze_dims(x.dims, x.shape, axis) return x._new(_dims, _data) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 8257d471e7b..7acd3fa8b24 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -533,6 +533,24 @@ def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Di return _dims_from_tuple_indexing(dims, tuple(key)) +def _squeeze_dims(dims: _Dims, shape: _Shape, axis: _AxisLike) -> _Dims: + """ + Squeeze dims. + + Examples + -------- + >>> _squeeze_dims(("x", "y", "z"), (0, 2, 1), (0, 2)) + ('y',) + """ + sizes = dict(zip(dims, shape)) + for a in _normalize_axis_tuple(axis, len(dims)): + d = dims[a] + if sizes[d] < 2: + sizes.pop(d) + + return tuple(sizes.keys()) + + def _raise_if_any_duplicate_dimensions( dims: _Dims, err_context: str = "This function" ) -> None: From 353d8bf724141664f22845de809eff9ce9b0039c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 20:23:34 +0000 Subject: [PATCH 353/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/namedarray/_array_api/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index ec2407677fb..bafe4896b2a 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -540,7 +540,7 @@ def _squeeze_dims(dims: _Dims, shape: _Shape, axis: _AxisLike) -> _Dims: >>> _squeeze_dims(("x", "y", "z"), (0, 2, 1), (0, 2)) ('y',) """ - sizes = dict(zip(dims, shape)) + sizes = dict(zip(dims, shape, strict=False)) for a in _normalize_axis_tuple(axis, len(dims)): d = dims[a] if sizes[d] < 2: From 0f26d51a6139812cdd55ba70e35e23f9ec89375b Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 22:27:01 +0200 Subject: [PATCH 354/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index bafe4896b2a..fd0e7aa1129 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -540,7 +540,7 @@ def _squeeze_dims(dims: _Dims, shape: _Shape, axis: _AxisLike) -> _Dims: >>> _squeeze_dims(("x", "y", "z"), (0, 2, 1), (0, 2)) ('y',) """ - sizes = dict(zip(dims, shape, strict=False)) + sizes = dict(zip(dims, shape, strict=True)) for a in _normalize_axis_tuple(axis, len(dims)): d = dims[a] if sizes[d] < 2: From 2cec2c7c24bc9292670dbb21ad56d5a61c1fcbd7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 23:01:53 +0200 Subject: [PATCH 355/367] typing --- xarray/namedarray/_array_api/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index fd0e7aa1129..f135ac65a7a 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -497,7 +497,7 @@ def _atleast1d_dims(dims: _Dims) -> _Dims: return (_new_unique_dim_name(dims),) if len(dims) < 1 else dims -def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Dims: +def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: bool) -> _Dims: """ Reduce dims according to axis. @@ -524,7 +524,8 @@ def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: False) -> _Di else: _axis = _normalize_axis_tuple(axis, ndim) - key = [slice(None)] * ndim + k: _IndexKeys = (slice(None),) * ndim + key = list(k) for i, v in enumerate(_axis): key[v] = 0 From c5a3054f5f1f676f2411e054dc79e2d31e5b0a39 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 25 Sep 2024 23:03:59 +0200 Subject: [PATCH 356/367] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index f135ac65a7a..81e66809557 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -526,7 +526,7 @@ def _reduce_dims(dims: _Dims, *, axis: _AxisLike | None, keepdims: bool) -> _Dim k: _IndexKeys = (slice(None),) * ndim key = list(k) - for i, v in enumerate(_axis): + for v in _axis: key[v] = 0 return _dims_from_tuple_indexing(dims, tuple(key)) From e3923f49dc7825cabf03bf218a840e238d3f2466 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 26 Sep 2024 21:39:18 +0200 Subject: [PATCH 357/367] matrix_transpose --- .../_array_api/_linear_algebra_functions.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index ae40344c5c2..f55f0674504 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -59,10 +59,24 @@ def tensordot( def matrix_transpose(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: + """ + Transposes a matrix (or a stack of matrices) x. + + Examples + -------- + >>> import numpy as np + >>> x = NamedArray(("x", "y", "z"), np.zeros((2, 3, 4))) + >>> xT = matrix_transpose(x) + >>> xT.dims, xT.shape + (('x', 'z', 'y'), (2, 4, 3)) + + >>> x = NamedArray(("x", "y"), np.zeros((2, 3))) + >>> xT = matrix_transpose(x) + >>> xT.dims, xT.shape + (('y', 'x'), (3, 2)) xp = _get_data_namespace(x) _data = xp.matrix_transpose(x._data) - # TODO: Figure out a better way: - _dims = _infer_dims(_data.shape) + d = x.dims return NamedArray(_dims, _data) From 31d888218799484717de45483d2640066af55e79 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 26 Sep 2024 21:41:44 +0200 Subject: [PATCH 358/367] Update _linear_algebra_functions.py --- xarray/namedarray/_array_api/_linear_algebra_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index f55f0674504..a8f5ab9c364 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -74,9 +74,11 @@ def matrix_transpose(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: >>> xT = matrix_transpose(x) >>> xT.dims, xT.shape (('y', 'x'), (3, 2)) + """ xp = _get_data_namespace(x) _data = xp.matrix_transpose(x._data) d = x.dims + _dims = d[:-2] + d[-2:][::-1] # (..., M, N) -> (..., N, M) return NamedArray(_dims, _data) From 0c83ff6282949ba4cb69fe04c792950650b0b184 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 26 Sep 2024 22:28:27 +0200 Subject: [PATCH 359/367] simplify function names --- .../_array_api/_manipulation_functions.py | 8 ++++---- xarray/namedarray/_array_api/_utils.py | 20 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 316bff7d868..e3d1a9e2918 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -9,7 +9,7 @@ _dims_from_tuple_indexing, _dims_to_axis, _flatten_dims, - _get_broadcasted_dims, + _broadcast_dims, _get_data_namespace, _infer_dims, _insert_dim, @@ -56,7 +56,7 @@ def broadcast_arrays(*arrays: NamedArray[Any, Any]) -> list[NamedArray[Any, Any] """ x = arrays[0] xp = _get_data_namespace(x) - _dims, _ = _get_broadcasted_dims(*arrays) + _dims, _ = _broadcast_dims(*arrays) _arrays = tuple(a._data for a in arrays) _datas = xp.broadcast_arrays(*_arrays) return [arr._new(_dims, _data) for arr, _data in zip(arrays, _datas, strict=False)] @@ -472,7 +472,7 @@ def _broadcast_arrays(*arrays: NamedArray[Any, Any]) -> NamedArray[Any, Any]: dimensions are sorted in order of appearance in the first variable's dimensions followed by the second variable's dimensions. """ - dims, shape = _get_broadcasted_dims(*arrays) + dims, shape = _broadcast_dims(*arrays) return tuple(_set_dims(var, dims, shape) for var in arrays) @@ -485,7 +485,7 @@ def _broadcast_arrays_with_minimal_size( Unlike the result of broadcast_variables(), variables with missing dimensions will have them added with size 1 instead of the size of the broadcast dimension. """ - dims, _ = _get_broadcasted_dims(*arrays) + dims, _ = _broadcast_dims(*arrays) return tuple(_set_dims(var, dims, None) for var in arrays) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 81e66809557..be05279e623 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -566,7 +566,7 @@ def _isnone(shape: _Shape) -> tuple[bool, ...]: return tuple(v is None and math.isnan(v) for v in shape) -def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]: +def _broadcast_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape]: """ Get the expected broadcasted dims. @@ -574,40 +574,40 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] -------- >>> import numpy as np >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) - >>> _get_broadcasted_dims(a) + >>> _broadcast_dims(a) (('x', 'y', 'z'), (5, 3, 4)) Broadcasting 0- and 1-sized dims >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4))) - >>> _get_broadcasted_dims(a, b) + >>> _broadcast_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) - >>> _get_broadcasted_dims(b, a) + >>> _broadcast_dims(b, a) (('x', 'y', 'z'), (5, 3, 4)) >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((1, 3, 4))) - >>> _get_broadcasted_dims(a, b) + >>> _broadcast_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) - >>> _get_broadcasted_dims(a, b) + >>> _broadcast_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) Broadcasting different dims >>> a = NamedArray(("x",), np.zeros((5,))) >>> b = NamedArray(("y",), np.zeros((3,))) - >>> _get_broadcasted_dims(a, b) + >>> _broadcast_dims(a, b) (('x', 'y'), (5, 3)) >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("y", "z"), np.zeros((3, 4))) - >>> _get_broadcasted_dims(a, b) + >>> _broadcast_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) - >>> _get_broadcasted_dims(b, a) + >>> _broadcast_dims(b, a) (('x', 'y', 'z'), (5, 3, 4)) @@ -615,7 +615,7 @@ def _get_broadcasted_dims(*arrays: NamedArray[Any, Any]) -> tuple[_Dims, _Shape] >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((2, 3, 4))) - >>> _get_broadcasted_dims(a, b) + >>> _broadcast_dims(a, b) Traceback (most recent call last): ... ValueError: operands could not be broadcast together with dims = (('x', 'y', 'z'), ('x', 'y', 'z')) and shapes = ((5, 3, 4), (2, 3, 4)) From 27ebb03c011276d66f2a457945253acdcdee368d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 26 Sep 2024 22:31:54 +0200 Subject: [PATCH 360/367] vecdot --- .../_array_api/_linear_algebra_functions.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index a8f5ab9c364..3c7fc082933 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -3,7 +3,13 @@ from collections.abc import Sequence from typing import Any -from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims +from xarray.namedarray._array_api._utils import ( + _get_data_namespace, + _infer_dims, + _reduce_dims, + _broadcast_dims, +) + from xarray.namedarray.core import NamedArray @@ -85,8 +91,19 @@ def matrix_transpose(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: def vecdot( x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /, *, axis: int = -1 ) -> NamedArray[Any, Any]: + """ + Computes the (vector) dot product of two arrays. + + Examples + -------- + >>> v = NamedArray(("y", "x"), np.array([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.], [0., 6., 8.]])) + >>> n = NamedArray(("x",), np.array([0., 0.6, 0.8])) + >>> xdot = vecdot(v, n) + >>> xdot.dims, xdot.shape + (('y',), (4,)) + """ xp = _get_data_namespace(x1) _data = xp.vecdot(x1._data, x2._data, axis=axis) - # TODO: Figure out a better way: - _dims = _infer_dims(_data.shape) + d, _ = _broadcast_dims(x1, x2) + _dims = _reduce_dims(d, axis=axis, keepdims=False) return NamedArray(_dims, _data) From 0022d04afccd5a8127bf6e52e8062fcec17aecc9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 20:33:19 +0000 Subject: [PATCH 361/367] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../_array_api/_linear_algebra_functions.py | 12 ++++++++---- .../namedarray/_array_api/_manipulation_functions.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index 3c7fc082933..b61453a7713 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -4,12 +4,11 @@ from typing import Any from xarray.namedarray._array_api._utils import ( + _broadcast_dims, _get_data_namespace, _infer_dims, _reduce_dims, - _broadcast_dims, ) - from xarray.namedarray.core import NamedArray @@ -96,8 +95,13 @@ def vecdot( Examples -------- - >>> v = NamedArray(("y", "x"), np.array([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.], [0., 6., 8.]])) - >>> n = NamedArray(("x",), np.array([0., 0.6, 0.8])) + >>> v = NamedArray( + ... ("y", "x"), + ... np.array( + ... [[0.0, 5.0, 0.0], [0.0, 0.0, 10.0], [0.0, 6.0, 8.0], [0.0, 6.0, 8.0]] + ... ), + ... ) + >>> n = NamedArray(("x",), np.array([0.0, 0.6, 0.8])) >>> xdot = vecdot(v, n) >>> xdot.dims, xdot.shape (('y',), (4,)) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index e3d1a9e2918..93e8b072dd8 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -6,10 +6,10 @@ from xarray.namedarray._array_api._data_type_functions import result_type from xarray.namedarray._array_api._utils import ( _atleast1d_dims, + _broadcast_dims, _dims_from_tuple_indexing, _dims_to_axis, _flatten_dims, - _broadcast_dims, _get_data_namespace, _infer_dims, _insert_dim, From b7d734afaae81df7ba6c816442ca519aa7749e8c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 27 Sep 2024 20:31:08 +0200 Subject: [PATCH 362/367] Use normal repr --- xarray/namedarray/core.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 5d2c40bdb45..51e99374342 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -19,7 +19,7 @@ import numpy as np # TODO: get rid of this after migrating this class to array API -from xarray.core import dtypes +from xarray.core import dtypes, formatting, formatting_html from xarray.core.indexing import ( ExplicitlyIndexed, ImplicitToExplicitIndexingAdapter, @@ -1408,12 +1408,10 @@ def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]: ) def __repr__(self) -> str: - # return formatting.array_repr(self) - return f"" + return formatting.array_repr(self) def _repr_html_(self) -> str: - # return formatting_html.array_repr(self) - return f"" + return formatting_html.array_repr(self) def _as_sparse( self, From 1891a6c87ffa1e866f03200e16828654fdae8e08 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 27 Sep 2024 20:51:23 +0200 Subject: [PATCH 363/367] Normal repr --- .../_array_api/_manipulation_functions.py | 147 ++++++++++-------- 1 file changed, 86 insertions(+), 61 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 93e8b072dd8..99144a00e5c 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -75,13 +75,13 @@ def broadcast_to( -------- >>> import numpy as np >>> x = NamedArray(("x",), np.arange(0, 3)) - >>> x_new = broadcast_to(x, (1, 1, 3)) - >>> x_new.dims, x_new.shape - (('dim_1', 'dim_0', 'x'), (1, 1, 3)) + >>> broadcast_to(x, (1, 1, 3)) + Size: 24B + array([[[0, 1, 2]]]) - >>> x_new = broadcast_to(x, shape=(1, 1, 3), dims=("y", "x")) - >>> x_new.dims, x_new.shape - (('dim_0', 'y', 'x'), (1, 1, 3)) + >>> broadcast_to(x, shape=(1, 1, 3), dims=("y", "x")) + Size: 24B + array([[[0, 1, 2]]]) """ xp = _get_data_namespace(x) _data = xp.broadcast_to(x._data, shape=shape) @@ -102,22 +102,23 @@ def concat( -------- >>> import numpy as np >>> x = NamedArray(("x",), np.zeros((3,))) - >>> xc = concat((x, 1 + x)) - >>> xc.dims, xc.shape - (('x',), (6,)) + >>> concat((x, 1 + x)) + Size: 48B + array([0., 0., 0., 1., 1., 1.]) - >>> x = NamedArray(("x", "y"), np.zeros((3, 4))) - >>> xc = concat((x, 1 + x)) - >>> xc.dims, xc.shape - (('x', 'y'), (6, 4)) + >>> x = NamedArray(("x", "y"), np.zeros((1, 3))) + >>> concat((x, 1 + x)) + Size: 48B + array([[0., 0., 0.], + [1., 1., 1.]]) 0D >>> x1 = NamedArray((), np.array(0)) - >>> x2 = NamedArray((), np.array(0)) - >>> xc = concat((x1, x2), axis=None) - >>> xc.dims, xc.shape - (('dim_0',), (2,)) + >>> x2 = NamedArray((), np.array(1)) + >>> concat((x1, x2), axis=None) + Size: 16B + array([0, 1]) """ x = arrays[0] xp = _get_data_namespace(x) @@ -158,12 +159,17 @@ def expand_dims( -------- >>> import numpy as np >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]])) - >>> x_new = expand_dims(x) - >>> x_new.dims, x_new.shape - (('dim_2', 'x', 'y'), (1, 2, 2)) - >>> x_new = expand_dims(x, dim="z") - >>> x_new.dims, x_new.shape - (('z', 'x', 'y'), (1, 2, 2)) + >>> expand_dims(x) + Size: 32B + array([[[1., 2.], + [3., 4.]]]) + + Specify dimension name + + >>> expand_dims(x, dim="z") + Size: 32B + array([[[1., 2.], + [3., 4.]]]) """ # Array Api does not support multiple axes, but maybe in the future: # https://github.com/data-apis/array-api/issues/760 @@ -218,13 +224,23 @@ def permute_dims( Examples -------- >>> import numpy as np - >>> x = NamedArray(("x", "y", "z"), np.zeros((3, 4, 5))) - >>> y = permute_dims(x, (2, 1, 0)) - >>> y.dims, y.shape - (('z', 'y', 'x'), (5, 4, 3)) - >>> y = permute_dims(x, dims=("y", "x", "z")) - >>> y.dims, y.shape - (('y', 'x', 'z'), (4, 3, 5)) + >>> x = NamedArray(("x", "y", "z"), np.zeros((1, 2, 3))) + >>> permute_dims(x, (2, 1, 0)) + Size: 48B + array([[[0.], + [0.]], + + [[0.], + [0.]], + + [[0.], + [0.]]]) + + >>> permute_dims(x, dims=("y", "x", "z")) + Size: 48B + array([[[0., 0., 0.]], + + [[0., 0., 0.]]]) """ xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axes) @@ -258,20 +274,22 @@ def reshape( -------- >>> import numpy as np >>> x = NamedArray(("x",), np.zeros((3,))) - >>> x1 = reshape(x, (-1,)) - >>> x1.dims, x1.shape - (('x',), (3,)) + >>> reshape(x, (-1,)) + Size: 24B + array([0., 0., 0.]) To N-dimensions - >>> x1 = reshape(x, (1, -1, 1)) - >>> x1.dims, x1.shape - (('dim_0', 'x', 'dim_2'), (1, 3, 1)) + >>> reshape(x, (1, -1, 1)) + Size: 24B + array([[[0.], + [0.], + [0.]]]) >>> x = NamedArray(("x", "y"), np.zeros((3, 4))) - >>> x1 = reshape(x, (-1,)) - >>> x1.dims, x1.shape - ((('x', 'y'),), (12,)) + >>> reshape(x, (-1,)) + Size: 96B + array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) """ xp = _get_data_namespace(x) _data = xp.reshape(x._data, shape, copy=copy) @@ -311,9 +329,10 @@ def squeeze(x: NamedArray[Any, _DType], /, axis: _AxisLike) -> NamedArray[Any, _ -------- >>> import numpy as np >>> x = NamedArray(("x", "y", "z"), np.arange(1 * 2 * 3).reshape((1, 2, 3))) - >>> xs = squeeze(x, axis=0) - >>> xs.dims, xs.shape - (('y', 'z'), (2, 3)) + >>> squeeze(x, axis=0) + Size: 48B + array([[0, 1, 2], + [3, 4, 5]]) """ xp = _get_data_namespace(x) _data = xp.squeeze(x._data, axis=axis) @@ -356,9 +375,11 @@ def unstack( >>> x = NamedArray(("x", "y", "z"), np.arange(1 * 2 * 3).reshape((1, 2, 3))) >>> x_y0, x_y1 = unstack(x, axis=1) >>> x_y0 - + Size: 24B + array([[0, 1, 2]]) >>> x_y1 - + Size: 24B + array([[3, 4, 5]]) """ xp = _get_data_namespace(x) _datas = xp.unstack(x._data, axis=axis) @@ -396,37 +417,41 @@ def _set_dims( -------- >>> import numpy as np >>> x = NamedArray(("x",), np.asarray([1, 2, 3])) - >>> x_new = _set_dims(x, ("y", "x"), None) - >>> x_new.dims, x_new.shape - (('y', 'x'), (1, 3)) - >>> x_new = _set_dims(x, ("x", "y"), None) - >>> x_new.dims, x_new.shape - (('x', 'y'), (3, 1)) + >>> _set_dims(x, ("y", "x"), None) + Size: 24B + array([[1, 2, 3]]) + >>> _set_dims(x, ("x", "y"), None) + Size: 24B + array([[1], + [2], + [3]]) With shape: - >>> x_new = _set_dims(x, ("y", "x"), (2, 3)) - >>> x_new.dims, x_new.shape - (('y', 'x'), (2, 3)) + >>> _set_dims(x, ("y", "x"), (2, 3)) + Size: 48B + array([[1, 2, 3], + [1, 2, 3]]) No operation - >>> x_new = _set_dims(x, ("x",), None) - >>> x_new.dims, x_new.shape - (('x',), (3,)) - + >>> _set_dims(x, ("x",), None) + Size: 24B + array([1, 2, 3]) Unordered dims >>> x = NamedArray(("y", "x"), np.zeros((2, 3))) - >>> x_new = _set_dims(x, ("x", "y"), None) - >>> x_new.dims, x_new.shape - (('x', 'y'), (3, 2)) + >>> _set_dims(x, ("x", "y"), None) + Size: 48B + array([[0., 0.], + [0., 0.], + [0., 0.]]) Errors >>> x = NamedArray(("x",), np.asarray([1, 2, 3])) - >>> x_new = _set_dims(x, (), None) + >>> _set_dims(x, (), None) Traceback (most recent call last): ... ValueError: new dimensions () must be a superset of existing dimensions ('x',) From 26de5e42b3577c8410b848cad560d875bc5ad28c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 27 Sep 2024 21:03:33 +0200 Subject: [PATCH 364/367] Update _statistical_functions.py --- .../_array_api/_statistical_functions.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/xarray/namedarray/_array_api/_statistical_functions.py b/xarray/namedarray/_array_api/_statistical_functions.py index e2617722e8b..294cfa6012c 100644 --- a/xarray/namedarray/_array_api/_statistical_functions.py +++ b/xarray/namedarray/_array_api/_statistical_functions.py @@ -132,20 +132,22 @@ def mean( -------- >>> import numpy as np >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]])) - >>> mean(x).data - Array(2.5, dtype=float64) - >>> mean(x, dims=("x",)).data - Array([2., 3.], dtype=float64) + >>> mean(x) + Size: 8B + np.float64(2.5) + >>> mean(x, dims=("x",)) + Size: 16B + array([2., 3.]) Using keepdims: >>> mean(x, dims=("x",), keepdims=True) - - Array([[2., 3.]], dtype=float64) + Size: 16B + array([[2., 3.]]) >>> mean(x, dims=("y",), keepdims=True) - - Array([[1.5], - [3.5]], dtype=float64) + Size: 16B + array([[1.5], + [3.5]]) """ xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dims, axis) @@ -231,3 +233,9 @@ def var( _data = xp.var(x._data, axis=_axis, correction=correction, keepdims=keepdims) _dims = _reduce_dims(x.dims, axis=_axis, keepdims=keepdims) return x._new(dims=_dims, data=_data) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() From 5496da929e4b141a686dafd689b97c8fc145cb35 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 27 Sep 2024 21:03:42 +0200 Subject: [PATCH 365/367] Update _linear_algebra_functions.py --- .../_array_api/_linear_algebra_functions.py | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/xarray/namedarray/_array_api/_linear_algebra_functions.py b/xarray/namedarray/_array_api/_linear_algebra_functions.py index b61453a7713..5fb435e5387 100644 --- a/xarray/namedarray/_array_api/_linear_algebra_functions.py +++ b/xarray/namedarray/_array_api/_linear_algebra_functions.py @@ -26,8 +26,9 @@ def matmul( >>> a = NamedArray(("y", "x"), np.array([[1, 0], [0, 1]])) >>> b = NamedArray(("y", "x"), np.array([[4, 1], [2, 2]])) >>> matmul(a, b) - + Size: 32B + array([[4, 1], + [2, 2]]) For 2-D mixed with 1-D, the result is the usual. @@ -39,13 +40,18 @@ def matmul( >>> a = NamedArray(("z", "y", "x"), np.arange(2 * 2 * 4).reshape((2, 2, 4))) >>> b = NamedArray(("z", "y", "x"), np.arange(2 * 2 * 4).reshape((2, 4, 2))) - >>> axb = matmul(a, b) - >>> axb.dims, axb.shape + >>> matmul(a, b) + Size: 64B + array([[[ 28, 34], + [ 76, 98]], + + [[428, 466], + [604, 658]]]) """ xp = _get_data_namespace(x1) _data = xp.matmul(x1._data, x2._data) # TODO: Figure out a better way: - _dims = _infer_dims(_data.shape) + _dims = x1.dims # _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -70,15 +76,19 @@ def matrix_transpose(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: Examples -------- >>> import numpy as np - >>> x = NamedArray(("x", "y", "z"), np.zeros((2, 3, 4))) - >>> xT = matrix_transpose(x) - >>> xT.dims, xT.shape - (('x', 'z', 'y'), (2, 4, 3)) + >>> x = NamedArray(("x", "y", "z"), np.zeros((1, 2, 3))) + >>> matrix_transpose(x) + Size: 48B + array([[[0., 0.], + [0., 0.], + [0., 0.]]]) >>> x = NamedArray(("x", "y"), np.zeros((2, 3))) - >>> xT = matrix_transpose(x) - >>> xT.dims, xT.shape - (('y', 'x'), (3, 2)) + >>> matrix_transpose(x) + Size: 48B + array([[0., 0.], + [0., 0.], + [0., 0.]]) """ xp = _get_data_namespace(x) _data = xp.matrix_transpose(x._data) @@ -95,6 +105,7 @@ def vecdot( Examples -------- + >>> import numpy as np >>> v = NamedArray( ... ("y", "x"), ... np.array( @@ -102,12 +113,18 @@ def vecdot( ... ), ... ) >>> n = NamedArray(("x",), np.array([0.0, 0.6, 0.8])) - >>> xdot = vecdot(v, n) - >>> xdot.dims, xdot.shape - (('y',), (4,)) + >>> vecdot(v, n) + Size: 32B + array([ 3., 8., 10., 10.]) """ xp = _get_data_namespace(x1) _data = xp.vecdot(x1._data, x2._data, axis=axis) d, _ = _broadcast_dims(x1, x2) _dims = _reduce_dims(d, axis=axis, keepdims=False) return NamedArray(_dims, _data) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() From 6499fb763ba09d0430c1108d2c45c2eb86b2fb09 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 27 Sep 2024 21:15:47 +0200 Subject: [PATCH 366/367] Update _set_functions.py --- .../namedarray/_array_api/_set_functions.py | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/xarray/namedarray/_array_api/_set_functions.py b/xarray/namedarray/_array_api/_set_functions.py index 46d6c50bb0a..37590fb409e 100644 --- a/xarray/namedarray/_array_api/_set_functions.py +++ b/xarray/namedarray/_array_api/_set_functions.py @@ -50,12 +50,20 @@ def unique_counts(x: NamedArray[Any, Any], /) -> UniqueCountsResult: >>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int)) >>> x_unique = unique_counts(x) >>> x_unique.values + Size: 24B + array([0, 1, 2]) >>> x_unique.counts + Size: 24B + array([1, 1, 2]) >>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2))) >>> x_unique = unique_counts(x) >>> x_unique.values + Size: 24B + array([0, 1, 2]) >>> x_unique.counts + Size: 24B + array([1, 1, 2]) """ xp = _get_data_namespace(x) values, counts = xp.unique_counts(x._data) @@ -77,17 +85,27 @@ def unique_inverse(x: NamedArray[Any, Any], /) -> UniqueInverseResult: >>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int)) >>> x_unique = unique_inverse(x) >>> x_unique.values + Size: 24B + array([0, 1, 2]) >>> x_unique.inverse_indices + Size: 32B + array([0, 1, 2, 2]) + >>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2))) >>> x_unique = unique_inverse(x) - >>> x_unique.dims, x_unique.shape - (('x',), (3,)) + >>> x_unique.values + Size: 24B + array([0, 1, 2]) + >>> x_unique.inverse_indices + Size: 32B + array([[0, 1], + [2, 2]]) """ xp = _get_data_namespace(x) values, inverse_indices = xp.unique_inverse(x._data) return UniqueInverseResult( NamedArray(_flatten_dims(_atleast1d_dims(x.dims)), values), - NamedArray(_flatten_dims(x.dims), inverse_indices), + NamedArray(x.dims, inverse_indices), ) @@ -99,22 +117,28 @@ def unique_values(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]: -------- >>> import numpy as np >>> x = NamedArray(("x",), np.array([0, 1, 2, 2], dtype=int)) - >>> x_unique = unique_values(x) - >>> x_unique.dims, x_unique.shape - (('x',), (3,)) + >>> unique_values(x) + Size: 24B + array([0, 1, 2]) >>> x = NamedArray(("x", "y"), np.array([0, 1, 2, 2], dtype=int).reshape((2, 2))) - >>> x_unique = unique_values(x) - >>> x_unique.dims, x_unique.shape - (('x',), (3,)) + >>> unique_values(x) + Size: 24B + array([0, 1, 2]) # Scalars becomes 1-dimensional >>> x = NamedArray((), np.array(0, dtype=int)) - x_unique = unique_values(x) - >>> x_unique.dims, x_unique.shape - (('dim_0',), (1,)) + >>> unique_values(x) + Size: 8B + array([0]) """ xp = _get_data_namespace(x) _data = xp.unique_values(x._data) _dims = _flatten_dims(_atleast1d_dims(x.dims)) return x._new(_dims, _data) + + +if __name__ == "__main__": + import doctest + + doctest.testmod() From 4dba4bf2960fa1358876a5e6cd44d62c660bdd10 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 30 Sep 2024 21:28:05 +0200 Subject: [PATCH 367/367] flip --- .../_array_api/_manipulation_functions.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/xarray/namedarray/_array_api/_manipulation_functions.py b/xarray/namedarray/_array_api/_manipulation_functions.py index 99144a00e5c..c36c572d49c 100644 --- a/xarray/namedarray/_array_api/_manipulation_functions.py +++ b/xarray/namedarray/_array_api/_manipulation_functions.py @@ -183,10 +183,28 @@ def expand_dims( def flip( x: NamedArray[_ShapeType, _DType], /, *, axis: _Axes | None = None ) -> NamedArray[_ShapeType, _DType]: + """ + Reverse the order of elements in an array along the given axis. + + Examples + -------- + >>> import numpy as np + >>> A = NamedArray(("x",), np.arange(8)) + >>> flip(A, axis=0) + Size: 64B + array([7, 6, 5, 4, 3, 2, 1, 0]) + >>> A = NamedArray(("z", "y", "x"), np.arange(8).reshape((2, 2, 2))) + >>> flip(A, axis=0) + Size: 64B + array([[[4, 5], + [6, 7]], + + [[0, 1], + [2, 3]]]) + """ xp = _get_data_namespace(x) _data = xp.flip(x._data, axis=axis) - _dims = _infer_dims(_data.shape) # TODO: Fix dims - return x._new(_dims, _data) + return x._new(x.dims, _data) def moveaxis(