Hi, thanks a lot for your code. I think I have found a bug.
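In the `MultiHeadAttention` layer, the transpose in the `reshape1` function puts the head axis before the batch axis. After reshaping, the first axis is ordered like this (suppose N samples and only 2 heads): `[sample_1 head_1, sample_2 head_1, ..., sample_N head_1, sample_1 head_2, sample_2 head_2, ..., sample_N head_2]`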
But repeating `mask` with `repeat_elements` returns `mask` ordered like this (see the `repeat_elements` documentation for its usage): `[mask_1, mask_1, mask_2, mask_2, ..., mask_N, mask_N]`
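`repeat_elements` behaves like `numpy.repeat`: each element along the given axis is repeated `rep` times in a row. The sketch below mimics this on plain Python lists and pairs the result with the head-major rows produced by `reshape1`, assuming 3 samples and 2 heads. The helper name `repeat_rows` and the mask labels are hypothetical:

```python
def repeat_rows(rows, rep):
    """Mimic repeat_elements(x, rep, axis=0): repeat every row `rep` times
    consecutively, like numpy.repeat."""
    return [row for row in rows for _ in range(rep)]


n_samples, n_heads = 3, 2
masks = [f"m{s}" for s in range(n_samples)]  # one mask per sample
repeated = repeat_rows(masks, n_heads)
print(repeated)
# → ['m0', 'm0', 'm1', 'm1', 'm2', 'm2']

# Rows of the attention input after the current reshape1 (head-major order).
head_major = [f"s{s}h{h}" for h in range(n_heads) for s in range(n_samples)]
for row, mask in zip(head_major, repeated):
    # e.g. row s1h0 is paired with m0, i.e. with another sample's mask
    print(row, mask)
```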
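However, to match the head-major order of the first axis, we would actually want `mask` to look like this: `[mask_1, mask_2, ..., mask_N, mask_1, mask_2, ..., mask_N]`.

So I think `reshape1` should change `x = tf.transpose(x, [2, 0, 1, 3])` into `x = tf.transpose(x, [0, 2, 1, 3])`. This makes the first axis sample-major (`[sample_1 head_1, sample_1 head_2, sample_2 head_1, ..., sample_N head_2]`), so it matches the repeated `mask`. The same change applies to `reshape2`.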