diff --git a/.github/workflows/test-studio.yml b/.github/workflows/test-studio.yml index 765a6bf761..9b96bf270f 100644 --- a/.github/workflows/test-studio.yml +++ b/.github/workflows/test-studio.yml @@ -81,6 +81,4 @@ jobs: source shark.venv/bin/activate pip install -r requirements.txt --no-cache-dir pip install -e . - pip uninstall -y torch - pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html python apps/shark_studio/tests/api_test.py diff --git a/.gitignore b/.gitignore index f67152b007..7d6a0d4215 100644 --- a/.gitignore +++ b/.gitignore @@ -164,7 +164,7 @@ cython_debug/ # vscode related .vscode -# Shark related artefacts +# Shark related artifacts *venv/ shark_tmp/ *.vmfb @@ -172,6 +172,7 @@ shark_tmp/ tank/dict_configs.py *.csv reproducers/ +apps/shark_studio/web/configs # ORT related artefacts cache_models/ @@ -188,6 +189,11 @@ variants.json # models folder apps/stable_diffusion/web/models/ +# model artifacts (SHARK) +*.tempfile +*.mlir +*.vmfb + # Stencil annotators. stencil_annotator/ diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 6ee80ae49e..217fb6784f 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -155,7 +155,9 @@ def __init__( use_auth_token=hf_auth_token, ) elif not os.path.exists(self.tempfile_name): - self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"]( + self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][ + "initializer" + ]( self.hf_model_name, hf_auth_token, compile_to="torch", @@ -258,8 +260,7 @@ def format_out(results): history.append(format_out(token)) while ( - format_out(token) - != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"] + format_out(token) != llm_model_map[self.hf_model_name]["stop_token"] and len(history) < self.max_tokens ): dec_time = time.time() @@ -273,10 +274,7 @@ def format_out(results): self.prev_token_len = token_len + len(history) - if ( - format_out(token) - == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"] - ): + if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]: break for i in range(len(history)): @@ -310,7 +308,7 @@ def chat_hf(self, prompt): self.first_input = False history.append(int(token)) - while token != llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]: + while token != llm_model_map[self.hf_model_name]["stop_token"]: dec_time = time.time() result = self.hf_mod(token.reshape([1, 1]), past_key_values=pkv) history.append(int(token)) @@ -321,7 +319,7 @@ def chat_hf(self, prompt): self.prev_token_len = token_len + len(history) - if token == llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]: + if token == llm_model_map[self.hf_model_name]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index b4f0f0ddc0..83574d294d 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -4,51 +4,59 @@ import os import json import numpy as np +import copy from tqdm.auto import tqdm from pathlib import Path from random import randint -from turbine_models.custom_models.sd_inference import clip, unet, vae +from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline +from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import ( + SharkSDXLPipeline, +) + + from apps.shark_studio.api.controlnet import control_adapter_map +from apps.shark_studio.api.utils import parse_device from apps.shark_studio.web.utils.state import status_label from apps.shark_studio.web.utils.file_utils import ( safe_name, get_resource_path, get_checkpoints_path, ) -from apps.shark_studio.modules.pipeline import SharkPipelineBase -from apps.shark_studio.modules.schedulers import get_schedulers -from apps.shark_studio.modules.prompt_encoding import ( - get_weighted_text_embeddings, -) + from apps.shark_studio.modules.img_processing import ( - resize_stencil, save_output_img, - resamplers, - resampler_list, ) from apps.shark_studio.modules.ckpt_processing import ( preprocessCKPT, - process_custom_pipe_weights, + save_irpa, ) -from transformers import CLIPTokenizer -from diffusers.image_processor import VaeImageProcessor - -sd_model_map = { - "clip": { - "initializer": clip.export_clip_model, - }, - "unet": { - "initializer": unet.export_unet_model, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - }, + +EMPTY_SD_MAP = { + "clip": None, + "scheduler": None, + "unet": None, + "vae_decode": None, +} + +EMPTY_SDXL_MAP = { + "prompt_encoder": None, + "scheduled_unet": None, + "vae_decode": None, + "pipeline": None, + "full_pipeline": None, +} + +EMPTY_FLAGS = { + "clip": None, + "unet": None, + "vae": None, + "pipeline": None, } -class StableDiffusion(SharkPipelineBase): +class StableDiffusion: # This class is responsible for executing image generation and creating # /managing a set of compiled modules to run Stable Diffusion. The init # aims to be as general as possible, and the class will infer and compile @@ -61,66 +69,36 @@ def __init__( height: int, width: int, batch_size: int, + steps: int, + scheduler: str, precision: str, device: str, custom_vae: str = None, num_loras: int = 0, import_ir: bool = True, is_controlled: bool = False, - hf_auth_token=None, ): - self.model_max_length = 77 - self.batch_size = batch_size self.precision = precision - self.dtype = torch.float16 if precision == "fp16" else torch.float32 - self.height = height - self.width = width - self.scheduler_obj = {} - static_kwargs = { - "pipe": { - "external_weights": "safetensors", - }, - "clip": {"hf_model_name": base_model_id}, - "unet": { - "hf_model_name": base_model_id, - "unet_model": unet.UnetModel(hf_model_name=base_model_id), - "batch_size": batch_size, - # "is_controlled": is_controlled, - # "num_loras": num_loras, - "height": height, - "width": width, - "precision": precision, - "max_length": self.model_max_length, - }, - "vae_encode": { - "hf_model_name": base_model_id, - "vae_model": vae.VaeModel( - hf_model_name=custom_vae if custom_vae else base_model_id, - ), - "batch_size": batch_size, - "height": height, - "width": width, - "precision": precision, - }, - "vae_decode": { - "hf_model_name": base_model_id, - "vae_model": vae.VaeModel( - hf_model_name=custom_vae if custom_vae else base_model_id, - ), - "batch_size": batch_size, - "height": height, - "width": width, - "precision": precision, - }, - } - super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir) + self.compiled_pipeline = False + self.base_model_id = base_model_id + self.custom_vae = custom_vae + self.is_sdxl = "xl" in self.base_model_id.lower() + if self.is_sdxl: + self.turbine_pipe = SharkSDXLPipeline + self.model_map = EMPTY_SDXL_MAP + else: + self.turbine_pipe = SharkSDPipeline + self.model_map = EMPTY_SD_MAP + external_weights = "safetensors" + max_length = 64 + target_backend, self.rt_device, triple = parse_device(device) pipe_id_list = [ safe_name(base_model_id), str(batch_size), - str(self.model_max_length), + str(max_length), f"{str(height)}x{str(width)}", precision, - self.device, + triple, ] if num_loras > 0: pipe_id_list.append(str(num_loras) + "lora") @@ -129,227 +107,116 @@ def __init__( if custom_vae: pipe_id_list.append(custom_vae) self.pipe_id = "_".join(pipe_id_list) + self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id)) + self.weights_path = Path( + os.path.join( + get_checkpoints_path(), safe_name(self.base_model_id + "_" + precision) + ) + ) + if not os.path.exists(self.weights_path): + os.mkdir(self.weights_path) + + decomp_attn = True + attn_spec = None + if triple in ["gfx940", "gfx942", "gfx90a"]: + decomp_attn = False + attn_spec = "mfma" + elif triple in ["gfx1100", "gfx1103"]: + decomp_attn = False + attn_spec = "wmma" + elif target_backend == "llvm-cpu": + decomp_attn = False + + self.sd_pipe = self.turbine_pipe( + hf_model_name=base_model_id, + scheduler_id=scheduler, + height=height, + width=width, + precision=precision, + max_length=max_length, + batch_size=batch_size, + num_inference_steps=steps, + device=target_backend, + iree_target_triple=triple, + ireec_flags=EMPTY_FLAGS, + attn_spec=attn_spec, + decomp_attn=decomp_attn, + pipeline_dir=self.pipeline_dir, + external_weights_dir=self.weights_path, + external_weights=external_weights, + custom_vae=custom_vae, + ) print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") - del static_kwargs gc.collect() def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): print(f"\n[LOG] Preparing pipeline...") - self.is_img2img = is_img2img - self.schedulers = get_schedulers(self.base_model_id) - - self.weights_path = os.path.join( - get_checkpoints_path(), self.safe_name(self.base_model_id) - ) - if not os.path.exists(self.weights_path): - os.mkdir(self.weights_path) - - for model in adapters: - self.model_map[model] = adapters[model] - - for submodel in self.static_kwargs: - if custom_weights: - custom_weights_params, _ = process_custom_pipe_weights(custom_weights) - if submodel not in ["clip", "clip2"]: - self.static_kwargs[submodel][ - "external_weights" - ] = custom_weights_params - else: - self.static_kwargs[submodel]["external_weight_path"] = os.path.join( - self.weights_path, submodel + ".safetensors" + self.is_img2img = False + mlirs = copy.deepcopy(self.model_map) + vmfbs = copy.deepcopy(self.model_map) + weights = copy.deepcopy(self.model_map) + + if custom_weights: + custom_weights = os.path.join( + get_checkpoints_path("checkpoints"), + safe_name(self.base_model_id.split("/")[-1]), + custom_weights, + ) + diffusers_weights_path = preprocessCKPT(custom_weights, self.precision) + for key in weights: + if key in ["scheduled_unet", "unet"]: + unet_weights_path = os.path.join( + diffusers_weights_path, + "unet", + "diffusion_pytorch_model.safetensors", ) - else: - self.static_kwargs[submodel]["external_weight_path"] = os.path.join( - self.weights_path, submodel + ".safetensors" - ) - - self.get_compiled_map(pipe_id=self.pipe_id) - print("\n[LOG] Pipeline successfully prepared for runtime.") - return + weights[key] = save_irpa(unet_weights_path, "unet.") + + elif key in ["clip", "prompt_encoder"]: + if not self.is_sdxl: + sd1_path = os.path.join( + diffusers_weights_path, "text_encoder", "model.safetensors" + ) + weights[key] = save_irpa(sd1_path, "text_encoder_model.") + else: + clip_1_path = os.path.join( + diffusers_weights_path, "text_encoder", "model.safetensors" + ) + clip_2_path = os.path.join( + diffusers_weights_path, + "text_encoder_2", + "model.safetensors", + ) + weights[key] = [ + save_irpa(clip_1_path, "text_encoder_model_1."), + save_irpa(clip_2_path, "text_encoder_model_2."), + ] + + elif key in ["vae_decode"] and weights[key] is None: + vae_weights_path = os.path.join( + diffusers_weights_path, + "vae", + "diffusion_pytorch_model.safetensors", + ) + weights[key] = save_irpa(vae_weights_path, "vae.") - def encode_prompts_weight( - self, - prompt, - negative_prompt, - do_classifier_free_guidance=True, - ): - # Encodes the prompt into text encoder hidden states. - self.load_submodels(["clip"]) - self.tokenizer = CLIPTokenizer.from_pretrained( - self.base_model_id, - subfolder="tokenizer", + vmfbs, weights = self.sd_pipe.check_prepared( + mlirs, vmfbs, weights, interactive=False ) - clip_inf_start = time.time() - - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + print(f"\n[LOG] Loading pipeline to device {self.rt_device}.") + self.sd_pipe.load_pipeline( + vmfbs, weights, self.rt_device, self.compiled_pipeline ) - - if do_classifier_free_guidance: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - pad = (0, 0) * (len(text_embeddings.shape) - 2) - pad = pad + ( - 0, - self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1], + print( + "\n[LOG] Pipeline successfully prepared for runtime. Generating images..." ) - text_embeddings = torch.nn.functional.pad(text_embeddings, pad) - - # SHARK: Report clip inference time - clip_inf_time = (time.time() - clip_inf_start) * 1000 - if self.ondemand: - self.unload_submodels(["clip"]) - gc.collect() - print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}") - - return text_embeddings.numpy().astype(np.float16) - - def prepare_latents( - self, - generator, - num_inference_steps, - image, - strength, - ): - noise = torch.randn( - ( - self.batch_size, - 4, - self.height // 8, - self.width // 8, - ), - generator=generator, - dtype=self.dtype, - ).to("cpu") - - self.scheduler.set_timesteps(num_inference_steps) - if self.is_img2img: - init_timestep = min( - int(num_inference_steps * strength), num_inference_steps - ) - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] - latents = self.encode_image(image) - latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1)) - return latents, [timesteps] - else: - self.scheduler.is_scale_input_called = True - latents = noise * self.scheduler.init_noise_sigma - return latents, self.scheduler.timesteps - - def encode_image(self, input_image): - self.load_submodels(["vae_encode"]) - vae_encode_start = time.time() - latents = self.run("vae_encode", input_image) - vae_inf_time = (time.time() - vae_encode_start) * 1000 - if self.ondemand: - self.unload_submodels(["vae_encode"]) - print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}") - - return latents - - def produce_img_latents( - self, - latents, - text_embeddings, - guidance_scale, - total_timesteps, - cpu_scheduling, - mask=None, - masked_image_latents=None, - return_all_latents=False, - ): - # self.status = SD_STATE_IDLE - step_time_sum = 0 - latent_history = [latents] - text_embeddings = torch.from_numpy(text_embeddings).to(self.dtype) - text_embeddings_numpy = text_embeddings.detach().numpy() - guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype) - self.load_submodels(["unet"]) - for i, t in tqdm(enumerate(total_timesteps)): - step_start_time = time.time() - timestep = torch.tensor([t]).to(self.dtype).detach().numpy() - latent_model_input = self.scheduler.scale_model_input(latents, t).to( - self.dtype - ) - if mask is not None and masked_image_latents is not None: - latent_model_input = torch.cat( - [ - torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype), - mask, - masked_image_latents, - ], - dim=1, - ).to(self.dtype) - if cpu_scheduling: - latent_model_input = latent_model_input.detach().numpy() - - # Profiling Unet. - # profile_device = start_profiling(file_path="unet.rdc") - noise_pred = self.run( - "unet", - [ - latent_model_input, - timestep, - text_embeddings_numpy, - guidance_scale, - ], - ) - # end_profiling(profile_device) - - if cpu_scheduling: - noise_pred = torch.from_numpy(noise_pred.to_host()) - latents = self.scheduler.step(noise_pred, t, latents).prev_sample - else: - latents = self.run("scheduler_step", (noise_pred, t, latents)) - - latent_history.append(latents) - step_time = (time.time() - step_start_time) * 1000 - # print( - # f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms" - # ) - step_time_sum += step_time - - # if self.status == SD_STATE_CANCEL: - # break - - if self.ondemand: - self.unload_submodels(["unet"]) - gc.collect() - - avg_step_time = step_time_sum / len(total_timesteps) - print(f"\n[LOG] Average step time: {avg_step_time}ms/it") - - if not return_all_latents: - return latents - all_latents = torch.cat(latent_history, dim=0) - return all_latents - - def decode_latents(self, latents, cpu_scheduling=True): - latents_numpy = latents.to(self.dtype) - if cpu_scheduling: - latents_numpy = latents.detach().numpy() - - # profile_device = start_profiling(file_path="vae.rdc") - vae_start = time.time() - images = self.run("vae_decode", latents_numpy).to_host() - vae_inf_time = (time.time() - vae_start) * 1000 - # end_profiling(profile_device) - print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}") - - images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy() - pil_images = self.image_processor.numpy_to_pil(images) - return pil_images + return def generate_images( self, prompt, negative_prompt, image, - scheduler, - steps, strength, guidance_scale, seed, @@ -359,69 +226,15 @@ def generate_images( control_mode, hints, ): - # TODO: Batched args - self.image_processor = VaeImageProcessor(do_convert_rgb=True) - self.scheduler = self.schedulers[scheduler] - self.ondemand = ondemand - if self.is_img2img: - image, _ = self.image_processor.preprocess(image, resample_type) - else: - image = None - - print("\n[LOG] Generating images...") - batched_args = [ - prompt, - negative_prompt, - image, - ] - for arg in batched_args: - if not isinstance(arg, list): - arg = [arg] * self.batch_size - if len(arg) < self.batch_size: - arg = arg * self.batch_size - else: - arg = [arg[i] for i in range(self.batch_size)] - - text_embeddings = self.encode_prompts_weight( + img = self.sd_pipe.generate_images( prompt, negative_prompt, + 1, + guidance_scale, + seed, + return_imgs=True, ) - - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - - generator = torch.manual_seed(seed) - - init_latents, final_timesteps = self.prepare_latents( - generator=generator, - num_inference_steps=steps, - image=image, - strength=strength, - ) - - latents = self.produce_img_latents( - latents=init_latents, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - total_timesteps=final_timesteps, - cpu_scheduling=True, # until we have schedulers through Turbine - ) - - # Img latents -> PIL images - all_imgs = [] - self.load_submodels(["vae_decode"]) - for i in tqdm(range(0, latents.shape[0], self.batch_size)): - imgs = self.decode_latents( - latents=latents[i : i + self.batch_size], - cpu_scheduling=True, - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_submodels(["vae_decode"]) - - return all_imgs + return img def shark_sd_fn_dict_input( @@ -481,6 +294,7 @@ def shark_sd_fn( control_mode = None hints = [] num_loras = 0 + import_ir = True for i in embeddings: num_loras += 1 if embeddings[i] else 0 if "model" in controlnets: @@ -514,8 +328,10 @@ def shark_sd_fn( "device": device, "custom_vae": custom_vae, "num_loras": num_loras, - "import_ir": cmd_opts.import_mlir, + "import_ir": import_ir, "is_controlled": is_controlled, + "steps": steps, + "scheduler": scheduler, } submit_prep_kwargs = { "custom_weights": custom_weights, @@ -527,8 +343,6 @@ def shark_sd_fn( "prompt": prompt, "negative_prompt": negative_prompt, "image": sd_init_image, - "steps": steps, - "scheduler": scheduler, "strength": strength, "guidance_scale": guidance_scale, "seed": seed, @@ -566,9 +380,9 @@ def shark_sd_fn( for current_batch in range(batch_count): start_time = time.time() out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs) - total_time = time.time() - start_time - text_output = f"Total image(s) generation time: {total_time:.4f}sec" - print(f"\n[LOG] {text_output}") + # total_time = time.time() - start_time + # text_output = f"Total image(s) generation time: {total_time:.4f}sec" + # print(f"\n[LOG] {text_output}") # if global_obj.get_sd_status() == SD_STATE_CANCEL: # break # else: @@ -596,6 +410,10 @@ def view_json_file(file_path): return content +def safe_name(name): + return name.replace("/", "_").replace("\\", "_").replace(".", "_") + + if __name__ == "__main__": from apps.shark_studio.modules.shared_cmd_opts import cmd_opts import apps.shark_studio.web.utils.globals as global_obj diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index e9268aa83b..0516255d2b 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -71,6 +71,8 @@ def get_devices_by_name(driver_name): available_devices.extend(cuda_devices) rocm_devices = get_devices_by_name("rocm") available_devices.extend(rocm_devices) + hip_devices = get_devices_by_name("hip") + available_devices.extend(hip_devices) cpu_device = get_devices_by_name("cpu-sync") available_devices.extend(cpu_device) cpu_device = get_devices_by_name("cpu-task") @@ -127,6 +129,54 @@ def set_iree_runtime_flags(): set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) +def parse_device(device_str): + from shark.iree_utils.compile_utils import ( + clean_device_info, + get_iree_target_triple, + iree_target_map, + ) + + rt_driver, device_id = clean_device_info(device_str) + target_backend = iree_target_map(rt_driver) + if device_id: + rt_device = f"{rt_driver}://{device_id}" + else: + rt_device = rt_driver + + match target_backend: + case "vulkan-spirv": + triple = get_iree_target_triple(device_str) + return target_backend, rt_device, triple + case "rocm": + triple = get_rocm_target_chip(device_str) + return target_backend, rt_device, triple + case "llvm-cpu": + return "llvm-cpu", "local-task", "x86_64-linux-gnu" + + +def get_rocm_target_chip(device_str): + # TODO: Use a data file to map device_str to target chip. + rocm_chip_map = { + "6700": "gfx1031", + "6800": "gfx1030", + "6900": "gfx1030", + "7900": "gfx1100", + "MI300X": "gfx942", + "MI300A": "gfx940", + "MI210": "gfx90a", + "MI250": "gfx90a", + "MI100": "gfx908", + "MI50": "gfx906", + "MI60": "gfx906", + } + for key in rocm_chip_map: + if key in device_str: + return rocm_chip_map[key] + raise AssertionError( + f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues." + ) + + def get_all_devices(driver_name): """ Inputs: driver_name diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index fc0bd3b7b8..433df13654 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -2,6 +2,11 @@ import json import re import requests +import torch +import safetensors +from shark_turbine.aot.params import ( + ParameterArchiveBuilder, +) from io import BytesIO from pathlib import Path from tqdm import tqdm @@ -15,21 +20,21 @@ ) -def get_path_to_diffusers_checkpoint(custom_weights): +def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"): path = Path(custom_weights) diffusers_path = path.parent.absolute() - diffusers_directory_name = os.path.join("diffusers", path.stem) + diffusers_directory_name = os.path.join("diffusers", path.stem + f"_{precision}") complete_path_to_diffusers = diffusers_path / diffusers_directory_name complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) path_to_diffusers = complete_path_to_diffusers.as_posix() return path_to_diffusers -def preprocessCKPT(custom_weights, is_inpaint=False): - path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) +def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights, precision) if next(Path(path_to_diffusers).iterdir(), None): print("Checkpoint already loaded at : ", path_to_diffusers) - return + return path_to_diffusers else: print( "Diffusers' checkpoint will be identified here : ", @@ -51,8 +56,24 @@ def preprocessCKPT(custom_weights, is_inpaint=False): from_safetensors=from_safetensors, num_in_channels=num_in_channels, ) + if precision == "fp16": + pipe.to(dtype=torch.float16) pipe.save_pretrained(path_to_diffusers) + del pipe print("Loading complete") + return path_to_diffusers + + +def save_irpa(weights_path, prepend_str): + weights = safetensors.torch.load_file(weights_path) + archive = ParameterArchiveBuilder() + for key in weights.keys(): + new_key = prepend_str + key + archive.add_tensor(new_key, weights[key]) + + irpa_file = weights_path.replace(".safetensors", ".irpa") + archive.save(irpa_file) + return irpa_file def convert_original_vae(vae_checkpoint): diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index 3e931b1c78..56df8973d0 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -101,11 +101,12 @@ def export_scheduler_model(model): scheduler_model_map = { + "PNDM": export_scheduler_model("PNDMScheduler"), + "DPMSolverSDE": export_scheduler_model("DpmSolverSDEScheduler"), "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"), "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"), "LCM": export_scheduler_model("LCMScheduler"), "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"), - "PNDM": export_scheduler_model("PNDMScheduler"), "DDPM": export_scheduler_model("DDPMScheduler"), "DDIM": export_scheduler_model("DDIMScheduler"), "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"), diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index 7bed2cb7b0..49f4482576 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -36,6 +36,7 @@ def test01_LLMSmall(self): device="cpu", precision="fp32", quantization="None", + streaming_llm=True, ) count = 0 label = "Turkishoure Turkish" diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 418f087548..54ae4a139f 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -137,7 +137,7 @@ def view_json_file(file_obj): streaming_llm = gr.Checkbox( label="Run in streaming mode (requires recompilation)", value=True, - interactive=True, + interactive=False, ) prompt_prefix = gr.Checkbox( label="Add System Prompt", diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index fa2c1836fd..a4df173b1c 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -17,7 +17,6 @@ write_default_sd_config, ) from apps.shark_studio.api.sd import ( - sd_model_map, shark_sd_fn_dict_input, cancel_sd, ) @@ -45,11 +44,10 @@ import apps.shark_studio.web.utils.globals as global_obj sd_default_models = [ - "CompVis/stable-diffusion-v1-4", "runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-2-1-base", "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-xl-1.0", + "stabilityai/stable-diffusion-xl-base-1.0", "stabilityai/sdxl-turbo", ] @@ -231,14 +229,9 @@ def import_original(original_img, width, height): def base_model_changed(base_model_id): - ckpt_path = Path( - os.path.join( - cmd_opts.model_dir, "checkpoints", os.path.basename(str(base_model_id)) - ) - ) - ckpt_path.mkdir(parents=True, exist_ok=True) - - new_choices = get_checkpoints(ckpt_path) + get_checkpoints(model_type="checkpoints") + new_choices = get_checkpoints( + os.path.join("checkpoints", os.path.basename(str(base_model_id))) + ) + get_checkpoints(model_type="checkpoints") return gr.Dropdown( value=new_choices[0] if len(new_choices) > 0 else "None", @@ -286,14 +279,14 @@ def base_model_changed(base_model_id): with gr.Row(): height = gr.Slider( 384, - 768, + 1024, value=cmd_opts.height, step=8, label="\U00002195\U0000FE0F Height", ) width = gr.Slider( 384, - 768, + 1024, value=cmd_opts.width, step=8, label="\U00002194\U0000FE0F Width", @@ -586,6 +579,21 @@ def base_model_changed(base_model_id): object_fit="fit", preview=True, ) + with gr.Row(): + std_output = gr.Textbox( + value=f"{sd_model_info}\n" + f"Images will be saved at " + f"{get_generated_imgs_path()}", + lines=2, + elem_id="std_output", + show_label=True, + label="Log", + show_copy_button=True, + ) + sd_element.load( + logger.read_sd_logs, None, std_output, every=1 + ) + sd_status = gr.Textbox(visible=False) with gr.Row(): batch_count = gr.Slider( 1, @@ -621,18 +629,19 @@ def base_model_changed(base_model_id): stop_batch = gr.Button("Stop") with gr.Tab(label="Config", id=102) as sd_tab_config: with gr.Column(elem_classes=["sd-right-panel"]): - Path(get_configs_path()).mkdir(parents=True, exist_ok=True) - default_config_file = os.path.join( - get_configs_path(), - "default_sd_config.json", - ) - write_default_sd_config(default_config_file) - sd_json = gr.JSON( - label="SD Config", - elem_classes=["fill"], - value=view_json_file(default_config_file), - render=False, - ) + with gr.Row(elem_classes=["fill"]): + Path(get_configs_path()).mkdir( + parents=True, exist_ok=True + ) + default_config_file = os.path.join( + get_configs_path(), + "default_sd_config.json", + ) + write_default_sd_config(default_config_file) + sd_json = gr.JSON( + elem_classes=["fill"], + value=view_json_file(default_config_file), + ) with gr.Row(): with gr.Column(scale=3): load_sd_config = gr.FileExplorer( @@ -695,30 +704,11 @@ def base_model_changed(base_model_id): inputs=[sd_json, sd_config_name], outputs=[sd_config_name], ) - with gr.Row(elem_classes=["fill"]): - sd_json.render() save_sd_config.click( fn=save_sd_cfg, inputs=[sd_json, sd_config_name], outputs=[sd_config_name], ) - with gr.Tab(label="Log", id=103) as sd_tab_log: - with gr.Row(): - std_output = gr.Textbox( - value=f"{sd_model_info}\n" - f"Images will be saved at " - f"{get_generated_imgs_path()}", - elem_id="std_output", - show_label=True, - label="Log", - show_copy_button=True, - ) - sd_element.load( - logger.read_sd_logs, None, std_output, every=1 - ) - sd_status = gr.Textbox(visible=False) - with gr.Tab(label="Automation", id=104) as sd_tab_automation: - pass pull_kwargs = dict( fn=pull_sd_configs, diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py index 9617c16565..3619055676 100644 --- a/apps/shark_studio/web/utils/file_utils.py +++ b/apps/shark_studio/web/utils/file_utils.py @@ -47,7 +47,7 @@ def write_default_sd_config(path): def safe_name(name): - return name.replace("/", "_").replace("-", "_") + return name.split("/")[-1].replace("-", "_") def get_path_stem(path): diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py index 72f663f246..d1cadc1e00 100644 --- a/apps/shark_studio/web/utils/metadata/png_metadata.py +++ b/apps/shark_studio/web/utils/metadata/png_metadata.py @@ -3,9 +3,8 @@ from apps.shark_studio.web.utils.file_utils import ( get_checkpoint_pathfile, ) -from apps.shark_studio.api.sd import ( - sd_model_map, -) +from apps.shark_studio.api.sd import EMPTY_SD_MAP as sd_model_map + from apps.shark_studio.modules.schedulers import ( scheduler_model_map, ) diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py index 7f65120cbb..ebbc4ae6af 100644 --- a/apps/shark_studio/web/utils/tmp_configs.py +++ b/apps/shark_studio/web/utils/tmp_configs.py @@ -17,7 +17,7 @@ def clear_tmp_mlir(): and filename.endswith(".mlir") ] for filename in mlir_files: - os.remove(shark_tmp + filename) + os.remove(os.path.join(shark_tmp, filename)) print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.") diff --git a/requirements.txt b/requirements.txt index 1ff2b685d7..bbb8adf6b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,14 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html +-f https://download.pytorch.org/whl/nightly/cpu -f https://iree.dev/pip-release-links.html --pre setuptools wheel -torch==2.3.0 + +torch>=2.3.0 shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main -turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=models +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models # SHARK Runner tqdm @@ -17,17 +18,14 @@ google-cloud-storage # Testing pytest -pytest-xdist -pytest-forked Pillow parameterized # Add transformers, diffusers and scipy since it most commonly used #accelerate is now required for diffusers import from ckpt. accelerate -scipy ftfy -gradio==4.19.2 +gradio==4.29.0 altair omegaconf # 0.3.2 doesn't have binaries for arm64 diff --git a/setup_venv.ps1 b/setup_venv.ps1 index 749a7c4e6f..b6a6994449 100644 --- a/setup_venv.ps1 +++ b/setup_venv.ps1 @@ -88,5 +88,8 @@ else {python -m venv .\shark.venv\} .\shark.venv\Scripts\activate python -m pip install --upgrade pip pip install wheel -pip install -r requirements.txt +pip install --pre -r requirements.txt +pip install -e . + +>>>>>>> 0c904eb7 (Shark Studio SDXL support, HIP driver support, simpler device info, small fixes) Write-Host "Source your venv with ./shark.venv/Scripts/activate" diff --git a/setup_venv.sh b/setup_venv.sh index 64f769d794..11fe79fd68 100755 --- a/setup_venv.sh +++ b/setup_venv.sh @@ -84,21 +84,7 @@ else PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/ fi -$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL} - -if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then - T_VER=$($PYTHON -m pip show torch | grep Version) - T_VER_MIN=${T_VER:14:12} - TV_VER=$($PYTHON -m pip show torchvision | grep Version) - TV_VER_MAJ=${TV_VER:9:6} - $PYTHON -m pip uninstall -y torchvision - $PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ - if [ $? -eq 0 ];then - echo "Successfully Installed torch + cu118." - else - echo "Could not install torch + cu118." >&2 - fi -fi +$PYTHON -m pip install --no-warn-conflicts -e . -f ${RUNTIME} -f ${PYTORCH_URL} if [[ -z "${NO_BREVITAS}" ]]; then $PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev diff --git a/shark/iree_utils/_common.py b/shark/iree_utils/_common.py index c58405b46e..1d022f67e4 100644 --- a/shark/iree_utils/_common.py +++ b/shark/iree_utils/_common.py @@ -76,6 +76,7 @@ def get_supported_device_list(): "vulkan": "vulkan", "metal": "metal", "rocm": "rocm", + "hip": "hip", "intel-gpu": "level_zero", } @@ -94,6 +95,7 @@ def iree_target_map(device): "vulkan": "vulkan-spirv", "metal": "metal", "rocm": "rocm", + "hip": "rocm", "intel-gpu": "opencl-spirv", } diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 5fd1d4006a..f93c8fef2e 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -62,13 +62,16 @@ def get_iree_device_args(device, extra_args=[]): from shark.iree_utils.gpu_utils import get_iree_rocm_args return get_iree_rocm_args(device_num=device_num, extra_args=extra_args) + if device == "hip": + from shark.iree_utils.gpu_utils import get_iree_rocm_args + return get_iree_rocm_args(device_num=device_num, extra_args=extra_args, hip_driver=True) return [] def get_iree_target_triple(device): args = get_iree_device_args(device) for flag in args: - if "triple" in flag.split("-"): - triple = flag.split("=") + if "triple" in flag: + triple = flag.split("=")[-1] return triple return "" @@ -89,9 +92,9 @@ def clean_device_info(raw_device): if len(device_id) <= 2: device_id = int(device_id) - if device not in ["rocm", "vulkan"]: + if device not in ["hip", "rocm", "vulkan"]: device_id = None - if device in ["rocm", "vulkan"] and device_id == None: + if device in ["hip", "rocm", "vulkan"] and device_id == None: device_id = 0 return device, device_id diff --git a/shark/iree_utils/gpu_utils.py b/shark/iree_utils/gpu_utils.py index 0eba67ff53..db6ef14e34 100644 --- a/shark/iree_utils/gpu_utils.py +++ b/shark/iree_utils/gpu_utils.py @@ -52,7 +52,7 @@ def check_rocm_device_arch_in_args(extra_args): return None -def get_rocm_device_arch(device_num=0, extra_args=[]): +def get_rocm_device_arch(device_num=0, extra_args=[], hip_driver=False): # ROCM Device Arch selection: # 1 : User given device arch using `--iree-rocm-target-chip` flag # 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index @@ -68,15 +68,23 @@ def get_rocm_device_arch(device_num=0, extra_args=[]): arch_in_device_dump = None # get rocm arch from iree dump devices - def get_devices_info_from_dump(dump): + def get_devices_info_from_dump(dump, driver): from os import linesep - - dump_clean = list( - filter( - lambda s: "--device=rocm" in s or "gpu-arch-name:" in s, - dump.split(linesep), + + if driver == "hip": + dump_clean = list( + filter( + lambda s: "AMD" in s, + dump.split(linesep), + ) + ) + else: + dump_clean = list( + filter( + lambda s: f"--device={driver}" in s or "gpu-arch-name:" in s, + dump.split(linesep), + ) ) - ) arch_pairs = [ ( dump_clean[i].split("=")[1].strip(), @@ -87,16 +95,17 @@ def get_devices_info_from_dump(dump): return arch_pairs dump_device_info = None + driver = "hip" if hip_driver else "rocm" try: dump_device_info = run_cmd( - "iree-run-module --dump_devices=rocm", raise_err=True + "iree-run-module --dump_devices=" + driver, raise_err=True ) except Exception as e: - print("could not execute `iree-run-module --dump_devices=rocm`") + print("could not execute `iree-run-module --dump_devices=" + driver + "`") if dump_device_info is not None: device_num = 0 if device_num is None else device_num - device_arch_pairs = get_devices_info_from_dump(dump_device_info[0]) + device_arch_pairs = get_devices_info_from_dump(dump_device_info[0], driver) if len(device_arch_pairs) > device_num: # can find arch in the list arch_in_device_dump = device_arch_pairs[device_num][1] @@ -107,24 +116,22 @@ def get_devices_info_from_dump(dump): default_rocm_arch = "gfx1100" print( "Did not find ROCm architecture from `--iree-rocm-target-chip` flag" - "\n or from `iree-run-module --dump_devices=rocm` command." + "\n or from `iree-run-module --dump_devices` command." f"\nUsing {default_rocm_arch} as ROCm arch for compilation." ) return default_rocm_arch # Get the default gpu args given the architecture. -def get_iree_rocm_args(device_num=0, extra_args=[]): +def get_iree_rocm_args(device_num=0, extra_args=[], hip_driver=False): ireert.flags.FUNCTION_INPUT_VALIDATION = False - rocm_flags = ["--iree-rocm-link-bc=true"] - + rocm_flags = [] if check_rocm_device_arch_in_args(extra_args) is None: - rocm_arch = get_rocm_device_arch(device_num, extra_args) + rocm_arch = get_rocm_device_arch(device_num, extra_args, hip_driver=hip_driver) rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}") return rocm_flags - # Some constants taken from cuda.h CUDA_SUCCESS = 0 CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16