API: Implement tensordot and matmul #22
Conversation
Yeah, perhaps we need to define einsum better in Finch and express both … Side note: I think we should implement …
Force-pushed from d9f991f to ef02c9a
@willow-ahrens The PR is ready from my side! (I also want to benchmark these functions.) For now I constrained …
Couple of API changes, thanks for the excellent work here, @mtsokol.
Force-pushed from ef02c9a to dc1eeb0
LGTM, thanks @mtsokol!
Tracking issue #21
Hi @willow-ahrens @hameerabbasi,
This WIP PR introduces `finch.tensordot(x1, x2, axes)` and `finch.matmul(x1, x2)` / `x1 @ x2`. I still need to complete exhaustive testing.
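For context, a rough usage sketch, assuming tensors are constructed from NumPy arrays via `finch.Tensor` (the exact constructor and signatures may differ):

```python
import numpy as np
import finch

# Assumed construction from NumPy arrays; the actual entry point may differ.
a = finch.Tensor(np.random.rand(3, 4))
b = finch.Tensor(np.random.rand(4, 5))

c = finch.tensordot(a, b, axes=((1,), (0,)))  # contract a's last dim with b's first
d = finch.matmul(a, b)                        # equivalent to a @ b for 2-D inputs
e = a @ b
```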
@willow-ahrens I've got one question regarding `matmul`. I just noticed that there's a slight difference between `tensordot` and `matmul` in the Array API for >2-D input: `tensordot` aggregates the non-contracted dims from both inputs, whereas `matmul` takes the two innermost dims for multiplication and treats the remaining dims as stack/batch dimensions. Here's NumPy code showing it (described in the notes of the docs):
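A minimal sketch of the distinction, with assumed example shapes (2, 3, 4) and (2, 4, 5):

```python
import numpy as np

# Assumed example shapes: a stack of 2 matrices on each side.
a = np.random.rand(2, 3, 4)
b = np.random.rand(2, 4, 5)

# tensordot contracts the given axes and then forms the outer product of
# all remaining dims from both inputs.
print(np.tensordot(a, b, axes=((-1,), (-2,))).shape)  # (2, 3, 2, 5)

# matmul multiplies only the two innermost dims and treats the leading
# dims as broadcast/batch dimensions.
print(np.matmul(a, b).shape)  # (2, 3, 5)
```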
So I'm not sure I can implement `matmul` as `self.tensordot(other, axes=((-1,), (-2,)))`. WDYT?
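As an aside on the einsum point raised earlier in the thread, the batched `matmul` behaviour can be written in NumPy einsum notation; this is only an illustration of the semantics, not a proposed implementation:

```python
import numpy as np

a = np.random.rand(2, 3, 4)
b = np.random.rand(2, 4, 5)

# Batched matmul keeps the leading dims aligned instead of forming
# their outer product, which einsum expresses directly.
via_einsum = np.einsum("...ij,...jk->...ik", a, b)
assert np.allclose(via_einsum, np.matmul(a, b))

# The tensordot formulation gives a different (larger) result for stacked inputs.
via_tensordot = np.tensordot(a, b, axes=((-1,), (-2,)))
print(via_tensordot.shape)  # (2, 3, 2, 5), not (2, 3, 5)
```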