Skip to content

Commit

Permalink
re-indents add_upcast in shark importer (#1523)
Browse files Browse the repository at this point in the history
* The two with blocks in add_upcast appear to be underindented making
SD 1.4 break on rdna3, I've pushed them out one more tab, and then
everything appears to work again.
  • Loading branch information
one-lithe-rune authored Jun 12, 2023
1 parent 5e7d593 commit 2fec3c8
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions shark/shark_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,24 +327,24 @@ def add_upcast(fx_g):
print("found an upcasting block let's upcast it.")
pow_node = node.args[0].args[0].args[0]
rsqrt_node = node
with fx_g.graph.inserting_before(pow_node):
lhs = pow_node.args[0]
upcast_lhs = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(lhs,),
kwargs={"dtype": torch.float32},
)
pow_node.args = (upcast_lhs, pow_node.args[1])
with fx_g.graph.inserting_before(rsqrt_node):
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(rsqrt_node,),
kwargs={"dtype": torch.float16},
)
rsqrt_node.append(new_node)
rsqrt_node.replace_all_uses_with(new_node)
new_node.args = (rsqrt_node,)
new_node.kwargs = {"dtype": torch.float16}
with fx_g.graph.inserting_before(pow_node):
lhs = pow_node.args[0]
upcast_lhs = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(lhs,),
kwargs={"dtype": torch.float32},
)
pow_node.args = (upcast_lhs, pow_node.args[1])
with fx_g.graph.inserting_before(rsqrt_node):
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(rsqrt_node,),
kwargs={"dtype": torch.float16},
)
rsqrt_node.append(new_node)
rsqrt_node.replace_all_uses_with(new_node)
new_node.args = (rsqrt_node,)
new_node.kwargs = {"dtype": torch.float16}

fx_g.graph.lint()

Expand Down

0 comments on commit 2fec3c8

Please sign in to comment.