From 5ad765b543fb3896414ea8b44fa5b063750ffa70 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 6 Dec 2023 21:25:01 -0600 Subject: [PATCH 1/4] Fix controlnet switching. --- apps/stable_diffusion/src/models/model_wrappers.py | 10 +++++++--- .../pipeline_shark_stable_diffusion_stencil.py | 2 +- apps/stable_diffusion/web/ui/img2img_ui.py | 7 +++++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index ee3da65547..d7bb6339b5 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -219,7 +219,6 @@ def __init__( self.model_name = self.model_name + "_" + get_path_stem(use_lora) self.use_lora = use_lora - print(self.model_name) self.model_name = self.get_extended_name_for_all_model() self.debug = debug self.sharktank_dir = sharktank_dir @@ -241,7 +240,7 @@ def __init__( args.hf_model_id = self.base_model_id self.return_mlir = return_mlir - def get_extended_name_for_all_model(self): + def get_extended_name_for_all_model(self, model_list=None): model_name = {} sub_model_list = [ "clip", @@ -255,6 +254,8 @@ def get_extended_name_for_all_model(self): "stencil_adapter", "stencil_adapter_512", ] + if model_list: + sub_model_list=model_list index = 0 for model in sub_model_list: sub_model = model @@ -283,6 +284,8 @@ def get_extended_name_for_all_model(self): else: model_name[model] = get_extended_name(sub_model + model_config) index += 1 + print(f"model name at {index} = {self.model_name}") + return model_name def check_params(self, max_len, width, height): @@ -765,7 +768,8 @@ def forward( inputs = tuple(self.inputs["stencil_adapter"]) model_name = "stencil_adapter_512" if use_large else "stencil_adapter" - ext_model_name = self.model_name[model_name] + stencil_names = self.get_extended_name_for_all_model([model_name]) + ext_model_name = stencil_names[model_name] if isinstance(ext_model_name, list): desired_name = None print(ext_model_name) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py index 66877f0d6d..4c7b80935d 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py @@ -475,7 +475,7 @@ def generate_images( for i, stencil in enumerate(stencils): if stencil == None: continue - if len(stencil_hints) >= i: + if len(stencil_hints) > i: if stencil_hints[i] is not None: print(f"Using preprocessed controlnet hint for {stencil}") continue diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index 11ca040f5b..84ec5c860d 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -105,7 +105,7 @@ def img2img_inf( for i, stencil in enumerate(stencils): if images[i] is None and stencil is not None: - return + continue if images[i] is not None: if isinstance(images[i], dict): images[i] = images[i]["composite"] @@ -120,6 +120,8 @@ def img2img_inf( else: # TODO: enable t2i + controlnets image = None + if image: + image, _, _ = resize_stencil(image, width, height) # set ckpt_loc and hf_model_id. args.ckpt_loc = "" @@ -152,7 +154,6 @@ def img2img_inf( stencil_count += 1 if stencil_count > 0: args.hf_model_id = "runwayml/stable-diffusion-v1-5" - image, _, _ = resize_stencil(image, width, height) elif "Shark" in args.scheduler: print( f"Shark schedulers are not supported. Switching to EulerDiscrete " @@ -162,6 +163,7 @@ def img2img_inf( cpu_scheduling = not args.scheduler.startswith("Shark") args.precision = precision dtype = torch.float32 if precision == "fp32" else torch.half + print(stencils) new_config_obj = Config( "img2img", args.hf_model_id, @@ -180,6 +182,7 @@ def img2img_inf( if ( not global_obj.get_sd_obj() or global_obj.get_cfg_obj() != new_config_obj + or global_obj.get_cfg_obj().stencils != new_config_obj.stencils ): global_obj.clear_cache() global_obj.set_cfg_obj(new_config_obj) From a622cfd3413de93a2897e76aab223c5529da70d6 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 7 Dec 2023 00:10:51 -0600 Subject: [PATCH 2/4] fixup switching --- apps/stable_diffusion/src/models/model_wrappers.py | 6 +++--- .../pipeline_shark_stable_diffusion_stencil.py | 1 + apps/stable_diffusion/web/ui/img2img_ui.py | 10 +++++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index d7bb6339b5..4f70a3df0a 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -254,8 +254,8 @@ def get_extended_name_for_all_model(self, model_list=None): "stencil_adapter", "stencil_adapter_512", ] - if model_list: - sub_model_list=model_list + if model_list is not None: + sub_model_list = model_list index = 0 for model in sub_model_list: sub_model = model @@ -273,7 +273,7 @@ def get_extended_name_for_all_model(self, model_list=None): if stencil is not None: cnet_config = ( self.model_namedata - + "_v1-5" + + "_sd15_" + stencil.split("_")[-1] ) stencil_names.append( diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py index 4c7b80935d..ea19431d48 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py @@ -461,6 +461,7 @@ def generate_images( # image, use_stencil, height, width, dtype, num_images_per_prompt=1 # ) stencil_hints = [] + self.sd_model.stencils = stencils for i, hint in enumerate(preprocessed_hints): if hint is not None: hint = controlnet_hint_reshaping( diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index 84ec5c860d..f3522656e4 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -182,8 +182,12 @@ def img2img_inf( if ( not global_obj.get_sd_obj() or global_obj.get_cfg_obj() != new_config_obj - or global_obj.get_cfg_obj().stencils != new_config_obj.stencils + or any( + global_obj.get_cfg_obj().stencils[idx] != stencil + for idx, stencil in enumerate(stencils) + ) ): + print("clearing config because you changed something important") global_obj.clear_cache() global_obj.set_cfg_obj(new_config_obj) args.batch_count = batch_count @@ -635,7 +639,7 @@ def update_cn_input( [cnet_1_image], ) - cnet_1_model.input( + cnet_1_model.change( fn=( lambda m, w, h, s, i, p: update_cn_input( m, w, h, s, i, p, 0 @@ -742,7 +746,7 @@ def update_cn_input( label="Preprocessed Hint", interactive=True, ) - cnet_2_model.select( + cnet_2_model.change( fn=( lambda m, w, h, s, i, p: update_cn_input( m, w, h, s, i, p, 0 From 500636decdd72e000e404e5b74c8c6c22152954a Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 7 Dec 2023 00:26:44 -0600 Subject: [PATCH 3/4] Remove print --- apps/stable_diffusion/src/models/model_wrappers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py index 4f70a3df0a..57be6da97f 100644 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ b/apps/stable_diffusion/src/models/model_wrappers.py @@ -284,7 +284,6 @@ def get_extended_name_for_all_model(self, model_list=None): else: model_name[model] = get_extended_name(sub_model + model_config) index += 1 - print(f"model name at {index} = {self.model_name}") return model_name From c01bbecb6aa664af2275fef1a313bb2d08baa9b5 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 7 Dec 2023 00:35:48 -0600 Subject: [PATCH 4/4] Fix txt2img + control adapters --- ...pipeline_shark_stable_diffusion_stencil.py | 68 +++++++++++++++---- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py index ea19431d48..51fd51f92d 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py @@ -125,6 +125,31 @@ def unload_controlnet_512(self, index): self.controlnet_512_id[index] = None self.controlnet_512[index] = None + def prepare_latents( + self, + batch_size, + height, + width, + generator, + num_inference_steps, + dtype, + ): + latents = torch.randn( + ( + batch_size, + 4, + height // 8, + width // 8, + ), + generator=generator, + dtype=torch.float32, + ).to(dtype) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.is_scale_input_called = True + latents = latents * self.scheduler.init_noise_sigma + return latents + def prepare_image_latents( self, image, @@ -203,10 +228,16 @@ def produce_stencil_latents( self.load_unet_512() for i, name in enumerate(self.controlnet_names): + use_names = [] + if name is not None: + use_names.append(name) + else: + continue if text_embeddings.shape[1] <= self.model_max_length: self.load_controlnet(i, name) else: self.load_controlnet_512(i, name) + self.controlnet_names = use_names for i, t in tqdm(enumerate(total_timesteps)): step_start_time = time.time() @@ -519,19 +550,30 @@ def generate_images( # guidance scale as a float32 tensor. guidance_scale = torch.tensor(guidance_scale).to(torch.float32) - - # Prepare input image latent - init_latents, final_timesteps = self.prepare_image_latents( - image=image, - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - strength=strength, - dtype=dtype, - resample_type=resample_type, - ) + if image is not None: + # Prepare input image latent + init_latents, final_timesteps = self.prepare_image_latents( + image=image, + batch_size=batch_size, + height=height, + width=width, + generator=generator, + num_inference_steps=num_inference_steps, + strength=strength, + dtype=dtype, + resample_type=resample_type, + ) + else: + # Prepare initial latent. + init_latents = self.prepare_latents( + batch_size=batch_size, + height=height, + width=width, + generator=generator, + num_inference_steps=num_inference_steps, + dtype=dtype, + ) + final_timesteps = self.scheduler.timesteps # Get Image latents latents = self.produce_stencil_latents(