From 66202ed29c989e02455b424890b3dfdc4fa81802 Mon Sep 17 00:00:00 2001
From: Max Ren
Date: Tue, 3 Oct 2023 15:05:59 -0700
Subject: [PATCH] [pt2e][xnnpack_quantizer] add util function to convert scalars to attrs (#110427)

Jerry provided a notebook solution for converting scalars to attrs so that they may be properly quantized: https://fburl.com/anp/kzz7tfn1

Adding this pass as a util function in xnnpack_quantizer_utils.py

Differential Revision: [D49850150](https://our.internmc.facebook.com/intern/diff/D49850150/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110427
Approved by: https://github.com/jerryzh168
---
 .../quantizer/xnnpack_quantizer_utils.py | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
index add533dd80d153..085c65b768bce0 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py
@@ -5,6 +5,7 @@
 
 import torch
 import torch.nn.functional as F
+from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
 from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -866,3 +867,30 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
             output_qspec=shared_qspec,
             _annotated=True,
         )
+
+
+def convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for n in model.graph.nodes:
+        if n.op != "call_function" or n.target not in [
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.mul.Tensor,
+        ]:
+            continue
+        args = list(n.args)
+        new_args = []
+        for i in range(len(args)):
+            if isinstance(args[i], torch.fx.Node):
+                new_args.append(args[i])
+                continue
+            prefix = "_tensor_constant_"
+            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+            tensor_constant_name = get_new_attr_name(model)
+            model.register_buffer(tensor_constant_name, torch.tensor(float(args[i])))
+            with model.graph.inserting_before(n):
+                get_attr_node = model.graph.create_node(
+                    "get_attr", tensor_constant_name, (), {}
+                )
+                new_args.append(get_attr_node)
+        n.args = tuple(new_args)
+    model.recompile()
+    return model
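
For context, a minimal usage sketch of how this pass might slot into a pt2e quantization flow. Only `convert_scalars_to_attrs` comes from this patch; the toy module `ScalarMulAdd`, the `capture_pre_autograd_graph` call, the quantizer configuration, and the calibration step below are illustrative assumptions, not part of this change.

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    convert_scalars_to_attrs,
)


class ScalarMulAdd(torch.nn.Module):  # hypothetical toy module for illustration
    def forward(self, x):
        # Python scalars feeding aten.mul.Tensor / aten.add.Tensor are the
        # non-Node args this pass lifts into tensor attributes.
        return x * 0.5 + 1.0


example_inputs = (torch.randn(1, 3, 8, 8),)
m = capture_pre_autograd_graph(ScalarMulAdd(), example_inputs)

# Lift the float scalars into registered buffers exposed via get_attr nodes,
# so the quantizer can annotate them like any other tensor input.
m = convert_scalars_to_attrs(m)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration
m = convert_pt2e(m)
```

The design point of the pass is that annotation only attaches quantization specs to `torch.fx.Node` inputs; by converting bare scalars on `add`/`mul` into `get_attr` tensor constants first, those operands become visible to the quantizer instead of silently remaining in float.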