From 3e1123b2f1a8c4dc0a0c2f09b342d079ce0a7667 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 25 Aug 2024 17:26:55 +0200 Subject: [PATCH] fix linalg --- .../namedarray/_array_api/_linalg/_linalg.py | 42 +++++++++++++------ .../_array_api/_sorting_functions.py | 5 +-- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/xarray/namedarray/_array_api/_linalg/_linalg.py b/xarray/namedarray/_array_api/_linalg/_linalg.py index 189ad6b81f2..9470436a3ae 100644 --- a/xarray/namedarray/_array_api/_linalg/_linalg.py +++ b/xarray/namedarray/_array_api/_linalg/_linalg.py @@ -63,9 +63,13 @@ def diagonal(x: NamedArray, /, *, offset: int = 0) -> NamedArray: def eigh(x: NamedArray, /) -> EighResult: xp = _get_data_namespace(x) - _datas = xp.linalg.eigh(x._data) - _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims - return EighResult(*(x._new(_dims, _data) for _data in _datas)) + eigvals, eigvecs = xp.linalg.eigh(x._data) + _dims_vals = _infer_dims(eigvals.shape) # TODO: Fix dims + _dims_vecs = _infer_dims(eigvecs.shape) # TODO: Fix dims + return EighResult( + x._new(_dims_vals, eigvals), + x._new(_dims_vecs, eigvecs), + ) def eigvalsh(x: NamedArray, /) -> NamedArray: @@ -129,16 +133,24 @@ def qr( x: NamedArray, /, *, mode: Literal["reduced", "complete"] = "reduced" ) -> QRResult: xp = _get_data_namespace(x) - _datas = xp.linalg.qr(x._data) - _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims - return QRResult(*(x._new(_dims, _data) for _data in _datas)) + q, r = xp.linalg.qr(x._data) + _dims_q = _infer_dims(q.shape) # TODO: Fix dims + _dims_r = _infer_dims(r.shape) # TODO: Fix dims + return QRResult( + x._new(_dims_q, q), + x._new(_dims_r, r), + ) def slogdet(x: NamedArray, /) -> SlogdetResult: xp = _get_data_namespace(x) - _datas = xp.linalg.slogdet(x._data) - _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims - return SlogdetResult(*(x._new(_dims, _data) for _data in _datas)) + sign, logabsdet = xp.linalg.slogdet(x._data) + _dims_sign = _infer_dims(sign.shape) # TODO: Fix dims + _dims_logabsdet = _infer_dims(logabsdet.shape) # TODO: Fix dims + return SlogdetResult( + x._new(_dims_sign, sign), + x._new(_dims_logabsdet, logabsdet), + ) def solve(x1: NamedArray, x2: NamedArray, /) -> NamedArray: @@ -150,9 +162,15 @@ def solve(x1: NamedArray, x2: NamedArray, /) -> NamedArray: def svd(x: NamedArray, /, *, full_matrices: bool = True) -> SVDResult: xp = _get_data_namespace(x) - _datas = xp.linalg.svd(x._data, full_matrices=full_matrices) - _dims = _infer_dims(_datas[0].shape) # TODO: Fix dims - return SVDResult(*(x._new(_dims, _data) for _data in _datas)) + u, s, vh = xp.linalg.svd(x._data, full_matrices=full_matrices) + _dims_u = _infer_dims(u.shape) # TODO: Fix dims + _dims_s = _infer_dims(s.shape) # TODO: Fix dims + _dims_vh = _infer_dims(vh.shape) # TODO: Fix dims + return SVDResult( + x._new(_dims_u, u), + x._new(_dims_s, s), + x._new(_dims_vh, vh), + ) def svdvals(x: NamedArray, /) -> NamedArray: diff --git a/xarray/namedarray/_array_api/_sorting_functions.py b/xarray/namedarray/_array_api/_sorting_functions.py index a47ab7ae527..66798daccdb 100644 --- a/xarray/namedarray/_array_api/_sorting_functions.py +++ b/xarray/namedarray/_array_api/_sorting_functions.py @@ -20,12 +20,10 @@ def argsort( ) -> NamedArray: xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis)[0] + # TODO: As NumPy currently has no native descending sort, we imitate it here: if not descending: _data = xp.argsort(x._data, axis=_axis, stable=stable) else: - # As NumPy has no native descending sort, we imitate it here. Note that - # simply flipping the results of np.argsort(x._array, ...) would not - # respect the relative order like it would in native descending sorts. _data = xp.flip( xp.argsort(xp.flip(x._data, axis=_axis), stable=stable, axis=_axis), axis=_axis, @@ -49,6 +47,7 @@ def sort( xp = _get_data_namespace(x) _axis = _dims_to_axis(x, dim, axis)[0] _data = xp.sort(x._data, axis=_axis, stable=stable) + # TODO: As NumPy currently has no native descending sort, we imitate it here: if descending: _data = xp.flip(_data, axis=_axis) return x._new(data=_data)