Skip to content

Commit

Permalink
add support for external weights
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey authored and monorimet committed Dec 12, 2023
1 parent b197df6 commit 93dab60
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions apps/shark_studio/web/ui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,110 @@ def chat_fn(
history[-1][-1] = "Getting the model ready... Done"
yield history, ""
history[-1][-1] = ""
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")

from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args

new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512

# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)

if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add

if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1

assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")

elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")

if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)

if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")

prompt = create_prompt(model_name, history, prompt_prefix)

partial_text = ""
token_count = 0
total_time = 0.001 # In order to avoid divide by zero error
prefill_time = 0
Expand Down

0 comments on commit 93dab60

Please sign in to comment.