Skip to content

Commit

Permalink
Fix txt2img + control adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Dec 7, 2023
1 parent 500636d commit c01bbec
Showing 1 changed file with 55 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,31 @@ def unload_controlnet_512(self, index):
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None

def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)

self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents

def prepare_image_latents(
self,
image,
Expand Down Expand Up @@ -203,10 +228,16 @@ def produce_stencil_latents(
self.load_unet_512()

for i, name in enumerate(self.controlnet_names):
use_names = []
if name is not None:
use_names.append(name)
else:
continue
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)
self.controlnet_names = use_names

for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
Expand Down Expand Up @@ -519,19 +550,30 @@ def generate_images(

# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

# Prepare input image latent
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
if image is not None:
# Prepare input image latent
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
else:
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps

# Get Image latents
latents = self.produce_stencil_latents(
Expand Down

0 comments on commit c01bbec

Please sign in to comment.