diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py
index e4593570db..da76c60630 100644
--- a/apps/shark_studio/api/initializers.py
+++ b/apps/shark_studio/api/initializers.py
@@ -8,10 +8,10 @@

 from apps.shark_studio.modules.timer import startup_timer
 from apps.shark_studio.web.utils.tmp_configs import (
-    config_tmp,
-    clear_tmp_mlir,
-    clear_tmp_imgs,
-    )
+    config_tmp,
+    clear_tmp_mlir,
+    clear_tmp_imgs,
+)


 def imports():
@@ -57,6 +57,7 @@ def initialize():
     from apps.shark_studio.web.utils.file_utils import (
         create_checkpoint_folders,
     )
+
     # Create custom models folders if they don't exist
     create_checkpoint_folders()

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index 80ad1e8edf..688cc63b29 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -57,7 +57,8 @@ def __init__(
         self.use_system_prompt = use_system_prompt
         self.global_iter = 0
         if os.path.exists(self.vmfb_name) and (
-            external_weights is None or os.path.exists(str(self.external_weight_file))
+            external_weights is None
+            or os.path.exists(str(self.external_weight_file))
         ):
             self.iree_module_dict = dict()
             (
diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py
index 43c6a1830c..db9b595413 100644
--- a/apps/shark_studio/api/sd.py
+++ b/apps/shark_studio/api/sd.py
@@ -11,10 +11,16 @@
 from turbine_models.custom_models.sd_inference import clip, unet, vae
 from apps.shark_studio.api.controlnet import control_adapter_map
 from apps.shark_studio.web.utils.state import status_label
-from apps.shark_studio.web.utils.file_utils import safe_name, get_resource_path, get_checkpoints_path
+from apps.shark_studio.web.utils.file_utils import (
+    safe_name,
+    get_resource_path,
+    get_checkpoints_path,
+)
 from apps.shark_studio.modules.pipeline import SharkPipelineBase
 from apps.shark_studio.modules.schedulers import get_schedulers
-from apps.shark_studio.modules.prompt_encoding import get_weighted_text_embeddings
+from apps.shark_studio.modules.prompt_encoding import (
+    get_weighted_text_embeddings,
+)
 from apps.shark_studio.modules.img_processing import (
     resize_stencil,
     save_output_img,
@@ -42,25 +48,26 @@
     },
     "unet": {
         "initializer": unet.export_unet_model,
-        "ireec_flags": ["--iree-flow-collapse-reduction-dims",
-            "--iree-opt-const-expr-hoisting=False",
-            "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
+        "ireec_flags": [
+            "--iree-flow-collapse-reduction-dims",
+            "--iree-opt-const-expr-hoisting=False",
+            "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
         ],
         "external_weight_file": None,
     },
     "vae_decode": {
         "initializer": vae.export_vae_model,
         "external_weight_file": None,
-        "ireec_flags": ["--iree-flow-collapse-reduction-dims",
-            "--iree-opt-const-expr-hoisting=False",
-            "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
+        "ireec_flags": [
+            "--iree-flow-collapse-reduction-dims",
+            "--iree-opt-const-expr-hoisting=False",
+            "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
         ],
     },
 }


 class StableDiffusion(SharkPipelineBase):
-
     # This class is responsible for executing image generation and creating
     # /managing a set of compiled modules to run Stable Diffusion. The init
     # aims to be as general as possible, and the class will infer and compile
@@ -73,7 +80,6 @@ class StableDiffusion(SharkPipelineBase):
     # embeddings: a dict of embedding checkpoints or model IDs to use when
     # initializing the compiled modules.

-
     def __init__(
         self,
         base_model_id,
@@ -99,10 +105,12 @@ def __init__(
             "clip": {"hf_model_name": base_model_id},
             "unet": {
                 "hf_model_name": base_model_id,
-                "unet_model": unet.UnetModel(hf_model_name=base_model_id, hf_auth_token=None),
+                "unet_model": unet.UnetModel(
+                    hf_model_name=base_model_id, hf_auth_token=None
+                ),
                 "batch_size": batch_size,
-                #"is_controlled": is_controlled,
-                #"num_loras": num_loras,
+                # "is_controlled": is_controlled,
+                # "num_loras": num_loras,
                 "height": height,
                 "width": width,
                 "precision": precision,
@@ -110,7 +118,9 @@ def __init__(
             },
             "vae_encode": {
                 "hf_model_name": custom_vae if custom_vae else base_model_id,
-                "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
+                "vae_model": vae.VaeModel(
+                    hf_model_name=base_model_id, hf_auth_token=None
+                ),
                 "batch_size": batch_size,
                 "height": height,
                 "width": width,
@@ -118,7 +128,9 @@ def __init__(
             },
             "vae_decode": {
                 "hf_model_name": custom_vae if custom_vae else base_model_id,
-                "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
+                "vae_model": vae.VaeModel(
+                    hf_model_name=base_model_id, hf_auth_token=None
+                ),
                 "batch_size": batch_size,
                 "height": height,
                 "width": width,
@@ -135,7 +147,7 @@ def __init__(
             precision,
         ]
         if num_loras > 0:
-            pipe_id_list.append(str(num_loras)+"lora")
+            pipe_id_list.append(str(num_loras) + "lora")
         if is_controlled:
             pipe_id_list.append("controlled")
         if custom_vae:
@@ -145,8 +157,9 @@ def __init__(
         del static_kwargs
         gc.collect()

-
-    def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img):
+    def prepare_pipe(
+        self, scheduler, custom_weights, adapters, embeddings, is_img2img
+    ):
         print(
             f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
             f"\n[LOG] Custom embeddings currently unsupported."
@@ -165,15 +178,16 @@ def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2i
             for i in self.model_map:
                 self.model_map[i]["external_weights_file"] = None
         elif custom_weights:
-            print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?")
+            print(
+                f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?"
+            )
         self.static_kwargs["pipe"] = {
-            # "external_weight_path": self.weights_path,
-# "external_weights": "safetensors",
+            # "external_weight_path": self.weights_path,
+            # "external_weights": "safetensors",
         }
         self.get_compiled_map(pipe_id=self.pipe_id)
         print("\n[LOG] Pipeline successfully prepared for runtime.")
         return
-

     def generate_images(
         self,
@@ -191,24 +205,24 @@ def generate_images(
         control_mode,
         hints,
     ):
-        #TODO: Batched args
+        # TODO: Batched args
         self.ondemand = ondemand
         if self.is_img2img:
             image, _ = self.process_sd_init_image(image, resample_type)
-        else: 
+        else:
             image = None

         print("\n[LOG] Generating images...")
-        batched_args=[
+        batched_args = [
             prompt,
             negative_prompt,
-            #steps,
-            #strength,
-            #guidance_scale,
-            #seed,
-            #resample_type,
-            #control_mode,
-            #hints,
+            # steps,
+            # strength,
+            # guidance_scale,
+            # seed,
+            # resample_type,
+            # control_mode,
+            # hints,
         ]
         for arg in batched_args:
             if not isinstance(arg, list):
@@ -222,7 +236,7 @@ def generate_images(
             prompt,
             negative_prompt,
         )
-        
+
         uint32_info = np.iinfo(np.uint32)
         uint32_min, uint32_max = uint32_info.min, uint32_info.max
         if seed < uint32_min or seed >= uint32_max:
@@ -242,7 +256,7 @@ def generate_images(
             text_embeddings=text_embeddings,
             guidance_scale=guidance_scale,
             total_timesteps=final_timesteps,
-            cpu_scheduling=True, # until we have schedulers through Turbine
+            cpu_scheduling=True,  # until we have schedulers through Turbine
         )

         # Img latents -> PIL images
@@ -260,7 +274,6 @@ def generate_images(

         return all_imgs

-
     def encode_prompts_weight(
         self,
         prompt,
@@ -275,7 +288,6 @@ def encode_prompts_weight(
         )
         clip_inf_start = time.time()

-
         text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
             pipe=self,
             prompt=prompt,
@@ -300,7 +312,6 @@ def encode_prompts_weight(

         return text_embeddings.numpy().astype(np.float16)

-
     def prepare_latents(
         self,
         generator,
@@ -318,7 +329,7 @@ def prepare_latents(
             generator=generator,
             dtype=self.dtype,
         ).to("cpu")
-        
+
         self.scheduler.set_timesteps(num_inference_steps)
         if self.is_img2img:
             init_timestep = min(
@@ -336,7 +347,6 @@ def prepare_latents(
             latents = noise * self.scheduler.init_noise_sigma
         return latents, self.scheduler.timesteps

-
     def encode_image(self, input_image):
         self.load_submodels(["vae_encode"])
         vae_encode_start = time.time()
@@ -348,7 +358,6 @@ def encode_image(self, input_image):

         return latents

-
     def produce_img_latents(
         self,
         latents,
@@ -370,11 +379,15 @@ def produce_img_latents(
         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
             timestep = torch.tensor([t]).to(self.dtype).detach().numpy()
-            latent_model_input = self.scheduler.scale_model_input(latents, t).to(self.dtype)
+            latent_model_input = self.scheduler.scale_model_input(
+                latents, t
+            ).to(self.dtype)
             if mask is not None and masked_image_latents is not None:
                 latent_model_input = torch.cat(
                     [
-                        torch.from_numpy(np.asarray(latent_model_input)).to(torch.float16),
+                        torch.from_numpy(np.asarray(latent_model_input)).to(
+                            torch.float16
+                        ),
                         mask,
                         masked_image_latents,
                     ],
@@ -411,7 +424,7 @@ def produce_img_latents(
             #     )
             step_time_sum += step_time

-            #if self.status == SD_STATE_CANCEL:
+            # if self.status == SD_STATE_CANCEL:
             #    break

             if self.ondemand:
@@ -426,7 +439,6 @@ def produce_img_latents(
         all_latents = torch.cat(latent_history, dim=0)
         return all_latents

-
     def decode_latents(self, latents, use_base_vae, cpu_scheduling):
         if use_base_vae:
             latents = 1 / 0.18215 * latents
@@ -435,11 +447,11 @@ def decode_latents(self, latents, use_base_vae, cpu_scheduling):

         if cpu_scheduling:
             latents_numpy = latents.detach().numpy()
-            #profile_device = start_profiling(file_path="vae.rdc")
+            # profile_device = start_profiling(file_path="vae.rdc")
             vae_start = time.time()
             images = self.run("vae_decode", latents_numpy).to_host()
             vae_inf_time = (time.time() - vae_start) * 1000
-            #end_profiling(profile_device)
+            # end_profiling(profile_device)
             print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}")

         if use_base_vae:
@@ -451,7 +463,6 @@ def decode_latents(self, latents, use_base_vae, cpu_scheduling):
         pil_images = [Image.fromarray(image) for image in images.numpy()]
         return pil_images

-
     def process_sd_init_image(self, sd_init_image, resample_type):
         if isinstance(sd_init_image, list):
             images = []
@@ -462,8 +473,12 @@ def process_sd_init_image(self, sd_init_image, resample_type):
             return images, is_img2img
         if isinstance(sd_init_image, str):
             if os.path.isfile(sd_init_image):
-                sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
-                image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type)
+                sd_init_image = Image.open(sd_init_image, mode="r").convert(
+                    "RGB"
+                )
+                image, is_img2img = self.process_sd_init_image(
+                    sd_init_image, resample_type
+                )
             else:
                 image = None
                 is_img2img = False
@@ -481,10 +496,14 @@ def process_sd_init_image(self, sd_init_image, resample_type):
                 # Fallback to Lanczos
                 else Image.Resampling.LANCZOS
             )
-            image = image.resize((self.width, self.height), resample=resample_type)
+            image = image.resize(
+                (self.width, self.height), resample=resample_type
+            )
            image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
            image_arr = image_arr / 255.0
-            image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
+            image_arr = (
+                torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
+            )
            image_arr = 2 * (image_arr - 0.5)
            is_img2img = True
            image = image_arr
@@ -536,7 +555,6 @@ def shark_sd_fn(
     sd_kwargs = locals()
     is_img2img = True if sd_init_image[0] is not None else False

-
     print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

@@ -566,7 +584,7 @@ def shark_sd_fn(
                 "strength": controlnets["strength"][i],
             }
         if model is not None:
-            is_controlled=True
+            is_controlled = True
             control_mode = controlnets["control_mode"]
             for i in controlnets["hint"]:
                 hints.append[i]
@@ -659,13 +677,15 @@ def view_json_file(file_path):
     return content


-
 if __name__ == "__main__":
     from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-    import apps.shark_studio.web.utils.globals as global_obj
+    import apps.shark_studio.web.utils.globals as global_obj
+
     global_obj._init()

-    sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
+    sd_json = view_json_file(
+        get_resource_path("../configs/default_sd_config.json")
+    )
     sd_kwargs = json.loads(sd_json)
     for i in shark_sd_fn_dict_input(sd_kwargs):
-        print(i)
\ No newline at end of file
+        print(i)
diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py
index 6c78515cca..af183f8eeb 100644
--- a/apps/shark_studio/modules/pipeline.py
+++ b/apps/shark_studio/modules/pipeline.py
@@ -43,8 +43,9 @@ def __init__(
         self.iree_module_dict = {}
         self.tempfiles = {}

-
-    def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
+    def get_compiled_map(
+        self, pipe_id, submodel="None", init_kwargs={}
+    ) -> None:
         # First checks whether we have .vmfbs precompiled, then populates the map
         # with the precompiled executables and fetches executables for the rest of the map.
         # The weights aren't static here anymore so this function should be a part of pipeline
@@ -52,37 +53,47 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
         # and your model map is populated with any IR - unique model IDs and their static params,
         # call this method to get the artifacts associated with your map.
         self.pipe_id = self.safe_name(pipe_id)
-        self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
+        self.pipe_vmfb_path = Path(
+            os.path.join(get_checkpoints_path(".."), self.pipe_id)
+        )
         self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
         if submodel == "None":
             print("\n[LOG] Gathering any pre-compiled artifacts....")
             for key in self.model_map:
                 self.get_compiled_map(pipe_id, submodel=key)
-        else: 
+        else:
             self.get_precompiled(pipe_id, submodel)
             ireec_flags = []
             if submodel in self.iree_module_dict:
                 if "vmfb" in self.iree_module_dict[submodel]:
-                    print(f"\n[LOG] Executable for {submodel} already loaded...")
+                    print(
+                        f"\n[LOG] Executable for {submodel} already loaded..."
+                    )
                     return
             elif "vmfb_path" in self.model_map[submodel]:
                 return
             elif submodel not in self.tempfiles:
-                print(f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
+                print(
+                    f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
+                )
                 if submodel in self.static_kwargs:
                     init_kwargs = self.static_kwargs[submodel]
                 for key in self.static_kwargs["pipe"]:
                     if key not in init_kwargs:
                         init_kwargs[key] = self.static_kwargs["pipe"][key]
-                self.import_torch_ir(
-                    submodel, init_kwargs
-                )
+                self.import_torch_ir(submodel, init_kwargs)
                 self.get_compiled_map(pipe_id, submodel)
-            else: 
-                ireec_flags = self.model_map[submodel]["ireec_flags"] if "ireec_flags" in self.model_map[submodel] else []
+            else:
+                ireec_flags = (
+                    self.model_map[submodel]["ireec_flags"]
+                    if "ireec_flags" in self.model_map[submodel]
+                    else []
+                )
                 if "external_weights_file" in self.model_map[submodel]:
-                    weights_path = self.model_map[submodel]["external_weights_file"]
+                    weights_path = self.model_map[submodel][
+                        "external_weights_file"
+                    ]
                 else:
                     weights_path = None
                 self.iree_module_dict[submodel] = get_iree_compiled_module(
@@ -92,11 +103,12 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
                     mmap=True,
                     external_weight_file=weights_path,
                     extra_args=ireec_flags,
-                    write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb")
+                    write_to=os.path.join(
+                        self.pipe_vmfb_path, submodel + ".vmfb"
+                    ),
                 )
         return

-
     def get_precompiled(self, pipe_id, submodel="None"):
         if submodel == "None":
             for model in self.model_map:
@@ -109,10 +121,11 @@ def get_precompiled(self, pipe_id, submodel="None"):
                 break
         for file in vmfbs:
             if submodel in file:
-                self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file)
+                self.model_map[submodel]["vmfb_path"] = os.path.join(
+                    vmfbs_path, file
+                )
         return

-
     def import_torch_ir(self, submodel, kwargs):
         torch_ir = self.model_map[submodel]["initializer"](
             **self.safe_dict(kwargs), compile_to="torch"
@@ -120,17 +133,16 @@ def import_torch_ir(self, submodel, kwargs):
         if submodel == "clip":
             # clip.export_clip_model returns (torch_ir, tokenizer)
             torch_ir = torch_ir[0]
-        self.tempfiles[submodel] = get_resource_path(os.path.join(
-            "..", "shark_tmp", f"{submodel}.torch.tempfile"
-        ))
-        
+        self.tempfiles[submodel] = get_resource_path(
+            os.path.join("..", "shark_tmp", f"{submodel}.torch.tempfile")
+        )
+
         with open(self.tempfiles[submodel], "w+") as f:
             f.write(torch_ir)
         del torch_ir
         gc.collect()
         return

-
     def load_submodels(self, submodels: list):
         for submodel in submodels:
             if submodel in self.iree_module_dict:
@@ -149,13 +161,14 @@ def load_submodels(self, submodels: list):
                     self.device,
                     device_idx=0,
                     rt_flags=[],
-                    external_weight_file=self.model_map[submodel]['external_weight_file'],
+                    external_weight_file=self.model_map[submodel][
+                        "external_weight_file"
+                    ],
                 )
             else:
                 self.get_compiled_map(self.pipe_id, submodel)
         return

-
     def unload_submodels(self, submodels: list):
         for submodel in submodels:
             if submodel in self.iree_module_dict:
@@ -163,18 +176,20 @@ def unload_submodels(self, submodels: list):
             gc.collect()
         return

-
     def run(self, submodel, inputs):
         if not isinstance(inputs, list):
             inputs = [inputs]
-        inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, input) for input in inputs]
-        return self.iree_module_dict[submodel]['vmfb']['main'](*inp)
-
+        inp = [
+            ireert.asdevicearray(
+                self.iree_module_dict[submodel]["config"].device, input
+            )
+            for input in inputs
+        ]
+        return self.iree_module_dict[submodel]["vmfb"]["main"](*inp)

     def safe_name(self, name):
         return name.replace("/", "_").replace("-", "_").replace("\\", "_")

-
     def safe_dict(self, kwargs: dict):
         flat_args = {}
         for i in kwargs:
@@ -183,4 +198,4 @@ def safe_dict(self, kwargs: dict):
             else:
                 flat_args[i] = kwargs[i]

-        return flat_args
+        return flat_args
diff --git a/apps/shark_studio/modules/prompt_encoding.py b/apps/shark_studio/modules/prompt_encoding.py
index b2a5e8a27e..2def4c311b 100644
--- a/apps/shark_studio/modules/prompt_encoding.py
+++ b/apps/shark_studio/modules/prompt_encoding.py
@@ -1,4 +1,3 @@
-
 from typing import List, Optional, Union
 from iree import runtime as ireert
 import re
@@ -112,9 +111,7 @@ def multiply_range(start_position, multiplier):
     return res


-def get_prompts_with_weights(
-    pipe, prompt: List[str], max_length: int
-):
+def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
     r"""
     Tokenize a list of prompts and return its tokens with weights of each token.
     No padding, starting or ending token is included.
@@ -195,6 +192,7 @@ def pad_tokens_and_weights(

     return tokens, weights

+
 def get_unweighted_text_embeddings(
     pipe,
     text_input,
@@ -242,7 +240,6 @@ def get_unweighted_text_embeddings(

     return text_embeddings

-
 # This function deals with NoneType values occuring in tokens after padding
 # It switches out None with 49407 as truncating None values causes matrix dimension errors,
 def filter_nonetype_tokens(tokens: List[List]):
diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py
index 7a42338b1a..215bfc6e68 100644
--- a/apps/shark_studio/modules/schedulers.py
+++ b/apps/shark_studio/modules/schedulers.py
@@ -17,7 +17,7 @@


 def get_schedulers(model_id):
-    #TODO: switch over to turbine and run all on GPU
+    # TODO: switch over to turbine and run all on GPU
     print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
     schedulers = dict()
     schedulers["PNDM"] = PNDMScheduler.from_pretrained(
diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py
index 66ec452d0b..432bcc239a 100644
--- a/apps/shark_studio/web/ui/sd.py
+++ b/apps/shark_studio/web/ui/sd.py
@@ -50,6 +50,7 @@
     "stabilityai/sdxl-turbo",
 ]

+
 def view_json_file(file_path):
     content = ""
     with open(file_path, "r") as fopen:
@@ -149,7 +150,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
     else:
         sd_json = new_sd_config
     for i in sd_json["sd_init_image"]:
-        if i is not None: 
+        if i is not None:
             if os.path.isfile(i):
                 sd_image = [Image.open(i, mode="r")]
             else:
@@ -651,7 +652,7 @@ def import_original(original_img, width, height):
             with gr.Column(scale=3, min_width=600):
                 with gr.Group():
                     sd_gallery = gr.Gallery(
-                        label="Generated images", 
+                        label="Generated images",
                         show_label=False,
                         elem_id="gallery",
                         columns=2,