Skip to content

Commit

Permalink
API: Implement tensordot and matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Mar 27, 2024
1 parent d1b33c2 commit dc1eeb0
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 5 deletions.
6 changes: 4 additions & 2 deletions src/finch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
Tensor,
astype,
random,
tensordot,
matmul,
permute_dims,
multiply,
sum,
prod,
add,
subtract,
multiply,
divide,
positive,
negative,
Expand Down Expand Up @@ -65,6 +66,8 @@
"DenseStorage",
"astype",
"random",
"tensordot",
"matmul",
"permute_dims",
"int_",
"int8",
Expand All @@ -82,7 +85,6 @@
"complex64",
"complex128",
"bool",
"multiply",
"lazy",
"compiled",
"compute",
Expand Down
2 changes: 1 addition & 1 deletion src/finch/julia.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import juliapkg

juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.19")
juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.20")
import juliacall # noqa

juliapkg.resolve()
Expand Down
22 changes: 20 additions & 2 deletions src/finch/tensor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Union
from typing import Callable, Iterable, Optional, Union

import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
Expand Down Expand Up @@ -113,7 +113,11 @@ def __pow__(self, other):
return self._elemwise_op(".^", other)

def __matmul__(self, other):
raise NotImplementedError
# TODO: Implement and use mul instead of tensordot
# https://github.com/willow-ahrens/finch-tensor/pull/22#issuecomment-2007884763
if self.ndim != 2 or other.ndim != 2:
raise ValueError(f"Both tensors must be 2-dimensional, but are: {self.ndim=} and {other.ndim=}.")
return tensordot(self, other, axes=((-1,), (-2,)))

def __abs__(self):
return self._elemwise_op("abs")
Expand Down Expand Up @@ -463,6 +467,20 @@ def prod(
return _reduce(x, jl.prod, axis, dtype)


def tensordot(x1: Tensor, x2: Tensor, /, *, axes=2) -> Tensor:
if isinstance(axes, Iterable):
self_axes = normalize_axis_tuple(axes[0], x1.ndim)
other_axes = normalize_axis_tuple(axes[1], x2.ndim)
axes = (tuple(i + 1 for i in self_axes), tuple(i + 1 for i in other_axes))

result = jl.tensordot(x1._obj, x2._obj, axes)
return Tensor(result)


def matmul(x1: Tensor, x2: Tensor) -> Tensor:
return x1 @ x2


def add(x1: Tensor, x2: Tensor, /) -> Tensor:
return x1 + x2

Expand Down
57 changes: 57 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,60 @@ def test_reductions(arr3d, func_name, axis, dtype):
actual = actual.todense()

assert_equal(actual, expected)


@pytest.mark.parametrize(
"storage",
[
None,
(
finch.Storage(finch.SparseList(finch.Element(np.int64(0))), order="C"),
finch.Storage(finch.Dense(finch.SparseList(finch.Element(np.int64(0)))), order="C"),
finch.Storage(
finch.Dense(finch.SparseList(finch.SparseList(finch.Element(np.int64(0))))),
order="C",
),
)
]
)
def test_tensordot(arr3d, storage):
A_finch = finch.Tensor(arr1d)
B_finch = finch.Tensor(arr2d)
C_finch = finch.Tensor(arr3d)
if storage is not None:
A_finch = A_finch.to_device(storage[0])
B_finch = B_finch.to_device(storage[1])
C_finch = C_finch.to_device(storage[2])

actual = finch.tensordot(B_finch, B_finch)
expected = np.tensordot(arr2d, arr2d)
assert_equal(actual.todense(), expected)

actual = finch.tensordot(B_finch, B_finch, axes=(1, 1))
expected = np.tensordot(arr2d, arr2d, axes=(1, 1))
assert_equal(actual.todense(), expected)

actual = finch.tensordot(C_finch, finch.permute_dims(C_finch, (2, 1, 0)), axes=((2, 0), (0, 2)))
expected = np.tensordot(arr3d, arr3d.T, axes=((2, 0), (0, 2)))
assert_equal(actual.todense(), expected)

actual = finch.tensordot(C_finch, A_finch, axes=(2, 0))
expected = np.tensordot(arr3d, arr1d, axes=(2, 0))
assert_equal(actual.todense(), expected)


def test_matmul(arr2d, arr3d):
A_finch = finch.Tensor(arr2d)
B_finch = finch.Tensor(arr2d.T)
C_finch = finch.permute_dims(A_finch, (1, 0))
D_finch = finch.Tensor(arr3d)

actual = A_finch @ B_finch
expected = arr2d @ arr2d.T
assert_equal(actual.todense(), expected)

actual = A_finch @ C_finch
assert_equal(actual.todense(), expected)

with pytest.raises(ValueError, match="Both tensors must be 2-dimensional"):
A_finch @ D_finch

0 comments on commit dc1eeb0

Please sign in to comment.