From c01bbecb6aa664af2275fef1a313bb2d08baa9b5 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Thu, 7 Dec 2023 00:35:48 -0600
Subject: [PATCH] Fix txt2img + control adapters

---
 ...pipeline_shark_stable_diffusion_stencil.py | 68 +++++++++++++++----
 1 file changed, 55 insertions(+), 13 deletions(-)

diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py
index ea19431d48..51fd51f92d 100644
--- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py
+++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py
@@ -125,6 +125,31 @@ def unload_controlnet_512(self, index):
         self.controlnet_512_id[index] = None
         self.controlnet_512[index] = None
 
+    def prepare_latents(
+        self,
+        batch_size,
+        height,
+        width,
+        generator,
+        num_inference_steps,
+        dtype,
+    ):
+        latents = torch.randn(
+            (
+                batch_size,
+                4,
+                height // 8,
+                width // 8,
+            ),
+            generator=generator,
+            dtype=torch.float32,
+        ).to(dtype)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+        self.scheduler.is_scale_input_called = True
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
     def prepare_image_latents(
         self,
         image,
@@ -203,10 +228,16 @@ def produce_stencil_latents(
             self.load_unet_512()
 
         for i, name in enumerate(self.controlnet_names):
+            use_names = []
+            if name is not None:
+                use_names.append(name)
+            else:
+                continue
             if text_embeddings.shape[1] <= self.model_max_length:
                 self.load_controlnet(i, name)
             else:
                 self.load_controlnet_512(i, name)
+        self.controlnet_names = use_names
 
         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
@@ -519,19 +550,30 @@ def generate_images(
 
         # guidance scale as a float32 tensor.
         guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
-
-        # Prepare input image latent
-        init_latents, final_timesteps = self.prepare_image_latents(
-            image=image,
-            batch_size=batch_size,
-            height=height,
-            width=width,
-            generator=generator,
-            num_inference_steps=num_inference_steps,
-            strength=strength,
-            dtype=dtype,
-            resample_type=resample_type,
-        )
+        if image is not None:
+            # Prepare input image latent
+            init_latents, final_timesteps = self.prepare_image_latents(
+                image=image,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                generator=generator,
+                num_inference_steps=num_inference_steps,
+                strength=strength,
+                dtype=dtype,
+                resample_type=resample_type,
+            )
+        else:
+            # Prepare initial latent.
+            init_latents = self.prepare_latents(
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                generator=generator,
+                num_inference_steps=num_inference_steps,
+                dtype=dtype,
+            )
+            final_timesteps = self.scheduler.timesteps
 
         # Get Image latents
         latents = self.produce_stencil_latents(