From c00b0da35a13cb956d1a475ee4c589d3c2d570fc Mon Sep 17 00:00:00 2001
From: dan
Date: Sun, 25 Jun 2023 20:55:40 -0500
Subject: [PATCH 1/2] set task_topology_max_group_count to cpu_count by
 default. Can be overridden with a flag of the same name

---
 shark/iree_utils/compile_utils.py |  8 +++++++-
 shark/iree_utils/cpu_utils.py     | 17 ++++++++++++++++-
 shark/parser.py                   |  6 ++++++
 3 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index a05bfc89c6..b0e6d042a2 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -14,6 +14,7 @@
 import iree.runtime as ireert
 import iree.compiler as ireec
 from shark.iree_utils._common import iree_device_map, iree_target_map
+from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
 from shark.iree_utils.benchmark_utils import *
 from shark.parser import shark_args
 import numpy as np
@@ -352,6 +353,12 @@ def load_vmfb_using_mmap(
         config = ireert.Config(device=haldevice)
     else:
         config = get_iree_runtime_config(device)
+    if "task" in device:
+        print(
+            f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
+        )
+        for flag in get_iree_cpu_rt_args():
+            ireert.flags.parse_flags(flag)
     # Now load vmfb.
     # Two scenarios we have here :-
     # 1. We either have the vmfb already saved and therefore pass the path of it.
@@ -359,7 +366,6 @@ def load_vmfb_using_mmap(
     # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
     # (This would arise if we're invoking `compile` from a SharkInference obj)
     temp_file_to_unlink = None
-
     if isinstance(flatbuffer_blob_or_path, Path):
         flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
     if (
diff --git a/shark/iree_utils/cpu_utils.py b/shark/iree_utils/cpu_utils.py
index 38c294db27..182b8c8c48 100644
--- a/shark/iree_utils/cpu_utils.py
+++ b/shark/iree_utils/cpu_utils.py
@@ -16,6 +16,7 @@
 
 import subprocess
 import platform
+from shark.parser import shark_args
 
 
 def get_cpu_count():
@@ -44,4 +45,18 @@ def get_iree_cpu_args():
         error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
         raise Exception(error_message)
     print(f"Target triple found:{target_triple}")
-    return [f"--iree-llvmcpu-target-triple={target_triple}"]
+    return [
+        f"--iree-llvmcpu-target-triple={target_triple}",
+    ]
+
+
+# Get iree runtime flags for cpu
+def get_iree_cpu_rt_args():
+    default = get_cpu_count()
+    default = default if default <= 8 else default - 2
+    cpu_count = (
+        default
+        if shark_args.task_topology_max_group_count is None
+        else shark_args.task_topology_max_group_count
+    )
+    return [f"--task_topology_max_group_count={cpu_count}"]
diff --git a/shark/parser.py b/shark/parser.py
index 47cc1a86df..7c51290ecf 100644
--- a/shark/parser.py
+++ b/shark/parser.py
@@ -119,5 +119,11 @@
     "to augment the base device allocator",
     choices=["debug", "caching"],
 )
+parser.add_argument(
+    "--task_topology_max_group_count",
+    type=str,
+    default=None,
+    help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
+)
 
 shark_args, unknown = parser.parse_known_args()
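Note on the default picked in cpu_utils.py above: get_iree_cpu_rt_args() uses
the full core count on machines with up to 8 cores and reserves 2 cores of
headroom beyond that. A minimal standalone sketch of that selection logic,
assuming multiprocessing.cpu_count() stands in for get_cpu_count() (whose body
is outside this hunk) and with the shark_args override shown as a plain
variable:

    # Sketch of the group-count default from cpu_utils.py (patch 1).
    # Assumption: multiprocessing.cpu_count() approximates get_cpu_count().
    import multiprocessing

    cores = multiprocessing.cpu_count()            # e.g. 16 logical cores
    default = cores if cores <= 8 else cores - 2   # keep 2 cores free on big hosts
    override = None  # stands in for shark_args.task_topology_max_group_count
    group_count = default if override is None else override
    print(f"--task_topology_max_group_count={group_count}")  # e.g. ...=14

With no override, a 16-core machine would therefore run with 14 task topology
groups; passing --task_topology_max_group_count on the command line wins over
the computed default, and the resulting flag is only parsed into the IREE
runtime when the device string contains "task".
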
From 5c06a202d4e2b7ddf77ed47d1937031696508de2 Mon Sep 17 00:00:00 2001
From: dan
Date: Thu, 29 Jun 2023 14:47:23 -0500
Subject: [PATCH 2/2] add download for int4/int8 mlir

---
 apps/language_models/src/pipelines/vicuna_pipeline.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py
index 60cf258f4e..1bef8508ac 100644
--- a/apps/language_models/src/pipelines/vicuna_pipeline.py
+++ b/apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -38,9 +38,6 @@ def __init__(
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
         self.device = device
-        if precision in ["int4", "int8"]:
-            print("int4 and int8 are not supported yet, using fp32")
-            precision = "fp32"
         self.precision = precision
         self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
         self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
@@ -103,7 +100,7 @@ def compile_first_vicuna(self):
         else:
             mlir_generated = False
             if self.load_mlir_from_shark_tank:
-                if self.precision in ["fp32", "fp16"]:
+                if self.precision in ["fp32", "fp16", "int8", "int4"]:
                     # download MLIR from shark_tank for fp32/fp16
                     download_public_file(
                         f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
@@ -245,7 +242,7 @@ def compile_second_vicuna(self):
         else:
             mlir_generated = False
             if self.load_mlir_from_shark_tank:
-                if self.precision in ["fp32", "fp16"]:
+                if self.precision in ["fp32", "fp16", "int8", "int4"]:
                     # download MLIR from shark_tank for fp32/fp16
                     download_public_file(
                         f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
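
Note on the widened precision gate: with patch 2, int8 and int4 MLIR is no
longer forced back to fp32 and is fetched from the same shark_tank bucket as
fp32/fp16. A minimal sketch of the resulting selection, using a hypothetical
file name in place of self.first_vicuna_mlir_path.name (the real name is
built elsewhere in the pipeline):

    # Sketch of the download gate widened in vicuna_pipeline.py (patch 2).
    # "first_vicuna_int8.mlir" is a hypothetical stand-in for
    # self.first_vicuna_mlir_path.name.
    precision = "int8"
    if precision in ["fp32", "fp16", "int8", "int4"]:
        mlir_name = f"first_vicuna_{precision}.mlir"
        url = f"gs://shark_tank/vicuna/unsharded/mlir/{mlir_name}"
        print(f"would download {url}")
    else:
        print(f"{precision}: no prebuilt MLIR on shark_tank, generate locally")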