From 18ecd61cce5707d93553b679069c840482067513 Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Thu, 30 May 2024 18:30:40 -0400
Subject: [PATCH] Tweaks to chatbot

---
 apps/shark_studio/api/llm.py     | 29 +++++++++++++++++++++++++++--
 apps/shark_studio/web/ui/chat.py |  1 +
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index 217fb6784f..5207002c8d 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -3,8 +3,10 @@
 from turbine_models.gen_external_params.gen_external_params import gen_external_params
 import time
 from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
-from apps.shark_studio.web.utils.file_utils import get_resource_path
+from apps.shark_studio.web.utils.file_utils import get_resource_path, get_checkpoints_path
 from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+from apps.shark_studio.api.utils import parse_device
+from urllib.request import urlopen
 import iree.runtime as ireert
 from itertools import chain
 import gc
@@ -65,6 +67,7 @@ def __init__(
         use_system_prompt=True,
         streaming_llm=False,
     ):
+        _, _, self.triple = parse_device(device)
         self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
         self.device = device.split("=>")[-1].strip()
         self.backend = self.device.split("://")[0]
@@ -165,6 +168,7 @@ def __init__(
             precision=self.precision,
             quantization=self.quantization,
             streaming_llm=self.streaming_llm,
+            decomp_attn=True,
         )
         with open(self.tempfile_name, "w+") as f:
             f.write(self.torch_ir)
@@ -194,11 +198,23 @@ def compile(self) -> None:
             )
         elif self.backend == "vulkan":
             flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
+        elif self.backend == "rocm":
+            flags.extend([
+                "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
+                "--iree-llvmgpu-enable-prefetch=true",
+                "--iree-opt-outer-dim-concat=true",
+                "--iree-flow-enable-aggressive-fusion",
+            ])
+            if "gfx9" in self.triple:
+                flags.extend([
+                    f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
+                    "--iree-codegen-llvmgpu-use-vector-distribution=true",
+                ])
         flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
         flatbuffer_blob = compile_module_to_flatbuffer(
             self.tempfile_name,
             device=self.device,
-            frontend="torch",
+            frontend="auto",
             model_config_path=None,
             extra_args=flags,
             write_to=self.vmfb_name,
@@ -328,6 +344,15 @@ def chat_hf(self, prompt):
         self.global_iter += 1
         return result_output, total_time
 
+def get_mfma_spec_path(target_chip, save_dir):
+    spec_path = os.path.join(save_dir, "attention_and_matmul_spec_mfma.mlir")
+    if os.path.exists(spec_path):
+        return spec_path
+    url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
+    attn_spec = urlopen(url).read().decode("utf-8")
+    with open(spec_path, "w") as f:
+        f.write(attn_spec)
+    return spec_path
 
 def llm_chat_api(InputData: dict):
     from datetime import datetime as dt
diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py
index 54ae4a139f..cad9f4cb00 100644
--- a/apps/shark_studio/web/ui/chat.py
+++ b/apps/shark_studio/web/ui/chat.py
@@ -138,6 +138,7 @@ def view_json_file(file_obj):
                 label="Run in streaming mode (requires recompilation)",
                 value=True,
                 interactive=False,
+                visible=False,
             )
             prompt_prefix = gr.Checkbox(
                 label="Add System Prompt",
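
Notes and illustrative sketches (plain Python, not part of the patch):

The first llm.py hunk assumes parse_device returns a three-element tuple whose
last element is the target triple (an architecture string such as "gfx90a" for
CDNA-class ROCm devices), which compile() later matches with `"gfx9" in
self.triple`. The stand-in below only illustrates that assumed contract; the
real helper lives in apps.shark_studio.api.utils and may derive the triple
differently.

# Hypothetical stand-in for apps.shark_studio.api.utils.parse_device.
# The patch relies only on the shape of the return value: a 3-tuple whose
# final element is the target triple.
def parse_device(device: str):
    target = device.split("=>")[-1].strip()  # e.g. "rocm://0"
    backend = target.split("://")[0]         # "rocm"
    device_id = target.split("://")[-1]      # "0"
    triple = "gfx90a"                        # placeholder architecture string
    return backend, device_id, triple

_, _, triple = parse_device("AMD Instinct MI210 => rocm://0")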
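
The ROCm branch added to compile() can be read as the standalone flag builder
below. The flag strings are copied from the hunk; build_rocm_flags itself is a
hypothetical name used only to show which flags a gfx9 (CDNA) target picks up
relative to other ROCm targets.

# Hypothetical standalone rendering of the new ROCm branch in compile().
# Flag strings match the patch; the function name is illustrative.
def build_rocm_flags(triple: str, spec_path: str) -> list[str]:
    flags = [
        "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
        "--iree-llvmgpu-enable-prefetch=true",
        "--iree-opt-outer-dim-concat=true",
        "--iree-flow-enable-aggressive-fusion",
    ]
    if "gfx9" in triple:
        # CDNA-class targets additionally get the MFMA attention/matmul spec
        # and vector distribution, mirroring the `"gfx9" in self.triple` check.
        flags += [
            f"--iree-codegen-transform-dialect-library={spec_path}",
            "--iree-codegen-llvmgpu-use-vector-distribution=true",
        ]
    return flags

# An MI210 (gfx90a) receives all six flags; an RDNA3 card (gfx1100) only the
# first four.
print(build_rocm_flags("gfx90a", "/tmp/attention_and_matmul_spec_mfma.mlir"))
print(build_rocm_flags("gfx1100", "/tmp/attention_and_matmul_spec_mfma.mlir"))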
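
Finally, a usage sketch for the new get_mfma_spec_path helper, which returns a
cached copy of the spec when one exists and downloads it from the IREE
repository otherwise. The temporary directory and the "gfx942" chip string are
example values only; the import path follows this patch.

# Usage sketch for get_mfma_spec_path as added in apps/shark_studio/api/llm.py.
import os
import tempfile

from apps.shark_studio.api.llm import get_mfma_spec_path

save_dir = tempfile.mkdtemp()

# First call: no cached copy, so the helper fetches the spec and writes
# attention_and_matmul_spec_mfma.mlir into save_dir.
spec_path = get_mfma_spec_path("gfx942", save_dir)
assert os.path.exists(spec_path)

# Second call: the file now exists, so it is returned without re-downloading.
assert get_mfma_spec_path("gfx942", save_dir) == spec_path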