From 29d0f88fbaf1cf5791b6e683b94103ee94d3a447 Mon Sep 17 00:00:00 2001 From: Colin Chan Date: Thu, 17 Aug 2023 06:46:39 -0700 Subject: [PATCH] Allow benchmark_function to accept leaf modules (#902) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/902 Allow benchmark_function to accept leaf modules in order to test lowered ait submodules Reviewed By: henryhu6 Differential Revision: D48363248 fbshipit-source-id: 5f67353cd31ac2bd87265b914e613dc907f88064 --- fx2ait/fx2ait/tools/common_fx2ait.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/fx2ait/fx2ait/tools/common_fx2ait.py b/fx2ait/fx2ait/tools/common_fx2ait.py index 4d25997a1..371b7b40f 100644 --- a/fx2ait/fx2ait/tools/common_fx2ait.py +++ b/fx2ait/fx2ait/tools/common_fx2ait.py @@ -388,15 +388,34 @@ def benchmark_function( mod: torch.nn.Module, inputs: List[torch.Tensor], permute_inputs: Optional[List[int]] = None, + precision: LowerPrecision = LowerPrecision.FP16, + leaf_module: Callable = None, ) -> float: mod.eval() + + leaf_module_list = [] + if leaf_module: + if isinstance(leaf_module, list): + leaf_module_list.extend(leaf_module) + else: + leaf_module_list.append(leaf_module) + mod = acc_tracer.trace( mod, inputs, + leaf_module_list=leaf_module_list, ) original_inputs = inputs if permute_inputs: inputs = [inp.permute(*permute_inputs).contiguous() for inp in inputs] + torch_dtype = lower_precision_to_torch_type(precision) + mod.to(torch_dtype) + inputs = map_aggregate( + inputs, + lambda inp: inp.to(torch_dtype).contiguous() + if inp.dtype not in (torch.bool, torch.int64) + else inp.contiguous(), + ) interp = AITInterpreter( mod, inputs,