revert attention
fsx950223 committed Jun 26, 2023
1 parent cede684 · commit 8cc9e76
Showing 1 changed file with 15 additions and 40 deletions.
python/aitemplate/frontend/nn/attention.py: 55 changes (15 additions & 40 deletions)

@@ -16,7 +16,6 @@
 Frontend for attention module
 """
 from aitemplate.compiler import ops
-from aitemplate.compiler.base import IntVar
 from aitemplate.compiler.ops import flash_attention
 from aitemplate.compiler.ops.common.epilogue import FuncEnum
 from aitemplate.frontend import Tensor
@@ -362,55 +361,32 @@ def __init__(
         )
         self.proj_drop = Dropout(proj_drop, dtype=dtype)
 
-    def attention(self, q, k, v, seqlens=None):
+    def attention(self, q, k, v):
         batch = q.shape()[0]
         head_dim = self.dim // self.num_heads
 
         query = self.proj_q(q)
         key = self.proj_k(k)
         value = self.proj_v(v)
 
-        if detect_target().name() == "cuda":
-            query = ops.permute()(
-                ops.reshape()(query, [batch, -1, self.num_heads, head_dim]),
-                [0, 2, 1, 3],
-            )
-            key = ops.permute()(
-                ops.reshape()(key, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
-            )
-            value = ops.permute()(
-                ops.reshape()(value, [batch, -1, self.num_heads, head_dim]),
-                [0, 2, 1, 3],
-            )
-            return self.op(query, key, value)
-        elif seqlens:
-            query = ops.reshape()(query, [batch, self.num_heads, -1, head_dim])
-            key = ops.reshape()(key, [batch, self.num_heads, -1, head_dim])
-            value = ops.reshape()(value, [batch, self.num_heads, -1, head_dim])
-            return self.op(query, key, value, seqlens)
-
-        query = ops.reshape()(query, [batch, -1, self.num_heads, head_dim])
-        query = ops.transpose()(query, 1, 2)
-        query = ops.reshape()(query, [-1, query.shape()[2], head_dim])
-        key = ops.reshape()(key, [batch, -1, self.num_heads, head_dim])
-        key = ops.transpose()(key, 1, 2)
-        key = ops.reshape()(key, [-1, key.shape()[2], head_dim])
-        value = ops.reshape()(value, [batch, -1, self.num_heads, head_dim])
-        value = ops.transpose()(value, 1, 2)
-        value = ops.reshape()(value, [-1, value.shape()[2], head_dim])
-        OP = ops.bmm_softmax_bmm_permute(
-            shape=(self.num_heads,),
-            scale=head_dim**-0.5,
-            causal=self.causal,
-        )
-        return OP(query, key, value)
+        query = ops.permute()(
+            ops.reshape()(query, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
+        )
+        key = ops.permute()(
+            ops.reshape()(key, [batch, -1, self.num_heads, head_dim]), [0, 2, 1, 3]
+        )
+        value = ops.permute()(
+            ops.reshape()(value, [batch, -1, self.num_heads, head_dim]),
+            [0, 2, 1, 3],
+        )
+        return self.op(query, key, value)
 
-    def forward(self, *args, seqlens=None):
+    def forward(self, *args):
         """forward pass for calling mha module"""
         assert len(args) >= 3
         x = args[0]
         batch = x.shape()[0]
-        attn_output = self.attention(args[0], args[1], args[2], seqlens=seqlens)
+        attn_output = self.attention(args[0], args[1], args[2])
         attn_output = ops.reshape()(attn_output, [batch, -1, self.dim])
 
         if self.has_residual:
@@ -419,8 +395,7 @@ def forward(self, *args, seqlens=None):
         else:
             x = self.proj(attn_output)
             x = self.proj_drop(x)
-            if not isinstance(batch, IntVar):
-                x = ops.reshape()(x, [batch, -1, self.dim])
+            x = ops.reshape()(x, [batch, -1, self.dim])
         return x
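
A note on the shape bookkeeping in the reverted attention() above: each projected tensor is reshaped from [batch, seq_len, num_heads * head_dim] to [batch, seq_len, num_heads, head_dim] and then permuted with [0, 2, 1, 3] so the head axis precedes the sequence axis, and forward() later flattens the attention output back to [batch, seq_len, dim]. A minimal NumPy sketch of that round trip follows; the sizes are made-up example values, not anything taken from this commit.

# Minimal NumPy sketch (not AITemplate code) of the layout change done with
# ops.reshape()/ops.permute() in attention(), plus the inverse head merge.
# All sizes below are assumed example values.
import numpy as np

batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
dim = num_heads * head_dim

x = np.random.rand(batch, seq_len, dim).astype(np.float32)  # a projected q/k/v tensor

# [batch, seq_len, num_heads * head_dim] -> [batch, num_heads, seq_len, head_dim]
split = x.reshape(batch, -1, num_heads, head_dim).transpose(0, 2, 1, 3)
assert split.shape == (batch, num_heads, seq_len, head_dim)

# Inverse: merge heads back to [batch, seq_len, dim] before the output projection.
merged = split.transpose(0, 2, 1, 3).reshape(batch, -1, dim)
assert merged.shape == (batch, seq_len, dim)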


@@ -430,4 +405,4 @@ def __init__(self) -> None:
 
     def forward(self, q, k, v):
         attn = ops.mem_eff_attention(causal=False)(q, k, v)
-        return attn
+        return attn
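
For reference, the ops.mem_eff_attention(causal=False)(q, k, v) call kept in the CrossAttention hunk is a fused, memory-efficient kernel for standard scaled dot-product attention. A naive NumPy version of the same computation is sketched below for comparison; the [batch, num_heads, seq_len, head_dim] layout and the sizes are illustrative assumptions, not the op's documented contract.

# Naive scaled dot-product attention in NumPy, for comparison with the fused
# memory-efficient kernel. Layout and sizes are assumptions for illustration only.
import numpy as np

batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
q = np.random.rand(batch, num_heads, seq_len, head_dim).astype(np.float32)
k = np.random.rand(batch, num_heads, seq_len, head_dim).astype(np.float32)
v = np.random.rand(batch, num_heads, seq_len, head_dim).astype(np.float32)

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)     # [batch, heads, seq, seq]
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)      # row-wise softmax
out = weights @ v                                             # [batch, heads, seq, head_dim]
assert out.shape == (batch, num_heads, seq_len, head_dim)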
