Skip to content

Commit

Permalink
Add attention op insertion code
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Aug 10, 2023
1 parent 8e90f1b commit b817c8f
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions shark/shark_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,51 @@ def get_f16_inputs(inputs, is_f16, f16_input_mask):
return tuple(f16_masked_inputs)


def insert_attention_block(fx_g):
import torch

unary_ops = [
torch.ops.aten._unsafe_view,
torch.ops.aten.view,
torch.ops.aten.expand,
torch.ops.aten.clone,
]

def traverse(node):
while node.target in unary_ops:
node = node.args[0]
return node

for node in fx_g.graph.nodes:
if node.target in [torch.ops.aten.bmm]:
outer_bmm = node
node = traverse(outer_bmm.args[0])
if node.target in [torch.ops.aten._softmax]:
softmax_node = node
node = traverse(softmax_node.args[0])
if node.target in [torch.ops.aten.bmm]:
inner_bmm = node
value = outer_bmm.args[1]
key = inner_bmm.args[1]
with fx_g.graph.inserting_before(outer_bmm):
key = fx_g.graph.call_function(
torch.ops.aten.transpose,
args=(key, -2, -1),
kwargs={},
)
query = inner_bmm.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten.scaled_dot_product_attention,
args=(query, key, value),
kwargs={},
)
outer_bmm.append(new_node)
outer_bmm.append(key)
outer_bmm.replace_all_uses_with(new_node)

fx_g.graph.lint()


# Upcasts the block/list of ops.
def add_upcast(fx_g):
import torch
Expand Down Expand Up @@ -640,6 +685,8 @@ def strip_overloads(gm):

strip_overloads(fx_g)

insert_attention_block(fx_g)

if is_f16:
fx_g = fx_g.half()
transform_fx(fx_g)
Expand All @@ -659,6 +706,7 @@ def strip_overloads(gm):
return ts_graph

inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)

mlir_importer = SharkImporter(
ts_graph,
inputs,
Expand Down

0 comments on commit b817c8f

Please sign in to comment.