Skip to content

Commit

Permalink
fix linalg
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 25, 2024
1 parent 5ff5f66 commit 3e1123b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 15 deletions.
42 changes: 30 additions & 12 deletions xarray/namedarray/_array_api/_linalg/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,13 @@ def diagonal(x: NamedArray, /, *, offset: int = 0) -> NamedArray:

def eigh(x: NamedArray, /) -> EighResult:
xp = _get_data_namespace(x)
_datas = xp.linalg.eigh(x._data)
_dims = _infer_dims(_datas[0].shape) # TODO: Fix dims
return EighResult(*(x._new(_dims, _data) for _data in _datas))
eigvals, eigvecs = xp.linalg.eigh(x._data)
_dims_vals = _infer_dims(eigvals.shape) # TODO: Fix dims
_dims_vecs = _infer_dims(eigvecs.shape) # TODO: Fix dims
return EighResult(
x._new(_dims_vals, eigvals),
x._new(_dims_vecs, eigvecs),
)


def eigvalsh(x: NamedArray, /) -> NamedArray:
Expand Down Expand Up @@ -129,16 +133,24 @@ def qr(
x: NamedArray, /, *, mode: Literal["reduced", "complete"] = "reduced"
) -> QRResult:
xp = _get_data_namespace(x)
_datas = xp.linalg.qr(x._data)
_dims = _infer_dims(_datas[0].shape) # TODO: Fix dims
return QRResult(*(x._new(_dims, _data) for _data in _datas))
q, r = xp.linalg.qr(x._data)
_dims_q = _infer_dims(q.shape) # TODO: Fix dims
_dims_r = _infer_dims(r.shape) # TODO: Fix dims
return QRResult(
x._new(_dims_q, q),
x._new(_dims_r, r),
)


def slogdet(x: NamedArray, /) -> SlogdetResult:
xp = _get_data_namespace(x)
_datas = xp.linalg.slogdet(x._data)
_dims = _infer_dims(_datas[0].shape) # TODO: Fix dims
return SlogdetResult(*(x._new(_dims, _data) for _data in _datas))
sign, logabsdet = xp.linalg.slogdet(x._data)
_dims_sign = _infer_dims(sign.shape) # TODO: Fix dims
_dims_logabsdet = _infer_dims(logabsdet.shape) # TODO: Fix dims
return SlogdetResult(
x._new(_dims_sign, sign),
x._new(_dims_logabsdet, logabsdet),
)


def solve(x1: NamedArray, x2: NamedArray, /) -> NamedArray:
Expand All @@ -150,9 +162,15 @@ def solve(x1: NamedArray, x2: NamedArray, /) -> NamedArray:

def svd(x: NamedArray, /, *, full_matrices: bool = True) -> SVDResult:
xp = _get_data_namespace(x)
_datas = xp.linalg.svd(x._data, full_matrices=full_matrices)
_dims = _infer_dims(_datas[0].shape) # TODO: Fix dims
return SVDResult(*(x._new(_dims, _data) for _data in _datas))
u, s, vh = xp.linalg.svd(x._data, full_matrices=full_matrices)
_dims_u = _infer_dims(u.shape) # TODO: Fix dims
_dims_s = _infer_dims(s.shape) # TODO: Fix dims
_dims_vh = _infer_dims(vh.shape) # TODO: Fix dims
return SVDResult(
x._new(_dims_u, u),
x._new(_dims_s, s),
x._new(_dims_vh, vh),
)


def svdvals(x: NamedArray, /) -> NamedArray:
Expand Down
5 changes: 2 additions & 3 deletions xarray/namedarray/_array_api/_sorting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ def argsort(
) -> NamedArray:
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dim, axis)[0]
# TODO: As NumPy currently has no native descending sort, we imitate it here:
if not descending:
_data = xp.argsort(x._data, axis=_axis, stable=stable)
else:
# As NumPy has no native descending sort, we imitate it here. Note that
# simply flipping the results of np.argsort(x._array, ...) would not
# respect the relative order like it would in native descending sorts.
_data = xp.flip(
xp.argsort(xp.flip(x._data, axis=_axis), stable=stable, axis=_axis),
axis=_axis,
Expand All @@ -49,6 +47,7 @@ def sort(
xp = _get_data_namespace(x)
_axis = _dims_to_axis(x, dim, axis)[0]
_data = xp.sort(x._data, axis=_axis, stable=stable)
# TODO: As NumPy currently has no native descending sort, we imitate it here:
if descending:
_data = xp.flip(_data, axis=_axis)
return x._new(data=_data)

0 comments on commit 3e1123b

Please sign in to comment.