Skip to content

Commit

Permalink
Add dispatch-level config file generator for manual annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah committed Jun 21, 2023
1 parent d61b664 commit 6c64078
Showing 1 changed file with 67 additions and 5 deletions.
72 changes: 67 additions & 5 deletions shark/shark_generate_model_config.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,81 @@
import re
import json
from collections import OrderedDict
import torch_mlir
from iree.compiler import compile_str
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
)


class GenerateConfigFile:
    """Generate a JSON config file used to manually annotate how a model's
    layers (or compiled dispatches) map onto sharding stages."""

    def __init__(
        self,
        model,
        num_sharding_stages: int,
        sharding_stages_id: list[str],
        fx_tracing_required=False,
        model_input=None,
        config_file_path="model_config.json",
    ):
        """Set up the generator.

        Args:
            model: The model to annotate. Presumably a ``torch.nn.Module``
                (it is later passed to ``torch_mlir.compile`` and iterated
                with ``named_modules``) — TODO confirm.
            num_sharding_stages: Number of sharding stages; must equal
                ``len(sharding_stages_id)``.
            sharding_stages_id: One identifier per sharding stage; these
                become the keys of every entry in the generated config.
            fx_tracing_required: When True, trace the model up front via
                ``get_torch_mlir_module_bytecode`` using ``model_input``.
            model_input: Example input for tracing; only used when
                ``fx_tracing_required`` is True.
            config_file_path: Destination path of the generated JSON file.

        Raises:
            AssertionError: If ``num_sharding_stages`` does not match the
                length of ``sharding_stages_id``.
        """
        self.model = model
        self.num_sharding_stages = num_sharding_stages
        self.sharding_stages_id = sharding_stages_id
        # NOTE(review): assert is stripped under `python -O`; kept as an
        # assert so callers that catch AssertionError are unaffected.
        assert self.num_sharding_stages == len(
            self.sharding_stages_id
        ), "Number of sharding stages should be equal to the list of their ID"
        # Tracing is optional and needs a concrete example input, so the
        # traced graph is only produced on request.
        self.ts_graph = None
        if fx_tracing_required:
            self.ts_graph = get_torch_mlir_module_bytecode(
                self.model, model_input
            )
        self.config_file_path = config_file_path

def split_into_dispatches(self, input_tensor, backend):
    """Compile the model down to IREE's flow dialect and emit one config
    entry per dispatch, then write the config JSON.

    Args:
        self: The ``GenerateConfigFile`` instance.
        input_tensor: Example input used to trace/compile the model.
            Presumably a ``torch.Tensor`` — TODO confirm against callers.
        backend: IREE target backend name passed to ``compile_str``
            (e.g. ``"llvm-cpu"`` — NOTE(review): verify against caller).

    Side effects:
        Writes the dispatch-level config to ``self.config_file_path``
        via ``self.generate_json``.
    """
    # Prefer the pre-traced TorchScript graph when tracing was requested
    # at construction time; otherwise compile the raw model.
    graph_for_compilation = self.model
    if self.ts_graph:
        graph_for_compilation = self.ts_graph

    module = torch_mlir.compile(
        graph_for_compilation,
        # Original code wrapped this in parentheses, which is NOT a
        # tuple; torch_mlir.compile receives the bare tensor.
        input_tensor,
        torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=True,
        verbose=False,
    )
    # Elide large constants so the textual IR (and the scan below) stays
    # small.
    module = module.operation.get_asm(large_elements_limit=4)
    compiled_module_str = str(
        compile_str(
            str(module),
            target_backends=[backend],
            extra_args=[
                "--compile-to=flow",
                "--mlir-elide-elementsattrs-if-larger=4",
            ],
        )
    )

    # Fix: the original pattern "flow.dispatch @" left '.' unescaped, so
    # it matched any character in that position.
    substring_start_idx = [
        m.start()
        for m in re.finditer(r"flow\.dispatch @", compiled_module_str)
    ]
    dispatch_list = dict()
    # dispatch_no is the i'th index of a dispatch out of n total dispatches
    # of a model; dispatch_id is the unique id of a dispatch — multiple
    # instances of the same dispatch can occur in a model.
    for dispatch_no, substring_idx in enumerate(substring_start_idx):
        dispatch_idx = (
            compiled_module_str[substring_idx:]
            .split(":")[0]
            .split("@")[-1]
        )
        key = "dispatch_no_" + str(dispatch_no)
        # One "None" placeholder per sharding stage, to be filled in by
        # hand, plus the dispatch's unique id.
        dispatch_list[key] = {n: "None" for n in self.sharding_stages_id}
        dispatch_list[key]["dispatch_id"] = dispatch_idx

    self.generate_json(dispatch_list)

def split_into_layers(self):
model_dictionary = dict()

for name, m in self.model.named_modules():
Expand All @@ -34,5 +93,8 @@ def generate_json(self):
layer_dict = {n: "None" for n in self.sharding_stages_id}
model_dictionary[name] = layer_dict

with open("model_config.json", "w") as outfile:
json.dump(model_dictionary, outfile)
self.generate_json(model_dictionary)

def generate_json(self, artifacts):
    """Serialize *artifacts* to ``self.config_file_path`` as JSON.

    The file exists to be hand-edited (manual sharding annotation), so it
    is written indented and in UTF-8 without ASCII-escaping rather than as
    a single compact line.

    Args:
        self: Object exposing ``config_file_path``.
        artifacts: JSON-serializable mapping of layer/dispatch names to
            their per-stage annotations.
    """
    with open(self.config_file_path, "w", encoding="utf-8") as outfile:
        json.dump(artifacts, outfile, indent=4, ensure_ascii=False)

0 comments on commit 6c64078

Please sign in to comment.