Skip to content

Commit

Permalink
vecdot
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 26, 2024
1 parent 0c83ff6 commit 27ebb03
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions xarray/namedarray/_array_api/_linear_algebra_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from collections.abc import Sequence
from typing import Any

from xarray.namedarray._array_api._utils import _get_data_namespace, _infer_dims
from xarray.namedarray._array_api._utils import (
_get_data_namespace,
_infer_dims,
_reduce_dims,
_broadcast_dims,
)

from xarray.namedarray.core import NamedArray


Expand Down Expand Up @@ -85,8 +91,19 @@ def matrix_transpose(x: NamedArray[Any, Any], /) -> NamedArray[Any, Any]:
def vecdot(
x1: NamedArray[Any, Any], x2: NamedArray[Any, Any], /, *, axis: int = -1
) -> NamedArray[Any, Any]:
"""
Computes the (vector) dot product of two arrays.
Examples
--------
>>> v = NamedArray(("y", "x"), np.array([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.], [0., 6., 8.]]))
>>> n = NamedArray(("x",), np.array([0., 0.6, 0.8]))
>>> xdot = vecdot(v, n)
>>> xdot.dims, xdot.shape
(('y',), (4,))
"""
xp = _get_data_namespace(x1)
_data = xp.vecdot(x1._data, x2._data, axis=axis)
# TODO: Figure out a better way:
_dims = _infer_dims(_data.shape)
d, _ = _broadcast_dims(x1, x2)
_dims = _reduce_dims(d, axis=axis, keepdims=False)
return NamedArray(_dims, _data)

0 comments on commit 27ebb03

Please sign in to comment.