
Commit

Tweaks to chatbot
monorimet committed May 30, 2024
1 parent 222f387 commit 18ecd61
Showing 2 changed files with 28 additions and 2 deletions.
29 changes: 27 additions & 2 deletions apps/shark_studio/api/llm.py
@@ -3,8 +3,10 @@
 from turbine_models.gen_external_params.gen_external_params import gen_external_params
 import time
 from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
-from apps.shark_studio.web.utils.file_utils import get_resource_path
+from apps.shark_studio.web.utils.file_utils import get_resource_path, get_checkpoints_path
 from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.shark_studio.api.utils import parse_device
+from urllib.request import urlopen
 import iree.runtime as ireert
 from itertools import chain
 import gc
@@ -65,6 +67,7 @@ def __init__(
         use_system_prompt=True,
         streaming_llm=False,
     ):
+        _, _, self.triple = parse_device(device)
         self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
         self.device = device.split("=>")[-1].strip()
         self.backend = self.device.split("://")[0]
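
The unpacking above implies parse_device returns a 3-tuple whose last element is the target architecture triple. A minimal sketch of a stand-in helper, assuming that shape (the real parse_device lives in apps.shark_studio.api.utils and may resolve the triple differently; the names and the gfx942 triple here are illustrative assumptions):

def parse_device_sketch(device: str):
    # Hypothetical stand-in for parse_device, shown only to illustrate the
    # (backend, device_id, triple) unpacking used in __init__ above.
    backend = device.split("://")[0]      # e.g. "rocm"
    device_id = device.split("://")[-1]   # e.g. "0"
    triple = "gfx942"                     # e.g. an MI300-class AMD target
    return backend, device_id, triple

_, _, triple = parse_device_sketch("rocm://0")
assert "gfx9" in triple  # the same check compile() performs below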
@@ -165,6 +168,7 @@ def __init__(
             precision=self.precision,
             quantization=self.quantization,
             streaming_llm=self.streaming_llm,
+            decomp_attn=True,
         )
         with open(self.tempfile_name, "w+") as f:
             f.write(self.torch_ir)
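
The new decomp_attn=True argument asks the exporter to decompose attention into primitive ops rather than emit a single fused op. A conceptual sketch of what that decomposition means, assuming the standard scaled-dot-product formulation (an illustration, not turbine_models' actual lowering):

import torch

def decomposed_attention(q, k, v):
    # Unfused attention: explicit matmul -> softmax -> matmul, the kind of
    # primitive sequence a decomposed export produces instead of one
    # scaled_dot_product_attention op.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v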
@@ -194,11 +198,23 @@ def compile(self) -> None:
             )
         elif self.backend == "vulkan":
             flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
+        elif self.backend == "rocm":
+            flags.extend([
+                "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
+                "--iree-llvmgpu-enable-prefetch=true",
+                "--iree-opt-outer-dim-concat=true",
+                "--iree-flow-enable-aggressive-fusion",
+            ])
+            if "gfx9" in self.triple:
+                flags.extend([
+                    f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
+                    "--iree-codegen-llvmgpu-use-vector-distribution=true",
+                ])
         flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
         flatbuffer_blob = compile_module_to_flatbuffer(
             self.tempfile_name,
             device=self.device,
-            frontend="torch",
+            frontend="auto",
             model_config_path=None,
             extra_args=flags,
             write_to=self.vmfb_name,
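
Restated outside the class for clarity, the new ROCm branch amounts to the helper below (a sketch: the flag strings are copied from the diff, and spec_path would come from get_mfma_spec_path on gfx9-class targets):

def rocm_compile_flags(triple, spec_path):
    # Base ROCm flags applied to every AMD target.
    flags = [
        "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
        "--iree-llvmgpu-enable-prefetch=true",
        "--iree-opt-outer-dim-concat=true",
        "--iree-flow-enable-aggressive-fusion",
    ]
    if "gfx9" in triple:
        # gfx9-class (MFMA-capable) chips additionally get the
        # transform-dialect spec library and vector distribution.
        flags += [
            f"--iree-codegen-transform-dialect-library={spec_path}",
            "--iree-codegen-llvmgpu-use-vector-distribution=true",
        ]
    return flags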
@@ -328,6 +344,15 @@ def chat_hf(self, prompt):
         self.global_iter += 1
         return result_output, total_time
 
+def get_mfma_spec_path(target_chip, save_dir):
+    spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
+    if os.path.exists(spec_path):
+        return spec_path
+    url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
+    attn_spec = urlopen(url).read().decode("utf-8")
+    with open(spec_path, "w") as f:
+        f.write(attn_spec)
+    return spec_path
 
 def llm_chat_api(InputData: dict):
     from datetime import datetime as dt
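
Note the cache check now precedes the download, so the spec is fetched at most once per checkpoints directory; target_chip is accepted but not currently used to select a spec. Example use as compile() calls it, assuming a gfx942 triple (the chip name is illustrative):

spec = get_mfma_spec_path("gfx942", get_checkpoints_path())
flags.append(f"--iree-codegen-transform-dialect-library={spec}")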
1 change: 1 addition & 0 deletions apps/shark_studio/web/ui/chat.py
Original file line number Diff line number Diff line change
@@ -138,6 +138,7 @@ def view_json_file(file_obj):
             label="Run in streaming mode (requires recompilation)",
             value=True,
             interactive=False,
+            visible=False,
         )
         prompt_prefix = gr.Checkbox(
             label="Add System Prompt",
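
Setting visible=False hides the checkbox while keeping it in the component graph, so its value=True still reaches callbacks: streaming mode stays on without exposing the toggle. If it ever needs to reappear, an event handler can flip visibility with Gradio's gr.update (a sketch; the function name is hypothetical):

import gradio as gr

def show_streaming_toggle():
    # Re-show the hidden checkbox from an event handler.
    return gr.update(visible=True)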
