diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py
index 1bef8508ac..1467a0e1de 100644
--- a/apps/language_models/src/pipelines/vicuna_pipeline.py
+++ b/apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -436,6 +436,14 @@ def compile(self):
         # return tuple of shark_modules once mem is supported
         # return fvic_shark_model, svic_shark_model

+    def decode_tokens(self, res_tokens):
+        for i in range(len(res_tokens)):
+            if type(res_tokens[i]) != int:
+                res_tokens[i] = int(res_tokens[i][0])
+
+        res_str = self.tokenizer.decode(res_tokens)
+        return res_str
+
     def generate(self, prompt, cli=False):
         # TODO: refactor for cleaner integration
         import gc
@@ -445,7 +453,6 @@ def generate(self, prompt, cli=False):
             self.first_vic = self.compile_first_vicuna()
         if self.second_vic == None:
             self.second_vic = self.compile_second_vicuna()
-        res = []
         res_tokens = []
         params = {
             "prompt": prompt,
@@ -461,8 +468,8 @@ def generate(self, prompt, cli=False):
         logits = generated_token_op["logits"]
         pkv = generated_token_op["pkv"]
         detok = generated_token_op["detok"]
+        yield detok

-        res.append(detok)
         res_tokens.append(token)
         if cli:
             print(f"Assistant: {detok}", end=" ", flush=True)
@@ -495,25 +502,24 @@ def generate(self, prompt, cli=False):
                 break
             res_tokens.append(token)
             if detok == "<0x0A>":
-                res.append("\n")
                 if cli:
                     print("\n", end="", flush=True)
             else:
-                res.append(detok)
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
+
+            if len(res_tokens) % 3 == 0:
+                part_str = self.decode_tokens(res_tokens)
+                yield part_str
+
         if self.device == "cuda":
             del sec_vic, pkv, logits
             torch.cuda.empty_cache()
             gc.collect()

-        for i in range(len(res_tokens)):
-            if type(res_tokens[i]) != int:
-                res_tokens[i] = int(res_tokens[i][0])
-
-        res_str = self.tokenizer.decode(res_tokens)
+        res_str = self.decode_tokens(res_tokens)
         # print(f"[DEBUG] final output : \n{res_str}")
-        return res_str
+        yield res_str

     def generate_new_token(self, params, debug=False):
         def forward_first(first_vic, prompt, cache_outputs=False):
diff --git a/apps/stable_diffusion/scripts/img2img.py b/apps/stable_diffusion/scripts/img2img.py
index 4bd568c2f7..8175ee1ed4 100644
--- a/apps/stable_diffusion/scripts/img2img.py
+++ b/apps/stable_diffusion/scripts/img2img.py
@@ -103,6 +103,7 @@ def main():
         dtype,
         args.use_base_vae,
         cpu_scheduling,
+        args.max_embeddings_multiples,
         use_stencil=use_stencil,
     )
     total_time = time.time() - start_time
diff --git a/apps/stable_diffusion/scripts/inpaint.py b/apps/stable_diffusion/scripts/inpaint.py
index ce5d6c8e09..6aa48327f7 100644
--- a/apps/stable_diffusion/scripts/inpaint.py
+++ b/apps/stable_diffusion/scripts/inpaint.py
@@ -81,6 +81,7 @@ def main():
         dtype,
         args.use_base_vae,
         cpu_scheduling,
+        args.max_embeddings_multiples,
     )
     total_time = time.time() - start_time
     text_output = f"prompt={args.prompts}"
diff --git a/apps/stable_diffusion/scripts/outpaint.py b/apps/stable_diffusion/scripts/outpaint.py
index 10ef8f7ed9..affe8f31d3 100644
--- a/apps/stable_diffusion/scripts/outpaint.py
+++ b/apps/stable_diffusion/scripts/outpaint.py
@@ -79,6 +79,7 @@ def main():
         dtype,
         args.use_base_vae,
         cpu_scheduling,
+        args.max_embeddings_multiples,
     )
     total_time = time.time() - start_time
     text_output = f"prompt={args.prompts}"
diff --git a/apps/stable_diffusion/scripts/upscaler.py b/apps/stable_diffusion/scripts/upscaler.py
index f5a18b115d..baf433aed4 100644
--- a/apps/stable_diffusion/scripts/upscaler.py
+++ b/apps/stable_diffusion/scripts/upscaler.py
@@ -73,6 +73,7 @@
         dtype,
         args.use_base_vae,
         cpu_scheduling,
+        args.max_embeddings_multiples,
     )
     total_time = time.time() - start_time
     text_output = f"prompt={args.prompts}"
diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py
index 93e416d740..a3c4831b68 100644
--- a/apps/stable_diffusion/src/models/model_wrappers.py
+++ b/apps/stable_diffusion/src/models/model_wrappers.py
@@ -520,16 +520,17 @@ def forward(self, latent, timestep, text_embedding, noise_level):
                     torch.nn.functional.pad(inputs[2], pad),
                     inputs[3])
             input_mask = [True, True, True, False]
+        model_name = "unet512" if use_large else "unet"
         shark_unet, unet_mlir = compile_through_fx(
             unet,
             inputs,
-            extended_model_name=self.model_name["unet"],
+            extended_model_name=self.model_name[model_name],
             is_f16=is_f16,
             f16_input_mask=input_mask,
             use_tuned=self.use_tuned,
             extra_args=get_opt_flags("unet", precision=self.precision),
             base_model_id=self.base_model_id,
-            model_name="unet",
+            model_name=model_name,
             precision=self.precision,
             return_mlir=self.return_mlir,
         )
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py
index 2f8bcd8bc9..24ad167cea 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py
@@ -135,6 +135,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
         use_stencil,
     ):
         # prompts and negative prompts must be a list.
@@ -156,7 +157,10 @@ def generate_images(

         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )

         # guidance scale as a float32 tensor.
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_inpaint.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_inpaint.py
index e15aec33ea..515a7bd5ea 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_inpaint.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_inpaint.py
@@ -378,6 +378,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
     ):
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
@@ -408,7 +409,10 @@ def generate_images(

         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )

         # guidance scale as a float32 tensor.
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_outpaint.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_outpaint.py
index bc5d509e24..d782633d97 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_outpaint.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_outpaint.py
@@ -379,6 +379,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
     ):
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
@@ -409,7 +410,10 @@ def generate_images(

         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )

         # guidance scale as a float32 tensor.
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py
index f69d701300..e7fa415f78 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py
@@ -204,6 +204,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
         use_stencil,
     ):
         # Control Embedding check & conversion
@@ -230,7 +231,10 @@ def generate_images(

         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )

         # guidance scale as a float32 tensor.
diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_upscaler.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_upscaler.py
index a3f91cba14..5c50f79c9d 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_upscaler.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_upscaler.py
@@ -168,7 +168,10 @@ def produce_img_latents(
         text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
         text_embeddings_numpy = text_embeddings.detach().numpy()
         self.status = SD_STATE_IDLE
-        self.load_unet()
+        if text_embeddings.shape[1] <= self.model_max_length:
+            self.load_unet()
+        else:
+            self.load_unet_512()
         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
             latent_model_input = torch.cat([latents] * 2)
@@ -182,15 +185,26 @@ def produce_img_latents(

             # Profiling Unet.
             profile_device = start_profiling(file_path="unet.rdc")
-            noise_pred = self.unet(
-                "forward",
-                (
-                    latent_model_input,
-                    timestep,
-                    text_embeddings_numpy,
-                    noise_level,
-                ),
-            )
+            if text_embeddings.shape[1] <= self.model_max_length:
+                noise_pred = self.unet(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        noise_level,
+                    ),
+                )
+            else:
+                noise_pred = self.unet_512(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        noise_level,
+                    ),
+                )
             end_profiling(profile_device)
             noise_pred = torch.from_numpy(noise_pred)
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -219,6 +233,7 @@
         if self.ondemand:
             self.unload_unet()
+            self.unload_unet_512()

         avg_step_time = step_time_sum / len(total_timesteps)
         self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -243,6 +258,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
     ):
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
@@ -264,7 +280,10 @@ def generate_images(

         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )

         # 4. Preprocess image
diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py
index 6ed965cece..d3f87d38d9 100644
--- a/apps/stable_diffusion/web/ui/img2img_ui.py
+++ b/apps/stable_diffusion/web/ui/img2img_ui.py
@@ -249,6 +249,7 @@ def img2img_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
             use_stencil=use_stencil,
         )
         seeds.append(img_seed)
diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py
index be8a58def4..c13f5eacad 100644
--- a/apps/stable_diffusion/web/ui/inpaint_ui.py
+++ b/apps/stable_diffusion/web/ui/inpaint_ui.py
@@ -204,6 +204,7 @@ def inpaint_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
         )
         seeds.append(img_seed)
         total_time = time.time() - start_time
diff --git a/apps/stable_diffusion/web/ui/outpaint_ui.py b/apps/stable_diffusion/web/ui/outpaint_ui.py
index d6b0d2b317..35fb93ee63 100644
--- a/apps/stable_diffusion/web/ui/outpaint_ui.py
+++ b/apps/stable_diffusion/web/ui/outpaint_ui.py
@@ -211,6 +211,7 @@ def outpaint_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
         )
         seeds.append(img_seed)
         total_time = time.time() - start_time
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index 58cbaa72d9..07e9d7b631 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -65,16 +65,11 @@ def chat(curr_system_message, history, model, device, precision):
         )
         prompt = messages.strip()
         print("prompt = ", prompt)
-        sentence = vicuna_model.generate(prompt)
-        partial_text = ""
-        for new_text in sentence.split(" "):
-            # print(new_text)
-            partial_text += new_text + " "
+        for partial_text in vicuna_model.generate(prompt):
             history[-1][1] = partial_text
-            # Yield an empty string to cleanup the message textbox and the updated conversation history
             yield history
-        history[-1][1] = sentence
+
         return history

     # else Model is StableLM
diff --git a/apps/stable_diffusion/web/ui/upscaler_ui.py b/apps/stable_diffusion/web/ui/upscaler_ui.py
index 7a4b6469c1..882da8a957 100644
--- a/apps/stable_diffusion/web/ui/upscaler_ui.py
+++ b/apps/stable_diffusion/web/ui/upscaler_ui.py
@@ -202,6 +202,7 @@ def upscaler_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
         )
         if global_obj.get_sd_status() == SD_STATE_CANCEL:
             break
diff --git a/shark/examples/shark_inference/mega_test.py b/shark/examples/shark_inference/mega_test.py
index efc5e70b79..a4e6f6b406 100644
--- a/shark/examples/shark_inference/mega_test.py
+++ b/shark/examples/shark_inference/mega_test.py
@@ -1,10 +1,7 @@
 import torch
 import torch_mlir
 from shark.shark_inference import SharkInference
-from apps.stable_diffusion.src.utils import (
-    compile_through_fx,
-    args,
-)
+from shark.shark_compile import shark_compile_through_fx
 from MEGABYTE_pytorch import MEGABYTE

 import os
@@ -37,23 +34,22 @@ def forward(self, input):


 megaModel = MegaModel()
-input = [torch.randint(0, 16000, (1, 1024, 4))]
+inputs = [torch.randint(0, 16000, (1, 1024, 4))]

 # CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
 # 1. aten.alias
-shark_module, _ = compile_through_fx(
-    megaModel,
-    inputs=input,
+shark_module, _ = shark_compile_through_fx(
+    model=megaModel,
+    inputs=inputs,
     extended_model_name="mega_shark",
-    debug=False,
-    generate_vmfb=True,
+    is_f16=False,
+    f16_input_mask=None,
     save_dir=os.getcwd(),
+    debug=False,
+    generate_or_load_vmfb=True,
     extra_args=[],
-    base_model_id=None,
-    model_name="mega_shark",
-    precision=None,
-    return_mlir=True,
     device="cuda",
+    mlir_dialect="tm_tensor",
 )

 # logits = model(x)
@@ -63,10 +59,10 @@ def print_output_info(output, msg):
     print("\n\t", output.shape)


-ans = shark_module("forward", input)
+ans = shark_module("forward", inputs)
 print_output_info(torch.from_numpy(ans), "SHARK's output")

-ans = megaModel.forward(*input)
+ans = megaModel.forward(*inputs)
 print_output_info(ans, "ORIGINAL Model's output")

 # and sample from the logits accordingly
diff --git a/shark/shark_compile.py b/shark/shark_compile.py
new file mode 100644
index 0000000000..79431155f5
--- /dev/null
+++ b/shark/shark_compile.py
@@ -0,0 +1,99 @@
+import os
+import tempfile
+from shark.shark_inference import SharkInference
+from shark.shark_importer import import_with_fx
+
+
+def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
+    vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
+    shark_module = None
+    if os.path.isfile(vmfb_path):
+        shark_module = SharkInference(
+            None,
+            device=device,
+            mlir_dialect=mlir_dialect,
+        )
+        print(f"loading existing vmfb from: {vmfb_path}")
+        shark_module.load_module(vmfb_path, extra_args=extra_args)
+    return shark_module
+
+
+def compile_module(
+    shark_module, extended_model_name, generate_vmfb, extra_args=[]
+):
+    if generate_vmfb:
+        vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
+        if os.path.isfile(vmfb_path):
+            print(f"loading existing vmfb from: {vmfb_path}")
+            shark_module.load_module(vmfb_path, extra_args=extra_args)
+        else:
+            print(
+                "No vmfb found. Compiling and saving to {}".format(vmfb_path)
+            )
+            path = shark_module.save_module(
+                os.getcwd(), extended_model_name, extra_args
+            )
+            shark_module.load_module(path, extra_args=extra_args)
+    else:
+        shark_module.compile(extra_args)
+    return shark_module
+
+
+def shark_compile_through_fx(
+    model,
+    inputs,
+    extended_model_name,
+    is_f16=False,
+    f16_input_mask=None,
+    save_dir=tempfile.gettempdir(),
+    debug=False,
+    generate_or_load_vmfb=True,
+    extra_args=[],
+    device=None,
+    mlir_dialect="tm_tensor",
+):
+    if generate_or_load_vmfb:
+        shark_module = load_vmfb(
+            extended_model_name=extended_model_name,
+            device=device,
+            mlir_dialect=mlir_dialect,
+            extra_args=extra_args,
+        )
+        if shark_module:
+            return (
+                shark_module,
+                None,
+            )
+
+    from shark.parser import shark_args
+
+    if "cuda" in device:
+        shark_args.enable_tf32 = True
+
+    (
+        mlir_module,
+        _,
+    ) = import_with_fx(
+        model=model,
+        inputs=inputs,
+        is_f16=is_f16,
+        f16_input_mask=f16_input_mask,
+        debug=debug,
+        model_name=extended_model_name,
+        save_dir=save_dir,
+    )
+
+    shark_module = SharkInference(
+        mlir_module,
+        device=device,
+        mlir_dialect=mlir_dialect,
+    )
+    return (
+        compile_module(
+            shark_module,
+            extended_model_name,
+            generate_vmfb=generate_or_load_vmfb,
+            extra_args=extra_args,
+        ),
+        mlir_module,
+    )
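
Usage sketch (not part of the patch): how the new shark_compile_through_fx helper added in shark/shark_compile.py is driven, mirroring the updated mega_test.py call. The ToyModel module, input shape, and "cpu" device choice below are illustrative assumptions, not code from this PR.

import os

import torch

from shark.shark_compile import shark_compile_through_fx


class ToyModel(torch.nn.Module):
    # Hypothetical stand-in for a real model such as MegaModel in mega_test.py.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, input):
        return self.linear(input)


inputs = [torch.randn(1, 16)]

# First run: imports the model via torch FX, compiles it, and saves
# "toy_shark.vmfb" in the current working directory (see compile_module above).
# Later runs: load_vmfb finds the cached artifact and skips recompilation,
# in which case the returned mlir_module is None.
shark_module, mlir_module = shark_compile_through_fx(
    model=ToyModel(),
    inputs=inputs,
    extended_model_name="toy_shark",
    is_f16=False,
    f16_input_mask=None,
    save_dir=os.getcwd(),
    debug=False,
    generate_or_load_vmfb=True,
    extra_args=[],
    device="cpu",
    mlir_dialect="tm_tensor",
)

# Invoke the compiled module the same way mega_test.py does.
ans = shark_module("forward", inputs)
print(torch.from_numpy(ans).shape)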