Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed May 30, 2024
1 parent 18ecd61 commit 7e57c83
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
from shark.iree_utils.compile_utils import compile_module_to_flatbuffer
from apps.shark_studio.web.utils.file_utils import get_resource_path, get_checkpoints_path
from apps.shark_studio.web.utils.file_utils import (
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.api.utils import parse_device
from urllib.request import urlopen
Expand Down Expand Up @@ -199,17 +202,21 @@ def compile(self) -> None:
elif self.backend == "vulkan":
flags.extend(["--iree-stream-resource-max-allocation-size=4294967296"])
elif self.backend == "rocm":
flags.extend([
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-outer-dim-concat=true",
"--iree-flow-enable-aggressive-fusion",
])
flags.extend(
[
"--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
"--iree-llvmgpu-enable-prefetch=true",
"--iree-opt-outer-dim-concat=true",
"--iree-flow-enable-aggressive-fusion",
]
)
if "gfx9" in self.triple:
flags.extend([
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
"--iree-codegen-llvmgpu-use-vector-distribution=true"
])
flags.extend(
[
f"--iree-codegen-transform-dialect-library={get_mfma_spec_path(self.triple, get_checkpoints_path())}",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
]
)
flags.extend(llm_model_map[self.hf_model_name]["compile_flags"])
flatbuffer_blob = compile_module_to_flatbuffer(
self.tempfile_name,
Expand Down Expand Up @@ -344,6 +351,7 @@ def chat_hf(self, prompt):
self.global_iter += 1
return result_output, total_time


def get_mfma_spec_path(target_chip, save_dir):
url = "https://raw.githubusercontent.com/iree-org/iree/main/build_tools/pkgci/external_test_suite/attention_and_matmul_spec.mlir"
attn_spec = urlopen(url).read().decode("utf-8")
Expand All @@ -354,6 +362,7 @@ def get_mfma_spec_path(target_chip, save_dir):
f.write(attn_spec)
return spec_path


def llm_chat_api(InputData: dict):
from datetime import datetime as dt

Expand Down

0 comments on commit 7e57c83

Please sign in to comment.