Skip to content

Commit

Permalink
Add unet512 support for the other StableDiffusion pipelines (#1602)
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 authored Jun 27, 2023
1 parent 1d6a1f9 commit 6274a81
Show file tree
Hide file tree
Showing 14 changed files with 61 additions and 17 deletions.
1 change: 1 addition & 0 deletions apps/stable_diffusion/scripts/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
total_time = time.time() - start_time
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/scripts/inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/scripts/outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/scripts/upscaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
Expand Down
5 changes: 3 additions & 2 deletions apps/stable_diffusion/src/models/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,16 +520,17 @@ def forward(self, latent, timestep, text_embedding, noise_level):
torch.nn.functional.pad(inputs[2], pad),
inputs[3])
input_mask = [True, True, True, False]
+ model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
- extended_model_name=self.model_name["unet"],
+ extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
- model_name="unet",
+ model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def generate_images(
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# prompts and negative prompts must be a list.
Expand All @@ -156,7 +157,10 @@ def generate_images(

# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
- prompts, neg_prompts, max_length
+ prompts,
+ neg_prompts,
+ max_length,
+ max_embeddings_multiples=max_embeddings_multiples,
)

# guidance scale as a float32 tensor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def generate_images(
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
Expand Down Expand Up @@ -408,7 +409,10 @@ def generate_images(

# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
- prompts, neg_prompts, max_length
+ prompts,
+ neg_prompts,
+ max_length,
+ max_embeddings_multiples=max_embeddings_multiples,
)

# guidance scale as a float32 tensor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def generate_images(
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
Expand Down Expand Up @@ -409,7 +410,10 @@ def generate_images(

# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
- prompts, neg_prompts, max_length
+ prompts,
+ neg_prompts,
+ max_length,
+ max_embeddings_multiples=max_embeddings_multiples,
)

# guidance scale as a float32 tensor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def generate_images(
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# Control Embedding check & conversion
Expand All @@ -230,7 +231,10 @@ def generate_images(

# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
- prompts, neg_prompts, max_length
+ prompts,
+ neg_prompts,
+ max_length,
+ max_embeddings_multiples=max_embeddings_multiples,
)

# guidance scale as a float32 tensor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ def produce_img_latents(
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.status = SD_STATE_IDLE
- self.load_unet()
+ if text_embeddings.shape[1] <= self.model_max_length:
+ self.load_unet()
+ else:
+ self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
Expand All @@ -182,15 +185,26 @@ def produce_img_latents(

# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
- noise_pred = self.unet(
- "forward",
- (
- latent_model_input,
- timestep,
- text_embeddings_numpy,
- noise_level,
- ),
- )
+ if text_embeddings.shape[1] <= self.model_max_length:
+ noise_pred = self.unet(
+ "forward",
+ (
+ latent_model_input,
+ timestep,
+ text_embeddings_numpy,
+ noise_level,
+ ),
+ )
+ else:
+ noise_pred = self.unet_512(
+ "forward",
+ (
+ latent_model_input,
+ timestep,
+ text_embeddings_numpy,
+ noise_level,
+ ),
+ )
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
Expand Down Expand Up @@ -219,6 +233,7 @@ def produce_img_latents(

if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"

Expand All @@ -243,6 +258,7 @@ def generate_images(
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
Expand All @@ -264,7 +280,10 @@ def generate_images(

# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
- prompts, neg_prompts, max_length
+ prompts,
+ neg_prompts,
+ max_length,
+ max_embeddings_multiples=max_embeddings_multiples,
)

# 4. Preprocess image
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/img2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def img2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
seeds.append(img_seed)
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/inpaint_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def inpaint_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/outpaint_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def outpaint_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time
Expand Down
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/upscaler_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def upscaler_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
Expand Down

0 comments on commit 6274a81

Please sign in to comment.