Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix batch count and tweaks to chatbot. #2151

Merged
merged 8 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils.file_utils import get_resource_path
from apps.shark_studio.web.utils.file_utils import (
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import parse_device
from urllib.request import urlopen
import iree.runtime as ireert
from itertools import chain
import gc
Expand Down Expand Up @@ -65,6 +70,7 @@ def __init__(
use_system_prompt=True,
streaming_llm=False,
):
_, _, self.triple = parse_device(device)
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.device = device.split("=>")[-1].strip()
self.backend = self.device.split("://")[0]
Expand Down Expand Up @@ -165,6 +171,7 @@ def __init__(
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
Expand Down Expand Up @@ -194,11 +201,27 @@ def compile(self) -> None:
)
elif self.backend == "vulkan":
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
elif self.backend == "rocm":
flags.extend(
[
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-outer-dim-concat=true",
"--iree-flow-enable-aggressive-fusion",
]
)
if "gfx9" in self.triple:
flags.extend(
[
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
]
)
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
flatbuffer_blob = compile_module_to_flatbuffer(
self.tempfile_name,
device=self.device,
frontend="torch",
frontend="auto",
model_config_path=None,
extra_args=flags,
write_to=self.vmfb_name,
Expand Down Expand Up @@ -329,6 +352,17 @@ def chat_hf(self, prompt):
return result_output, total_time


def get_mfma_spec_path(target_chip, save_dir):
url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
attn_spec = urlopen(url).read().decode("utf-8")
spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
if os.path.exists(spec_path):
return spec_path
with open(spec_path, "w") as f:
f.write(attn_spec)
return spec_path


def llm_chat_api(InputData: dict):
from datetime import datetime as dt

Expand Down
44 changes: 30 additions & 14 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,17 @@ def __init__(
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
gc.collect()

def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
def prepare_pipe(
self, custom_weights, adapters, embeddings, is_img2img, compiled_pipeline
):
print(f"\n[LOG] Preparing pipeline...")
self.is_img2img = False
mlirs = copy.deepcopy(self.model_map)
vmfbs = copy.deepcopy(self.model_map)
weights = copy.deepcopy(self.model_map)
if not self.is_sdxl:
compiled_pipeline = False
self.compiled_pipeline = compiled_pipeline

if custom_weights:
custom_weights = os.path.join(
Expand Down Expand Up @@ -253,7 +258,6 @@ def generate_images(
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
hints,
Expand All @@ -272,7 +276,7 @@ def generate_images(
def shark_sd_fn_dict_input(
sd_kwargs: dict,
):
print("[LOG] Submitting Request...")
print("\n[LOG] Submitting Request...")

for key in sd_kwargs:
if sd_kwargs[key] in [None, []]:
Expand All @@ -282,9 +286,8 @@ def shark_sd_fn_dict_input(
if key == "seed":
sd_kwargs[key] = int(sd_kwargs[key])

for i in range(1):
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
yield generated_imgs
generated_imgs = yield from shark_sd_fn(**sd_kwargs)
return generated_imgs


def shark_sd_fn(
Expand All @@ -307,7 +310,7 @@ def shark_sd_fn(
device: str,
target_triple: str,
ondemand: bool,
repeatable_seeds: bool,
compiled_pipeline: bool,
resample_type: str,
controlnets: dict,
embeddings: dict,
Expand Down Expand Up @@ -370,6 +373,7 @@ def shark_sd_fn(
"adapters": adapters,
"embeddings": embeddings,
"is_img2img": is_img2img,
"compiled_pipeline": compiled_pipeline,
}
submit_run_kwargs = {
"prompt": prompt,
Expand All @@ -379,7 +383,6 @@ def shark_sd_fn(
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"hints": hints,
Expand Down Expand Up @@ -412,22 +415,35 @@ def shark_sd_fn(
for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
if not isinstance(out_imgs, list):
out_imgs = [out_imgs]
# total_time = time.time() - start_time
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
# print(f"\n[LOG] {text_output}")
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
save_output_img(
out_imgs[current_batch],
seed,
sd_kwargs,
)
for batch in range(batch_size):
save_output_img(
out_imgs[batch],
seed,
sd_kwargs,
)
generated_imgs.extend(out_imgs)
# TODO: make seed changes over batch counts more configurable.
submit_run_kwargs["seed"] = submit_run_kwargs["seed"] + 1
yield generated_imgs, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
)
return generated_imgs, ""
return (generated_imgs, "")


def unload_sd():
print("Unloading models.")
import apps.shark_studio.web.utils.globals as global_obj

global_obj.clear_cache()
gc.collect()


def cancel_sd():
Expand Down
1 change: 1 addition & 0 deletions apps/shark_studio/web/ui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def view_json_file(file_obj):
label="Run in streaming mode (requires recompilation)",
value=True,
interactive=False,
visible=False,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
Expand Down
54 changes: 27 additions & 27 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from apps.shark_studio.api.sd import (
shark_sd_fn_dict_input,
cancel_sd,
unload_sd,
)
from apps.shark_studio.api.controlnet import (
cnet_preview,
Expand Down Expand Up @@ -119,7 +120,7 @@ def pull_sd_configs(
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
controlnets,
embeddings,
Expand Down Expand Up @@ -178,7 +179,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["repeatable_seeds"],
sd_json["compiled_pipeline"],
sd_json["resample_type"],
sd_json["controlnets"],
sd_json["embeddings"],
Expand Down Expand Up @@ -587,21 +588,6 @@ def base_model_changed(base_model_id):
object_fit="fit",
preview=True,
)
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Row():
batch_count = gr.Slider(
1,
Expand All @@ -620,17 +606,15 @@ def base_model_changed(base_model_id):
interactive=True,
visible=True,
)
repeatable_seeds = gr.Checkbox(
cmd_opts.repeatable_seeds,
label="Use Repeatable Seeds for Batches",
compiled_pipeline = gr.Checkbox(
False,
label="Faster txt2img (SDXL only)",
)
with gr.Row():
stable_diffusion = gr.Button("Start")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
unload = gr.Button("Unload Models")
unload.click(
fn=unload_sd,
queue=False,
show_progress=False,
)
Expand Down Expand Up @@ -701,7 +685,7 @@ def base_model_changed(base_model_id):
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
Expand All @@ -718,6 +702,22 @@ def base_model_changed(base_model_id):
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)

pull_kwargs = dict(
fn=pull_sd_configs,
Expand All @@ -741,7 +741,7 @@ def base_model_changed(base_model_id):
device,
target_triple,
ondemand,
repeatable_seeds,
compiled_pipeline,
resample_type,
cnet_config,
embeddings_config,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ torch==2.3.0
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models
diffusers @ git+https://github.com/nod-ai/[email protected]
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b

# SHARK Runner
tqdm
Expand Down
Loading