Skip to content

Commit

Permalink
add examples
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 21, 2024
1 parent 8a5a041 commit fae50aa
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
27 changes: 27 additions & 0 deletions xarray/namedarray/_array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,33 @@
def matmul(
x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /
) -> NamedArray[Any, Any]:
"""
Matrix product of two arrays.
Examples
--------
For 2-D arrays it is the matrix product:
>>> import numpy as np
>>> a = NamedArray(("y", "x"), np.array([[1, 0], [0, 1]]))
>>> b = NamedArray(("y", "x"), np.array([[4, 1], [2, 2]]))
>>> matmul(a, b)
<Namedarray, shape=(2, 2), dims=('y', 'x'), dtype=int64, data=[[4 1]
[2 2]]>
For 2-D mixed with 1-D, the result is the usual.
>>> a = NamedArray(("y", "x"), np.array([[1, 0], [0, 1]]))
>>> b = NamedArray(("x",), np.array([1, 2]))
>>> matmul(a, b)
Broadcasting is conventional for stacks of arrays
>>> a = NamedArray(("z", "y", "x"), np.arange(2 * 2 * 4).reshape((2, 2, 4)))
>>> b = NamedArray(("z", "y", "x"), np.arange(2 * 2 * 4).reshape((2, 4, 2)))
>>> axb = matmul(a,b)
>>> axb.dims, axb.shape
"""
xp = _get_data_namespace(x1)
_data = xp.matmul(x1._data, x2._data)
# TODO: Figure out a better way:
Expand Down
16 changes: 9 additions & 7 deletions xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def _infer_dims(
>>> _infer_dims((3, 1))
('dim_1', 'dim_0')
>>> _infer_dims((), ())
()
>>> _infer_dims((1,), "x")
('x',)
>>> _infer_dims((1,), None)
Expand Down Expand Up @@ -406,19 +408,19 @@ def _dims_from_tuple_indexing(dims: _Dims, key: _IndexKeys) -> _Dims:
Examples
--------
>>> dims_from_tuple_indexing(("x", "y"), ())
>>> _dims_from_tuple_indexing(("x", "y"), ())
('x', 'y')
>>> dims_from_tuple_indexing(("x", "y"), (0,))
>>> _dims_from_tuple_indexing(("x", "y"), (0,))
('y',)
>>> dims_from_tuple_indexing(("x", "y"), (0, 0))
>>> _dims_from_tuple_indexing(("x", "y"), (0, 0))
()
>>> dims_from_tuple_indexing(("x", "y"), (0, ...))
>>> _dims_from_tuple_indexing(("x", "y"), (0, ...))
('y',)
>>> dims_from_tuple_indexing(("x", "y"), (0, slice(0)))
>>> _dims_from_tuple_indexing(("x", "y"), (0, slice(0)))
('y',)
>>> dims_from_tuple_indexing(("x", "y"), (None,))
>>> _dims_from_tuple_indexing(("x", "y"), (None,))
('dim_2', 'x', 'y')
>>> dims_from_tuple_indexing(("x", "y"), (0, None, None, 0))
>>> _dims_from_tuple_indexing(("x", "y"), (0, None, None, 0))
('dim_1', 'dim_2')
"""
_dims = list(dims)
Expand Down

0 comments on commit fae50aa

Please sign in to comment.