diff --git a/shark/shark_generate_model_config.py b/shark/shark_generate_model_config.py
index 5041242ed1..9af05edbd6 100644
--- a/shark/shark_generate_model_config.py
+++ b/shark/shark_generate_model_config.py
@@ -1,5 +1,10 @@
+import re
 import json
-from collections import OrderedDict
+import torch_mlir
+from iree.compiler import compile_str
+from apps.language_models.utils import (
+    get_torch_mlir_module_bytecode,
+)
 
 
 class GenerateConfigFile:
@@ -7,7 +12,10 @@ def __init__(
         self,
         model,
         num_sharding_stages: int,
-        sharding_stages_id: list[str] = None,
+        sharding_stages_id: list[str],
+        fx_tracing_required=False,
+        model_input=None,
+        config_file_path="model_config.json",
     ):
         self.model = model
         self.num_sharding_stages = num_sharding_stages
@@ -15,8 +23,59 @@ def __init__(
         assert self.num_sharding_stages == len(
             self.sharding_stages_id
         ), "Number of sharding stages should be equal to the list of their ID"
+        self.ts_graph = None
+        if fx_tracing_required:
+            self.ts_graph = get_torch_mlir_module_bytecode(
+                self.model, model_input
+            )
+        self.config_file_path = config_file_path
+
+    def split_into_dispatches(self, input_tensor, backend):
+        graph_for_compilation = self.model
+        if self.ts_graph:
+            graph_for_compilation = self.ts_graph
+
+        module = torch_mlir.compile(
+            graph_for_compilation,
+            (input_tensor),
+            torch_mlir.OutputType.LINALG_ON_TENSORS,
+            use_tracing=True,
+            verbose=False,
+        )
+        module = module.operation.get_asm(large_elements_limit=4)
+        compiled_module_str = str(
+            compile_str(
+                str(module),
+                target_backends=[backend],
+                extra_args=[
+                    "--compile-to=flow",
+                    "--mlir-elide-elementsattrs-if-larger=4",
+                ],
+            )
+        )
+
+        substring_start_idx = [
+            m.start()
+            for m in re.finditer("flow.dispatch @", compiled_module_str)
+        ]
+        dispatch_list = dict()
 
-    def generate_json(self):
+        # dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model
+        # dispatch_id is the unique id of a dispatch, multiple instances of the same dispatch
+        # can occur in a model
+        for dispatch_no, substring_idx in enumerate(substring_start_idx):
+            dispatch_idx = (
+                compiled_module_str[substring_idx:]
+                .split(":")[0]
+                .split("@")[-1]
+            )
+            key = "dispatch_no_" + str(dispatch_no)
+            dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
+            dispatch_list[key]["dispatch_id"] = dispatch_idx
+
+        self.generate_json(dispatch_list)
+
+    def split_into_layers(self):
         model_dictionary = dict()
 
         for name, m in self.model.named_modules():
@@ -34,5 +93,8 @@ def generate_json(self):
             layer_dict = {n: "None" for n in self.sharding_stages_id}
             model_dictionary[name] = layer_dict
 
-        with open("model_config.json", "w") as outfile:
-            json.dump(model_dictionary, outfile)
+        self.generate_json(model_dictionary)
+
+    def generate_json(self, artifacts):
+        with open(self.config_file_path, "w") as outfile:
+            json.dump(artifacts, outfile)