From 7e57c8394b3c33c1ada2c3a1510242e6f62c6717 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 30 May 2024 17:33:38 -0500 Subject: [PATCH] Formatting --- apps/shark_studio/api/llm.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 5207002c8d..f6d33adcb6 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -3,7 +3,10 @@ from turbine_models.gen_external_params.gen_external_params import gen_external_params import time from shark.iree_utils.compile_utils import compile_module_to_flatbuffer -from apps.shark_studio.web.utils.file_utils import get_resource_path, get_checkpoints_path +from apps.shark_studio.web.utils.file_utils import ( + get_resource_path, + get_checkpoints_path, +) from apps.shark_studio.modules.shared_cmd_opts import cmd_opts from apps.shark_studio.api.utils import parse_device from urllib.request import urlopen @@ -199,17 +202,21 @@ def compile(self) -> None: elif self.backend == "vulkan": flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"]) elif self.backend == "rocm": - flags.extend([ - "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", - "--iree-llvmgpu-enable-prefetch=true", - "--iree-opt-outer-dim-concat=true", - "--iree-flow-enable-aggressive-fusion", - ]) + flags.extend( + [ + "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-opt-outer-dim-concat=true", + "--iree-flow-enable-aggressive-fusion", + ] + ) if "gfx9" in self.triple: - flags.extend([ - f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}", - "--iree-codegen-llvmgpu-use-vector-distribution=true" - ]) + flags.extend( + [ + f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}", + "--iree-codegen-llvmgpu-use-vector-distribution=true", + ] + ) flags.extend(llm_model_map[self.hf_model_name]["compile_flags"]) flatbuffer_blob = compile_module_to_flatbuffer( self.tempfile_name, @@ -344,6 +351,7 @@ def chat_hf(self, prompt): self.global_iter += 1 return result_output, total_time + def get_mfma_spec_path(target_chip, save_dir): url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir" attn_spec = urlopen(url).read().decode("utf-8") @@ -354,6 +362,7 @@ def get_mfma_spec_path(target_chip, save_dir): f.write(attn_spec) return spec_path + def llm_chat_api(InputData: dict): from datetime import datetime as dt