Minor changes in the fx transforms.
Prashant Kumar committed Jun 13, 2023
1 parent 2fec3c8 commit 0a4c8fc
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions shark/shark_importer.py
@@ -317,16 +317,19 @@ def add_upcast(fx_g):
     import torch

     for node in fx_g.graph.nodes:
-        if node.target in [torch.ops.aten.rsqrt]:
+        if node.target in [torch.ops.aten.mul]:
             # This is a very strict check.
             if (
-                node.args[0].target in [torch.ops.aten.add]
-                and node.args[0].args[0].target in [torch.ops.aten.mean]
-                and node.args[0].args[0].args[0].target in [torch.ops.aten.pow]
+                node.args[1].target in [torch.ops.aten.rsqrt]
+                and node.args[1].args[0].target in [torch.ops.aten.add]
+                and node.args[1].args[0].args[0].target
+                in [torch.ops.aten.mean]
+                and node.args[1].args[0].args[0].args[0].target
+                in [torch.ops.aten.pow]
             ):
                 print("found an upcasting block let's upcast it.")
-                pow_node = node.args[0].args[0].args[0]
-                rsqrt_node = node
+                pow_node = node.args[1].args[0].args[0].args[0]
+                mul_node = node
                 with fx_g.graph.inserting_before(pow_node):
                     lhs = pow_node.args[0]
                     upcast_lhs = fx_g.graph.call_function(
@@ -335,15 +338,15 @@ def add_upcast(fx_g):
                         kwargs={"dtype": torch.float32},
                     )
                     pow_node.args = (upcast_lhs, pow_node.args[1])
-                with fx_g.graph.inserting_before(rsqrt_node):
+                with fx_g.graph.inserting_before(mul_node):
                     new_node = fx_g.graph.call_function(
                         torch.ops.aten._to_copy,
-                        args=(rsqrt_node,),
+                        args=(mul_node,),
                         kwargs={"dtype": torch.float16},
                     )
-                    rsqrt_node.append(new_node)
-                    rsqrt_node.replace_all_uses_with(new_node)
-                    new_node.args = (rsqrt_node,)
+                    mul_node.append(new_node)
+                    mul_node.replace_all_uses_with(new_node)
+                    new_node.args = (mul_node,)
                     new_node.kwargs = {"dtype": torch.float16}

     fx_g.graph.lint()
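The hunks above re-anchor the matcher from the `rsqrt` node to the final `mul`: the chain `mul(x, rsqrt(add(mean(pow(x, 2)), eps)))` is the usual fp16 RMSNorm-style block, and walking backwards through `node.args[1]` from the `mul` means the `_to_copy` back to fp16 is now inserted after the multiply, so the multiply itself also runs in fp32. A minimal sketch of the shape the new check expects, hand-building the graph so node targets are exactly the `torch.ops.aten` packets the membership tests compare against (the `[-1]` reduction dim and the `1e-6` epsilon are illustrative assumptions, not from the commit):

```python
import torch
import torch.fx as fx

# Hand-build mul(x, rsqrt(add(mean(pow(x, 2)), eps))), the pattern the
# patched add_upcast anchors on.
g = fx.Graph()
x = g.placeholder("x")
pow_ = g.call_function(torch.ops.aten.pow, args=(x, 2))
mean = g.call_function(torch.ops.aten.mean, args=(pow_, [-1], True))
add = g.call_function(torch.ops.aten.add, args=(mean, 1e-6))
rsqrt = g.call_function(torch.ops.aten.rsqrt, args=(add,))
mul = g.call_function(torch.ops.aten.mul, args=(x, rsqrt))
g.output(mul)

# The patched check walks backwards from the mul node through args[1]:
assert mul.target in [torch.ops.aten.mul]
assert mul.args[1].target in [torch.ops.aten.rsqrt]
assert mul.args[1].args[0].target in [torch.ops.aten.add]
assert mul.args[1].args[0].args[0].target in [torch.ops.aten.mean]
assert mul.args[1].args[0].args[0].args[0].target in [torch.ops.aten.pow]
```

With the old anchor on `rsqrt`, the fp16 downcast landed between `rsqrt` and `mul`, so the multiply still consumed a downcast operand. Upcasting this block matters because `pow(x, 2)` overflows fp16 (max ~65504) for activations above ~256.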
@@ -433,6 +436,14 @@ def transform_fx(fx_g):
                 node.replace_all_uses_with(new_node)
                 new_node.args = (node,)

+    # Required for cuda debugging.
+    # for node in fx_g.graph.nodes:
+    #     if node.op == "call_function":
+    #         if node.kwargs.get("device") == torch.device(type="cpu"):
+    #             new_kwargs = node.kwargs.copy()
+    #             new_kwargs["device"] = torch.device(type="cuda")
+    #             node.kwargs = new_kwargs
+
     fx_g.graph.lint()


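The commented-out block added at the end of `transform_fx` is a debugging aid: when enabled, it rewrites any literal `device="cpu"` kwarg on a `call_function` node to CUDA, so factory ops baked into the graph follow the model onto the GPU. A minimal sketch of the same rewrite on a hand-built graph (the `aten.arange` node is an illustrative assumption, not from the commit):

```python
import torch
import torch.fx as fx

# A graph containing a factory op with a baked-in device="cpu" kwarg.
g = fx.Graph()
n = g.call_function(
    torch.ops.aten.arange,
    args=(4,),
    kwargs={"device": torch.device("cpu")},
)
g.output(n)

# The same rewrite as the commented-out debugging block: node.kwargs is
# an immutable mapping, so mutate a copy and reassign it to the node.
for node in g.nodes:
    if node.op == "call_function":
        if node.kwargs.get("device") == torch.device("cpu"):
            new_kwargs = dict(node.kwargs)
            new_kwargs["device"] = torch.device("cuda")
            node.kwargs = new_kwargs

print(g)  # the arange node now carries device=cuda
```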
