Skip to content

Commit

Permalink
[pt2e][xnnpack_quantizer] add util function to convert scalars to att…
Browse files Browse the repository at this point in the history
…rs (pytorch#110427)

Jerry provided a notebook solution for converting scalars to attrs so that they may be properly quantized:

https://fburl.com/anp/kzz7tfn1

Adding this pass as a util function in xnnpack_quantizer_utils.py

Differential Revision: [D49850150](https://our.internmc.facebook.com/intern/diff/D49850150/)
Pull Request resolved: pytorch#110427
Approved by: https://github.com/jerryzh168
  • Loading branch information
mcr229 authored and pytorchmergebot committed Oct 4, 2023
1 parent 64416a1 commit 66202ed
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
import torch.nn.functional as F
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
Expand Down Expand Up @@ -866,3 +867,30 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
output_qspec=shared_qspec,
_annotated=True,
)


def convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
for n in model.graph.nodes:
if n.op != "call_function" or n.target not in [
torch.ops.aten.add.Tensor,
torch.ops.aten.mul.Tensor,
]:
continue
args = list(n.args)
new_args = []
for i in range(len(args)):
if isinstance(args[i], torch.fx.Node):
new_args.append(args[i])
continue
prefix = "_tensor_constant_"
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
tensor_constant_name = get_new_attr_name(model)
model.register_buffer(tensor_constant_name, torch.tensor(float(args[i])))
with model.graph.inserting_before(n):
get_attr_node = model.graph.create_node(
"get_attr", tensor_constant_name, (), {}
)
new_args.append(get_attr_node)
n.args = tuple(new_args)
model.recompile()
return model

0 comments on commit 66202ed

Please sign in to comment.