diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py
index 1a03b817ff..a209d8d1ba 100644
--- a/apps/shark_studio/api/llm.py
+++ b/apps/shark_studio/api/llm.py
@@ -106,6 +106,7 @@ def compile(self) -> None:
             frontend="torch",
             external_weight_file=self.external_weight_file,
             write_to=self.vmfb_name,
+            extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"],
         )
         # TODO: delete the temp file
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index bae1908e1c..ca6a12c45b 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -43,7 +43,6 @@ def get_iree_device_args(device, extra_args=[]):
             get_iree_cpu_args()
             + u_kernel_flag
             + stack_size_flag
-            + ["--iree-global-opt-enable-quantized-matmul-reassociation"]
         )
     if device == "cuda":
         from shark.iree_utils.gpu_utils import get_iree_gpu_args
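In effect, the reassociation flag stops being a blanket default added for every CPU target and becomes a per-call opt-in for the LLM compile path. A minimal sketch of that usage pattern follows, assuming a compiler helper that accepts extra_args; the helper name and signature below are illustrative, not SHARK's actual API.

# Illustrative sketch only: "compile_fn" stands in for whatever compiler entry
# point the app uses; its name and signature here are assumptions, not SHARK's
# real API. It shows the patch's intent: the
# --iree-global-opt-enable-quantized-matmul-reassociation flag is no longer
# appended for every CPU target in get_iree_device_args(), and is instead
# passed explicitly, per compile call, through extra_args.
from typing import Any, Callable


def compile_llm_vmfb(
    compile_fn: Callable[..., Any],
    mlir_path: str,
    device: str,
    external_weight_file: str,
    vmfb_name: str,
) -> Any:
    """Compile one LLM module with the reassociation flag opted in."""
    return compile_fn(
        mlir_path,
        device=device,
        frontend="torch",
        external_weight_file=external_weight_file,
        write_to=vmfb_name,
        # Opt in here only; other device/compile paths are unaffected.
        extra_args=["--iree-global-opt-enable-quantized-matmul-reassociation"],
    )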