(Studio) Fix controlnet switching. (#2026)
* Fix controlnet switching.

* Fix txt2img + control adapters
monorimet authored Dec 7, 2023
1 parent 7e12d17 commit 7159698
Showing 3 changed files with 75 additions and 22 deletions.
11 changes: 7 additions & 4 deletions apps/stable_diffusion/src/models/model_wrappers.py
@@ -219,7 +219,6 @@ def __init__(
             self.model_name = self.model_name + "_" + get_path_stem(use_lora)
             self.use_lora = use_lora

-        print(self.model_name)
         self.model_name = self.get_extended_name_for_all_model()
         self.debug = debug
         self.sharktank_dir = sharktank_dir
@@ -241,7 +240,7 @@ def __init__(
         args.hf_model_id = self.base_model_id
         self.return_mlir = return_mlir

-    def get_extended_name_for_all_model(self):
+    def get_extended_name_for_all_model(self, model_list=None):
         model_name = {}
         sub_model_list = [
             "clip",
@@ -255,6 +254,8 @@ def get_extended_name_for_all_model(self):
             "stencil_adapter",
             "stencil_adapter_512",
         ]
+        if model_list is not None:
+            sub_model_list = model_list
         index = 0
         for model in sub_model_list:
             sub_model = model
@@ -272,7 +273,7 @@ def get_extended_name_for_all_model(self):
                 if stencil is not None:
                     cnet_config = (
                         self.model_namedata
-                        + "_v1-5"
+                        + "_sd15_"
                         + stencil.split("_")[-1]
                     )
                     stencil_names.append(
@@ -283,6 +284,7 @@ def get_extended_name_for_all_model(self):
             else:
                 model_name[model] = get_extended_name(sub_model + model_config)
             index += 1
+
         return model_name

     def check_params(self, max_len, width, height):
@@ -765,7 +767,8 @@ def forward(

         inputs = tuple(self.inputs["stencil_adapter"])
         model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
-        ext_model_name = self.model_name[model_name]
+        stencil_names = self.get_extended_name_for_all_model([model_name])
+        ext_model_name = stencil_names[model_name]
         if isinstance(ext_model_name, list):
             desired_name = None
             print(ext_model_name)
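Here, instead of reading a name cached at construction time, forward regenerates the extended name for just the requested sub-model, so it reflects the stencils currently selected. The optional model_list parameter narrows name generation to that subset. The shape of the pattern, with hypothetical names (a sketch, not the repo's full method):

import re  # not needed; placeholder imports omitted deliberately

def get_extended_names(sub_models, model_list=None):
    # When a subset is requested, generate names only for it.
    if model_list is not None:
        sub_models = model_list
    return {m: f"{m}_extended" for m in sub_models}

names = get_extended_names(
    ["clip", "unet", "stencil_adapter"], model_list=["stencil_adapter"]
)
assert names == {"stencil_adapter": "stencil_adapter_extended"}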
71 changes: 57 additions & 14 deletions (second changed file; path not captured)
@@ -125,6 +125,31 @@ def unload_controlnet_512(self, index):
         self.controlnet_512_id[index] = None
         self.controlnet_512[index] = None

+    def prepare_latents(
+        self,
+        batch_size,
+        height,
+        width,
+        generator,
+        num_inference_steps,
+        dtype,
+    ):
+        latents = torch.randn(
+            (
+                batch_size,
+                4,
+                height // 8,
+                width // 8,
+            ),
+            generator=generator,
+            dtype=torch.float32,
+        ).to(dtype)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+        self.scheduler.is_scale_input_called = True
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
     def prepare_image_latents(
         self,
         image,
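The added prepare_latents follows the standard txt2img initialization: sample seeded Gaussian noise in float32 (more reproducible across devices), cast to the model dtype, then scale by the scheduler's initial noise sigma. A minimal standalone sketch of the same pattern, assuming a diffusers-style scheduler; the scheduler settings and shapes here are illustrative, not this repo's exact config:

import torch
from diffusers import EulerDiscreteScheduler

# Illustrative scheduler config; SHARK wires up its own scheduler elsewhere.
scheduler = EulerDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
scheduler.set_timesteps(50)

generator = torch.manual_seed(42)  # seeded for reproducible latents
latents = torch.randn(
    (1, 4, 512 // 8, 512 // 8),  # (batch, channels, height/8, width/8)
    generator=generator,
    dtype=torch.float32,  # sample in fp32, then cast
).to(torch.float16)
latents = latents * scheduler.init_noise_sigma  # scale the initial noise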
@@ -203,10 +228,16 @@ def produce_stencil_latents(
             self.load_unet_512()

         for i, name in enumerate(self.controlnet_names):
+            use_names = []
+            if name is not None:
+                use_names.append(name)
+            else:
+                continue
             if text_embeddings.shape[1] <= self.model_max_length:
                 self.load_controlnet(i, name)
             else:
                 self.load_controlnet_512(i, name)
+        self.controlnet_names = use_names

         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
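The loop above now records only the active (non-None) controlnet names and writes them back to self.controlnet_names, so a slot whose controlnet was switched off no longer gets loaded on later steps. The intent, in isolation (a simplified sketch, not the repo's code):

# Keep only the active controlnet slots; None marks a disabled slot.
controlnet_names = ["canny", None, "openpose"]
active_names = [name for name in controlnet_names if name is not None]
assert active_names == ["canny", "openpose"]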
@@ -461,6 +492,7 @@ def generate_images(
         # image, use_stencil, height, width, dtype, num_images_per_prompt=1
         # )
         stencil_hints = []
+        self.sd_model.stencils = stencils
         for i, hint in enumerate(preprocessed_hints):
             if hint is not None:
                 hint = controlnet_hint_reshaping(
@@ -475,7 +507,7 @@ def generate_images(
         for i, stencil in enumerate(stencils):
             if stencil == None:
                 continue
-            if len(stencil_hints) >= i:
+            if len(stencil_hints) > i:
                 if stencil_hints[i] is not None:
                     print(f"Using preprocessed controlnet hint for {stencil}")
                     continue
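The comparison flip from >= to > fixes an off-by-one bound: when i equals len(stencil_hints), the old check still passed and the following stencil_hints[i] lookup could raise IndexError. In isolation:

hints = []  # no preprocessed hints yet
i = 0
# Old check: len(hints) >= i  ->  0 >= 0 is True, then hints[0] raises
# IndexError. New check: len(hints) > i  ->  0 > 0 is False, so the
# lookup is safely skipped.
if len(hints) > i and hints[i] is not None:
    print("using preprocessed hint")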
@@ -518,19 +550,30 @@ def generate_images(

         # guidance scale as a float32 tensor.
         guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
-
-        # Prepare input image latent
-        init_latents, final_timesteps = self.prepare_image_latents(
-            image=image,
-            batch_size=batch_size,
-            height=height,
-            width=width,
-            generator=generator,
-            num_inference_steps=num_inference_steps,
-            strength=strength,
-            dtype=dtype,
-            resample_type=resample_type,
-        )
+        if image is not None:
+            # Prepare input image latent
+            init_latents, final_timesteps = self.prepare_image_latents(
+                image=image,
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                generator=generator,
+                num_inference_steps=num_inference_steps,
+                strength=strength,
+                dtype=dtype,
+                resample_type=resample_type,
+            )
+        else:
+            # Prepare initial latent.
+            init_latents = self.prepare_latents(
+                batch_size=batch_size,
+                height=height,
+                width=width,
+                generator=generator,
+                num_inference_steps=num_inference_steps,
+                dtype=dtype,
+            )
+            final_timesteps = self.scheduler.timesteps

         # Get Image latents
         latents = self.produce_stencil_latents(
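This branch is what enables txt2img with control adapters: with an input image the pipeline encodes and noises it, truncating the timestep schedule by strength; without one it starts from pure noise over the full schedule. The strength-based truncation is the usual diffusers-style pattern (a sketch under that assumption, not this repo's exact code):

# img2img keeps only the tail of the schedule, proportional to strength.
num_inference_steps, strength = 50, 0.8
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
# With scheduler.set_timesteps(num_inference_steps) already called:
# final_timesteps = scheduler.timesteps[t_start:]  # 40 of 50 steps here
# txt2img instead uses the full scheduler.timesteps, as in the else branch.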
15 changes: 11 additions & 4 deletions apps/stable_diffusion/web/ui/img2img_ui.py
@@ -105,7 +105,7 @@ def img2img_inf(

     for i, stencil in enumerate(stencils):
         if images[i] is None and stencil is not None:
-            return
+            continue
         if images[i] is not None:
             if isinstance(images[i], dict):
                 images[i] = images[i]["composite"]
@@ -120,6 +120,8 @@ def img2img_inf(
     else:
         # TODO: enable t2i + controlnets
         image = None
+    if image:
+        image, _, _ = resize_stencil(image, width, height)

     # set ckpt_loc and hf_model_id.
     args.ckpt_loc = ""
@@ -152,7 +154,6 @@ def img2img_inf(
             stencil_count += 1
     if stencil_count > 0:
         args.hf_model_id = "runwayml/stable-diffusion-v1-5"
-        image, _, _ = resize_stencil(image, width, height)
     elif "Shark" in args.scheduler:
         print(
             f"Shark schedulers are not supported. Switching to EulerDiscrete "
@@ -162,6 +163,7 @@ def img2img_inf(
     cpu_scheduling = not args.scheduler.startswith("Shark")
     args.precision = precision
     dtype = torch.float32 if precision == "fp32" else torch.half
+    print(stencils)
     new_config_obj = Config(
         "img2img",
         args.hf_model_id,
@@ -180,7 +182,12 @@ def img2img_inf(
     if (
         not global_obj.get_sd_obj()
         or global_obj.get_cfg_obj() != new_config_obj
+        or any(
+            global_obj.get_cfg_obj().stencils[idx] != stencil
+            for idx, stencil in enumerate(stencils)
+        )
     ):
+        print("clearing config because you changed something important")
         global_obj.clear_cache()
         global_obj.set_cfg_obj(new_config_obj)
     args.batch_count = batch_count
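The new any(...) clause invalidates the cached pipeline whenever the stencil selection changes, even if the rest of the config compares equal; this is the core of the controlnet-switching fix on the UI side. Element-wise, the check behaves like this (illustrative values, not the repo's objects):

cached_stencils = ["canny", None]
requested_stencils = ["canny", "openpose"]
# Any per-slot mismatch forces a cache clear and a pipeline rebuild.
stencils_changed = any(
    cached_stencils[idx] != stencil
    for idx, stencil in enumerate(requested_stencils)
)
assert stencils_changed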
@@ -632,7 +639,7 @@ def update_cn_input(
                         [cnet_1_image],
                     )

-                    cnet_1_model.input(
+                    cnet_1_model.change(
                         fn=(
                             lambda m, w, h, s, i, p: update_cn_input(
                                 m, w, h, s, i, p, 0
@@ -739,7 +746,7 @@ def update_cn_input(
                         label="Preprocessed Hint",
                         interactive=True,
                     )
-                    cnet_2_model.select(
+                    cnet_2_model.change(
                         fn=(
                             lambda m, w, h, s, i, p: update_cn_input(
                                 m, w, h, s, i, p, 0
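Switching these listeners to .change matters because Gradio's .change fires whenever the component's value changes, whether from user interaction or a programmatic update, whereas .input only fires on direct user edits; the controlnet inputs therefore stay in sync however the model dropdown gets set. A minimal sketch of the pattern (component names here are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    cnet_model = gr.Dropdown(["canny", "openpose"], label="ControlNet model")
    status = gr.Textbox(label="Status")
    # .change fires on any value change, including programmatic ones;
    # .input would only fire on direct user edits.
    cnet_model.change(
        fn=lambda m: f"Selected: {m}",
        inputs=cnet_model,
        outputs=status,
    )

# demo.launch()  # uncomment to run locally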
