Process separate q/k/v weights in MHA converter (#1020)
Summary:
Pull Request resolved: #1020

ATT. The converter did not handle the `not self._qkv_same_embed_dim` case ([here](https://github.com/pytorch/pytorch/blob/80ed3e9ccdaab20814b4156611a19043aaaaef03/torch/nn/modules/activation.py#L1074)), in which `nn.MultiheadAttention` stores separate q/k/v projection weights instead of a single fused `in_proj_weight`. Here we cover this case.
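
For context, a minimal sketch (not part of this diff) of when the separate-weight case arises: `nn.MultiheadAttention` falls back to per-projection `q_proj_weight` / `k_proj_weight` / `v_proj_weight` parameters whenever `kdim` or `vdim` differs from `embed_dim`.

```python
import torch

# When kdim/vdim differ from embed_dim, nn.MultiheadAttention sets
# _qkv_same_embed_dim = False and registers separate q/k/v projection
# weights instead of a single fused in_proj_weight.
mha = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8, kdim=128, vdim=128)

print(mha._qkv_same_embed_dim)   # False
print(mha.in_proj_weight)        # None
print(mha.q_proj_weight.shape)   # torch.Size([512, 512])
print(mha.k_proj_weight.shape)   # torch.Size([512, 128])
print(mha.v_proj_weight.shape)   # torch.Size([512, 128])
```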

Internal:

The missing case causes a failure in the AIT lowering of the IGCTR MC model. See the post: https://fb.workplace.com/groups/gpuinference/permalink/2872581106223872/.

Reviewed By: ColinPeppler

Differential Revision: D61155566

fbshipit-source-id: 98ba4c4150a036268ec8bcbe4f6b5aa7934374d2
aakhundov authored and facebook-github-bot committed Aug 13, 2024
1 parent 2aef297 commit 7b41778
Showing 2 changed files with 66 additions and 2 deletions.
32 changes: 31 additions & 1 deletion fx2ait/fx2ait/converters/ait_module_converters.py
@@ -63,6 +63,7 @@ def multi_head_attention_module(
    )

    # Bind constant tensor for MHA module
    q_w, k_w, v_w = None, None, None
    qkv_weight, qkv_bias = None, None
    for k, v in submod.named_parameters():
        ait_data = _TorchConstantTensorData(v.data.contiguous().cuda().half())
@@ -81,6 +82,27 @@
                name=make_str_ait_friendly(f"{target}.{k}"),
            )
            qkv_bias._bind_data(ait_data)
        elif k == "q_proj_weight":
            q_w = Tensor(
                shape=v.shape,
                dtype="float16",
                name=make_str_ait_friendly(f"{target}.{k}"),
            )
            q_w._bind_data(ait_data)
        elif k == "k_proj_weight":
            k_w = Tensor(
                shape=v.shape,
                dtype="float16",
                name=make_str_ait_friendly(f"{target}.{k}"),
            )
            k_w._bind_data(ait_data)
        elif k == "v_proj_weight":
            v_w = Tensor(
                shape=v.shape,
                dtype="float16",
                name=make_str_ait_friendly(f"{target}.{k}"),
            )
            v_w._bind_data(ait_data)
        elif "out_proj" in k:
            if "weight" in k:
                tensor = attn.proj.weight.tensor()
@@ -90,7 +112,15 @@
            tensor._bind_data(ait_data)

    # Swap out qkv tensor used by nn.CrossAttention.
    q_w, k_w, v_w = chunk()(qkv_weight, 3)
    if qkv_weight is not None:
        assert q_w is None
        assert k_w is None
        assert v_w is None
        q_w, k_w, v_w = chunk()(qkv_weight, 3)
    else:
        assert q_w is not None
        assert k_w is not None
        assert v_w is not None
    q_b, k_b, v_b = chunk()(qkv_bias, 3)

    attn.proj_q.weight._tensor = q_w
@@ -18,7 +18,7 @@


class TestMultiHeadAttentionConverter(AITTestCase):
    def test_multihead_attention_cross_attenytion(self):
    def test_multihead_attention_cross_attention(self):
        class TestModule(torch.nn.Module):
            def __init__(self, dim, nheads):
                super().__init__()
@@ -77,3 +77,37 @@ def forward(self, x):
            expected_ops={torch.nn.MultiheadAttention},
            leaf_module=torch.nn.MultiheadAttention,
        )

    def test_multihead_attention_different_kv_dims(self):
        class TestModule(torch.nn.Module):
            def __init__(self, qdim, kdim, vdim, nheads):
                super().__init__()
                self.attn = torch.nn.MultiheadAttention(
                    embed_dim=qdim,
                    num_heads=nheads,
                    batch_first=True,
                    kdim=kdim,
                    vdim=vdim,
                )

            def forward(self, q, k, v):
                return self.attn(query=q, key=k, value=v)

        batch_size = 2
        seqlen = 4
        qdim = 512
        kdim = 128
        vdim = 128
        num_heads = 8

        q = torch.ones(batch_size, seqlen, qdim).cuda().half()
        k = torch.ones(batch_size, seqlen, kdim).cuda().half()
        v = torch.ones(batch_size, seqlen, vdim).cuda().half()
        model = TestModule(qdim, kdim, vdim, num_heads).eval().half().cuda()

        self.run_test(
            model,
            [q, k, v],
            expected_ops={torch.nn.MultiheadAttention},
            leaf_module=torch.nn.MultiheadAttention,
        )
