(Studio) Fix controlnet switching. #2026

Merged: 4 commits, Dec 7, 2023
11 changes: 7 additions & 4 deletions apps/stable_diffusion/src/models/model_wrappers.py
@@ -219,7 +219,6 @@ def __init__(
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora

print(self.model_name)
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
self.sharktank_dir = sharktank_dir
@@ -241,7 +240,7 @@ def __init__(
args.hf_model_id = self.base_model_id
self.return_mlir = return_mlir

def get_extended_name_for_all_model(self):
def get_extended_name_for_all_model(self, model_list=None):
model_name = {}
sub_model_list = [
"clip",
@@ -255,6 +254,8 @@ def get_extended_name_for_all_model(self):
"stencil_adapter",
"stencil_adapter_512",
]
if model_list is not None:
sub_model_list = model_list
index = 0
for model in sub_model_list:
sub_model = model
@@ -272,7 +273,7 @@ def get_extended_name_for_all_model(self):
if stencil is not None:
cnet_config = (
self.model_namedata
+ "_v1-5"
+ "_sd15_"
+ stencil.split("_")[-1]
)
stencil_names.append(
@@ -283,6 +284,7 @@ def get_extended_name_for_all_model(self):
else:
model_name[model] = get_extended_name(sub_model + model_config)
index += 1

return model_name

def check_params(self, max_len, width, height):
@@ -765,7 +767,8 @@ def forward(

inputs = tuple(self.inputs["stencil_adapter"])
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
ext_model_name = self.model_name[model_name]
stencil_names = self.get_extended_name_for_all_model([model_name])
ext_model_name = stencil_names[model_name]
if isinstance(ext_model_name, list):
desired_name = None
print(ext_model_name)
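
For context, a standalone sketch of how the new optional model_list argument is meant to be used: the default still walks every sub-model, but a caller (such as the stencil adapter path in forward above) can regenerate extended names for just the entry it is about to load after the active stencils have changed. The body below is a simplification with made-up suffixes, not the real method.

# Simplified sketch (illustrative names and suffixes, not the real implementation).
def get_extended_name_for_all_model(model_list=None):
    sub_model_list = ["clip", "unet", "vae", "stencil_adapter", "stencil_adapter_512"]
    if model_list is not None:
        sub_model_list = model_list
    return {model: f"{model}_fp16_cuda" for model in sub_model_list}

print(get_extended_name_for_all_model(["stencil_adapter"]))
# {'stencil_adapter': 'stencil_adapter_fp16_cuda'}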
@@ -125,6 +125,31 @@ def unload_controlnet_512(self, index):
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None

def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)

self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents

def prepare_image_latents(
self,
image,
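
A minimal, self-contained sketch of the initialization the new prepare_latents performs, assuming a diffusers Euler scheduler and a 512x512 fp16 run (both assumptions); the real method uses the pipeline's configured scheduler and also sets is_scale_input_called.

import torch
from diffusers import EulerDiscreteScheduler

# Assumed scheduler and sizes; mirrors the helper: sample Gaussian noise at
# 1/8 the target resolution, set the timesteps, scale by init_noise_sigma.
scheduler = EulerDiscreteScheduler()
generator = torch.manual_seed(0)
latents = torch.randn(
    (1, 4, 512 // 8, 512 // 8), generator=generator, dtype=torch.float32
).to(torch.half)
scheduler.set_timesteps(50)
latents = latents * scheduler.init_noise_sigma
print(latents.shape)  # torch.Size([1, 4, 64, 64])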
@@ -203,10 +228,16 @@ def produce_stencil_latents(
self.load_unet_512()

for i, name in enumerate(self.controlnet_names):
use_names = []
if name is not None:
use_names.append(name)
else:
continue
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)
self.controlnet_names = use_names

for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
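
The intent of the loop above appears to be twofold: skip None slots when loading, and shrink controlnet_names down to the names that were actually loaded so later iterations only see live controlnets. A standalone illustration with made-up names:

# Made-up names; shows the intended filtering effect only.
controlnet_names = ["canny", None, "openpose"]
use_names = [name for name in controlnet_names if name is not None]
controlnet_names = use_names
print(controlnet_names)  # ['canny', 'openpose']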
@@ -461,6 +492,7 @@ def generate_images(
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )
stencil_hints = []
self.sd_model.stencils = stencils
for i, hint in enumerate(preprocessed_hints):
if hint is not None:
hint = controlnet_hint_reshaping(
@@ -475,7 +507,7 @@ def generate_images(
for i, stencil in enumerate(stencils):
if stencil == None:
continue
if len(stencil_hints) >= i:
if len(stencil_hints) > i:
if stencil_hints[i] is not None:
print(f"Using preprocessed controlnet hint for {stencil}")
continue
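
The change from ">=" to ">" is a bounds fix. A standalone illustration with made-up data: with two stencils but only one preprocessed hint, i reaches 1 while len(stencil_hints) is 1, so the old check passed and stencil_hints[1] raised IndexError; ">" only permits indices that actually exist.

# Made-up data illustrating the off-by-one the fix addresses.
stencil_hints = ["hint_for_canny"]
stencils = ["canny", "openpose"]
for i, stencil in enumerate(stencils):
    if len(stencil_hints) > i and stencil_hints[i] is not None:
        print(f"Using preprocessed controlnet hint for {stencil}")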
@@ -518,19 +550,30 @@ def generate_images(

# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

# Prepare input image latent
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
if image is not None:
# Prepare input image latent
init_latents, final_timesteps = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
else:
# Prepare initial latent.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
final_timesteps = self.scheduler.timesteps

# Get Image latents
latents = self.produce_stencil_latents(
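
Together with the new prepare_latents above, this branch looks like the groundwork for hint-only (text-to-image plus controlnet) runs: with no init image the pipeline starts from pure noise and uses the scheduler's full timestep schedule instead of the strength-truncated schedule that prepare_image_latents returns.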
15 changes: 11 additions & 4 deletions apps/stable_diffusion/web/ui/img2img_ui.py
@@ -105,7 +105,7 @@ def img2img_inf(

for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
return
continue
if images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
@@ -120,6 +120,8 @@ def img2img_inf(
else:
# TODO: enable t2i + controlnets
image = None
if image:
image, _, _ = resize_stencil(image, width, height)

# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
@@ -152,7 +154,6 @@ def img2img_inf(
stencil_count += 1
if stencil_count > 0:
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, _, _ = resize_stencil(image, width, height)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete "
@@ -162,6 +163,7 @@ def img2img_inf(
cpu_scheduling = not args.scheduler.startswith("Shark")
args.precision = precision
dtype = torch.float32 if precision == "fp32" else torch.half
print(stencils)
new_config_obj = Config(
"img2img",
args.hf_model_id,
@@ -180,7 +182,12 @@ def img2img_inf(
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
or any(
global_obj.get_cfg_obj().stencils[idx] != stencil
for idx, stencil in enumerate(stencils)
)
):
print("clearing config because you changed something important")
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.batch_count = batch_count
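
A standalone illustration (values made up) of what the added any(...) clause catches: even when the model id, resolution and other config fields are unchanged, swapping a controlnet slot now invalidates the cached SD object so the right stencil models get compiled and loaded.

# Made-up values; mirrors the added comparison against the cached config.
cached_stencils = ["canny", None]
requested_stencils = ["openpose", None]
stencils_changed = any(
    cached_stencils[idx] != stencil
    for idx, stencil in enumerate(requested_stencils)
)
print(stencils_changed)  # True -> clear the cache and rebuild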
@@ -632,7 +639,7 @@ def update_cn_input(
[cnet_1_image],
)

cnet_1_model.input(
cnet_1_model.change(
fn=(
lambda m, w, h, s, i, p: update_cn_input(
m, w, h, s, i, p, 0
@@ -739,7 +746,7 @@ def update_cn_input(
label="Preprocessed Hint",
interactive=True,
)
cnet_2_model.select(
cnet_2_model.change(
fn=(
lambda m, w, h, s, i, p: update_cn_input(
m, w, h, s, i, p, 0
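
In my understanding of Gradio's event model (worth verifying against the version pinned here), .input() fires only on direct user edits and .select() only on an explicit user selection, while .change() also fires when the value is set programmatically, for example via gr.update(). Switching both handlers to .change() therefore runs the hint preprocessing however the controlnet dropdown ends up with a new value.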