From 045f2bb1473e3a0a8f200c3c0949ac31554055b6 Mon Sep 17 00:00:00 2001
From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com>
Date: Thu, 22 Jun 2023 15:11:41 -0700
Subject: [PATCH] Add dispatch-level config file generator for manual annotation (#1566)

---
 shark/shark_generate_model_config.py | 77 ++++++++++++++++++++++++++--
 1 file changed, 72 insertions(+), 5 deletions(-)

diff --git a/shark/shark_generate_model_config.py b/shark/shark_generate_model_config.py
index 5041242ed1..8eb70e0fa9 100644
--- a/shark/shark_generate_model_config.py
+++ b/shark/shark_generate_model_config.py
@@ -1,5 +1,8 @@
+import re
 import json
-from collections import OrderedDict
+import torch_mlir
+from iree.compiler import compile_str
+from shark.shark_importer import import_with_fx, get_f16_inputs
 
 
 class GenerateConfigFile:
@@ -7,7 +10,9 @@ def __init__(
         self,
         model,
         num_sharding_stages: int,
-        sharding_stages_id: list[str] = None,
+        sharding_stages_id: list[str],
+        model_input=None,
+        config_file_path="model_config.json",
     ):
         self.model = model
         self.num_sharding_stages = num_sharding_stages
@@ -15,8 +20,67 @@ def __init__(
         assert self.num_sharding_stages == len(
             self.sharding_stages_id
         ), "Number of sharding stages should be equal to the list of their ID"
+        self.model_input = model_input
+        self.config_file_path = config_file_path
 
-    def generate_json(self):
+    def split_into_dispatches(
+        self,
+        backend,
+        fx_tracing_required=True,
+        f16_model=False,
+        torch_mlir_tracing=False,
+    ):
+        graph_for_compilation = self.model
+        if fx_tracing_required:
+            graph_for_compilation = import_with_fx(
+                self.model,
+                self.model_input,
+                is_f16=f16_model,
+                f16_input_mask=[False, False],
+                mlir_type="torchscript",
+            )
+
+        module = torch_mlir.compile(
+            graph_for_compilation,
+            (self.model_input),
+            torch_mlir.OutputType.LINALG_ON_TENSORS,
+            use_tracing=torch_mlir_tracing,
+            verbose=False,
+        )
+        module = module.operation.get_asm(large_elements_limit=4)
+        compiled_module_str = str(
+            compile_str(
+                str(module),
+                target_backends=[backend],
+                extra_args=[
+                    "--compile-to=flow",
+                    "--mlir-elide-elementsattrs-if-larger=4",
+                ],
+            )
+        )
+
+        substring_start_idx = [
+            m.start()
+            for m in re.finditer("flow.dispatch @", compiled_module_str)
+        ]
+        dispatch_list = dict()
+
+        # dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model
+        # dispatch_id is the unique id of a dispatch, multiple instances of the same dispatch
+        # can occur in a model
+        for dispatch_no, substring_idx in enumerate(substring_start_idx):
+            dispatch_idx = (
+                compiled_module_str[substring_idx:]
+                .split(":")[0]
+                .split("@")[-1]
+            )
+            key = "dispatch_no_" + str(dispatch_no)
+            dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
+            dispatch_list[key]["dispatch_id"] = dispatch_idx
+
+        self.generate_json(dispatch_list)
+
+    def split_into_layers(self):
         model_dictionary = dict()
         for name, m in self.model.named_modules():
@@ -34,5 +98,8 @@ def generate_json(self):
             layer_dict = {n: "None" for n in self.sharding_stages_id}
             model_dictionary[name] = layer_dict
 
-        with open("model_config.json", "w") as outfile:
-            json.dump(model_dictionary, outfile)
+        self.generate_json(model_dictionary)
+
+    def generate_json(self, artifacts):
+        with open(self.config_file_path, "w") as outfile:
+            json.dump(artifacts, outfile)
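
Below is a minimal usage sketch of the API this patch adds. The model, example input, stage IDs, backend string, and output path are illustrative assumptions, not taken from the patch itself:

    import torch
    from shark.shark_generate_model_config import GenerateConfigFile

    # Hypothetical model and input, purely for illustration.
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8)
    )
    example_input = torch.randn(1, 16)

    config_gen = GenerateConfigFile(
        model,
        num_sharding_stages=2,
        sharding_stages_id=["stage_0", "stage_1"],
        model_input=[example_input],
        config_file_path="model_config.json",
    )

    # Layer-level config: entries keyed by module name, with each stage ID
    # initialized to "None" for manual annotation.
    config_gen.split_into_layers()

    # Dispatch-level config: lowers the model through torch-mlir to linalg,
    # compiles to IREE's flow phase, and emits one "dispatch_no_<i>" entry per
    # flow.dispatch op, each carrying its "dispatch_id" and the stage IDs.
    # "llvm-cpu" is assumed to be an available IREE target backend here.
    # Note: this overwrites the file written by split_into_layers above,
    # since both methods write to the same config_file_path.
    config_gen.split_into_dispatches(backend="llvm-cpu")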