diff --git a/.gitignore b/.gitignore index 395a677ba6..eeb217e2b6 100644 --- a/.gitignore +++ b/.gitignore @@ -159,7 +159,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ # vscode related .vscode diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 3e801050c6..54c32c8269 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -3,6 +3,10 @@ from apps.language_models.src.pipelines import vicuna_pipeline as vp from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp import torch +import json + +if __name__ == "__main__": + import gc parser = argparse.ArgumentParser( @@ -55,35 +59,38 @@ help="Run model in cli mode", ) +parser.add_argument( + "--config", + default=None, + help="configuration file", +) + if __name__ == "__main__": args, unknown = parser.parse_known_args() vic = None if not args.sharded: first_vic_mlir_path = ( - Path(f"first_vicuna_{args.precision}.mlir") + None if args.first_vicuna_mlir_path is None else Path(args.first_vicuna_mlir_path) ) second_vic_mlir_path = ( - Path(f"second_vicuna_{args.precision}.mlir") + None if args.second_vicuna_mlir_path is None else Path(args.second_vicuna_mlir_path) ) first_vic_vmfb_path = ( - Path( - f"first_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb" - ) + None if args.first_vicuna_vmfb_path is None else Path(args.first_vicuna_vmfb_path) ) second_vic_vmfb_path = ( - Path( - f"second_vicuna_{args.precision}_{args.device.replace('://', '_')}.vmfb" - ) + None if args.second_vicuna_vmfb_path is None else Path(args.second_vicuna_vmfb_path) ) + vic = vp.Vicuna( "vicuna", device=args.device, @@ -95,16 +102,21 @@ load_mlir_from_shark_tank=args.load_mlir_from_shark_tank, ) else: + if args.config is not None: + config_file = open(args.config) + config_json = json.load(config_file) + config_file.close() + else: + config_json = None vic = vsp.Vicuna( "vicuna", device=args.device, precision=args.precision, + config_json=config_json, ) prompt_history = "A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\n" prologue_prompt = "ASSISTANT:\n" - import gc - while True: # TODO: Add break condition from user input user_prompt = input("User: ") diff --git a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py index cadba37cae..cba0b0f952 100644 --- a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py +++ b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py @@ -145,7 +145,7 @@ def forward( class ShardedVicunaModel(torch.nn.Module): - def __init__(self, model, layers0, layers1): + def __init__(self, model, layers0, layers1, lmhead, embedding, norm): super().__init__() self.model = model assert len(layers0) == len(model.model.layers) @@ -154,6 +154,12 @@ def __init__(self, model, layers0, layers1): self.model.model.config.output_attentions = False self.layers0 = layers0 self.layers1 = layers1 + self.norm = norm + self.embedding = embedding + self.lmhead = lmhead + self.model.model.norm = self.norm + self.model.model.embed_tokens = self.embedding + self.model.lm_head = self.lmhead def forward( self, @@ -176,3 +182,69 @@ def forward( attention_mask=attention_mask, past_key_values=past_key_values, ) + + +class LMHead(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, hidden_states): + output = self.model(hidden_states) + return output + + +class LMHeadCompiled(torch.nn.Module): + def __init__(self, shark_module): + super().__init__() + self.model = shark_module + + def forward(self, hidden_states): + hidden_states = hidden_states.detach() + output = self.model("forward", (hidden_states,)) + output = torch.tensor(output) + return output + + +class VicunaNorm(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, hidden_states): + output = self.model(hidden_states) + return output + + +class VicunaNormCompiled(torch.nn.Module): + def __init__(self, shark_module): + super().__init__() + self.model = shark_module + + def forward(self, hidden_states): + hidden_states.detach() + output = self.model("forward", (hidden_states,)) + output = torch.tensor(output) + return output + + +class VicunaEmbedding(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_ids): + output = self.model(input_ids) + return output + + +class VicunaEmbeddingCompiled(torch.nn.Module): + def __init__(self, shark_module): + super().__init__() + self.model = shark_module + + def forward(self, input_ids): + input_ids.detach() + output = self.model("forward", (input_ids,)) + output = torch.tensor(output) + return output diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py index 4745504ae5..60cf258f4e 100644 --- a/apps/language_models/src/pipelines/vicuna_pipeline.py +++ b/apps/language_models/src/pipelines/vicuna_pipeline.py @@ -33,16 +33,23 @@ def __init__( first_vicuna_vmfb_path=None, second_vicuna_vmfb_path=None, load_mlir_from_shark_tank=True, + low_device_memory=False, ) -> None: super().__init__(model_name, hf_model_path, max_num_tokens) self.max_sequence_length = 256 self.device = device + if precision in ["int4", "int8"]: + print("int4 and int8 are not supported yet, using fp32") + precision = "fp32" self.precision = precision self.first_vicuna_vmfb_path = first_vicuna_vmfb_path 
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path self.first_vicuna_mlir_path = first_vicuna_mlir_path self.second_vicuna_mlir_path = second_vicuna_mlir_path self.load_mlir_from_shark_tank = load_mlir_from_shark_tank + self.low_device_memory = low_device_memory + self.first_vic = None + self.second_vic = None if self.first_vicuna_mlir_path == None: self.first_vicuna_mlir_path = self.get_model_path() if self.second_vicuna_mlir_path == None: @@ -61,7 +68,7 @@ def get_model_path(self, model_number="first", suffix="mlir"): if suffix == "mlir": return Path(f"{model_number}_vicuna_{self.precision}.{suffix}") return Path( - f"{model_number}_vicuna_{safe_device}_{self.precision}.{suffix}" + f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}" ) def get_tokenizer(self): @@ -87,7 +94,7 @@ def compile_first_vicuna(self): # Compilation path needs some more work before it is functional print( - f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with" + f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n" f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}" ) if self.first_vicuna_mlir_path.exists(): @@ -436,12 +443,19 @@ def generate(self, prompt, cli=False): # TODO: refactor for cleaner integration import gc + if not self.low_device_memory: + if self.first_vic == None: + self.first_vic = self.compile_first_vicuna() + if self.second_vic == None: + self.second_vic = self.compile_second_vicuna() res = [] res_tokens = [] params = { "prompt": prompt, "is_first": True, - "fv": self.compile_first_vicuna(), + "fv": self.compile_first_vicuna() + if self.first_vic == None + else self.first_vic, } generated_token_op = self.generate_new_token(params=params) @@ -457,18 +471,20 @@ def generate(self, prompt, cli=False): print(f"Assistant: {detok}", end=" ", flush=True) # Clear First Vic from Memory (main and cuda) - del params - torch.cuda.empty_cache() - gc.collect() + if self.low_device_memory: + del params + torch.cuda.empty_cache() + gc.collect() - sec_vic = self.compile_second_vicuna() for _ in range(self.max_num_tokens - 2): params = { "prompt": None, "is_first": False, "logits": logits, "pkv": pkv, - "sv": sec_vic, + "sv": self.compile_second_vicuna() + if self.second_vic == None + else self.second_vic, } generated_token_op = self.generate_new_token(params=params) @@ -489,9 +505,10 @@ def generate(self, prompt, cli=False): res.append(detok) if cli: print(f"{detok}", end=" ", flush=True) - del sec_vic, pkv, logits - torch.cuda.empty_cache() - gc.collect() + if self.device == "cuda": + del sec_vic, pkv, logits + torch.cuda.empty_cache() + gc.collect() for i in range(len(res_tokens)): if type(res_tokens[i]) != int: diff --git a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py index 79387af462..9cb3a428ea 100644 --- a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py +++ b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py @@ -4,6 +4,12 @@ CompiledFirstVicunaLayer, CompiledSecondVicunaLayer, ShardedVicunaModel, + LMHead, + LMHeadCompiled, + VicunaEmbedding, + VicunaEmbeddingCompiled, + VicunaNorm, + VicunaNormCompiled, ) from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase from shark.shark_importer import import_with_fx @@ -19,9 +25,11 @@ import torch import torch_mlir import os +import json class Vicuna(SharkLLMBase): + # Class 
representing Sharded Vicuna Model def __init__( self, model_name, @@ -29,21 +37,25 @@ def __init__( max_num_tokens=512, device="cuda", precision="fp32", + config_json=None, ) -> None: super().__init__(model_name, hf_model_path, max_num_tokens) self.max_sequence_length = 256 self.device = device self.precision = precision self.tokenizer = self.get_tokenizer() + self.config = config_json self.shark_model = self.compile(device=device) def get_tokenizer(self): + # Retrieve the tokenizer from Huggingface tokenizer = AutoTokenizer.from_pretrained( self.hf_model_path, use_fast=False ) return tokenizer def get_src_model(self): + # Retrieve the torch model from Huggingface kwargs = {"torch_dtype": torch.float} vicuna_model = AutoModelForCausalLM.from_pretrained( self.hf_model_path, **kwargs @@ -51,6 +63,8 @@ def get_src_model(self): return vicuna_model def write_in_dynamic_inputs0(self, module, dynamic_input_size): + # Current solution for ensuring mlir files support dynamic inputs + # TODO find a more elegant way to implement this new_lines = [] for line in module.splitlines(): line = re.sub(f"{dynamic_input_size}x", "?x", line) @@ -107,6 +121,7 @@ def compile_vicuna_layer( past_key_value0=None, past_key_value1=None, ): + # Compile a hidden decoder layer of vicuna if past_key_value0 is None and past_key_value1 is None: model_inputs = (hidden_states, attention_mask, position_ids) else: @@ -126,7 +141,154 @@ def compile_vicuna_layer( ) return mlir_bytecode + def get_device_index(self, layer_string): + # Get the device index from the config file + # In the event that different device indices are assigned to + # different parts of a layer, a majority vote will be taken and + # everything will be run on the most commonly used device + if self.config is None: + return None + idx_votes = {} + for key in self.config.keys(): + if re.search(layer_string, key): + if int(self.config[key]["gpu"]) in idx_votes.keys(): + idx_votes[int(self.config[key]["gpu"])] += 1 + else: + idx_votes[int(self.config[key]["gpu"])] = 1 + device_idx = max(idx_votes, key=idx_votes.get) + return device_idx + + def compile_lmhead( + self, lmh, hidden_states, device="cpu", device_idx=None + ): + # compile the lm head of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"lmhead.mlir") + vmfb_path = Path(f"lmhead.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + lmh, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="lmhead") + shark_module.load_module(vmfb_path) + compiled_module = LMHeadCompiled(shark_module) + return compiled_module + + def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): + # compile the normalization layer of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"norm.mlir") + vmfb_path = Path(f"norm.vmfb") + if 
mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + fvn, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="norm") + shark_module.load_module(vmfb_path) + compiled_module = VicunaNormCompiled(shark_module) + return compiled_module + + def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): + # compile the embedding layer of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"embedding.mlir") + vmfb_path = Path(f"embedding.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + input_ids = torch_mlir.TensorPlaceholder.like( + input_ids, dynamic_axes=[1] + ) + module = torch_mlir.compile( + fve, + (input_ids,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="embedding") + shark_module.load_module(vmfb_path) + compiled_module = VicunaEmbeddingCompiled(shark_module) + + return compiled_module + def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): + # compile all layers for vmfb + # this needs to be run seperatley for first and second vicuna mlirs, modules = [], [] for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"): if is_first: @@ -198,10 +360,6 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): verbose=False, ) - # bytecode_stream = BytesIO() - # module.operation.write_bytecode(bytecode_stream) - # bytecode = bytecode_stream.getvalue() - if is_first: module = self.write_in_dynamic_inputs0(str(module), 137) bytecode = module.encode("UTF-8") @@ -224,20 +382,25 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): if is_first: vmfb_path = Path(f"{idx}_0.vmfb") if vmfb_path.exists(): - # print(f"Found layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) module = SharkInference( None, device=device, - device_idx=idx % 1, + device_idx=device_idx, mlir_dialect="tm_tensor", ) module.load_module(vmfb_path) else: print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) module = SharkInference( mlirs[idx], device=device, - device_idx=idx % 1, + device_idx=device_idx, mlir_dialect="tm_tensor", ) module.save_module( @@ -255,19 +418,25 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): vmfb_path = Path(f"{idx}_1.vmfb") if vmfb_path.exists(): # print(f"Found layer {idx} 
vmfb") + device_idx = self.get_device_index( + f"second_vicuna.model.model.layers.{idx}[\s.$]" + ) module = SharkInference( None, device=device, - device_idx=idx % 1, + device_idx=device_idx, mlir_dialect="tm_tensor", ) module.load_module(vmfb_path) else: print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"second_vicuna.model.model.layers.{idx}[\s.$]" + ) module = SharkInference( mlirs[idx], device=device, - device_idx=idx % 1, + device_idx=device_idx, mlir_dialect="tm_tensor", ) module.save_module( @@ -303,6 +472,42 @@ def get_sharded_model(self, device="cpu"): torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), ) + norm = VicunaNorm(vicuna_model.model.norm) + device_idx = self.get_device_index( + r"vicuna\.model\.model\.norm(?:\.|\s|$)" + ) + print(device_idx) + norm = self.compile_norm( + norm, + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + device=self.device, + device_idx=device_idx, + ) + + embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens) + device_idx = self.get_device_index( + r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)" + ) + print(device_idx) + embeddings = self.compile_embedding( + embeddings, + (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)), + device=self.device, + device_idx=device_idx, + ) + + lmhead = LMHead(vicuna_model.lm_head) + device_idx = self.get_device_index( + r"vicuna\.model\.lm_head(?:\.|\s|$)" + ) + print(device_idx) + lmhead = self.compile_lmhead( + lmhead, + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + device=self.device, + device_idx=device_idx, + ) + layers0 = [ FirstVicunaLayer(layer) for layer in vicuna_model.model.layers ] @@ -323,7 +528,12 @@ def get_sharded_model(self, device="cpu"): shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1] sharded_model = ShardedVicunaModel( - vicuna_model, shark_layers0, shark_layers1 + vicuna_model, + shark_layers0, + shark_layers1, + lmhead, + embeddings, + norm, ) return sharded_model diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py index 6d11f96d08..1fcc03db09 100644 --- a/apps/stable_diffusion/src/utils/utils.py +++ b/apps/stable_diffusion/src/utils/utils.py @@ -757,6 +757,14 @@ def save_output_img(output_img, img_seed, extra_info={}): if args.ckpt_loc: img_model = Path(os.path.basename(args.ckpt_loc)).stem + img_vae = None + if args.custom_vae: + img_vae = Path(os.path.basename(args.custom_vae)).stem + + img_lora = None + if args.use_lora: + img_lora = Path(os.path.basename(args.use_lora)).stem + if args.output_img_format == "jpg": out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg") output_img.save(out_img_path, quality=95, subsampling=0) @@ -767,7 +775,9 @@ def save_output_img(output_img, img_seed, extra_info={}): if args.write_metadata_to_png: pngInfo.add_text( "parameters", - f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}", + f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps: {args.steps}," + f"Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}," + f"Size: {args.width}x{args.height}, Model: {img_model}, VAE: {img_vae}, LoRA: {img_lora}", ) output_img.save(out_img_path, "PNG", pnginfo=pngInfo) @@ -778,6 +788,9 @@ def save_output_img(output_img, img_seed, extra_info={}): "Image saved as png instead. 
Supported formats: png / jpg" ) + # To be as low-impact as possible to the existing CSV format, we append + # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of + # importance for each data point. Something to consider. new_entry = { "VARIANT": img_model, "SCHEDULER": args.scheduler, @@ -791,6 +804,8 @@ def save_output_img(output_img, img_seed, extra_info={}): "WIDTH": args.width, "MAX_LENGTH": args.max_length, "OUTPUT": out_img_path, + "VAE": img_vae, + "LORA": img_lora, } new_entry.update(extra_info) diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index a0cbd59a62..004a052e5f 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -44,6 +44,7 @@ def launch_app(address): img2img_api, upscaler_api, inpaint_api, + outpaint_api, ) from fastapi import FastAPI, APIRouter import uvicorn @@ -55,9 +56,7 @@ def launch_app(address): app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"]) app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"]) app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"]) - # app.add_api_route( - # "/sdapi/v1/outpaint", outpaint_api, methods=["post"] - # ) + app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"]) app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"]) app.include_router(APIRouter()) uvicorn.run(app, host="127.0.0.1", port=args.server_port) diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index 7df49bddec..6ed965cece 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -342,7 +342,7 @@ def img2img_api( ) # Converts generator type to subscriptable - res = list(res)[0] + res = next(res) return { "images": encode_pil_to_base64(res[0]), diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py index 082668e85b..be8a58def4 100644 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ b/apps/stable_diffusion/web/ui/inpaint_ui.py @@ -278,7 +278,7 @@ def inpaint_api( custom_model="None", hf_model_id=InputData["hf_model_id"] if "hf_model_id" in InputData.keys() - else "stabilityai/stable-diffusion-2-1-base", + else "stabilityai/stable-diffusion-2-inpainting", custom_vae="None", precision="fp16", device=available_devices[0], @@ -289,6 +289,10 @@ def inpaint_api( lora_hf_id="", ondemand=False, ) + + # Converts generator type to subscriptable + res = next(res) + return { "images": encode_pil_to_base64(res[0]), "parameters": {}, diff --git a/apps/stable_diffusion/web/ui/outpaint_ui.py b/apps/stable_diffusion/web/ui/outpaint_ui.py index 0401275f2d..d6b0d2b317 100644 --- a/apps/stable_diffusion/web/ui/outpaint_ui.py +++ b/apps/stable_diffusion/web/ui/outpaint_ui.py @@ -287,7 +287,7 @@ def outpaint_api( custom_model="None", hf_model_id=InputData["hf_model_id"] if "hf_model_id" in InputData.keys() - else "stabilityai/stable-diffusion-2-1-base", + else "stabilityai/stable-diffusion-2-inpainting", custom_vae="None", precision="fp16", device=available_devices[0], @@ -298,6 +298,10 @@ def outpaint_api( lora_hf_id="", ondemand=False, ) + + # Convert Generator to Subscriptable + res = next(res) + return { "images": encode_pil_to_base64(res[0]), "parameters": {}, diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py index 0e5cf4092d..369d864a31 100644 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ 
b/apps/stable_diffusion/web/ui/stablelm_ui.py @@ -140,6 +140,8 @@ def chat(curr_system_message, history, model, device, precision): choices=[ "fp16", "fp32", + "int4", + "int8", ], visible=True, ) diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 44e41f1d4c..6e9ff38860 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -265,6 +265,10 @@ def txt2img_api( lora_hf_id="", ondemand=False, ) + + # Convert Generator to Subscriptable + res = next(res) + return { "images": encode_pil_to_base64(res[0]), "parameters": {}, @@ -301,7 +305,7 @@ def txt2img_api( ) txt2img_hf_model_id = gr.Textbox( elem_id="hf_model_id", - placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3, https://civitai.com/api/download/models/15236", + placeholder="Select 'None' in the dropdown on the left and enter model ID here", value="", label="HuggingFace Model ID or Civitai model download URL", lines=3, @@ -553,6 +557,9 @@ def txt2img_api( height, txt2img_custom_model, txt2img_hf_model_id, + lora_weights, + lora_hf_id, + custom_vae, ], outputs=[ txt2img_png_info_img, @@ -566,5 +573,8 @@ def txt2img_api( height, txt2img_custom_model, txt2img_hf_model_id, + lora_weights, + lora_hf_id, + custom_vae, ], ) diff --git a/apps/stable_diffusion/web/ui/upscaler_ui.py b/apps/stable_diffusion/web/ui/upscaler_ui.py index 902af4feb8..7a4b6469c1 100644 --- a/apps/stable_diffusion/web/ui/upscaler_ui.py +++ b/apps/stable_diffusion/web/ui/upscaler_ui.py @@ -300,7 +300,7 @@ def upscaler_api( ondemand=False, ) # Converts generator type to subscriptable - res = list(res)[0] + res = next(res) return { "images": encode_pil_to_base64(res[0]), diff --git a/apps/stable_diffusion/web/utils/metadata/png_metadata.py b/apps/stable_diffusion/web/utils/metadata/png_metadata.py index a9128ee108..51a92f07d6 100644 --- a/apps/stable_diffusion/web/utils/metadata/png_metadata.py +++ b/apps/stable_diffusion/web/utils/metadata/png_metadata.py @@ -62,6 +62,82 @@ def parse_generation_parameters(x: str): return res +def try_find_model_base_from_png_metadata( + file: str, folder: str = "models" +) -> str: + custom = "" + + # Remove extension from file info + if file.endswith(".safetensors") or file.endswith(".ckpt"): + file = Path(file).stem + # Check for the file name match with one of the local ckpt or safetensors files + if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file(): + custom = file + ".ckpt" + if Path( + get_custom_model_pathfile(file + ".safetensors", folder) + ).is_file(): + custom = file + ".safetensors" + + return custom + + +def find_model_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + png_hf_id = "" + png_custom = "" + + if key in metadata: + model_file = metadata[key] + png_custom = try_find_model_base_from_png_metadata(model_file) + # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") + if model_file in predefined_models: + png_custom = model_file + # If nothing had matched, check vendor/hf_model_id + if not png_custom and model_file.count("/"): + png_hf_id = model_file + # No matching model was found + if not png_custom and not png_hf_id: + print( + "Import PNG info: Unable to find a matching model for %s" + % model_file + ) + + return png_custom, png_hf_id + + +def find_vae_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> str: + vae_custom = "" + + if key 
in metadata: + vae_file = metadata[key] + vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae") + + # VAE input is optional, should not print or throw an error if missing + + return vae_custom + + +def find_lora_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + lora_hf_id = "" + lora_custom = "" + + if key in metadata: + lora_file = metadata[key] + lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora") + # If nothing had matched, check vendor/hf_model_id + if not lora_custom and lora_file.count("/"): + lora_hf_id = lora_file + + # LoRA input is optional, should not print or throw an error if missing + + return lora_custom, lora_hf_id + + def import_png_metadata( pil_data, prompt, @@ -74,40 +150,21 @@ def import_png_metadata( height, custom_model, hf_model_id, + custom_lora, + hf_lora_id, + custom_vae, ): try: png_info = pil_data.info["parameters"] metadata = parse_generation_parameters(png_info) - png_hf_model_id = "" - png_custom_model = "" - - if "Model" in metadata: - # Remove extension from model info - if metadata["Model"].endswith(".safetensors") or metadata[ - "Model" - ].endswith(".ckpt"): - metadata["Model"] = Path(metadata["Model"]).stem - # Check for the model name match with one of the local ckpt or safetensors files - if Path( - get_custom_model_pathfile(metadata["Model"] + ".ckpt") - ).is_file(): - png_custom_model = metadata["Model"] + ".ckpt" - if Path( - get_custom_model_pathfile(metadata["Model"] + ".safetensors") - ).is_file(): - png_custom_model = metadata["Model"] + ".safetensors" - # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") - if metadata["Model"] in predefined_models: - png_custom_model = metadata["Model"] - # If nothing had matched, check vendor/hf_model_id - if not png_custom_model and metadata["Model"].count("/"): - png_hf_model_id = metadata["Model"] - # No matching model was found - if not png_custom_model and not png_hf_model_id: - print( - "Import PNG info: Unable to find a matching model for %s" - % metadata["Model"] - ) + + (png_custom_model, png_hf_model_id) = find_model_from_png_metadata( + "Model", metadata + ) + (lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata( + "LoRA", metadata + ) + vae_custom_model = find_vae_from_png_metadata("VAE", metadata) negative_prompt = metadata["Negative prompt"] steps = int(metadata["Steps"]) @@ -115,12 +172,24 @@ def import_png_metadata( seed = int(metadata["Seed"]) width = float(metadata["Size-1"]) height = float(metadata["Size-2"]) + if "Model" in metadata and png_custom_model: custom_model = png_custom_model hf_model_id = "" if "Model" in metadata and png_hf_model_id: custom_model = "None" hf_model_id = png_hf_model_id + + if "LoRA" in metadata and lora_custom_model: + custom_lora = lora_custom_model + hf_lora_id = "" + if "LoRA" in metadata and lora_hf_model_id: + custom_lora = "None" + hf_lora_id = lora_hf_model_id + + if "VAE" in metadata and vae_custom_model: + custom_vae = vae_custom_model + if "Prompt" in metadata: prompt = metadata["Prompt"] if "Sampler" in metadata: @@ -149,4 +218,7 @@ def import_png_metadata( height, custom_model, hf_model_id, + custom_lora, + hf_lora_id, + custom_vae, ) diff --git a/rest_api_tests/api_test.py b/rest_api_tests/api_test.py index 844dcfd58b..365dc51dcf 100644 --- a/rest_api_tests/api_test.py +++ b/rest_api_tests/api_test.py @@ -89,21 +89,155 @@ def img2img_test(): print(f"response from server was : {res.status_code}") - print("Extracting response 
object") + # NOTE Uncomment below to save the picture - # Uncomment below to save the picture + # print("Extracting response object") + # response_obj = res.json() + # img_b64 = response_obj.get("images", [False])[0] or response_obj.get( + # "image" + # ) + # img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", "")) + # im_file = BytesIO(img_b2) + # response_img = Image.open(im_file) + # print("Saving Response Image to: response_img") + # response_img.save(r"rest_api_tests/response_img.png") - response_obj = res.json() - img_b64 = response_obj.get("images", [False])[0] or response_obj.get( - "image" + +def inpainting_test(): + prompt = "Paint a rabbit riding on the dog" + negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft" + seed = 2121991605 + height = 512 + width = 512 + steps = 50 + noise_level = 10 + cfg_scale = 7 + is_full_res = False + full_res_padding = 32 + image_path = r"./rest_api_tests/dog.png" + + img_file = open(image_path, "rb") + image = ( + "data:image/png;base64," + base64.b64encode(img_file.read()).decode() ) - img_b2 = base64.b64decode(img_b64.replace("data:image/png;base64,", "")) - im_file = BytesIO(img_b2) - response_img = Image.open(im_file) - print("Saving Response Image to: response_img") - response_img.save(r"rest_api_tests/response_img.png") + img_file = open(image_path, "rb") + mask = ( + "data:image/png;base64," + base64.b64encode(img_file.read()).decode() + ) + + url = "http://127.0.0.1:8080/sdapi/v1/inpaint" + + headers = { + "User-Agent": "PythonTest", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } + + data = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "image": image, + "mask": mask, + "height": height, + "width": width, + "steps": steps, + "noise_level": noise_level, + "cfg_scale": cfg_scale, + "seed": seed, + "is_full_res": is_full_res, + "full_res_padding": full_res_padding, + } + + res = requests.post(url=url, json=data, headers=headers, timeout=1000) + + print(f"[Inpainting] response from server was : {res.status_code}") + + +def outpainting_test(): + prompt = "Paint a rabbit riding on the dog" + negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft" + seed = 2121991605 + height = 512 + width = 512 + steps = 50 + cfg_scale = 7 + color_variation = 0.2 + noise_q = 0.2 + directions = ["up", "down", "right", "left"] + pixels = 32 + mask_blur = 64 + image_path = r"./rest_api_tests/dog.png" + + # Converting Image to Base64 + img_file = open(image_path, "rb") + init_images = [ + "data:image/png;base64," + base64.b64encode(img_file.read()).decode() + ] + + url = "http://127.0.0.1:8080/sdapi/v1/outpaint" + + headers = { + "User-Agent": "PythonTest", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } + + data = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "seed": seed, + "height": height, + "width": width, + "steps": steps, + "cfg_scale": cfg_scale, + "color_variation": color_variation, + "noise_q": noise_q, + "directions": directions, + "pixels": pixels, + "mask_blur": mask_blur, + "init_images": init_images, + } + + res = requests.post(url=url, json=data, headers=headers, timeout=1000) + + 
print(f"[Outpaint] response from server was : {res.status_code}") + + +def txt2img_test(): + prompt = "Paint a rabbit in a top hate" + negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft" + seed = 2121991605 + height = 512 + width = 512 + steps = 50 + cfg_scale = 7 + + url = "http://127.0.0.1:8080/sdapi/v1/txt2img" + + headers = { + "User-Agent": "PythonTest", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } + + data = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "seed": seed, + "height": height, + "width": width, + "steps": steps, + "cfg_scale": cfg_scale, + } + + res = requests.post(url=url, json=data, headers=headers, timeout=1000) + + print(f"[txt2img] response from server was : {res.status_code}") if __name__ == "__main__": + txt2img_test() img2img_test() upscaler_test() + inpainting_test() + outpainting_test() diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index 8c79243129..3fc689fbe7 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -63,6 +63,7 @@ def get_supported_device_list(): _IREE_DEVICE_MAP = { "cpu": "local-task", "cpu-task": "local-task", + "AMD-AIE": "local-task", "cpu-sync": "local-sync", "cuda": "cuda", "vulkan": "vulkan", @@ -81,6 +82,7 @@ def iree_target_map(device): _IREE_TARGET_MAP = { "cpu": "llvm-cpu", "cpu-task": "llvm-cpu", + "AMD-AIE": "llvm-cpu", "cpu-sync": "llvm-cpu", "cuda": "cuda", "vulkan": "vulkan", diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 78ee1ca6d5..a05bfc89c6 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -20,6 +20,7 @@ import os import re import tempfile +from pathlib import Path # Get the iree-compile arguments given device. @@ -39,7 +40,10 @@ def get_iree_device_args(device, extra_args=[]): if device_uri[0] == "cpu": from shark.iree_utils.cpu_utils import get_iree_cpu_args - return get_iree_cpu_args() + data_tiling_flag = ["--iree-flow-enable-data-tiling"] + u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"] + + return get_iree_cpu_args() + data_tiling_flag + u_kernel_flag if device_uri[0] == "cuda": from shark.iree_utils.gpu_utils import get_iree_gpu_args @@ -355,6 +359,9 @@ def load_vmfb_using_mmap( # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with. 
# (This would arise if we're invoking `compile` from a SharkInference obj) temp_file_to_unlink = None + + if isinstance(flatbuffer_blob_or_path, Path): + flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__() if ( isinstance(flatbuffer_blob_or_path, str) and ".vmfb" in flatbuffer_blob_or_path diff --git a/shark/shark_downloader.py b/shark/shark_downloader.py index 8005ecc120..73a8054955 100644 --- a/shark/shark_downloader.py +++ b/shark/shark_downloader.py @@ -60,12 +60,15 @@ def download_public_file( else: continue - destination_filename = os.path.join(destination_folder_name, blob_name) - if os.path.isdir(destination_filename): - continue - with open(destination_filename, "wb") as f: - with tqdm.wrapattr(f, "write", total=blob.size) as file_obj: - storage_client.download_blob_to_file(blob, file_obj) + else: + destination_filename = os.path.join( + destination_folder_name, blob_name + ) + if os.path.isdir(destination_filename): + continue + with open(destination_filename, "wb") as f: + with tqdm.wrapattr(f, "write", total=blob.size) as file_obj: + storage_client.download_blob_to_file(blob, file_obj) input_type_to_np_dtype = { diff --git a/shark/shark_eager/shark_eager.py b/shark/shark_eager/shark_eager.py new file mode 100644 index 0000000000..807c6c014e --- /dev/null +++ b/shark/shark_eager/shark_eager.py @@ -0,0 +1,206 @@ +from typing import Any, Dict, List, Tuple +from collections import defaultdict +from shark.shark_importer import import_with_fx +import torchvision.models as models +import copy +import io +import numpy as np +import sys +import torch +import torch.fx +from torch.fx.node import Node +from typing import Dict +import torch_mlir + + +def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"): + mlir_module = torch_mlir.compile( + fx_g, inputs, output_type="linalg-on-tensors" + ) + bytecode_stream = io.BytesIO() + mlir_module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + from shark.shark_inference import SharkInference + + shark_module = SharkInference( + mlir_module=bytecode, + device=device, + mlir_dialect="tm_tensor", + ) + shark_module.compile(extra_args=[]) + return shark_module + + +def _make_single_op_gm(node, captured_val, compiled_graph): + """Make a GraphModule that just executes the given node.""" + g = torch.fx.Graph() + env = {} + inputs = [] + for arg in node.args: + if arg and hasattr(arg, "name"): + env[arg.name] = g.placeholder(arg.name) + if isinstance(captured_val[arg.name], (list, tuple)): + for val in captured_val[arg.name]: + inputs.append(val) + else: + inputs.append(captured_val[arg.name]) + + call = g.node_copy(node, lambda n: env[n.name]) + g.output(call) + g.lint() + single_node = torch.fx.GraphModule(torch.nn.Module(), g) + compiled_module = shark_backend(single_node, inputs) + compiled_graph[node.name] = { + "module": compiled_module, + "inputs": [i for i in env], + "result": None, + } + return + + +def compiled_graph(gm: torch.fx.GraphModule, attr_info): + compiled_graph = {} + g = gm.graph + for node in g.nodes: + if node.op == "call_function": + if not ( + node.target in [torch.ops.aten.empty] + or node.name.startswith("getitem") + ): + _make_single_op_gm(node, attr_info, compiled_graph) + + # Currently torch.aten.empty has an compilation issue, so running natively. 
+ elif node.target in [torch.ops.aten.empty]: + compiled_graph[node.name] = { + "target": node.target, + "args": node.args, + "kwargs": node.kwargs, + "result": None, + } + # Get item is a simple case takes a tuple and return the tensor at a particular index. + elif node.name.startswith("getitem"): + compiled_graph[node.name] = { + "input": node.args[0].name, + "pos": node.args[1], + "result": None, + } + + return compiled_graph + + +class ShapeProp: + """ + Shape propagation. This class takes a `GraphModule`. + Then, its `propagate` method executes the `GraphModule` + node-by-node with the given arguments. As each operation + executes, the ShapeProp class stores away the shape and + element type for the output values of each operation on + the `shape` and `dtype` attributes of the operation's + `Node`. + """ + + def __init__(self, mod): + self.mod = mod + self.graph = mod.graph + self.modules = dict(self.mod.named_modules()) + + def propagate(self, *args): + args_iter = iter(args) + env: Dict[str, Node] = {} + + def load_arg(a): + return torch.fx.graph.map_arg(a, lambda n: env[n.name]) + + def fetch_attr(target: str): + target_atoms = target.split(".") + attr_itr = self.mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + for node in self.graph.nodes: + if node.op == "placeholder": + result = next(args_iter) + elif node.op == "get_attr": + result = fetch_attr(node.target) + elif node.op == "call_function": + result = node.target( + *load_arg(node.args), **load_arg(node.kwargs) + ) + elif node.op == "call_method": + self_obj, *args = load_arg(node.args) + kwargs = load_arg(node.kwargs) + result = getattr(self_obj, node.target)(*args, **kwargs) + elif node.op == "call_module": + result = self.modules[node.target]( + *load_arg(node.args), **load_arg(node.kwargs) + ) + + # This is the only code specific to shape propagation. + # you can delete this `if` branch and this becomes + # a generic GraphModule interpreter. 
+ if isinstance(result, torch.Tensor): + node.shape = result.shape + node.dtype = result.dtype + + env[node.name] = result + + return env + + # return load_arg(self.graph.result) + + +resnet18 = models.resnet18(pretrained=True) +resnet18.train(False) +input = (torch.randn(1, 3, 224, 224),) + +print(resnet18(input[0])) + +fx_graph = import_with_fx(resnet18, input, mlir_type="fx") + +shape_prop = ShapeProp(fx_graph) + +x = shape_prop.propagate(input[0]) + +shark_graph = compiled_graph(fx_graph, x) + + +for key in shark_graph: + if key.startswith("getitem"): + input_val = shark_graph[key]["input"] + pos = shark_graph[key]["pos"] + if input_val not in shark_graph: + shark_graph[key]["result"] = x[input_val][pos].detach() + else: + shark_graph[key]["result"] = shark_graph[input_val]["result"][ + pos + ].detach() + elif key.startswith("empty"): + operator = shark_graph[key]["target"] + args = shark_graph[key]["args"] + kwargs = shark_graph[key]["kwargs"] + shark_graph[key]["result"] = operator(*args, **kwargs).detach() + else: + input_val = shark_graph[key]["inputs"] + input_tensors = [] + for input in input_val: + if input not in shark_graph: + input_tensors.append(x[input].detach()) + else: + input_tensors.append(shark_graph[input]["result"]) + + val = shark_graph[key]["module"]("forward", input_tensors) + if isinstance(val, (tuple, list)): + list_val = [] + for v in val: + list_val.append(torch.from_numpy(v)) + shark_graph[key]["result"] = list_val + else: + shark_graph[key]["result"] = torch.from_numpy(val) + + +print(shark_graph) diff --git a/shark/shark_generate_model_config.py b/shark/shark_generate_model_config.py index 5041242ed1..8eb70e0fa9 100644 --- a/shark/shark_generate_model_config.py +++ b/shark/shark_generate_model_config.py @@ -1,5 +1,8 @@ +import re import json -from collections import OrderedDict +import torch_mlir +from iree.compiler import compile_str +from shark.shark_importer import import_with_fx, get_f16_inputs class GenerateConfigFile: @@ -7,7 +10,9 @@ def __init__( self, model, num_sharding_stages: int, - sharding_stages_id: list[str] = None, + sharding_stages_id: list[str], + model_input=None, + config_file_path="model_config.json", ): self.model = model self.num_sharding_stages = num_sharding_stages @@ -15,8 +20,67 @@ def __init__( assert self.num_sharding_stages == len( self.sharding_stages_id ), "Number of sharding stages should be equal to the list of their ID" + self.model_input = model_input + self.config_file_path = config_file_path - def generate_json(self): + def split_into_dispatches( + self, + backend, + fx_tracing_required=True, + f16_model=False, + torch_mlir_tracing=False, + ): + graph_for_compilation = self.model + if fx_tracing_required: + graph_for_compilation = import_with_fx( + self.model, + self.model_input, + is_f16=f16_model, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + + module = torch_mlir.compile( + graph_for_compilation, + (self.model_input), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=torch_mlir_tracing, + verbose=False, + ) + module = module.operation.get_asm(large_elements_limit=4) + compiled_module_str = str( + compile_str( + str(module), + target_backends=[backend], + extra_args=[ + "--compile-to=flow", + "--mlir-elide-elementsattrs-if-larger=4", + ], + ) + ) + + substring_start_idx = [ + m.start() + for m in re.finditer("flow.dispatch @", compiled_module_str) + ] + dispatch_list = dict() + + # dispatch_no is the 'i'th index of a dispatch out of n total dispatches of a model + # dispatch_id is the 
unique id of a dispatch, multiple instances of the same dispatch + # can occur in a model + for dispatch_no, substring_idx in enumerate(substring_start_idx): + dispatch_idx = ( + compiled_module_str[substring_idx:] + .split(":")[0] + .split("@")[-1] + ) + key = "dispatch_no_" + str(dispatch_no) + dispatch_list[key] = {n: "None" for n in self.sharding_stages_id} + dispatch_list[key]["dispatch_id"] = dispatch_idx + + self.generate_json(dispatch_list) + + def split_into_layers(self): model_dictionary = dict() for name, m in self.model.named_modules(): @@ -34,5 +98,8 @@ def generate_json(self): layer_dict = {n: "None" for n in self.sharding_stages_id} model_dictionary[name] = layer_dict - with open("model_config.json", "w") as outfile: - json.dump(model_dictionary, outfile) + self.generate_json(model_dictionary) + + def generate_json(self, artifacts): + with open(self.config_file_path, "w") as outfile: + json.dump(artifacts, outfile) diff --git a/shark/shark_importer.py b/shark/shark_importer.py index 64480d02ce..e12f7c0922 100644 --- a/shark/shark_importer.py +++ b/shark/shark_importer.py @@ -555,6 +555,9 @@ def strip_overloads(gm): add_upcast(fx_g) fx_g.recompile() + if mlir_type == "fx": + return fx_g + if training: change_fx_graph_return_to_tuple(fx_g) inputs = flatten_training_input(inputs)
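Note on the sharding config consumed above: `Vicuna.get_device_index` in vicuna_sharded_pipeline.py collects every config key that matches a layer's name and takes a majority vote over the `"gpu"` indices it finds, so an entire layer is always placed on a single device even if the config assigns its sub-parts inconsistently. The snippet below is a minimal, standalone sketch of that selection logic only; the key names and `"gpu"` values in `example_config` are made up for illustration (the real file is whatever `GenerateConfigFile` writes and the `--config` flag loads), not a schema defined by this patch.

import re

# Hypothetical per-layer entries shaped like the keys the pipeline's regexes
# (e.g. "first_vicuna.model.model.layers.0...") expect, each carrying a "gpu"
# index as get_device_index reads it. Values here are illustrative only.
example_config = {
    "first_vicuna.model.model.layers.0.self_attn": {"gpu": "0"},
    "first_vicuna.model.model.layers.0.mlp": {"gpu": "1"},
    "first_vicuna.model.model.layers.0.input_layernorm": {"gpu": "1"},
}

def majority_device_index(config, layer_string):
    # Same idea as Vicuna.get_device_index: tally how often each device index
    # appears among the matching keys and run the whole layer on the winner.
    votes = {}
    for key in config:
        if re.search(layer_string, key):
            idx = int(config[key]["gpu"])
            votes[idx] = votes.get(idx, 0) + 1
    return max(votes, key=votes.get) if votes else None

print(majority_device_index(example_config, r"first_vicuna\.model\.model\.layers\.0\."))  # -> 1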