From f76fc4a31d1d6e71fbdbc637b5a5994827730c32 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Thu, 23 May 2024 11:04:27 -0500
Subject: [PATCH] Formatting

---
 apps/shark_studio/api/llm.py   |  8 +++-----
 apps/shark_studio/api/sd.py    | 35 +++++++++++++++++++---------------
 apps/shark_studio/api/utils.py | 19 +++++++++++++++----
 3 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index 5ce9c2fc2d..217fb6784f 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -155,7 +155,9 @@ def __init__(
                 use_auth_token=hf_auth_token,
             )
         elif not os.path.exists(self.tempfile_name):
-            self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name]["initializer"](
+            self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
+                "initializer"
+            ](
                 self.hf_model_name,
                 hf_auth_token,
                 compile_to="torch",
@@ -273,10 +275,6 @@ def format_out(results):
             self.prev_token_len = token_len + len(history)
 
             if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
-                if (
-                    format_out(token)
-                    == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
-                ):
                 break
 
         for i in range(len(history)):
diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py
index 744649db76..d064e04a8f 100644
--- a/apps/shark_studio/api/sd.py
+++ b/apps/shark_studio/api/sd.py
@@ -10,7 +10,9 @@
 from pathlib import Path
 from random import randint
 from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
-from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import SharkSDXLPipeline
+from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
+    SharkSDXLPipeline,
+)
 
 from apps.shark_studio.api.controlnet import control_adapter_map
 
@@ -104,9 +106,9 @@ def __init__(
             pipe_id_list.append(custom_vae)
         self.pipe_id = "_".join(pipe_id_list)
         self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
-        self.weights_path = Path(os.path.join(
-            get_checkpoints_path(), safe_name(self.base_model_id)
-        ))
+        self.weights_path = Path(
+            os.path.join(get_checkpoints_path(), safe_name(self.base_model_id))
+        )
 
         if not os.path.exists(self.weights_path):
             os.mkdir(self.weights_path)
@@ -140,18 +142,21 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
         weights = copy.deepcopy(self.model_map)
 
         if custom_weights:
-            custom_weights_params, _ = process_custom_pipe_weights(
-                custom_weights
-            )
+            custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
             for key in weights:
                 if key not in ["vae_decode", "pipeline", "full_pipeline"]:
                     weights[key] = custom_weights_params
 
-
-        vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False)
+        vmfbs, weights = self.sd_pipe.check_prepared(
+            mlirs, vmfbs, weights, interactive=False
+        )
         print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
-        self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline)
-        print("\n[LOG] Pipeline successfully prepared for runtime. Generating images...")
+        self.sd_pipe.load_pipeline(
+            vmfbs, weights, self.rt_device, self.compiled_pipeline
+        )
+        print(
+            "\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
+        )
         return
 
     def generate_images(
@@ -236,7 +241,7 @@ def shark_sd_fn(
     control_mode = None
     hints = []
     num_loras = 0
-    import_ir=True
+    import_ir = True
     for i in embeddings:
         num_loras += 1 if embeddings[i] else 0
     if "model" in controlnets:
@@ -305,7 +310,6 @@ def shark_sd_fn(
     # Initializes the pipeline and retrieves IR based on all
    # parameters that are static in the turbine output format,
     # which is currently MLIR in the torch dialect.
-
    sd_pipe = StableDiffusion(
        **submit_pipe_kwargs,
    )
@@ -325,7 +329,7 @@ def shark_sd_fn(
        out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
        # total_time = time.time() - start_time
        # text_output = f"Total image(s) generation time: {total_time:.4f}sec"
-        #print(f"\n[LOG] {text_output}")
+        # print(f"\n[LOG] {text_output}")
        # if global_obj.get_sd_status() == SD_STATE_CANCEL:
        #     break
        # else:
@@ -352,8 +356,9 @@ def view_json_file(file_path):
         content = fopen.read()
     return content
 
+
 def safe_name(name):
-    return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_")
+    return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_")
 
 
 if __name__ == "__main__":
diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py
index f213320045..0e53bd4a5a 100644
--- a/apps/shark_studio/api/utils.py
+++ b/apps/shark_studio/api/utils.py
@@ -18,6 +18,7 @@
     get_iree_vulkan_runtime_flags,
 )
 
+
 def get_available_devices():
     def get_devices_by_name(driver_name):
         from shark.iree_utils._common import iree_device_map
@@ -49,7 +50,7 @@ def get_devices_by_name(driver_name):
         return device_list
 
     set_iree_runtime_flags()
-    
+
     available_devices = []
     from shark.iree_utils.vulkan_utils import (
         get_all_vulkan_devices,
@@ -78,6 +79,7 @@ def get_devices_by_name(driver_name):
     available_devices.extend(cpu_device)
     return available_devices
 
+
 def set_init_device_flags():
     if "vulkan" in cmd_opts.device:
         # set runtime flags for vulkan.
@@ -126,8 +128,14 @@ def set_iree_runtime_flags():
         ]
         set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
 
+
 def parse_device(device_str):
-    from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map
+    from shark.iree_utils.compile_utils import (
+        clean_device_info,
+        get_iree_target_triple,
+        iree_target_map,
+    )
+
     rt_driver, device_id = clean_device_info(device_str)
     target_backend = iree_target_map(rt_driver)
     if device_id:
@@ -147,7 +155,7 @@ def parse_device(device_str):
 
 
 def get_rocm_target_chip(device_str):
-    #TODO: Use a data file to map device_str to target chip.
+    # TODO: Use a data file to map device_str to target chip.
     rocm_chip_map = {
         "6700": "gfx1031",
         "6800": "gfx1030",
@@ -164,7 +172,10 @@ def get_rocm_target_chip(device_str):
     for key in rocm_chip_map:
         if key in device_str:
             return rocm_chip_map[key]
-    raise AssertionError(f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues.")
+    raise AssertionError(
+        f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
+    )
+
 
 def get_all_devices(driver_name):
     """