Skip to content

Commit

Permalink
fix formatting and disable explicit vulkan env settings.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Mar 29, 2024
1 parent 73765cd commit 9f59a16
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 100 deletions.
2 changes: 1 addition & 1 deletion apps/shark_studio/api/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def dumpstacks():
if line:
code.append(" " + line.strip())
with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
f.write("\n".join(code))
f.write("\n".join(code))


def setup_middleware(app):
Expand Down
3 changes: 2 additions & 1 deletion apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def chat_hf(self, prompt):
self.global_iter += 1
return result_output, total_time


def llm_chat_api(InputData: dict):
from datetime import datetime as dt

Expand Down Expand Up @@ -392,7 +393,6 @@ def llm_chat_api(InputData: dict):
print("prompt = ", prompt)

for res_op, _ in llm_model.chat(prompt):

if is_chat_completion_api:
choices = [
{
Expand Down Expand Up @@ -421,6 +421,7 @@ def llm_chat_api(InputData: dict):
"choices": choices,
}


if __name__ == "__main__":
lm = LanguageModel(
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
Expand Down
4 changes: 1 addition & 3 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def __init__(
"clip": {"hf_model_name": base_model_id},
"unet": {
"hf_model_name": base_model_id,
"unet_model": unet.UnetModel(
hf_model_name=base_model_id
),
"unet_model": unet.UnetModel(hf_model_name=base_model_id),
"batch_size": batch_size,
# "is_controlled": is_controlled,
# "num_loras": num_loras,
Expand Down
4 changes: 4 additions & 0 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def get_devices_by_name(driver_name):
available_devices.extend(cpu_device)
return available_devices


def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
Expand Down Expand Up @@ -109,6 +110,7 @@ def set_init_device_flags():
elif "cpu" in cmd_opts.device:
cmd_opts.device = "cpu"


def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
Expand Down Expand Up @@ -177,6 +179,7 @@ def get_output_value(dev_dict):
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map


def get_opt_flags(model, precision="fp16"):
iree_flags = []
if len(cmd_opts.iree_vulkan_target_triple) > 0:
Expand All @@ -202,6 +205,7 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags


def map_device_to_name_path(device, key_combination=3):
"""Gives the appropriate device data (supported name/path) for user
selected execution device
Expand Down
2 changes: 2 additions & 0 deletions apps/shark_studio/modules/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
)

weights_path = self.get_io_params(submodel)
if weights_path:
ireec_flags.append("--iree-opt-const-eval=False")

self.iree_module_dict[submodel] = get_iree_compiled_module(
self.tempfiles[submodel],
Expand Down
6 changes: 2 additions & 4 deletions apps/shark_studio/studio_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@
# hidden imports for pyinstaller
hiddenimports = ["shark", "apps"]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
blacklist = ["tests", "convert"]
hiddenimports += [
x
Expand All @@ -67,4 +65,4 @@
]
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
hiddenimports += ["iree._runtime"]
hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]
hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]
1 change: 0 additions & 1 deletion apps/shark_studio/tests/rest_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def llm_chat_test(verbose=False):


if __name__ == "__main__":

# "Exercises the chatbot REST API of Shark. Make sure "
# "Shark is running in API mode on 127.0.0.1:8080 before running"
# "this script."
Expand Down
1 change: 0 additions & 1 deletion apps/shark_studio/web/api/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

def decode_base64_to_image(encoding):
if encoding.startswith("http://") or encoding.startswith("https://"):

headers = {}
response = requests.get(encoding, timeout=30, headers=headers)
try:
Expand Down
1 change: 1 addition & 0 deletions apps/shark_studio/web/index.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from multiprocessing import Process, freeze_support

freeze_support()
from PIL import Image

Expand Down
4 changes: 3 additions & 1 deletion apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,9 @@ def base_model_changed(base_model_id):
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
Expand Down
2 changes: 2 additions & 0 deletions apps/shark_studio/web/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@
"embeddings": {}
}"""


def write_default_sd_config(path):
with open(path, "w") as f:
f.write(default_sd_config)


def safe_name(name):
return name.replace("/", "_").replace("-", "_")

Expand Down
2 changes: 1 addition & 1 deletion shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
def get_iree_common_args(debug=False):
common_args = [
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
"--mlir-elide-elementsattrs-if-larger=10",
]
if debug == True:
common_args.extend(
Expand Down
Loading

0 comments on commit 9f59a16

Please sign in to comment.