From 0f9930096a596c7f21b156d4c9b0dfa0bfc9ee5d Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 24 Apr 2024 04:01:28 -0500 Subject: [PATCH 01/20] Shark Studio SDXL support, HIP driver support, simpler device info, small fixes --- .gitignore | 5 + apps/shark_studio/api/sd.py | 405 ++++++------------------------ apps/shark_studio/api/utils.py | 96 +++---- apps/shark_studio/web/ui/sd.py | 7 +- requirements.txt | 7 +- setup_venv.ps1 | 4 +- shark/iree_utils/_common.py | 2 + shark/iree_utils/compile_utils.py | 3 + shark/iree_utils/gpu_utils.py | 41 +-- 9 files changed, 166 insertions(+), 404 deletions(-) diff --git a/.gitignore b/.gitignore index f67152b007..bf07b2794f 100644 --- a/.gitignore +++ b/.gitignore @@ -188,6 +188,11 @@ variants.json # models folder apps/stable_diffusion/web/models/ +# model artifacts (SHARK) +*.tempfile +*.mlir +*.vmfb + # Stencil annotators. stencil_annotator/ diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index b4f0f0ddc0..fe0f14bbc5 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -9,14 +9,15 @@ from pathlib import Path from random import randint from turbine_models.custom_models.sd_inference import clip, unet, vae +from turbine_models.custom_models.sdxl_inference import sdxl_compiled_pipeline from apps.shark_studio.api.controlnet import control_adapter_map +from apps.shark_studio.api.utils import parse_device from apps.shark_studio.web.utils.state import status_label from apps.shark_studio.web.utils.file_utils import ( safe_name, get_resource_path, get_checkpoints_path, ) -from apps.shark_studio.modules.pipeline import SharkPipelineBase from apps.shark_studio.modules.schedulers import get_schedulers from apps.shark_studio.modules.prompt_encoding import ( get_weighted_text_embeddings, @@ -32,8 +33,6 @@ preprocessCKPT, process_custom_pipe_weights, ) -from transformers import CLIPTokenizer -from diffusers.image_processor import VaeImageProcessor sd_model_map = { "clip": { @@ -47,8 +46,15 @@ }, } +EMPTY_FLAGS = { + "clip": None, + "unet": None, + "vae": None, + "pipeline": None, +} + -class StableDiffusion(SharkPipelineBase): +class StableDiffusion: # This class is responsible for executing image generation and creating # /managing a set of compiled modules to run Stable Diffusion. 
The init # aims to be as general as possible, and the class will infer and compile @@ -61,6 +67,8 @@ def __init__( height: int, width: int, batch_size: int, + steps: int, + scheduler: str, precision: str, device: str, custom_vae: str = None, @@ -69,58 +77,18 @@ def __init__( is_controlled: bool = False, hf_auth_token=None, ): - self.model_max_length = 77 - self.batch_size = batch_size - self.precision = precision - self.dtype = torch.float16 if precision == "fp16" else torch.float32 - self.height = height - self.width = width - self.scheduler_obj = {} - static_kwargs = { - "pipe": { - "external_weights": "safetensors", - }, - "clip": {"hf_model_name": base_model_id}, - "unet": { - "hf_model_name": base_model_id, - "unet_model": unet.UnetModel(hf_model_name=base_model_id), - "batch_size": batch_size, - # "is_controlled": is_controlled, - # "num_loras": num_loras, - "height": height, - "width": width, - "precision": precision, - "max_length": self.model_max_length, - }, - "vae_encode": { - "hf_model_name": base_model_id, - "vae_model": vae.VaeModel( - hf_model_name=custom_vae if custom_vae else base_model_id, - ), - "batch_size": batch_size, - "height": height, - "width": width, - "precision": precision, - }, - "vae_decode": { - "hf_model_name": base_model_id, - "vae_model": vae.VaeModel( - hf_model_name=custom_vae if custom_vae else base_model_id, - ), - "batch_size": batch_size, - "height": height, - "width": width, - "precision": precision, - }, - } - super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir) + self.compiled_pipeline = False + self.base_model_id = base_model_id + external_weights = "safetensors" + max_length = 64 + target_backend, self.rt_device, triple = parse_device(device) pipe_id_list = [ safe_name(base_model_id), str(batch_size), - str(self.model_max_length), + str(max_length), f"{str(height)}x{str(width)}", precision, - self.device, + triple, ] if num_loras > 0: pipe_id_list.append(str(num_loras) + "lora") @@ -129,227 +97,67 @@ def __init__( if custom_vae: pipe_id_list.append(custom_vae) self.pipe_id = "_".join(pipe_id_list) + self.weights_path = os.path.join( + get_checkpoints_path(), safe_name(self.base_model_id) + ) + if not os.path.exists(self.weights_path): + os.mkdir(self.weights_path) + self.sd_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline( + hf_model_name=base_model_id, + scheduler_id=scheduler, + height=height, + width=width, + precision=precision, + max_length=max_length, + batch_size=batch_size, + num_inference_steps=steps, + device=target_backend, + iree_target_triple=triple, + ireec_flags=EMPTY_FLAGS, + attn_spec=None, + decomp_attn=True if "gfx9" not in triple else False, + pipeline_dir=self.pipe_id, + external_weights_dir=self.weights_path, + external_weights=external_weights, + ) print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") - del static_kwargs gc.collect() def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): print(f"\n[LOG] Preparing pipeline...") self.is_img2img = is_img2img - self.schedulers = get_schedulers(self.base_model_id) - - self.weights_path = os.path.join( - get_checkpoints_path(), self.safe_name(self.base_model_id) - ) - if not os.path.exists(self.weights_path): - os.mkdir(self.weights_path) - - for model in adapters: - self.model_map[model] = adapters[model] - - for submodel in self.static_kwargs: - if custom_weights: - custom_weights_params, _ = process_custom_pipe_weights(custom_weights) - if submodel not in ["clip", "clip2"]: - self.static_kwargs[submodel][ - 
"external_weights" - ] = custom_weights_params - else: - self.static_kwargs[submodel]["external_weight_path"] = os.path.join( - self.weights_path, submodel + ".safetensors" - ) - else: - self.static_kwargs[submodel]["external_weight_path"] = os.path.join( - self.weights_path, submodel + ".safetensors" - ) - - self.get_compiled_map(pipe_id=self.pipe_id) - print("\n[LOG] Pipeline successfully prepared for runtime.") + mlirs = { + "prompt_encoder": None, + "scheduled_unet": None, + "vae_decode": None, + "pipeline": None, + "full_pipeline": None, + } + vmfbs = { + "prompt_encoder": None, + "scheduled_unet": None, + "vae_decode": None, + "pipeline": None, + "full_pipeline": None, + } + weights = { + "prompt_encoder": None, + "scheduled_unet": None, + "vae_decode": None, + "pipeline": None, + "full_pipeline": None, + } + vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False) + print(f"\n[LOG] Loading pipeline to device {self.rt_device}.") + self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline) + print("\n[LOG] Pipeline successfully prepared for runtime. Generating images...") return - def encode_prompts_weight( - self, - prompt, - negative_prompt, - do_classifier_free_guidance=True, - ): - # Encodes the prompt into text encoder hidden states. - self.load_submodels(["clip"]) - self.tokenizer = CLIPTokenizer.from_pretrained( - self.base_model_id, - subfolder="tokenizer", - ) - clip_inf_start = time.time() - - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - ) - - if do_classifier_free_guidance: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - pad = (0, 0) * (len(text_embeddings.shape) - 2) - pad = pad + ( - 0, - self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1], - ) - text_embeddings = torch.nn.functional.pad(text_embeddings, pad) - - # SHARK: Report clip inference time - clip_inf_time = (time.time() - clip_inf_start) * 1000 - if self.ondemand: - self.unload_submodels(["clip"]) - gc.collect() - print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}") - - return text_embeddings.numpy().astype(np.float16) - - def prepare_latents( - self, - generator, - num_inference_steps, - image, - strength, - ): - noise = torch.randn( - ( - self.batch_size, - 4, - self.height // 8, - self.width // 8, - ), - generator=generator, - dtype=self.dtype, - ).to("cpu") - - self.scheduler.set_timesteps(num_inference_steps) - if self.is_img2img: - init_timestep = min( - int(num_inference_steps * strength), num_inference_steps - ) - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - latents = self.encode_image(image) - latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1)) - return latents, [timesteps] - else: - self.scheduler.is_scale_input_called = True - latents = noise * self.scheduler.init_noise_sigma - return latents, self.scheduler.timesteps - - def encode_image(self, input_image): - self.load_submodels(["vae_encode"]) - vae_encode_start = time.time() - latents = self.run("vae_encode", input_image) - vae_inf_time = (time.time() - vae_encode_start) * 1000 - if self.ondemand: - self.unload_submodels(["vae_encode"]) - print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}") - - return latents - - def produce_img_latents( - self, - latents, - text_embeddings, - guidance_scale, - total_timesteps, - 
cpu_scheduling, - mask=None, - masked_image_latents=None, - return_all_latents=False, - ): - # self.status = SD_STATE_IDLE - step_time_sum = 0 - latent_history = [latents] - text_embeddings = torch.from_numpy(text_embeddings).to(self.dtype) - text_embeddings_numpy = text_embeddings.detach().numpy() - guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype) - self.load_submodels(["unet"]) - for i, t in tqdm(enumerate(total_timesteps)): - step_start_time = time.time() - timestep = torch.tensor([t]).to(self.dtype).detach().numpy() - latent_model_input = self.scheduler.scale_model_input(latents, t).to( - self.dtype - ) - if mask is not None and masked_image_latents is not None: - latent_model_input = torch.cat( - [ - torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype), - mask, - masked_image_latents, - ], - dim=1, - ).to(self.dtype) - if cpu_scheduling: - latent_model_input = latent_model_input.detach().numpy() - - # Profiling Unet. - # profile_device = start_profiling(file_path="unet.rdc") - noise_pred = self.run( - "unet", - [ - latent_model_input, - timestep, - text_embeddings_numpy, - guidance_scale, - ], - ) - # end_profiling(profile_device) - - if cpu_scheduling: - noise_pred = torch.from_numpy(noise_pred.to_host()) - latents = self.scheduler.step(noise_pred, t, latents).prev_sample - else: - latents = self.run("scheduler_step", (noise_pred, t, latents)) - - latent_history.append(latents) - step_time = (time.time() - step_start_time) * 1000 - # print( - # f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms" - # ) - step_time_sum += step_time - - # if self.status == SD_STATE_CANCEL: - # break - - if self.ondemand: - self.unload_submodels(["unet"]) - gc.collect() - - avg_step_time = step_time_sum / len(total_timesteps) - print(f"\n[LOG] Average step time: {avg_step_time}ms/it") - - if not return_all_latents: - return latents - all_latents = torch.cat(latent_history, dim=0) - return all_latents - - def decode_latents(self, latents, cpu_scheduling=True): - latents_numpy = latents.to(self.dtype) - if cpu_scheduling: - latents_numpy = latents.detach().numpy() - - # profile_device = start_profiling(file_path="vae.rdc") - vae_start = time.time() - images = self.run("vae_decode", latents_numpy).to_host() - vae_inf_time = (time.time() - vae_start) * 1000 - # end_profiling(profile_device) - print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}") - - images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy() - pil_images = self.image_processor.numpy_to_pil(images) - return pil_images - def generate_images( self, prompt, negative_prompt, image, - scheduler, - steps, strength, guidance_scale, seed, @@ -359,69 +167,15 @@ def generate_images( control_mode, hints, ): - # TODO: Batched args - self.image_processor = VaeImageProcessor(do_convert_rgb=True) - self.scheduler = self.schedulers[scheduler] - self.ondemand = ondemand - if self.is_img2img: - image, _ = self.image_processor.preprocess(image, resample_type) - else: - image = None - - print("\n[LOG] Generating images...") - batched_args = [ + img = self.sd_pipe.generate_images( prompt, negative_prompt, - image, - ] - for arg in batched_args: - if not isinstance(arg, list): - arg = [arg] * self.batch_size - if len(arg) < self.batch_size: - arg = arg * self.batch_size - else: - arg = [arg[i] for i in range(self.batch_size)] - - text_embeddings = self.encode_prompts_weight( - prompt, - negative_prompt, - ) - - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - 
if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - - generator = torch.manual_seed(seed) - - init_latents, final_timesteps = self.prepare_latents( - generator=generator, - num_inference_steps=steps, - image=image, - strength=strength, - ) - - latents = self.produce_img_latents( - latents=init_latents, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - total_timesteps=final_timesteps, - cpu_scheduling=True, # until we have schedulers through Turbine + 1, + guidance_scale, + seed, + return_imgs=True, ) - - # Img latents -> PIL images - all_imgs = [] - self.load_submodels(["vae_decode"]) - for i in tqdm(range(0, latents.shape[0], self.batch_size)): - imgs = self.decode_latents( - latents=latents[i : i + self.batch_size], - cpu_scheduling=True, - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_submodels(["vae_decode"]) - - return all_imgs + return img def shark_sd_fn_dict_input( @@ -516,6 +270,8 @@ def shark_sd_fn( "num_loras": num_loras, "import_ir": cmd_opts.import_mlir, "is_controlled": is_controlled, + "steps": steps, + "scheduler": scheduler, } submit_prep_kwargs = { "custom_weights": custom_weights, @@ -527,8 +283,6 @@ def shark_sd_fn( "prompt": prompt, "negative_prompt": negative_prompt, "image": sd_init_image, - "steps": steps, - "scheduler": scheduler, "strength": strength, "guidance_scale": guidance_scale, "seed": seed, @@ -566,9 +320,9 @@ def shark_sd_fn( for current_batch in range(batch_count): start_time = time.time() out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs) - total_time = time.time() - start_time - text_output = f"Total image(s) generation time: {total_time:.4f}sec" - print(f"\n[LOG] {text_output}") + # total_time = time.time() - start_time + # text_output = f"Total image(s) generation time: {total_time:.4f}sec" + #print(f"\n[LOG] {text_output}") # if global_obj.get_sd_status() == SD_STATE_CANCEL: # break # else: @@ -595,6 +349,9 @@ def view_json_file(file_path): content = fopen.read() return content +def safe_name(name): + return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_") + if __name__ == "__main__": from apps.shark_studio.modules.shared_cmd_opts import cmd_opts diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index e9268aa83b..7241bf3691 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -12,11 +12,6 @@ from cpuinfo import get_cpu_info # TODO: migrate these utils to studio -from shark.iree_utils.vulkan_utils import ( - set_iree_vulkan_runtime_flags, - get_vulkan_target_triple, - get_iree_vulkan_runtime_flags, -) def get_available_devices(): @@ -49,8 +44,6 @@ def get_devices_by_name(driver_name): device_list.append(f"{device_name} => {driver_name}://{i}") return device_list - set_iree_runtime_flags() - available_devices = [] from shark.iree_utils.vulkan_utils import ( get_all_vulkan_devices, @@ -71,6 +64,8 @@ def get_devices_by_name(driver_name): available_devices.extend(cuda_devices) rocm_devices = get_devices_by_name("rocm") available_devices.extend(rocm_devices) + hip_devices = get_devices_by_name("hip") + available_devices.extend(hip_devices) cpu_device = get_devices_by_name("cpu-sync") available_devices.extend(cpu_device) cpu_device = get_devices_by_name("cpu-task") @@ -78,54 +73,45 @@ def get_devices_by_name(driver_name): return available_devices -def set_init_device_flags(): - if "vulkan" in cmd_opts.device: - # set runtime flags for vulkan. 
- set_iree_runtime_flags() - - # set triple flag to avoid multiple calls to get_vulkan_triple_flag - device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) - if not cmd_opts.iree_vulkan_target_triple: - triple = get_vulkan_target_triple(device_name) - if triple is not None: - cmd_opts.iree_vulkan_target_triple = triple - print( - f"Found device {device_name}. Using target triple " - f"{cmd_opts.iree_vulkan_target_triple}." - ) - elif "cuda" in cmd_opts.device: - cmd_opts.device = "cuda" - elif "metal" in cmd_opts.device: - device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) - if not cmd_opts.iree_metal_target_platform: - from shark.iree_utils.metal_utils import get_metal_target_triple - - triple = get_metal_target_triple(device_name) - if triple is not None: - cmd_opts.iree_metal_target_platform = triple.split("-")[-1] - print( - f"Found device {device_name}. Using target triple " - f"{cmd_opts.iree_metal_target_platform}." - ) - elif "cpu" in cmd_opts.device: - cmd_opts.device = "cpu" - - -def set_iree_runtime_flags(): - # TODO: This function should be device-agnostic and piped properly - # to general runtime driver init. - vulkan_runtime_flags = get_iree_vulkan_runtime_flags() - if cmd_opts.enable_rgp: - vulkan_runtime_flags += [ - f"--enable_rgp=true", - f"--vulkan_debug_utils=true", - ] - if cmd_opts.device_allocator_heap_key: - vulkan_runtime_flags += [ - f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}", - ] - set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) - +def parse_device(device_str): + from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map + rt_driver, device_id = clean_device_info(device_str) + target_backend = iree_target_map(rt_driver) + if device_id: + rt_device = f"{rt_driver}://{device_id}" + else: + rt_device = rt_driver + + match target_backend: + case "vulkan-spirv": + triple = get_iree_target_triple(device_str) + return target_backend, rt_device, triple + case "rocm": + triple = get_rocm_target_chip(device_str) + return target_backend, rt_device, triple + case "cpu": + return "llvm-cpu", "local-task", "x86_64-linux-gnu" + + +def get_rocm_target_chip(device_str): + #TODO: Use a data file to map device_str to target chip. + rocm_chip_map = { + "6700": "gfx1031", + "6800": "gfx1030", + "6900": "gfx1030", + "7900": "gfx1100", + "MI300X": "gfx942", + "MI300A": "gfx940", + "MI210": "gfx90a", + "MI250": "gfx90a", + "MI100": "gfx908", + "MI50": "gfx906", + "MI60": "gfx906", + } + for key in rocm_chip_map: + if key in device_str: + return rocm_chip_map[key] + raise AssertionError(f"Device {device_str} not recognized. 
Please file an issue at https://github.com/nod-ai/SHARK/issues.") def get_all_devices(driver_name): """ diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index fa2c1836fd..d26da8c581 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -45,11 +45,10 @@ import apps.shark_studio.web.utils.globals as global_obj sd_default_models = [ - "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1-base", "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-xl-1.0", + "stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", ] @@ -286,14 +285,14 @@ def base_model_changed(base_model_id): with gr.Row(): height = gr.Slider( 384, - 768, + 1024, value=cmd_opts.height, step=8, label="\U00002195\U0000FE0F Height", ) width = gr.Slider( 384, - 768, + 1024, value=cmd_opts.width, step=8, label="\U00002194\U0000FE0F Width", diff --git a/requirements.txt b/requirements.txt index 1ff2b685d7..fc644d814a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,9 +5,10 @@ setuptools wheel -torch==2.3.0 -shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main -turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=models +torch>=2.3.0 +shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sdxl-fixes#subdirectory=core +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sdxl-fixes#subdirectory=models +diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release # SHARK Runner tqdm diff --git a/setup_venv.ps1 b/setup_venv.ps1 index 749a7c4e6f..9b4dab3b02 100644 --- a/setup_venv.ps1 +++ b/setup_venv.ps1 @@ -88,5 +88,7 @@ else {python -m venv .\shark.venv\} .\shark.venv\Scripts\activate python -m pip install --upgrade pip pip install wheel -pip install -r requirements.txt +pip install --pre -r requirements.txt + +>>>>>>> 0c904eb7 (Shark Studio SDXL support, HIP driver support, simpler device info, small fixes) Write-Host "Source your venv with ./shark.venv/Scripts/activate" diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index c58405b46e..1d022f67e4 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -76,6 +76,7 @@ def get_supported_device_list(): "vulkan": "vulkan", "metal": "metal", "rocm": "rocm", + "hip": "hip", "intel-gpu": "level_zero", } @@ -94,6 +95,7 @@ def iree_target_map(device): "vulkan": "vulkan-spirv", "metal": "metal", "rocm": "rocm", + "hip": "rocm", "intel-gpu": "opencl-spirv", } diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 5fd1d4006a..9d43c115d9 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -62,6 +62,9 @@ def get_iree_device_args(device, extra_args=[]): from shark.iree_utils.gpu_utils import get_iree_rocm_args return get_iree_rocm_args(device_num=device_num, extra_args=extra_args) + if device == "hip": + from shark.iree_utils.gpu_utils import get_iree_rocm_args + return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True) return [] def get_iree_target_triple(device): diff --git a/shark/iree_utils/gpu_utils.py b/shark/iree_utils/gpu_utils.py index 0eba67ff53..db6ef14e34 100644 --- a/shark/iree_utils/gpu_utils.py +++ b/shark/iree_utils/gpu_utils.py @@ -52,7 +52,7 @@ def check_rocm_device_arch_in_args(extra_args): return None -def get_rocm_device_arch(device_num=0, extra_args=[]): +def get_rocm_device_arch(device_num=0, extra_args=[], 
hip_driver=False): # ROCM Device Arch selection: # 1 : User given device arch using `--iree-rocm-target-chip` flag # 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index @@ -68,15 +68,23 @@ def get_rocm_device_arch(device_num=0, extra_args=[]): arch_in_device_dump = None # get rocm arch from iree dump devices - def get_devices_info_from_dump(dump): + def get_devices_info_from_dump(dump, driver): from os import linesep - - dump_clean = list( - filter( - lambda s: "--device=rocm" in s or "gpu-arch-name:" in s, - dump.split(linesep), + + if driver == "hip": + dump_clean = list( + filter( + lambda s: "AMD" in s, + dump.split(linesep), + ) + ) + else: + dump_clean = list( + filter( + lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s, + dump.split(linesep), + ) ) - ) arch_pairs = [ ( dump_clean[i].split("=")[1].strip(), @@ -87,16 +95,17 @@ def get_devices_info_from_dump(dump): return arch_pairs dump_device_info = None + driver = "hip" if hip_driver else "rocm" try: dump_device_info = run_cmd( - "iree-run-module --dump_devices=rocm", raise_err=True + "iree-run-module --dump_devices=" + driver, raise_err=True ) except Exception as e: - print("could not execute `iree-run-module --dump_devices=rocm`") + print("could not execute `iree-run-module --dump_devices=" + driver + "`") if dump_device_info is not None: device_num = 0 if device_num is None else device_num - device_arch_pairs = get_devices_info_from_dump(dump_device_info[0]) + device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver) if len(device_arch_pairs) > device_num: # can find arch in the list arch_in_device_dump = device_arch_pairs[device_num][1] @@ -107,24 +116,22 @@ def get_devices_info_from_dump(dump): default_rocm_arch = "gfx1100" print( "Did not find ROCm architecture from `--iree-rocm-target-chip` flag" - "\n or from `iree-run-module --dump_devices=rocm` command." + "\n or from `iree-run-module --dump_devices` command." f"\nUsing {default_rocm_arch} as ROCm arch for compilation." ) return default_rocm_arch # Get the default gpu args given the architecture. -def get_iree_rocm_args(device_num=0, extra_args=[]): +def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False): ireert.flags.FUNCTION_INPUT_VALIDATION = False - rocm_flags = ["--iree-rocm-link-bc=true"] - + rocm_flags = [] if check_rocm_device_arch_in_args(extra_args) is None: - rocm_arch = get_rocm_device_arch(device_num, extra_args) + rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver) rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}") return rocm_flags - # Some constants taken from cuda.h CUDA_SUCCESS = 0 CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16 From 6bad4aa82668b3505d8e2dc05774c595c95412bc Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 24 Apr 2024 16:04:05 -0500 Subject: [PATCH 02/20] Fixups to llm API/UI and ignore user config files. 
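These fixups replace hardcoded "meta-llama/Llama-2-7b-chat-hf" lookups in the LLM API with lookups keyed by the instance's hf_model_name, so stop-token checks follow whichever model was actually loaded. A minimal sketch of that pattern follows; the map entries and the is_stop helper are illustrative placeholders, not the project's actual values or API:

    llm_model_map = {
        "meta-llama/Llama-2-7b-chat-hf": {"stop_token": 2},  # placeholder value
        "some-org/other-chat-model": {"stop_token": 0},      # hypothetical entry
    }

    class LanguageModel:
        def __init__(self, hf_model_name: str):
            self.hf_model_name = hf_model_name

        def is_stop(self, token: int) -> bool:
            # Key the lookup by this instance's model id instead of a
            # hardcoded Llama-2 string.
            return token == llm_model_map[self.hf_model_name]["stop_token"]
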
--- .gitignore | 3 ++- apps/shark_studio/api/llm.py | 14 +++++--------- apps/shark_studio/tests/api_test.py | 1 + apps/shark_studio/web/ui/chat.py | 2 +- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index bf07b2794f..7d6a0d4215 100644 --- a/.gitignore +++ b/.gitignore @@ -164,7 +164,7 @@ cython_debug/ # vscode related .vscode -# Shark related artefacts +# Shark related artifacts *venv/ shark_tmp/ *.vmfb @@ -172,6 +172,7 @@ shark_tmp/ tank/dict_configs.py *.csv reproducers/ +apps/shark_studio/web/configs # ORT related artefacts cache_models/ diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 6ee80ae49e..8ba816f353 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -155,7 +155,7 @@ def __init__( use_auth_token=hf_auth_token, ) elif not os.path.exists(self.tempfile_name): - self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"]( + self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name]["initializer"]( self.hf_model_name, hf_auth_token, compile_to="torch", @@ -258,8 +258,7 @@ def format_out(results): history.append(format_out(token)) while ( - format_out(token) - != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"] + format_out(token) != llm_model_map[self.hf_model_name]["stop_token"] and len(history) < self.max_tokens ): dec_time = time.time() @@ -273,10 +272,7 @@ def format_out(results): self.prev_token_len = token_len + len(history) - if ( - format_out(token) - == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"] - ): + if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]: break for i in range(len(history)): @@ -310,7 +306,7 @@ def chat_hf(self, prompt): self.first_input = False history.append(int(token)) - while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]: + while token != llm_model_map[self.hf_model_name]["stop_token"]: dec_time = time.time() result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv) history.append(int(token)) @@ -321,7 +317,7 @@ def chat_hf(self, prompt): self.prev_token_len = token_len + len(history) - if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]: + if token == llm_model_map[self.hf_model_name]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index 7bed2cb7b0..49f4482576 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -36,6 +36,7 @@ def test01_LLMSmall(self): device="cpu", precision="fp32", quantization="None", + streaming_llm=True, ) count = 0 label = "Turkishoure Turkish" diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 418f087548..54ae4a139f 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -137,7 +137,7 @@ def view_json_file(file_obj): streaming_llm = gr.Checkbox( label="Run in streaming mode (requires recompilation)", value=True, - interactive=True, + interactive=False, ) prompt_prefix = gr.Checkbox( label="Add System Prompt", From 8191cbeaa55e1feff808e626e6c348f42655d70b Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 22 May 2024 20:03:25 -0500 Subject: [PATCH 03/20] Small fixes for unifying pipelines. 
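This patch swaps the SDXL-specific pipeline import for the generic SharkSDPipeline and replaces the hand-written per-submodel dicts in prepare_pipe with deep copies of a single EMPTY_SD_MAP template. A minimal sketch of that template pattern, assuming the three-key map this patch introduces; make_artifact_maps is a hypothetical helper used only for illustration:

    import copy

    # Template of submodel artifact slots; the turbine pipeline fills these
    # in later via check_prepared() and load_pipeline().
    EMPTY_SD_MAP = {
        "clip": None,
        "unet": None,
        "vae_decode": None,
    }

    def make_artifact_maps(template):
        # Independent copies, so populating one map (e.g. vmfbs) never
        # mutates the shared template or its siblings.
        mlirs = copy.deepcopy(template)
        vmfbs = copy.deepcopy(template)
        weights = copy.deepcopy(template)
        return mlirs, vmfbs, weights

    mlirs, vmfbs, weights = make_artifact_maps(EMPTY_SD_MAP)
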
--- apps/shark_studio/api/sd.py | 51 ++++++++++--------------------------- 1 file changed, 14 insertions(+), 37 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index fe0f14bbc5..1994f6c7ff 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -4,12 +4,12 @@ import os import json import numpy as np +import copy from tqdm.auto import tqdm from pathlib import Path from random import randint -from turbine_models.custom_models.sd_inference import clip, unet, vae -from turbine_models.custom_models.sdxl_inference import sdxl_compiled_pipeline +from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline from apps.shark_studio.api.controlnet import control_adapter_map from apps.shark_studio.api.utils import parse_device from apps.shark_studio.web.utils.state import status_label @@ -34,16 +34,11 @@ process_custom_pipe_weights, ) -sd_model_map = { - "clip": { - "initializer": clip.export_clip_model, - }, - "unet": { - "initializer": unet.export_unet_model, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - }, + +EMPTY_SD_MAP = { + "clip": None, + "unet": None, + "vae_decode": None, } EMPTY_FLAGS = { @@ -75,7 +70,6 @@ def __init__( num_loras: int = 0, import_ir: bool = True, is_controlled: bool = False, - hf_auth_token=None, ): self.compiled_pipeline = False self.base_model_id = base_model_id @@ -102,7 +96,7 @@ def __init__( ) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) - self.sd_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline( + self.sd_pipe = SharkSDPipeline( hf_model_name=base_model_id, scheduler_id=scheduler, height=height, @@ -125,28 +119,10 @@ def __init__( def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): print(f"\n[LOG] Preparing pipeline...") - self.is_img2img = is_img2img - mlirs = { - "prompt_encoder": None, - "scheduled_unet": None, - "vae_decode": None, - "pipeline": None, - "full_pipeline": None, - } - vmfbs = { - "prompt_encoder": None, - "scheduled_unet": None, - "vae_decode": None, - "pipeline": None, - "full_pipeline": None, - } - weights = { - "prompt_encoder": None, - "scheduled_unet": None, - "vae_decode": None, - "pipeline": None, - "full_pipeline": None, - } + self.is_img2img = False + mlirs = copy.deepcopy(EMPTY_SD_MAP) + vmfbs = copy.deepcopy(EMPTY_SD_MAP) + weights = copy.deepcopy(EMPTY_SD_MAP) vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False) print(f"\n[LOG] Loading pipeline to device {self.rt_device}.") self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline) @@ -235,6 +211,7 @@ def shark_sd_fn( control_mode = None hints = [] num_loras = 0 + import_ir=True for i in embeddings: num_loras += 1 if embeddings[i] else 0 if "model" in controlnets: @@ -268,7 +245,7 @@ def shark_sd_fn( "device": device, "custom_vae": custom_vae, "num_loras": num_loras, - "import_ir": cmd_opts.import_mlir, + "import_ir": import_ir, "is_controlled": is_controlled, "steps": steps, "scheduler": scheduler, From 7e5f73d7f60c341619a8cfa787ac724bdc53e812 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Mon, 29 Apr 2024 09:28:14 -0700 Subject: [PATCH 04/20] Update requirements.txt for iree-turbine (#2130) * Update requirements.txt to iree-turbine creation * Update requirements.txt * Update requirements.txt * Update requirements.txt --- apps/shark_studio/api/sd.py | 35 +++++++++++++++++++++++++++++++---- requirements.txt | 6 +++--- 2 files changed, 34 
insertions(+), 7 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 1994f6c7ff..5e7895bc75 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -10,6 +10,9 @@ from pathlib import Path from random import randint from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline +from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import SharkSDXLPipeline + + from apps.shark_studio.api.controlnet import control_adapter_map from apps.shark_studio.api.utils import parse_device from apps.shark_studio.web.utils.state import status_label @@ -18,6 +21,7 @@ get_resource_path, get_checkpoints_path, ) + from apps.shark_studio.modules.schedulers import get_schedulers from apps.shark_studio.modules.prompt_encoding import ( get_weighted_text_embeddings, @@ -34,6 +38,10 @@ process_custom_pipe_weights, ) +from shark.iree_utils.compile_utils import ( + clean_device_info, + get_iree_target_triple, +) EMPTY_SD_MAP = { "clip": None, @@ -73,6 +81,12 @@ def __init__( ): self.compiled_pipeline = False self.base_model_id = base_model_id + self.custom_vae = custom_vae + self.is_sdxl = "xl" in self.base_model_id.lower() + if self.is_sdxl: + self.turbine_pipe = SharkSDXLPipeline + else: + self.turbine_pipe = SharkSDPipeline external_weights = "safetensors" max_length = 64 target_backend, self.rt_device, triple = parse_device(device) @@ -91,12 +105,14 @@ def __init__( if custom_vae: pipe_id_list.append(custom_vae) self.pipe_id = "_".join(pipe_id_list) - self.weights_path = os.path.join( + self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) + self.weights_path = Path(os.path.join( get_checkpoints_path(), safe_name(self.base_model_id) - ) + )) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) - self.sd_pipe = SharkSDPipeline( + + self.sd_pipe = self.turbine_pipe( hf_model_name=base_model_id, scheduler_id=scheduler, height=height, @@ -110,9 +126,10 @@ def __init__( ireec_flags=EMPTY_FLAGS, attn_spec=None, decomp_attn=True if "gfx9" not in triple else False, - pipeline_dir=self.pipe_id, + pipeline_dir=self.pipeline_dir, external_weights_dir=self.weights_path, external_weights=external_weights, + custom_vae=custom_vae, ) print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") gc.collect() @@ -123,6 +140,15 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): mlirs = copy.deepcopy(EMPTY_SD_MAP) vmfbs = copy.deepcopy(EMPTY_SD_MAP) weights = copy.deepcopy(EMPTY_SD_MAP) + + if custom_weights: + custom_weights_params, _ = process_custom_pipe_weights( + custom_weights + ) + weights["clip"] = custom_weights_params + weights["unet"] = custom_weights_params + + vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False) print(f"\n[LOG] Loading pipeline to device {self.rt_device}.") self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline) @@ -280,6 +306,7 @@ def shark_sd_fn( # Initializes the pipeline and retrieves IR based on all # parameters that are static in the turbine output format, # which is currently MLIR in the torch dialect. 
+ sd_pipe = StableDiffusion( **submit_pipe_kwargs, diff --git a/requirements.txt b/requirements.txt index fc644d814a..8cabf3f47a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,10 +5,10 @@ setuptools wheel + torch>=2.3.0 -shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sdxl-fixes#subdirectory=core -turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sdxl-fixes#subdirectory=models -diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release +shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models # SHARK Runner tqdm From 6d8fb5bbe199d8f45f9a3303f52a8523beeef997 Mon Sep 17 00:00:00 2001 From: gpetters94 Date: Tue, 30 Apr 2024 13:27:30 -0400 Subject: [PATCH 05/20] Remove IREE pin (fixes exe issue) (#2126) * Diagnose a build issue * Remove IREE pin * Revert the build on pull request change --- setup_venv.ps1 | 1 + 1 file changed, 1 insertion(+) diff --git a/setup_venv.ps1 b/setup_venv.ps1 index 9b4dab3b02..b6a6994449 100644 --- a/setup_venv.ps1 +++ b/setup_venv.ps1 @@ -89,6 +89,7 @@ else {python -m venv .\shark.venv\} python -m pip install --upgrade pip pip install wheel pip install --pre -r requirements.txt +pip install -e . >>>>>>> 0c904eb7 (Shark Studio SDXL support, HIP driver support, simpler device info, small fixes) Write-Host "Source your venv with ./shark.venv/Scripts/activate" From f99e794be8b76708694edbe2ad32dbf288e0851c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 24 Apr 2024 04:01:28 -0500 Subject: [PATCH 06/20] Shark Studio SDXL support, HIP driver support, simpler device info, small fixes --- apps/shark_studio/api/sd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 5e7895bc75..9cd6e6b79b 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -21,7 +21,6 @@ get_resource_path, get_checkpoints_path, ) - from apps.shark_studio.modules.schedulers import get_schedulers from apps.shark_studio.modules.prompt_encoding import ( get_weighted_text_embeddings, @@ -56,6 +55,12 @@ "pipeline": None, } +EMPTY_FLAGS = { + "clip": None, + "unet": None, + "vae": None, + "pipeline": None, +} class StableDiffusion: # This class is responsible for executing image generation and creating From 7e50013f6d1e680e4c51c478b38e3b92a192897b Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 00:00:38 -0500 Subject: [PATCH 07/20] Abstract out SD pipelines from Studio Webui (WIP) --- apps/shark_studio/api/sd.py | 1 + apps/shark_studio/web/ui/sd.py | 2 +- apps/shark_studio/web/utils/metadata/png_metadata.py | 5 ++--- apps/shark_studio/web/utils/tmp_configs.py | 2 +- shark/iree_utils/compile_utils.py | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 9cd6e6b79b..1b780bcbe1 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -21,6 +21,7 @@ get_resource_path, get_checkpoints_path, ) + from apps.shark_studio.modules.schedulers import get_schedulers from apps.shark_studio.modules.prompt_encoding import ( get_weighted_text_embeddings, diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index d26da8c581..048bbcc093 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -17,10 +17,10 @@ write_default_sd_config, ) from apps.shark_studio.api.sd import ( - 
sd_model_map, shark_sd_fn_dict_input, cancel_sd, ) +from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map from apps.shark_studio.api.controlnet import ( cnet_preview, ) diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py index 72f663f246..d1cadc1e00 100644 --- a/apps/shark_studio/web/utils/metadata/png_metadata.py +++ b/apps/shark_studio/web/utils/metadata/png_metadata.py @@ -3,9 +3,8 @@ from apps.shark_studio.web.utils.file_utils import ( get_checkpoint_pathfile, ) -from apps.shark_studio.api.sd import ( - sd_model_map, -) +from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map + from apps.shark_studio.modules.schedulers import ( scheduler_model_map, ) diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py index 7f65120cbb..ebbc4ae6af 100644 --- a/apps/shark_studio/web/utils/tmp_configs.py +++ b/apps/shark_studio/web/utils/tmp_configs.py @@ -17,7 +17,7 @@ def clear_tmp_mlir(): and filename.endswith(".mlir") ] for filename in mlir_files: - os.remove(shark_tmp + filename) + os.remove(os.path.join(shark_tmp, filename)) print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.") diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 9d43c115d9..b76fff81fc 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -92,9 +92,9 @@ def clean_device_info(raw_device): if len(device_id) <= 2: device_id = int(device_id) - if device not in ["rocm", "vulkan"]: + if device not in ["hip", "rocm", "vulkan"]: device_id = None - if device in ["rocm", "vulkan"] and device_id == None: + if device in ["hip", "rocm", "vulkan"] and device_id == None: device_id = 0 return device, device_id From febe88994c8d71c2896fef5bb1fff57deaeb6d12 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 01:49:41 -0500 Subject: [PATCH 08/20] Small fixes. --- apps/shark_studio/api/sd.py | 37 +++++++----------- apps/shark_studio/web/ui/sd.py | 71 +++++++++++++++------------------- 2 files changed, 46 insertions(+), 62 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 1b780bcbe1..744649db76 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -22,38 +22,27 @@ get_checkpoints_path, ) -from apps.shark_studio.modules.schedulers import get_schedulers -from apps.shark_studio.modules.prompt_encoding import ( - get_weighted_text_embeddings, -) from apps.shark_studio.modules.img_processing import ( - resize_stencil, save_output_img, - resamplers, - resampler_list, ) from apps.shark_studio.modules.ckpt_processing import ( - preprocessCKPT, process_custom_pipe_weights, ) -from shark.iree_utils.compile_utils import ( - clean_device_info, - get_iree_target_triple, -) - EMPTY_SD_MAP = { "clip": None, + "scheduler": None, "unet": None, "vae_decode": None, } -EMPTY_FLAGS = { - "clip": None, - "unet": None, - "vae": None, +EMPTY_SDXL_MAP = { + "prompt_encoder": None, + "scheduled_unet": None, + "vae_decode": None, "pipeline": None, + "full_pipeline": None, } EMPTY_FLAGS = { @@ -63,6 +52,7 @@ "pipeline": None, } + class StableDiffusion: # This class is responsible for executing image generation and creating # /managing a set of compiled modules to run Stable Diffusion. 
The init @@ -91,8 +81,10 @@ def __init__( self.is_sdxl = "xl" in self.base_model_id.lower() if self.is_sdxl: self.turbine_pipe = SharkSDXLPipeline + self.model_map = EMPTY_SDXL_MAP else: self.turbine_pipe = SharkSDPipeline + self.model_map = EMPTY_SD_MAP external_weights = "safetensors" max_length = 64 target_backend, self.rt_device, triple = parse_device(device) @@ -143,16 +135,17 @@ def __init__( def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): print(f"\n[LOG] Preparing pipeline...") self.is_img2img = False - mlirs = copy.deepcopy(EMPTY_SD_MAP) - vmfbs = copy.deepcopy(EMPTY_SD_MAP) - weights = copy.deepcopy(EMPTY_SD_MAP) + mlirs = copy.deepcopy(self.model_map) + vmfbs = copy.deepcopy(self.model_map) + weights = copy.deepcopy(self.model_map) if custom_weights: custom_weights_params, _ = process_custom_pipe_weights( custom_weights ) - weights["clip"] = custom_weights_params - weights["unet"] = custom_weights_params + for key in weights: + if key not in ["vae_decode", "pipeline", "full_pipeline"]: + weights[key] = custom_weights_params vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False) diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index 048bbcc093..a4df173b1c 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -20,7 +20,6 @@ shark_sd_fn_dict_input, cancel_sd, ) -from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map from apps.shark_studio.api.controlnet import ( cnet_preview, ) @@ -230,14 +229,9 @@ def import_original(original_img, width, height): def base_model_changed(base_model_id): - ckpt_path = Path( - os.path.join( - cmd_opts.model_dir, "checkpoints", os.path.basename(str(base_model_id)) - ) - ) - ckpt_path.mkdir(parents=True, exist_ok=True) - - new_choices = get_checkpoints(ckpt_path) + get_checkpoints(model_type="checkpoints") + new_choices = get_checkpoints( + os.path.join("checkpoints", os.path.basename(str(base_model_id))) + ) + get_checkpoints(model_type="checkpoints") return gr.Dropdown( value=new_choices[0] if len(new_choices) > 0 else "None", @@ -585,6 +579,21 @@ def base_model_changed(base_model_id): object_fit="fit", preview=True, ) + with gr.Row(): + std_output = gr.Textbox( + value=f"{sd_model_info}\n" + f"Images will be saved at " + f"{get_generated_imgs_path()}", + lines=2, + elem_id="std_output", + show_label=True, + label="Log", + show_copy_button=True, + ) + sd_element.load( + logger.read_sd_logs, None, std_output, every=1 + ) + sd_status = gr.Textbox(visible=False) with gr.Row(): batch_count = gr.Slider( 1, @@ -620,18 +629,19 @@ def base_model_changed(base_model_id): stop_batch = gr.Button("Stop") with gr.Tab(label="Config", id=102) as sd_tab_config: with gr.Column(elem_classes=["sd-right-panel"]): - Path(get_configs_path()).mkdir(parents=True, exist_ok=True) - default_config_file = os.path.join( - get_configs_path(), - "default_sd_config.json", - ) - write_default_sd_config(default_config_file) - sd_json = gr.JSON( - label="SD Config", - elem_classes=["fill"], - value=view_json_file(default_config_file), - render=False, - ) + with gr.Row(elem_classes=["fill"]): + Path(get_configs_path()).mkdir( + parents=True, exist_ok=True + ) + default_config_file = os.path.join( + get_configs_path(), + "default_sd_config.json", + ) + write_default_sd_config(default_config_file) + sd_json = gr.JSON( + elem_classes=["fill"], + value=view_json_file(default_config_file), + ) with gr.Row(): with gr.Column(scale=3): load_sd_config = gr.FileExplorer( @@ 
-694,30 +704,11 @@ def base_model_changed(base_model_id): inputs=[sd_json, sd_config_name], outputs=[sd_config_name], ) - with gr.Row(elem_classes=["fill"]): - sd_json.render() save_sd_config.click( fn=save_sd_cfg, inputs=[sd_json, sd_config_name], outputs=[sd_config_name], ) - with gr.Tab(label="Log", id=103) as sd_tab_log: - with gr.Row(): - std_output = gr.Textbox( - value=f"{sd_model_info}\n" - f"Images will be saved at " - f"{get_generated_imgs_path()}", - elem_id="std_output", - show_label=True, - label="Log", - show_copy_button=True, - ) - sd_element.load( - logger.read_sd_logs, None, std_output, every=1 - ) - sd_status = gr.Textbox(visible=False) - with gr.Tab(label="Automation", id=104) as sd_tab_automation: - pass pull_kwargs = dict( fn=pull_sd_configs, From b609a03da8cd4c7fdf28aed5ee4227fb05f9c440 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 01:58:23 -0500 Subject: [PATCH 09/20] Switch from pin to minimum torch version and fix index url --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8cabf3f47a..e650b6b391 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html +-f https://download.pytorch.org/whl/nightly/cpu -f https://iree.dev/pip-release-links.html --pre From 00353aa24f1811ba75e6c0a3aff2a690f33b67dd Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Thu, 23 May 2024 02:32:59 -0500 Subject: [PATCH 10/20] Update utils.py --- apps/shark_studio/api/utils.py | 55 +++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 7241bf3691..f213320045 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -12,7 +12,11 @@ from cpuinfo import get_cpu_info # TODO: migrate these utils to studio - +from shark.iree_utils.vulkan_utils import ( + set_iree_vulkan_runtime_flags, + get_vulkan_target_triple, + get_iree_vulkan_runtime_flags, +) def get_available_devices(): def get_devices_by_name(driver_name): @@ -44,6 +48,8 @@ def get_devices_by_name(driver_name): device_list.append(f"{device_name} => {driver_name}://{i}") return device_list + set_iree_runtime_flags() + available_devices = [] from shark.iree_utils.vulkan_utils import ( get_all_vulkan_devices, @@ -72,6 +78,53 @@ def get_devices_by_name(driver_name): available_devices.extend(cpu_device) return available_devices +def set_init_device_flags(): + if "vulkan" in cmd_opts.device: + # set runtime flags for vulkan. + set_iree_runtime_flags() + + # set triple flag to avoid multiple calls to get_vulkan_triple_flag + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_vulkan_target_triple: + triple = get_vulkan_target_triple(device_name) + if triple is not None: + cmd_opts.iree_vulkan_target_triple = triple + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_vulkan_target_triple}." 
+ ) + elif "cuda" in cmd_opts.device: + cmd_opts.device = "cuda" + elif "metal" in cmd_opts.device: + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_metal_target_platform: + from shark.iree_utils.metal_utils import get_metal_target_triple + + triple = get_metal_target_triple(device_name) + if triple is not None: + cmd_opts.iree_metal_target_platform = triple.split("-")[-1] + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_metal_target_platform}." + ) + elif "cpu" in cmd_opts.device: + cmd_opts.device = "cpu" + + +def set_iree_runtime_flags(): + # TODO: This function should be device-agnostic and piped properly + # to general runtime driver init. + vulkan_runtime_flags = get_iree_vulkan_runtime_flags() + if cmd_opts.enable_rgp: + vulkan_runtime_flags += [ + f"--enable_rgp=true", + f"--vulkan_debug_utils=true", + ] + if cmd_opts.device_allocator_heap_key: + vulkan_runtime_flags += [ + f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}", + ] + set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) def parse_device(device_str): from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map From 1bd7265960d4c4b1f4e04d83531b761bad94c72a Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 11:03:27 -0500 Subject: [PATCH 11/20] Fix typo --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index e650b6b391..9d6fc7cfb7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,8 +18,11 @@ google-cloud-storage # Testing pytest +<<<<<<< HEAD pytest-xdist pytest-forked +======= +>>>>>>> 08d48242 (Fix typo) Pillow parameterized From 0d2bc755ca4ef97e0d38cad7c7ada413b26efbec Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 11:04:27 -0500 Subject: [PATCH 12/20] Formatting --- apps/shark_studio/api/llm.py | 4 +++- apps/shark_studio/api/sd.py | 35 +++++++++++++++++++--------------- apps/shark_studio/api/utils.py | 19 ++++++++++++++---- 3 files changed, 38 insertions(+), 20 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 8ba816f353..217fb6784f 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -155,7 +155,9 @@ def __init__( use_auth_token=hf_auth_token, ) elif not os.path.exists(self.tempfile_name): - self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name]["initializer"]( + self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][ + "initializer" + ]( self.hf_model_name, hf_auth_token, compile_to="torch", diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 744649db76..d064e04a8f 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -10,7 +10,9 @@ from pathlib import Path from random import randint from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline -from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import SharkSDXLPipeline +from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import ( + SharkSDXLPipeline, +) from apps.shark_studio.api.controlnet import control_adapter_map @@ -104,9 +106,9 @@ def __init__( pipe_id_list.append(custom_vae) self.pipe_id = "_".join(pipe_id_list) self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) - self.weights_path = Path(os.path.join( - get_checkpoints_path(), safe_name(self.base_model_id) - )) + self.weights_path 
= Path( + os.path.join(get_checkpoints_path(), safe_name(self.base_model_id)) + ) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) @@ -140,18 +142,21 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): weights = copy.deepcopy(self.model_map) if custom_weights: - custom_weights_params, _ = process_custom_pipe_weights( - custom_weights - ) + custom_weights_params, _ = process_custom_pipe_weights(custom_weights) for key in weights: if key not in ["vae_decode", "pipeline", "full_pipeline"]: weights[key] = custom_weights_params - - vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False) + vmfbs, weights = self.sd_pipe.check_prepared( + mlirs, vmfbs, weights, interactive=False + ) print(f"\n[LOG] Loading pipeline to device {self.rt_device}.") - self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline) - print("\n[LOG] Pipeline successfully prepared for runtime. Generating images...") + self.sd_pipe.load_pipeline( + vmfbs, weights, self.rt_device, self.compiled_pipeline + ) + print( + "\n[LOG] Pipeline successfully prepared for runtime. Generating images..." + ) return def generate_images( @@ -236,7 +241,7 @@ def shark_sd_fn( control_mode = None hints = [] num_loras = 0 - import_ir=True + import_ir = True for i in embeddings: num_loras += 1 if embeddings[i] else 0 if "model" in controlnets: @@ -305,7 +310,6 @@ def shark_sd_fn( # Initializes the pipeline and retrieves IR based on all # parameters that are static in the turbine output format, # which is currently MLIR in the torch dialect. - sd_pipe = StableDiffusion( **submit_pipe_kwargs, @@ -325,7 +329,7 @@ def shark_sd_fn( out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs) # total_time = time.time() - start_time # text_output = f"Total image(s) generation time: {total_time:.4f}sec" - #print(f"\n[LOG] {text_output}") + # print(f"\n[LOG] {text_output}") # if global_obj.get_sd_status() == SD_STATE_CANCEL: # break # else: @@ -352,8 +356,9 @@ def view_json_file(file_path): content = fopen.read() return content + def safe_name(name): - return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_") + return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_") if __name__ == "__main__": diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index f213320045..0e53bd4a5a 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -18,6 +18,7 @@ get_iree_vulkan_runtime_flags, ) + def get_available_devices(): def get_devices_by_name(driver_name): from shark.iree_utils._common import iree_device_map @@ -49,7 +50,7 @@ def get_devices_by_name(driver_name): return device_list set_iree_runtime_flags() - + available_devices = [] from shark.iree_utils.vulkan_utils import ( get_all_vulkan_devices, @@ -78,6 +79,7 @@ def get_devices_by_name(driver_name): available_devices.extend(cpu_device) return available_devices + def set_init_device_flags(): if "vulkan" in cmd_opts.device: # set runtime flags for vulkan. 
@@ -126,8 +128,14 @@ def set_iree_runtime_flags(): ] set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) + def parse_device(device_str): - from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map + from shark.iree_utils.compile_utils import ( + clean_device_info, + get_iree_target_triple, + iree_target_map, + ) + rt_driver, device_id = clean_device_info(device_str) target_backend = iree_target_map(rt_driver) if device_id: @@ -147,7 +155,7 @@ def parse_device(device_str): def get_rocm_target_chip(device_str): - #TODO: Use a data file to map device_str to target chip. + # TODO: Use a data file to map device_str to target chip. rocm_chip_map = { "6700": "gfx1031", "6800": "gfx1030", @@ -164,7 +172,10 @@ def get_rocm_target_chip(device_str): for key in rocm_chip_map: if key in device_str: return rocm_chip_map[key] - raise AssertionError(f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues.") + raise AssertionError( + f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues." + ) + def get_all_devices(driver_name): """ From 6e7c91f8b596f82b52ddf97767115b1b89f8f696 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 11:22:18 -0500 Subject: [PATCH 13/20] Update requirements --- requirements.txt | 5 ----- 1 file changed, 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9d6fc7cfb7..fd4e9e5bb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,11 +18,6 @@ google-cloud-storage # Testing pytest -<<<<<<< HEAD -pytest-xdist -pytest-forked -======= ->>>>>>> 08d48242 (Fix typo) Pillow parameterized From 072d5afab6d59928ace8ec5428b89e6dba1f0889 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 11:35:18 -0500 Subject: [PATCH 14/20] Fix device parsing. --- shark/iree_utils/compile_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index b76fff81fc..f93c8fef2e 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -70,8 +70,8 @@ def get_iree_device_args(device, extra_args=[]): def get_iree_target_triple(device): args = get_iree_device_args(device) for flag in args: - if "triple" in flag.split("-"): - triple = flag.split("=") + if "triple" in flag: + triple = flag.split("=")[-1] return triple return "" From 5905904bf52c90f58b27af9b33003edda530aef8 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Thu, 23 May 2024 12:04:18 -0500 Subject: [PATCH 15/20] Update test-studio.yml --- .github/workflows/test-studio.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/test-studio.yml b/.github/workflows/test-studio.yml index 765a6bf761..9b96bf270f 100644 --- a/.github/workflows/test-studio.yml +++ b/.github/workflows/test-studio.yml @@ -81,6 +81,4 @@ jobs: source shark.venv/bin/activate pip install -r requirements.txt --no-cache-dir pip install -e . 
- pip uninstall -y torch - pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html python apps/shark_studio/tests/api_test.py From 34a40abb0de91d827a98b422eb9bbf22571f42a7 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 12:06:27 -0500 Subject: [PATCH 16/20] Fix linux setup --- requirements.txt | 1 - setup_venv.sh | 16 +--------------- 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/requirements.txt b/requirements.txt index fd4e9e5bb3..e400755c10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,7 +24,6 @@ parameterized # Add transformers, diffusers and scipy since it most commonly used #accelerate is now required for diffusers import from ckpt. accelerate -scipy ftfy gradio==4.19.2 altair diff --git a/setup_venv.sh b/setup_venv.sh index 64f769d794..11fe79fd68 100755 --- a/setup_venv.sh +++ b/setup_venv.sh @@ -84,21 +84,7 @@ else PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/ fi -$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL} - -if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then - T_VER=$($PYTHON -m pip show torch | grep Version) - T_VER_MIN=${T_VER:14:12} - TV_VER=$($PYTHON -m pip show torchvision | grep Version) - TV_VER_MAJ=${TV_VER:9:6} - $PYTHON -m pip uninstall -y torchvision - $PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ - if [ $? -eq 0 ];then - echo "Successfully Installed torch + cu118." - else - echo "Could not install torch + cu118." >&2 - fi -fi +$PYTHON -m pip install --no-warn-conflicts -e . -f ${RUNTIME} -f ${PYTORCH_URL} if [[ -z "${NO_BREVITAS}" ]]; then $PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev From 1c6db08b671ca6f477cb8c3b20603688a5a223be Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 23 May 2024 15:06:57 -0500 Subject: [PATCH 17/20] Small fixes --- apps/shark_studio/api/sd.py | 15 ++++++- apps/shark_studio/modules/schedulers.py | 48 +++++++++++------------ apps/shark_studio/web/utils/file_utils.py | 2 +- requirements.txt | 2 +- 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index d064e04a8f..150a72d28e 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -112,6 +112,17 @@ def __init__( if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) + decomp_attn = True + attn_spec = None + if triple in ["gfx940", "gfx942", "gfx90a"]: + decomp_attn = False + attn_spec = "mfma" + elif triple in ["gfx1100", "gfx1103"]: + decomp_attn = False + attn_spec = "wmma" + elif target_backend == "llvm-cpu": + decomp_attn = False + self.sd_pipe = self.turbine_pipe( hf_model_name=base_model_id, scheduler_id=scheduler, @@ -124,8 +135,8 @@ def __init__( device=target_backend, iree_target_triple=triple, ireec_flags=EMPTY_FLAGS, - attn_spec=None, - decomp_attn=True if "gfx9" not in triple else False, + attn_spec=attn_spec, + decomp_attn=decomp_attn, pipeline_dir=self.pipeline_dir, external_weights_dir=self.weights_path, external_weights=external_weights, diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index 3e931b1c78..8c2413c638 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -50,30 +50,30 @@ def get_schedulers(model_id): schedulers["DPMSolverMultistep++"] = 
DPMSolverMultistepScheduler.from_pretrained( model_id, subfolder="scheduler", algorithm_type="dpmsolver++" ) - schedulers["DPMSolverMultistepKarras"] = ( - DPMSolverMultistepScheduler.from_pretrained( - model_id, - subfolder="scheduler", - use_karras_sigmas=True, - ) - ) - schedulers["DPMSolverMultistepKarras++"] = ( - DPMSolverMultistepScheduler.from_pretrained( - model_id, - subfolder="scheduler", - algorithm_type="dpmsolver++", - use_karras_sigmas=True, - ) + schedulers[ + "DPMSolverMultistepKarras" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + use_karras_sigmas=True, + ) + schedulers[ + "DPMSolverMultistepKarras++" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + algorithm_type="dpmsolver++", + use_karras_sigmas=True, ) schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( model_id, subfolder="scheduler", ) - schedulers["EulerAncestralDiscrete"] = ( - EulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) + schedulers[ + "EulerAncestralDiscrete" + ] = EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", ) schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained( model_id, @@ -83,11 +83,11 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers["KDPM2AncestralDiscrete"] = ( - KDPM2AncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) + schedulers[ + "KDPM2AncestralDiscrete" + ] = KDPM2AncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", ) schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained( model_id, diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py index 9617c16565..3619055676 100644 --- a/apps/shark_studio/web/utils/file_utils.py +++ b/apps/shark_studio/web/utils/file_utils.py @@ -47,7 +47,7 @@ def write_default_sd_config(path): def safe_name(name): - return name.replace("/", "_").replace("-", "_") + return name.split("/")[-1].replace("-", "_") def get_path_stem(path): diff --git a/requirements.txt b/requirements.txt index e400755c10..bbb8adf6b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ parameterized #accelerate is now required for diffusers import from ckpt. accelerate ftfy -gradio==4.19.2 +gradio==4.29.0 altair omegaconf # 0.3.2 doesn't have binaries for arm64 From 7fac1d75a863b330cf075fd0e44d6d54b7b82d49 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 27 May 2024 00:01:57 -0500 Subject: [PATCH 18/20] Fix custom weights. 
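
Custom .safetensors checkpoints are now converted to a diffusers directory
layout with preprocessCKPT(), and each submodel's weights are repacked into
an IREE parameter archive (.irpa) by the new save_irpa() helper before the
archive paths are handed to the pipeline. A minimal sketch of that flow for
the UNet, assuming the save_irpa() helper added below and the updated
preprocessCKPT() that returns the diffusers path (the checkpoint filename and
precision here are illustrative, not fixed by this patch):

    import os
    from apps.shark_studio.modules.ckpt_processing import preprocessCKPT, save_irpa

    def custom_unet_irpa(ckpt_path: str, precision: str = "fp16") -> str:
        # 1) Convert the single-file checkpoint into a diffusers directory layout.
        diffusers_dir = preprocessCKPT(ckpt_path, precision)
        # 2) Repack the UNet safetensors into a parameter archive, prefixing
        #    tensor names so the compiled module can resolve them at runtime.
        unet_st = os.path.join(
            diffusers_dir, "unet", "diffusion_pytorch_model.safetensors"
        )
        return save_irpa(unet_st, "unet.")

    # e.g. weights["unet"] = custom_unet_irpa("my_checkpoint.safetensors")
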
--- apps/shark_studio/api/sd.py | 33 ++++++++++++++++---- apps/shark_studio/api/utils.py | 3 +- apps/shark_studio/modules/ckpt_processing.py | 30 +++++++++++++++--- apps/shark_studio/modules/schedulers.py | 3 +- 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 150a72d28e..a3c6f294e6 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -29,7 +29,8 @@ ) from apps.shark_studio.modules.ckpt_processing import ( - process_custom_pipe_weights, + preprocessCKPT, + save_irpa, ) EMPTY_SD_MAP = { @@ -77,6 +78,7 @@ def __init__( import_ir: bool = True, is_controlled: bool = False, ): + self.precision = precision self.compiled_pipeline = False self.base_model_id = base_model_id self.custom_vae = custom_vae @@ -107,7 +109,7 @@ def __init__( self.pipe_id = "_".join(pipe_id_list) self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) self.weights_path = Path( - os.path.join(get_checkpoints_path(), safe_name(self.base_model_id)) + os.path.join(get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)) ) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) @@ -153,10 +155,29 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): weights = copy.deepcopy(self.model_map) if custom_weights: - custom_weights_params, _ = process_custom_pipe_weights(custom_weights) + custom_weights = os.path.join(get_checkpoints_path("checkpoints"), safe_name(self.base_model_id.split("/")[-1]), custom_weights) + diffusers_weights_path = preprocessCKPT(custom_weights, self.precision) for key in weights: - if key not in ["vae_decode", "pipeline", "full_pipeline"]: - weights[key] = custom_weights_params + if key in ["scheduled_unet", "unet"]: + unet_weights_path = os.path.join(diffusers_weights_path, "unet", "diffusion_pytorch_model.safetensors") + weights[key] = save_irpa(unet_weights_path, "unet.") + + elif key in ["clip", "prompt_encoder"]: + if not self.is_sdxl: + sd1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors") + weights[key] = save_irpa(sd1_path, "text_encoder_model.") + else: + clip_1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors") + clip_2_path = os.path.join(diffusers_weights_path, "text_encoder_2", "model.safetensors") + weights[key] = [ + save_irpa(clip_1_path, "text_encoder_model_1."), + save_irpa(clip_2_path, "text_encoder_model_2.") + ] + + elif key in ["vae_decode"] and weights[key] is None: + vae_weights_path = os.path.join(diffusers_weights_path, "vae", "diffusion_pytorch_model.safetensors") + weights[key] = save_irpa(vae_weights_path, "vae.") + vmfbs, weights = self.sd_pipe.check_prepared( mlirs, vmfbs, weights, interactive=False @@ -369,7 +390,7 @@ def view_json_file(file_path): def safe_name(name): - return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_") + return name.replace("/", "_").replace("\\", "_").replace(".", "_") if __name__ == "__main__": diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 0e53bd4a5a..d63da5fc1a 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -135,7 +135,6 @@ def parse_device(device_str): get_iree_target_triple, iree_target_map, ) - rt_driver, device_id = clean_device_info(device_str) target_backend = iree_target_map(rt_driver) if device_id: @@ -150,7 +149,7 @@ def parse_device(device_str): case "rocm": triple = get_rocm_target_chip(device_str) 
return target_backend, rt_device, triple - case "cpu": + case "llvm-cpu": return "llvm-cpu", "local-task", "x86_64-linux-gnu" diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index fc0bd3b7b8..523ca08b57 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -2,6 +2,11 @@ import json import re import requests +import torch +import safetensors +from shark_turbine.aot.params import ( + ParameterArchiveBuilder, +) from io import BytesIO from pathlib import Path from tqdm import tqdm @@ -15,21 +20,21 @@ ) -def get_path_to_diffusers_checkpoint(custom_weights): +def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"): path = Path(custom_weights) diffusers_path = path.parent.absolute() - diffusers_directory_name = os.path.join("diffusers", path.stem) + diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}") complete_path_to_diffusers = diffusers_path / diffusers_directory_name complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) path_to_diffusers = complete_path_to_diffusers.as_posix() return path_to_diffusers -def preprocessCKPT(custom_weights, is_inpaint=False): - path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) +def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision) if next(Path(path_to_diffusers).iterdir(), None): print("Checkpoint already loaded at : ", path_to_diffusers) - return + return path_to_diffusers else: print( "Diffusers' checkpoint will be identified here : ", @@ -51,8 +56,23 @@ def preprocessCKPT(custom_weights, is_inpaint=False): from_safetensors=from_safetensors, num_in_channels=num_in_channels, ) + if precision == "fp16": + pipe.to(dtype=torch.float16) pipe.save_pretrained(path_to_diffusers) + del pipe print("Loading complete") + return path_to_diffusers + +def save_irpa(weights_path, prepend_str): + weights = safetensors.torch.load_file(weights_path) + archive = ParameterArchiveBuilder() + for key in weights.keys(): + new_key = prepend_str + key + archive.add_tensor(new_key, weights[key]) + + irpa_file = weights_path.replace(".safetensors", ".irpa") + archive.save(irpa_file) + return irpa_file def convert_original_vae(vae_checkpoint): diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index 8c2413c638..1f5f99bc6d 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -101,11 +101,12 @@ def export_scheduler_model(model): scheduler_model_map = { + "PNDM": export_scheduler_model("PNDMScheduler"), + "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"), "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"), "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"), "LCM": export_scheduler_model("LCMScheduler"), "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"), - "PNDM": export_scheduler_model("PNDMScheduler"), "DDPM": export_scheduler_model("DDPMScheduler"), "DDIM": export_scheduler_model("DDIMScheduler"), "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"), From c2f717d6e95555fa390f38ea4c687866197778fb Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 27 May 2024 00:02:18 -0500 Subject: [PATCH 19/20] Formatting --- apps/shark_studio/api/sd.py | 41 +++++++++++++++----- apps/shark_studio/api/utils.py | 1 + 
apps/shark_studio/modules/ckpt_processing.py | 3 +- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index a3c6f294e6..83574d294d 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -109,7 +109,9 @@ def __init__( self.pipe_id = "_".join(pipe_id_list) self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) self.weights_path = Path( - os.path.join(get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision)) + os.path.join( + get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision) + ) ) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) @@ -155,29 +157,48 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): weights = copy.deepcopy(self.model_map) if custom_weights: - custom_weights = os.path.join(get_checkpoints_path("checkpoints"), safe_name(self.base_model_id.split("/")[-1]), custom_weights) + custom_weights = os.path.join( + get_checkpoints_path("checkpoints"), + safe_name(self.base_model_id.split("/")[-1]), + custom_weights, + ) diffusers_weights_path = preprocessCKPT(custom_weights, self.precision) for key in weights: if key in ["scheduled_unet", "unet"]: - unet_weights_path = os.path.join(diffusers_weights_path, "unet", "diffusion_pytorch_model.safetensors") + unet_weights_path = os.path.join( + diffusers_weights_path, + "unet", + "diffusion_pytorch_model.safetensors", + ) weights[key] = save_irpa(unet_weights_path, "unet.") - + elif key in ["clip", "prompt_encoder"]: if not self.is_sdxl: - sd1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors") + sd1_path = os.path.join( + diffusers_weights_path, "text_encoder", "model.safetensors" + ) weights[key] = save_irpa(sd1_path, "text_encoder_model.") else: - clip_1_path = os.path.join(diffusers_weights_path, "text_encoder", "model.safetensors") - clip_2_path = os.path.join(diffusers_weights_path, "text_encoder_2", "model.safetensors") + clip_1_path = os.path.join( + diffusers_weights_path, "text_encoder", "model.safetensors" + ) + clip_2_path = os.path.join( + diffusers_weights_path, + "text_encoder_2", + "model.safetensors", + ) weights[key] = [ save_irpa(clip_1_path, "text_encoder_model_1."), - save_irpa(clip_2_path, "text_encoder_model_2.") + save_irpa(clip_2_path, "text_encoder_model_2."), ] elif key in ["vae_decode"] and weights[key] is None: - vae_weights_path = os.path.join(diffusers_weights_path, "vae", "diffusion_pytorch_model.safetensors") + vae_weights_path = os.path.join( + diffusers_weights_path, + "vae", + "diffusion_pytorch_model.safetensors", + ) weights[key] = save_irpa(vae_weights_path, "vae.") - vmfbs, weights = self.sd_pipe.check_prepared( mlirs, vmfbs, weights, interactive=False diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index d63da5fc1a..0516255d2b 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -135,6 +135,7 @@ def parse_device(device_str): get_iree_target_triple, iree_target_map, ) + rt_driver, device_id = clean_device_info(device_str) target_backend = iree_target_map(rt_driver) if device_id: diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index 523ca08b57..433df13654 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -30,7 +30,7 @@ def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"): return 
path_to_diffusers -def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False): +def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False): path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision) if next(Path(path_to_diffusers).iterdir(), None): print("Checkpoint already loaded at : ", path_to_diffusers) @@ -63,6 +63,7 @@ def preprocessCKPT(custom_weights, precision = "fp16", is_inpaint=False): print("Loading complete") return path_to_diffusers + def save_irpa(weights_path, prepend_str): weights = safetensors.torch.load_file(weights_path) archive = ParameterArchiveBuilder() From 7762166535b6653454e73c9ddcfbac5c4b77971c Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 28 May 2024 11:29:03 -0500 Subject: [PATCH 20/20] formatting --- apps/shark_studio/modules/schedulers.py | 48 ++++++++++++------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index 1f5f99bc6d..56df8973d0 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -50,30 +50,30 @@ def get_schedulers(model_id): schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained( model_id, subfolder="scheduler", algorithm_type="dpmsolver++" ) - schedulers[ - "DPMSolverMultistepKarras" - ] = DPMSolverMultistepScheduler.from_pretrained( - model_id, - subfolder="scheduler", - use_karras_sigmas=True, - ) - schedulers[ - "DPMSolverMultistepKarras++" - ] = DPMSolverMultistepScheduler.from_pretrained( - model_id, - subfolder="scheduler", - algorithm_type="dpmsolver++", - use_karras_sigmas=True, + schedulers["DPMSolverMultistepKarras"] = ( + DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + use_karras_sigmas=True, + ) + ) + schedulers["DPMSolverMultistepKarras++"] = ( + DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + algorithm_type="dpmsolver++", + use_karras_sigmas=True, + ) ) schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( model_id, subfolder="scheduler", ) - schedulers[ - "EulerAncestralDiscrete" - ] = EulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", + schedulers["EulerAncestralDiscrete"] = ( + EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) ) schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained( model_id, @@ -83,11 +83,11 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers[ - "KDPM2AncestralDiscrete" - ] = KDPM2AncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", + schedulers["KDPM2AncestralDiscrete"] = ( + KDPM2AncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) ) schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained( model_id,