diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index a05bfc89c6..b0e6d042a2 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -14,6 +14,7 @@ import iree.runtime as ireert import iree.compiler as ireec from shark.iree_utils._common import iree_device_map, iree_target_map +from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args from shark.iree_utils.benchmark_utils import * from shark.parser import shark_args import numpy as np @@ -352,6 +353,12 @@ def load_vmfb_using_mmap( config = ireert.Config(device=haldevice) else: config = get_iree_runtime_config(device) + if "task" in device: + print( + f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}" + ) + for flag in get_iree_cpu_rt_args(): + ireert.flags.parse_flags(flag) # Now load vmfb. # Two scenarios we have here :- # 1. We either have the vmfb already saved and therefore pass the path of it. @@ -359,7 +366,6 @@ def load_vmfb_using_mmap( # OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with. # (This would arise if we're invoking `compile` from a SharkInference obj) temp_file_to_unlink = None - if isinstance(flatbuffer_blob_or_path, Path): flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__() if ( diff --git a/shark/iree_utils/cpu_utils.py b/shark/iree_utils/cpu_utils.py index 38c294db27..182b8c8c48 100644 --- a/shark/iree_utils/cpu_utils.py +++ b/shark/iree_utils/cpu_utils.py @@ -16,6 +16,7 @@ import subprocess import platform +from shark.parser import shark_args def get_cpu_count(): @@ -44,4 +45,18 @@ def get_iree_cpu_args(): error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)" raise Exception(error_message) print(f"Target triple found:{target_triple}") - return [f"--iree-llvmcpu-target-triple={target_triple}"] + return [ + f"--iree-llvmcpu-target-triple={target_triple}", + ] + + +# Get iree runtime flags for cpu +def get_iree_cpu_rt_args(): + default = get_cpu_count() + default = default if default <= 8 else default - 2 + cpu_count = ( + default + if shark_args.task_topology_max_group_count is None + else shark_args.task_topology_max_group_count + ) + return [f"--task_topology_max_group_count={cpu_count}"] diff --git a/shark/parser.py b/shark/parser.py index 47cc1a86df..7c51290ecf 100644 --- a/shark/parser.py +++ b/shark/parser.py @@ -119,5 +119,11 @@ "to augment the base device allocator", choices=["debug", "caching"], ) +parser.add_argument( + "--task_topology_max_group_count", + type=str, + default=None, + help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count", +) shark_args, unknown = parser.parse_known_args()