From 58130432ab6444dd01b84067fb9ae1d693a13e41 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Sun, 17 Dec 2023 23:51:34 -0600
Subject: [PATCH] Complete SD pipeline.

---
 apps/shark_studio/api/sd.py                  | 410 ++++++++++++++----
 apps/shark_studio/modules/img_processing.py  |   6 +-
 apps/shark_studio/modules/pipeline.py        |  99 +++--
 apps/shark_studio/modules/prompt_encoding.py |  58 +--
 apps/shark_studio/modules/schedulers.py      |   2 +-
 .../web/configs/default_sd_config.json       |  21 +-
 apps/shark_studio/web/ui/sd.py               |  42 +-
 shark/iree_utils/compile_utils.py            |   8 +
 8 files changed, 434 insertions(+), 212 deletions(-)

diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py
index 14d508f298..43c6a1830c 100644
--- a/apps/shark_studio/api/sd.py
+++ b/apps/shark_studio/api/sd.py
@@ -4,8 +4,10 @@
 import os
 import json
 import numpy as np
+from tqdm.auto import tqdm
 from pathlib import Path
+from random import randint
 from turbine_models.custom_models.sd_inference import clip, unet, vae
 from apps.shark_studio.api.controlnet import control_adapter_map
 from apps.shark_studio.web.utils.state import status_label
@@ -16,6 +18,8 @@ from apps.shark_studio.modules.img_processing import (
     resize_stencil,
     save_output_img,
+    resamplers,
+    resampler_list,
 )
 
 from apps.shark_studio.modules.ckpt_processing import (
@@ -30,8 +34,7 @@ "clip": {
         "initializer": clip.export_clip_model,
         "external_weight_file": None,
-        "ireec_flags": ["--iree-flow-collapse-reduction-dims",
-        ],
+        "ireec_flags": ["--iree-flow-collapse-reduction-dims"],
     },
     "vae_encode": {
         "initializer": vae.export_vae_model,
         "external_weight_file": None,
@@ -48,6 +51,10 @@
     "vae_decode": {
         "initializer": vae.export_vae_model,
         "external_weight_file": None,
+        "ireec_flags": ["--iree-flow-collapse-reduction-dims",
+            "--iree-opt-const-expr-hoisting=False",
+            "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
+        ],
     },
 }
@@ -78,15 +85,15 @@ def __init__(
         custom_vae: str = None,
         num_loras: int = 0,
         import_ir: bool = True,
-        is_img2img: bool = False,
         is_controlled: bool = False,
     ):
         self.model_max_length = 77
         self.batch_size = batch_size
         self.precision = precision
-        self.is_img2img = is_img2img
+        self.dtype = torch.float16 if precision == "fp16" else torch.float32
+        self.height = height
+        self.width = width
         self.scheduler_obj = {}
-        self.precision = precision
         static_kwargs = {
             "pipe": {},
             "clip": {"hf_model_name": base_model_id},
@@ -98,6 +105,8 @@
                 #"num_loras": num_loras,
                 "height": height,
                 "width": width,
+                "precision": precision,
+                "max_length": 77 * 8,
             },
             "vae_encode": {
                 "hf_model_name": custom_vae if custom_vae else base_model_id,
@@ -105,13 +114,15 @@
                 "batch_size": batch_size,
                 "height": height,
                 "width": width,
+                "precision": precision,
             },
             "vae_decode": {
-                "hf_model_name": custom_vae,
+                "hf_model_name": custom_vae if custom_vae else base_model_id,
                 "vae_model": vae.VaeModel(hf_model_name=base_model_id, hf_auth_token=None),
                 "batch_size": batch_size,
                 "height": height,
                 "width": width,
+                "precision": precision,
             },
         }
         super().__init__(
@@ -135,26 +146,26 @@
         gc.collect()
 
-    def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings):
+    def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img):
         print(
             f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
             f"\n[LOG] Custom embeddings currently unsupported."
) + self.is_img2img = is_img2img schedulers = get_schedulers(self.base_model_id) - self.weights_path = get_checkpoints_path(self.pipe_id) + self.weights_path = get_checkpoints_path(self.safe_name(self.pipe_id)) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) - # accepting a list of schedulers in batched cases. - for i in scheduler: - self.scheduler_obj[i] = schedulers[i] - print(f"[LOG] Loaded scheduler: {i}") + self.scheduler = schedulers[scheduler] + print(f"[LOG] Loaded scheduler: {scheduler}") for model in adapters: self.model_map[model] = adapters[model] - if os.path.isfile(custom_weights): - for i in self.model_map: - self.model_map[i]["external_weights_file"] = None - elif custom_weights != "": - print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?") + if custom_weights: + if os.path.isfile(custom_weights): + for i in self.model_map: + self.model_map[i]["external_weights_file"] = None + elif custom_weights: + print(f"\n[LOG][WARNING] Custom weights were not found at {custom_weights}. Did you mean to pass a base model ID?") self.static_kwargs["pipe"] = { # "external_weight_path": self.weights_path, # "external_weights": "safetensors", @@ -162,6 +173,92 @@ def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings): self.get_compiled_map(pipe_id=self.pipe_id) print("\n[LOG] Pipeline successfully prepared for runtime.") return + + + def generate_images( + self, + prompt, + negative_prompt, + image, + steps, + strength, + guidance_scale, + seed, + ondemand, + repeatable_seeds, + use_base_vae, + resample_type, + control_mode, + hints, + ): + #TODO: Batched args + self.ondemand = ondemand + if self.is_img2img: + image, _ = self.process_sd_init_image(image, resample_type) + else: + image = None + + print("\n[LOG] Generating images...") + batched_args=[ + prompt, + negative_prompt, + #steps, + #strength, + #guidance_scale, + #seed, + #resample_type, + #control_mode, + #hints, + ] + for arg in batched_args: + if not isinstance(arg, list): + arg = [arg] * self.batch_size + if len(arg) < self.batch_size: + arg = arg * self.batch_size + else: + arg = [arg[i] for i in range(self.batch_size)] + + text_embeddings = self.encode_prompts_weight( + prompt, + negative_prompt, + ) + + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + + generator = torch.manual_seed(seed) + + init_latents, final_timesteps = self.prepare_latents( + generator=generator, + num_inference_steps=steps, + image=image, + strength=strength, + ) + + latents = self.produce_img_latents( + latents=init_latents, + text_embeddings=text_embeddings, + guidance_scale=guidance_scale, + total_timesteps=final_timesteps, + cpu_scheduling=True, # until we have schedulers through Turbine + ) + + # Img latents -> PIL images + all_imgs = [] + self.load_submodels(["vae_decode"]) + for i in tqdm(range(0, latents.shape[0], self.batch_size)): + imgs = self.decode_latents( + latents=latents[i : i + self.batch_size], + use_base_vae=use_base_vae, + cpu_scheduling=True, + ) + all_imgs.extend(imgs) + if self.ondemand: + self.unload_submodels(["vae_decode"]) + + return all_imgs def encode_prompts_weight( @@ -191,84 +288,220 @@ def encode_prompts_weight( text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) pad = (0, 0) * (len(text_embeddings.shape) - 2) - pad = pad + (0, 512 - text_embeddings.shape[1]) + pad = pad + (0, 
77 * 8 - text_embeddings.shape[1]) text_embeddings = torch.nn.functional.pad(text_embeddings, pad) # SHARK: Report clip inference time clip_inf_time = (time.time() - clip_inf_start) * 1000 if self.ondemand: - self.unload_clip() + self.unload_submodels(["clip"]) gc.collect() print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}") return text_embeddings.numpy().astype(np.float16) - - def generate_images( + + def prepare_latents( self, - prompt, - negative_prompt, - steps, + generator, + num_inference_steps, + image, strength, + ): + noise = torch.randn( + ( + self.batch_size, + 4, + self.height // 8, + self.width // 8, + ), + generator=generator, + dtype=self.dtype, + ).to("cpu") + + self.scheduler.set_timesteps(num_inference_steps) + if self.is_img2img: + init_timestep = min( + int(num_inference_steps * strength), num_inference_steps + ) + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + latents = self.encode_image(image) + latents = self.scheduler.add_noise( + latents, noise, timesteps[0].repeat(1) + ) + return latents, [timesteps] + else: + self.scheduler.is_scale_input_called = True + latents = noise * self.scheduler.init_noise_sigma + return latents, self.scheduler.timesteps + + + def encode_image(self, input_image): + self.load_submodels(["vae_encode"]) + vae_encode_start = time.time() + latents = self.run("vae_encode", input_image) + vae_inf_time = (time.time() - vae_encode_start) * 1000 + if self.ondemand: + self.unload_submodels(["vae_encode"]) + print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}") + + return latents + + + def produce_img_latents( + self, + latents, + text_embeddings, guidance_scale, - seed, - ondemand, - repeatable_seeds, - resample_type, - control_mode, - hints, + total_timesteps, + cpu_scheduling, + mask=None, + masked_image_latents=None, + return_all_latents=False, ): - print("\n[LOG] Generating images...") - batched_args=[ - prompt, - negative_prompt, - steps, - strength, - guidance_scale, - seed, - resample_type, - control_mode, - hints, - ] - for arg in batched_args: - if not isinstance(arg, list): - arg = [arg] * self.batch_size - if len(arg) < self.batch_size: - arg = arg * self.batch_size + # self.status = SD_STATE_IDLE + step_time_sum = 0 + latent_history = [latents] + text_embeddings = torch.from_numpy(text_embeddings).to(torch.float16) + text_embeddings_numpy = text_embeddings.detach().numpy() + guidance_scale = np.asarray([guidance_scale], dtype=np.float16) + self.load_submodels(["unet"]) + for i, t in tqdm(enumerate(total_timesteps)): + step_start_time = time.time() + timestep = torch.tensor([t]).to(self.dtype).detach().numpy() + latent_model_input = self.scheduler.scale_model_input(latents, t).to(self.dtype) + if mask is not None and masked_image_latents is not None: + latent_model_input = torch.cat( + [ + torch.from_numpy(np.asarray(latent_model_input)).to(torch.float16), + mask, + masked_image_latents, + ], + dim=1, + ).to(self.dtype) + if cpu_scheduling: + latent_model_input = latent_model_input.detach().numpy() + + # Profiling Unet. 
+ # profile_device = start_profiling(file_path="unet.rdc") + noise_pred = self.run( + "unet", + [ + latent_model_input, + timestep, + text_embeddings_numpy, + guidance_scale, + ], + ) + # end_profiling(profile_device) + + if cpu_scheduling: + noise_pred = torch.from_numpy(noise_pred.to_host()) + latents = self.scheduler.step( + noise_pred, t, latents + ).prev_sample else: - arg = [arg[i] for i in range(self.batch_size)] + latents = self.run("scheduler_step", (noise_pred, t, latents)) - text_embeddings = self.encode_prompts_weight( - prompt, - negative_prompt, - ) - print(text_embeddings) - test_img = [ - Image.open( - get_resource_path("../../tests/jupiter.png"), mode="r" - ).convert("RGB") - ] * self.batch_size - return test_img + latent_history.append(latents) + step_time = (time.time() - step_start_time) * 1000 + # self.log += ( + # f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms" + # ) + step_time_sum += step_time + + #if self.status == SD_STATE_CANCEL: + # break + + if self.ondemand: + self.unload_submodels(["unet"]) + gc.collect() + + avg_step_time = step_time_sum / len(total_timesteps) + print(f"\n[LOG] Average step time: {avg_step_time}ms/it") + + if not return_all_latents: + return latents + all_latents = torch.cat(latent_history, dim=0) + return all_latents + + + def decode_latents(self, latents, use_base_vae, cpu_scheduling): + if use_base_vae: + latents = 1 / 0.18215 * latents + + latents_numpy = latents.to(self.dtype) + if cpu_scheduling: + latents_numpy = latents.detach().numpy() + + #profile_device = start_profiling(file_path="vae.rdc") + vae_start = time.time() + images = self.run("vae_decode", latents_numpy).to_host() + vae_inf_time = (time.time() - vae_start) * 1000 + #end_profiling(profile_device) + print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}") + + if use_base_vae: + images = torch.from_numpy(images) + images = (images.detach().cpu() * 255.0).numpy() + images = images.round() + + images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1) + pil_images = [Image.fromarray(image) for image in images.numpy()] + return pil_images + + + def process_sd_init_image(self, sd_init_image, resample_type): + if isinstance(sd_init_image, list): + images = [] + for img in sd_init_image: + img, _ = self.process_sd_init_image(img, resample_type) + images.append(img) + is_img2img = True + return images, is_img2img + if isinstance(sd_init_image, str): + if os.path.isfile(sd_init_image): + sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB") + image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type) + else: + image = None + is_img2img = False + elif isinstance(sd_init_image, Image.Image): + image = sd_init_image.convert("RGB") + elif sd_init_image: + image = sd_init_image["image"].convert("RGB") + else: + image = None + is_img2img = False + if image: + resample_type = ( + resamplers[resample_type] + if resample_type in resampler_list + # Fallback to Lanczos + else Image.Resampling.LANCZOS + ) + image = image.resize((self.width, self.height), resample=resample_type) + image_arr = np.stack([np.array(i) for i in (image,)], axis=0) + image_arr = image_arr / 255.0 + image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype) + image_arr = 2 * (image_arr - 0.5) + is_img2img = True + image = image_arr + return image, is_img2img def shark_sd_fn_dict_input( sd_kwargs: dict, ): print("[LOG] Submitting Request...") - input_imgs = [] - img_paths = sd_kwargs["sd_init_image"] - - for img_path in img_paths: - if img_path: 
- if os.path.isfile(img_path): - input_imgs.append( - Image.open(img_path, mode="r").convert("RGB") - ) - sd_kwargs["sd_init_image"] = input_imgs - # result = shark_sd_fn(**sd_kwargs) - # for i in range(sd_kwargs["batch_count"]): - # yield from result - # return result + + for key in sd_kwargs: + if sd_kwargs[key] in ["None", "", None, []]: + sd_kwargs[key] = None + if key == "seed": + sd_kwargs[key] = int(sd_kwargs[key]) + for i in range(1): generated_imgs = yield from shark_sd_fn(**sd_kwargs) yield generated_imgs @@ -278,7 +511,7 @@ def shark_sd_fn_dict_input( def shark_sd_fn( prompt, negative_prompt, - sd_init_image, + sd_init_image: list, height: int, width: int, steps: int, @@ -291,6 +524,7 @@ def shark_sd_fn( base_model_id: str, custom_weights: str, custom_vae: str, + use_base_vae: bool, precision: str, device: str, ondemand: bool, @@ -300,20 +534,9 @@ def shark_sd_fn( embeddings: dict, ): sd_kwargs = locals() - if isinstance(sd_init_image, Image.Image): - image = sd_init_image.convert("RGB") - elif sd_init_image: - image = sd_init_image["image"].convert("RGB") - else: - image = None - is_img2img = False - if image: - ( - image, - _, - _, - ) = resize_stencil(image, width, height) - is_img2img = True + is_img2img = True if sd_init_image[0] is not None else False + + print("\n[LOG] Performing Stable Diffusion Pipeline setup...") from apps.shark_studio.modules.shared_cmd_opts import cmd_opts @@ -358,7 +581,6 @@ def shark_sd_fn( "custom_vae": custom_vae, "num_loras": num_loras, "import_ir": cmd_opts.import_mlir, - "is_img2img": is_img2img, "is_controlled": is_controlled, } submit_prep_kwargs = { @@ -366,16 +588,19 @@ def shark_sd_fn( "custom_weights": custom_weights, "adapters": adapters, "embeddings": embeddings, + "is_img2img": is_img2img, } submit_run_kwargs = { "prompt": prompt, "negative_prompt": negative_prompt, + "image": sd_init_image, "steps": steps, "strength": strength, "guidance_scale": guidance_scale, "seed": seed, "ondemand": ondemand, "repeatable_seeds": repeatable_seeds, + "use_base_vae": use_base_vae, "resample_type": resample_type, "control_mode": control_mode, "hints": hints, @@ -410,13 +635,9 @@ def shark_sd_fn( # if global_obj.get_sd_status() == SD_STATE_CANCEL: # break # else: - try: - this_seed = seed[current_batch] - except: - this_seed = seed[0] save_output_img( out_imgs[0], - this_seed, + seed, sd_kwargs, ) generated_imgs.extend(out_imgs) @@ -438,6 +659,7 @@ def view_json_file(file_path): return content + if __name__ == "__main__": from apps.shark_studio.modules.shared_cmd_opts import cmd_opts import apps.shark_studio.web.utils.globals as global_obj diff --git a/apps/shark_studio/modules/img_processing.py b/apps/shark_studio/modules/img_processing.py index d0e1b0196f..80062814cf 100644 --- a/apps/shark_studio/modules/img_processing.py +++ b/apps/shark_studio/modules/img_processing.py @@ -75,9 +75,9 @@ def save_output_img(output_img, img_seed, extra_info=None): "parameters", f"{extra_info['prompt'][0]}" f"\nNegative prompt: {extra_info['negative_prompt'][0]}" - f"\nSteps: {extra_info['steps'][0]}," - f"Sampler: {extra_info['scheduler'][0]}, " - f"CFG scale: {extra_info['guidance_scale'][0]}, " + f"\nSteps: {extra_info['steps']}," + f"Sampler: {extra_info['scheduler']}, " + f"CFG scale: {extra_info['guidance_scale']}, " f"Seed: {img_seed}," f"Size: {png_size_text}, " f"Model: {img_model}, " diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py index 30a493734b..6c78515cca 100644 --- a/apps/shark_studio/modules/pipeline.py +++ 
b/apps/shark_studio/modules/pipeline.py
@@ -1,5 +1,10 @@
 from msvcrt import kbhit
-from shark.iree_utils.compile_utils import get_iree_compiled_module, load_vmfb_using_mmap
+from shark.iree_utils.compile_utils import (
+    get_iree_compiled_module,
+    load_vmfb_using_mmap,
+    clean_device_info,
+    get_iree_target_triple,
+)
 from apps.shark_studio.web.utils.file_utils import (
     get_checkpoints_path,
     get_resource_path,
@@ -32,8 +37,8 @@ def __init__(
         self.model_map = model_map
         self.static_kwargs = static_kwargs
         self.base_model_id = base_model_id
-        self.device_name = device
-        self.device = device.split("=>")[-1].strip(" ")
+        self.triple = get_iree_target_triple(device)
+        self.device, self.device_id = clean_device_info(device)
         self.import_mlir = import_mlir
         self.iree_module_dict = {}
         self.tempfiles = {}
@@ -46,11 +51,11 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
         # initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
         # and your model map is populated with any IR - unique model IDs and their static params,
         # call this method to get the artifacts associated with your map.
-        self.pipe_id = pipe_id
+        self.pipe_id = self.safe_name(pipe_id)
         self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
         self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
-        print("\n[LOG] Checking for pre-compiled artifacts.")
         if submodel == "None":
+            print("\n[LOG] Gathering any pre-compiled artifacts....")
             for key in self.model_map:
                 self.get_compiled_map(pipe_id, submodel=key)
         else:
@@ -58,10 +63,12 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
             ireec_flags = []
             if submodel in self.iree_module_dict:
                 if "vmfb" in self.iree_module_dict[submodel]:
-                    print(f"[LOG] Found executable for {submodel} at {self.iree_module_dict[submodel]['vmfb']}...")
+                    print(f"\n[LOG] Executable for {submodel} already loaded...")
                     return
+            elif "vmfb_path" in self.model_map[submodel]:
+                return
             elif submodel not in self.tempfiles:
-                print(f"[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
+                print(f"\n[LOG] Tempfile for {submodel} not found. 
Fetching torch IR...") if submodel in self.static_kwargs: init_kwargs = self.static_kwargs[submodel] for key in self.static_kwargs["pipe"]: @@ -90,16 +97,6 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: return - def hijack_weights(self, weights_path, submodel="None"): - if submodel == "None": - for i in self.model_map: - self.hijack_weights(weights_path, i) - else: - if submodel in self.iree_module_dict: - self.model_map[submodel]["external_weights_file"] = weights_path - return - - def get_precompiled(self, pipe_id, submodel="None"): if submodel == "None": for model in self.model_map: @@ -112,33 +109,10 @@ def get_precompiled(self, pipe_id, submodel="None"): break for file in vmfbs: if submodel in file: - print(f"Found existing .vmfb at {file}") - self.iree_module_dict[submodel] = {} - ( - self.iree_module_dict[submodel]["vmfb"], - self.iree_module_dict[submodel]["config"], - self.iree_module_dict[submodel]["temp_file_to_unlink"], - ) = load_vmfb_using_mmap( - os.path.join(vmfbs_path, file), - self.device, - device_idx=0, - rt_flags=[], - external_weight_file=self.model_map[submodel]['external_weight_file'], - ) + self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file) return - def safe_dict(self, kwargs: dict): - flat_args = {} - for i in kwargs: - if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]: - flat_args[i] = [kwargs[i][j] for j in kwargs[i]] - else: - flat_args[i] = kwargs[i] - - return flat_args - - def import_torch_ir(self, submodel, kwargs): torch_ir = self.model_map[submodel]["initializer"]( **self.safe_dict(kwargs), compile_to="torch" @@ -160,18 +134,53 @@ def import_torch_ir(self, submodel, kwargs): def load_submodels(self, submodels: list): for submodel in submodels: if submodel in self.iree_module_dict: + print(f"\n[LOG] {submodel} is ready for inference.") + if "vmfb_path" in self.model_map[submodel]: print( - f"\n[LOG] Loading .vmfb for {submodel} from {self.iree_module_dict[submodel]['vmfb']}" + f"\n[LOG] Loading .vmfb for {submodel} from {self.model_map[submodel]['vmfb_path']}" + ) + self.iree_module_dict[submodel] = {} + ( + self.iree_module_dict[submodel]["vmfb"], + self.iree_module_dict[submodel]["config"], + self.iree_module_dict[submodel]["temp_file_to_unlink"], + ) = load_vmfb_using_mmap( + self.model_map[submodel]["vmfb_path"], + self.device, + device_idx=0, + rt_flags=[], + external_weight_file=self.model_map[submodel]['external_weight_file'], ) else: self.get_compiled_map(self.pipe_id, submodel) return + def unload_submodels(self, submodels: list): + for submodel in submodels: + if submodel in self.iree_module_dict: + del self.iree_module_dict[submodel] + gc.collect() + return + + def run(self, submodel, inputs): - inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, inputs)] + if not isinstance(inputs, list): + inputs = [inputs] + inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, input) for input in inputs] return self.iree_module_dict[submodel]['vmfb']['main'](*inp) - def safe_name(name): - return name.replace("/", "_").replace("-", "_") + def safe_name(self, name): + return name.replace("/", "_").replace("-", "_").replace("\\", "_") + + + def safe_dict(self, kwargs: dict): + flat_args = {} + for i in kwargs: + if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]: + flat_args[i] = [kwargs[i][j] for j in kwargs[i]] + else: + flat_args[i] = kwargs[i] + + return flat_args diff --git a/apps/shark_studio/modules/prompt_encoding.py 
b/apps/shark_studio/modules/prompt_encoding.py index d97d334f29..b2a5e8a27e 100644 --- a/apps/shark_studio/modules/prompt_encoding.py +++ b/apps/shark_studio/modules/prompt_encoding.py @@ -3,6 +3,7 @@ from iree import runtime as ireert import re import torch +import numpy as np re_attention = re.compile( r""" @@ -161,7 +162,7 @@ def pad_tokens_and_weights( r""" Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. """ - max_embeddings_multiples = 8 + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) weights_length = ( max_length if no_boseos_middle @@ -194,13 +195,16 @@ def pad_tokens_and_weights( return tokens, weights - def get_unweighted_text_embeddings( pipe, - text_input: torch.Tensor, + text_input, chunk_length: int, no_boseos_middle: Optional[bool] = True, ): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) if max_embeddings_multiples > 1: text_embeddings = [] @@ -214,7 +218,7 @@ def get_unweighted_text_embeddings( text_input_chunk[:, 0] = text_input[0, 0] text_input_chunk[:, -1] = text_input[0, -1] - text_embedding = pipe.run("clip", text_input_chunk)[0] + text_embedding = pipe.run("clip", text_input_chunk)[0].to_host() if no_boseos_middle: if i == 0: @@ -231,50 +235,14 @@ def get_unweighted_text_embeddings( # SHARK: Convert the result to tensor # text_embeddings = torch.concat(text_embeddings, axis=1) text_embeddings_np = np.concatenate(np.array(text_embeddings)) - text_embeddings = torch.from_numpy(text_embeddings_np)[None, :] + text_embeddings = torch.from_numpy(text_embeddings_np) else: text_embeddings = pipe.run("clip", text_input)[0] - # text_embeddings = torch.from_numpy(text_embeddings)[None, :] - return torch.from_numpy(text_embeddings.to_host()) - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. 
- """ - max_embeddings_multiples = 8 - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[ - :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 - ].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - text_input_chunk[:, -1] = text_input[0, -1] - # text_embedding = pipe.text_encoder(text_input_chunk)[0] - - print(text_input_chunk) - breakpoint() - text_embedding = pipe.run("clip", text_input_chunk) - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - # SHARK: Convert the result to tensor - # text_embeddings = torch.concat(text_embeddings, axis=1) - text_embeddings_np = np.concatenate(np.array(text_embeddings)) - text_embeddings = torch.from_numpy(text_embeddings_np)[None, :] + text_embeddings = torch.from_numpy(text_embeddings.to_host()) return text_embeddings + # This function deals with NoneType values occuring in tokens after padding # It switches out None with 49407 as truncating None values causes matrix dimension errors, def filter_nonetype_tokens(tokens: List[List]): @@ -286,7 +254,7 @@ def get_weighted_text_embeddings( prompt: List[str], uncond_prompt: List[str] = None, max_embeddings_multiples: Optional[int] = 8, - no_boseos_middle: Optional[bool] = False, + no_boseos_middle: Optional[bool] = True, skip_parsing: Optional[bool] = False, skip_weighting: Optional[bool] = False, ): @@ -325,12 +293,12 @@ def get_weighted_text_embeddings( max_length = max( max_length, max([len(token) for token in uncond_tokens]) ) - max_embeddings_multiples = min( max_embeddings_multiples, (max_length - 1) // (pipe.model_max_length - 2) + 1, ) max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 # pad the length of tokens and weights diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py index 484c8384a6..7a42338b1a 100644 --- a/apps/shark_studio/modules/schedulers.py +++ b/apps/shark_studio/modules/schedulers.py @@ -18,7 +18,7 @@ def get_schedulers(model_id): #TODO: switch over to turbine and run all on GPU - print(f"[LOG] Initializing schedulers from model id: {model_id}") + print(f"\n[LOG] Initializing schedulers from model id: {model_id}") schedulers = dict() schedulers["PNDM"] = PNDMScheduler.from_pretrained( model_id, diff --git a/apps/shark_studio/web/configs/default_sd_config.json b/apps/shark_studio/web/configs/default_sd_config.json index 5886e2e569..26b67660c1 100644 --- a/apps/shark_studio/web/configs/default_sd_config.json +++ b/apps/shark_studio/web/configs/default_sd_config.json @@ -1,23 +1,24 @@ { "prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ], "negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ], - "sd_init_image": [ "None" ], + "sd_init_image": [ null ], "height": 512, "width": 512, - "steps": [ 50 ], - "strength": [ 0.8 ], - "guidance_scale": [ 
7.5 ], - "seed": [ -1 ], + "steps": 50, + "strength": 0.8, + "guidance_scale": 7.5, + "seed": -1, "batch_count": 1, "batch_size": 1, - "scheduler": [ "EulerDiscrete" ], + "scheduler": "EulerDiscrete", "base_model_id": "runwayml/stable-diffusion-v1-5", - "custom_weights": "", - "custom_vae": "", + "custom_weights": null, + "custom_vae": null, + "use_base_vae": false, "precision": "fp16", "device": "vulkan", - "ondemand": "False", - "repeatable_seeds": "False", + "ondemand": false, + "repeatable_seeds": false, "resample_type": "Nearest Neighbor", "controlnets": {}, "embeddings": {} diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index c926b83798..66ec452d0b 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -41,6 +41,14 @@ from apps.shark_studio.modules import logger import apps.shark_studio.web.utils.globals as global_obj +sd_default_models = [ + "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", + "stabilityai/stable-diffusion-2-1-base", + "stabilityai/stable-diffusion-2-1", + "stabilityai/stable-diffusion-xl-1.0", + "stabilityai/sdxl-turbo", +] def view_json_file(file_path): content = "" @@ -105,6 +113,7 @@ def pull_sd_configs( base_model_id, custom_weights, custom_vae, + use_base_vae, precision, device, ondemand, @@ -120,11 +129,6 @@ def pull_sd_configs( "prompt", "negative_prompt", "sd_init_image", - "steps", - "strength", - "guidance_scale", - "seed", - "scheduler", ]: sd_cfg[arg] = [sd_args[arg]] elif arg in ["controlnets", "embeddings"]: @@ -144,8 +148,10 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str): sd_json[key] = new_sd_config[key] else: sd_json = new_sd_config - if os.path.isfile(sd_json["sd_init_image"][0]): - sd_image = Image.open(sd_json["sd_init_image"][0], mode="r") + for i in sd_json["sd_init_image"]: + if i is not None: + if os.path.isfile(i): + sd_image = [Image.open(i, mode="r")] else: sd_image = None @@ -155,16 +161,17 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str): sd_image, sd_json["height"], sd_json["width"], - sd_json["steps"][0], - sd_json["strength"][0], + sd_json["steps"], + sd_json["strength"], sd_json["guidance_scale"], - sd_json["seed"][0], + sd_json["seed"], sd_json["batch_count"], sd_json["batch_size"], - sd_json["scheduler"][0], + sd_json["scheduler"], sd_json["base_model_id"], sd_json["custom_weights"], sd_json["custom_vae"], + sd_json["use_base_vae"], sd_json["precision"], sd_json["device"], sd_json["ondemand"], @@ -320,7 +327,7 @@ def import_original(original_img, width, height): info="Select or enter HF model ID", elem_id="custom_model", value="stabilityai/stable-diffusion-2-1-base", - choices=sd_model_map.keys(), + choices=sd_default_models, ) # base_model_id custom_weights = gr.Dropdown( label="Custom Weights", @@ -328,7 +335,7 @@ def import_original(original_img, width, height): elem_id="custom_model", value="None", allow_custom_value=True, - choices=get_checkpoints(base_model_id), + choices=["None"] + get_checkpoints(base_model_id), ) # with gr.Column(scale=2): sd_vae_info = ( @@ -361,6 +368,11 @@ def import_original(original_img, width, height): ], visible=True, ) + use_base_vae = gr.Checkbox( + value=False, + label="Baked VAE", + interactive=True, + ) with gr.Group(elem_id="prompt_box_outer"): prompt = gr.Textbox( @@ -639,7 +651,7 @@ def import_original(original_img, width, height): with gr.Column(scale=3, min_width=600): with gr.Group(): sd_gallery = gr.Gallery( - label="Generated images", + label="Generated images", show_label=False, 
elem_id="gallery", columns=2, @@ -719,6 +731,7 @@ def import_original(original_img, width, height): base_model_id, custom_weights, custom_vae, + use_base_vae, precision, device, ondemand, @@ -753,6 +766,7 @@ def import_original(original_img, width, height): base_model_id, custom_weights, custom_vae, + use_base_vae, precision, device, ondemand, diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 25c363f652..fd21f9812f 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -65,6 +65,14 @@ def get_iree_device_args(device, extra_args=[]): return get_iree_rocm_args(device_num=device_num, extra_args=extra_args) return [] +def get_iree_target_triple(device): + args = get_iree_device_args(device) + for flag in args: + if "triple" in flag.split("-"): + triple = flag.split("=") + return triple + return "" + def clean_device_info(raw_device): # return appropriate device and device_id for consumption by Studio pipeline