Skip to content

Commit

Permalink
matrix_transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 26, 2024
1 parent c5a3054 commit e3923f4
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions xarray/namedarray/_array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,24 @@ def tensordot(


def matrix_transpose(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]:
"""
Transposes a matrix (or a stack of matrices) x.

Examples
--------
>>> import numpy as np
>>> x = NamedArray(("x", "y", "z"), np.zeros((2, 3, 4)))
>>> xT = matrix_transpose(x)
>>> xT.dims, xT.shape
(('x', 'z', 'y'), (2, 4, 3))

>>> x = NamedArray(("x", "y"), np.zeros((2, 3)))
>>> xT = matrix_transpose(x)
>>> xT.dims, xT.shape
(('y', 'x'), (3, 2))
xp = _get_data_namespace(x)
_data = xp.matrix_transpose(x._data)
# TODO: Figure out a better way:
_dims = _infer_dims(_data.shape)
d = x.dims
return NamedArray(_dims, _data)


Expand Down

0 comments on commit e3923f4

Please sign in to comment.