Merge branch 'main' into cpu_count_task_flag
dan-garvey authored Jun 29, 2023
2 parents 5c06a20 + d496053 commit 6cbfc30
Showing 18 changed files with 190 additions and 50 deletions.
26 changes: 16 additions & 10 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -436,6 +436,14 @@ def compile(self):
         # return tuple of shark_modules once mem is supported
         # return fvic_shark_model, svic_shark_model
 
+    def decode_tokens(self, res_tokens):
+        for i in range(len(res_tokens)):
+            if type(res_tokens[i]) != int:
+                res_tokens[i] = int(res_tokens[i][0])
+
+        res_str = self.tokenizer.decode(res_tokens)
+        return res_str
+
     def generate(self, prompt, cli=False):
         # TODO: refactor for cleaner integration
         import gc
@@ -445,7 +453,6 @@ def generate(self, prompt, cli=False):
             self.first_vic = self.compile_first_vicuna()
         if self.second_vic == None:
             self.second_vic = self.compile_second_vicuna()
-        res = []
         res_tokens = []
         params = {
             "prompt": prompt,
@@ -461,8 +468,8 @@ def generate(self, prompt, cli=False):
         logits = generated_token_op["logits"]
         pkv = generated_token_op["pkv"]
         detok = generated_token_op["detok"]
+        yield detok
 
-        res.append(detok)
         res_tokens.append(token)
         if cli:
             print(f"Assistant: {detok}", end=" ", flush=True)
@@ -495,25 +502,24 @@ def generate(self, prompt, cli=False):
                 break
             res_tokens.append(token)
             if detok == "<0x0A>":
-                res.append("\n")
                 if cli:
                     print("\n", end="", flush=True)
             else:
-                res.append(detok)
                 if cli:
                     print(f"{detok}", end=" ", flush=True)
+
+            if len(res_tokens) % 3 == 0:
+                part_str = self.decode_tokens(res_tokens)
+                yield part_str
+
         if self.device == "cuda":
             del sec_vic, pkv, logits
             torch.cuda.empty_cache()
             gc.collect()
 
-        for i in range(len(res_tokens)):
-            if type(res_tokens[i]) != int:
-                res_tokens[i] = int(res_tokens[i][0])
-
-        res_str = self.tokenizer.decode(res_tokens)
+        res_str = self.decode_tokens(res_tokens)
         # print(f"[DEBUG] final output : \n{res_str}")
-        return res_str
+        yield res_str
 
     def generate_new_token(self, params, debug=False):
         def forward_first(first_vic, prompt, cache_outputs=False):
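Taken as a whole, the vicuna_pipeline.py changes turn generate() into a generator: it yields the first detokenized token as soon as it is available, re-decodes and yields a partial string every third token, and finishes by yielding the full string produced by the new decode_tokens() helper. A minimal sketch of that streaming pattern, using a stand-in token source and tokenizer rather than the SHARK/vicuna objects:

# Illustrative sketch only; `next_token` and `tokenizer` are stand-ins,
# not the SHARK APIs used in vicuna_pipeline.py.
def stream_generate(next_token, tokenizer, max_new_tokens=256):
    res_tokens = []
    for _ in range(max_new_tokens):
        token = next_token()
        if token is None:  # end-of-sequence sentinel
            break
        res_tokens.append(token)
        # Re-decode and emit a partial string every third token,
        # mirroring the `len(res_tokens) % 3 == 0` check in the diff.
        if len(res_tokens) % 3 == 0:
            yield tokenizer.decode(res_tokens)
    # Always finish with the complete decoded string.
    yield tokenizer.decode(res_tokens)

Consumers simply iterate the generator and overwrite their display with each partial string, which is what the stablelm_ui.py change further below does.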
1 change: 1 addition & 0 deletions apps/stable_diffusion/scripts/img2img.py
@@ -103,6 +103,7 @@ def main():
         dtype,
         args.use_base_vae,
         cpu_scheduling,
+        args.max_embeddings_multiples,
         use_stencil=use_stencil,
     )
     total_time = time.time() - start_time
1 change: 1 addition & 0 deletions apps/stable_diffusion/scripts/inpaint.py
@@ -81,6 +81,7 @@ def main():
         dtype,
         args.use_base_vae,
         cpu_scheduling,
+        args.max_embeddings_multiples,
     )
     total_time = time.time() - start_time
     text_output = f"prompt={args.prompts}"
1 change: 1 addition & 0 deletions apps/stable_diffusion/scripts/outpaint.py
@@ -79,6 +79,7 @@ def main():
         dtype,
         args.use_base_vae,
         cpu_scheduling,
+        args.max_embeddings_multiples,
     )
     total_time = time.time() - start_time
     text_output = f"prompt={args.prompts}"
1 change: 1 addition & 0 deletions apps/stable_diffusion/scripts/upscaler.py
@@ -73,6 +73,7 @@
         dtype,
         args.use_base_vae,
         cpu_scheduling,
+        args.max_embeddings_multiples,
     )
     total_time = time.time() - start_time
     text_output = f"prompt={args.prompts}"
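Each of the four scripts above now forwards args.max_embeddings_multiples into generate_images(). In diffusers-style weighted prompt encoding, this multiplier is what lets prompts exceed CLIP's 77-token window: the prompt is encoded in chunks and the usable length grows with the multiplier. A rough illustration of the usual length budget (a hypothetical helper, not the SHARK implementation):

# Hypothetical illustration of what max_embeddings_multiples typically controls.
def max_weighted_prompt_length(model_max_length=77, max_embeddings_multiples=3):
    # Two slots per chunk are reserved for the BOS/EOS special tokens,
    # so the usable token budget scales with the multiplier.
    return (model_max_length - 2) * max_embeddings_multiples + 2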
5 changes: 3 additions & 2 deletions apps/stable_diffusion/src/models/model_wrappers.py
@@ -520,16 +520,17 @@ def forward(self, latent, timestep, text_embedding, noise_level):
                     torch.nn.functional.pad(inputs[2], pad),
                     inputs[3])
             input_mask = [True, True, True, False]
+        model_name = "unet512" if use_large else "unet"
         shark_unet, unet_mlir = compile_through_fx(
             unet,
             inputs,
-            extended_model_name=self.model_name["unet"],
+            extended_model_name=self.model_name[model_name],
             is_f16=is_f16,
             f16_input_mask=input_mask,
             use_tuned=self.use_tuned,
             extra_args=get_opt_flags("unet", precision=self.precision),
             base_model_id=self.base_model_id,
-            model_name="unet",
+            model_name=model_name,
             precision=self.precision,
             return_mlir=self.return_mlir,
         )
Changes in an additional file (file path not shown):
@@ -135,6 +135,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
         use_stencil,
     ):
         # prompts and negative prompts must be a list.
@@ -156,7 +157,10 @@
 
         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )
 
         # guidance scale as a float32 tensor.
Changes in an additional file (file path not shown):
@@ -378,6 +378,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
     ):
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
@@ -408,7 +409,10 @@
 
         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )
 
         # guidance scale as a float32 tensor.
Changes in an additional file (file path not shown):
@@ -379,6 +379,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
     ):
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
@@ -409,7 +410,10 @@
 
         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
        )
 
         # guidance scale as a float32 tensor.
Changes in an additional file (file path not shown):
@@ -204,6 +204,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
         use_stencil,
     ):
         # Control Embedding check & conversion
@@ -230,7 +231,10 @@
 
         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )
 
         # guidance scale as a float32 tensor.
Changes in an additional file (file path not shown):
@@ -168,7 +168,10 @@ def produce_img_latents(
         text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
         text_embeddings_numpy = text_embeddings.detach().numpy()
         self.status = SD_STATE_IDLE
-        self.load_unet()
+        if text_embeddings.shape[1] <= self.model_max_length:
+            self.load_unet()
+        else:
+            self.load_unet_512()
         for i, t in tqdm(enumerate(total_timesteps)):
             step_start_time = time.time()
             latent_model_input = torch.cat([latents] * 2)
@@ -182,15 +185,26 @@
 
             # Profiling Unet.
             profile_device = start_profiling(file_path="unet.rdc")
-            noise_pred = self.unet(
-                "forward",
-                (
-                    latent_model_input,
-                    timestep,
-                    text_embeddings_numpy,
-                    noise_level,
-                ),
-            )
+            if text_embeddings.shape[1] <= self.model_max_length:
+                noise_pred = self.unet(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        noise_level,
+                    ),
+                )
+            else:
+                noise_pred = self.unet_512(
+                    "forward",
+                    (
+                        latent_model_input,
+                        timestep,
+                        text_embeddings_numpy,
+                        noise_level,
+                    ),
+                )
             end_profiling(profile_device)
             noise_pred = torch.from_numpy(noise_pred)
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -219,6 +233,7 @@ def produce_img_latents(
 
         if self.ondemand:
             self.unload_unet()
+            self.unload_unet_512()
         avg_step_time = step_time_sum / len(total_timesteps)
         self.log += f"\nAverage step time: {avg_step_time}ms/it"
 
@@ -243,6 +258,7 @@ def generate_images(
         dtype,
         use_base_vae,
         cpu_scheduling,
+        max_embeddings_multiples,
     ):
         # prompts and negative prompts must be a list.
         if isinstance(prompts, str):
@@ -264,7 +280,10 @@
 
         # Get text embeddings with weight emphasis from prompts
         text_embeddings = self.encode_prompts_weight(
-            prompts, neg_prompts, max_length
+            prompts,
+            neg_prompts,
+            max_length,
+            max_embeddings_multiples=max_embeddings_multiples,
         )
 
         # 4. Preprocess image
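Together with the model_wrappers.py change above (which compiles the large variant under the name "unet512"), this pipeline now keeps two UNet modules and chooses per call: the standard module when the text embedding still fits in model_max_length, and the 512-token variant otherwise. The selection added above reduces to roughly this sketch (module and loader names follow the diff; everything else is simplified):

# Sketch of the per-step module selection; not the full produce_img_latents().
def run_unet_step(self, latent_model_input, timestep, text_embeddings_numpy, noise_level):
    fits = text_embeddings_numpy.shape[1] <= self.model_max_length
    unet_module = self.unet if fits else self.unet_512
    return unet_module(
        "forward",
        (latent_model_input, timestep, text_embeddings_numpy, noise_level),
    )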
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/img2img_ui.py
@@ -249,6 +249,7 @@ def img2img_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
             use_stencil=use_stencil,
         )
         seeds.append(img_seed)
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/inpaint_ui.py
@@ -204,6 +204,7 @@ def inpaint_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
         )
         seeds.append(img_seed)
         total_time = time.time() - start_time
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/outpaint_ui.py
@@ -211,6 +211,7 @@ def outpaint_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
         )
         seeds.append(img_seed)
         total_time = time.time() - start_time
9 changes: 2 additions & 7 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -65,16 +65,11 @@ def chat(curr_system_message, history, model, device, precision):
         )
         prompt = messages.strip()
         print("prompt = ", prompt)
-        sentence = vicuna_model.generate(prompt)
 
-        partial_text = ""
-        for new_text in sentence.split(" "):
-            # print(new_text)
-            partial_text += new_text + " "
+        for partial_text in vicuna_model.generate(prompt):
             history[-1][1] = partial_text
             # Yield an empty string to cleanup the message textbox and the updated conversation history
             yield history
-        history[-1][1] = sentence
 
         return history
 
     # else Model is StableLM
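With generate() now a generator, chat() no longer fakes streaming by splitting a finished sentence on spaces; it forwards each partial string as it arrives and yields the updated history. A Gradio event can stream such a generator callback directly, roughly like this (component names are hypothetical, and queuing must be enabled for generator functions):

# Hypothetical wiring; each `yield history` from chat() updates the chatbot live.
submit.click(
    fn=chat,
    inputs=[system_msg, chatbot, model_choice, device_choice, precision_choice],
    outputs=[chatbot],
)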
1 change: 1 addition & 0 deletions apps/stable_diffusion/web/ui/upscaler_ui.py
@@ -202,6 +202,7 @@ def upscaler_inf(
             dtype,
             args.use_base_vae,
             cpu_scheduling,
+            args.max_embeddings_multiples,
         )
         if global_obj.get_sd_status() == SD_STATE_CANCEL:
             break
28 changes: 12 additions & 16 deletions shark/examples/shark_inference/mega_test.py
@@ -1,10 +1,7 @@
 import torch
 import torch_mlir
 from shark.shark_inference import SharkInference
-from apps.stable_diffusion.src.utils import (
-    compile_through_fx,
-    args,
-)
+from shark.shark_compile import shark_compile_through_fx
 from MEGABYTE_pytorch import MEGABYTE
 
 import os
@@ -37,23 +34,22 @@ def forward(self, input):
 
 
 megaModel = MegaModel()
-input = [torch.randint(0, 16000, (1, 1024, 4))]
+inputs = [torch.randint(0, 16000, (1, 1024, 4))]
 
 # CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
 # 1. aten.alias
-shark_module, _ = compile_through_fx(
-    megaModel,
-    inputs=input,
+shark_module, _ = shark_compile_through_fx(
+    model=megaModel,
+    inputs=inputs,
     extended_model_name="mega_shark",
-    debug=False,
-    generate_vmfb=True,
     is_f16=False,
     f16_input_mask=None,
     save_dir=os.getcwd(),
+    debug=False,
+    generate_or_load_vmfb=True,
     extra_args=[],
-    base_model_id=None,
-    model_name="mega_shark",
-    precision=None,
-    return_mlir=True,
+    device="cuda",
+    mlir_dialect="tm_tensor",
 )
 # logits = model(x)
 
@@ -63,10 +59,10 @@ def print_output_info(output, msg):
     print("\n\t", output.shape)
 
 
-ans = shark_module("forward", input)
+ans = shark_module("forward", inputs)
 print_output_info(torch.from_numpy(ans), "SHARK's output")
 
-ans = megaModel.forward(*input)
+ans = megaModel.forward(*inputs)
 print_output_info(ans, "ORIGINAL Model's output")
 
 # and sample from the logits accordingly
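The example still stops at comparing raw logits (the closing comment notes that sampling is left out). A minimal, illustrative continuation is greedy selection over the final vocabulary axis: not MEGABYTE's own sampling routine, just a quick way to turn the logits into token ids.

# Illustrative only: greedy top-1 selection along the last (vocabulary) axis
# of the original model's logits.
pred_ids = ans.argmax(dim=-1)
print("greedy token ids shape:", pred_ids.shape)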
