diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 78ee1ca6d5..ddb015ca50 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -14,6 +14,7 @@ import iree.runtime as ireert import iree.compiler as ireec from shark.iree_utils._common import iree_device_map, iree_target_map +from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args from shark.iree_utils.benchmark_utils import * from shark.parser import shark_args import numpy as np @@ -345,6 +346,12 @@ def load_vmfb_using_mmap( haldriver.query_available_devices()[device_idx]["device_id"], allocators=shark_args.device_allocator, ) + if "local_task" in device: + print( + f"[DEBUG] setting iree runtime flags for cpu.\n{' '.join(get_iree_cpu_rt_args())}" + ) + for flag in get_iree_cpu_rt_args(): + ireert.flags.parse_flags(flag) config = ireert.Config(device=haldevice) else: config = get_iree_runtime_config(device) diff --git a/shark/iree_utils/cpu_utils.py b/shark/iree_utils/cpu_utils.py index 38c294db27..f553c92c17 100644 --- a/shark/iree_utils/cpu_utils.py +++ b/shark/iree_utils/cpu_utils.py @@ -16,6 +16,7 @@ import subprocess import platform +from shark.parser import shark_args def get_cpu_count(): @@ -44,4 +45,16 @@ def get_iree_cpu_args(): error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)" raise Exception(error_message) print(f"Target triple found:{target_triple}") - return [f"--iree-llvmcpu-target-triple={target_triple}"] + return [ + f"--iree-llvmcpu-target-triple={target_triple}", + ] + + +# Get iree runtime flags for cpu +def get_iree_cpu_rt_args(): + cpu_count = ( + get_cpu_count() + if shark_args.task_topology_max_group_count is None + else shark_args.task_topology_max_group_count + ) + return [f"--task_topology_max_group_count={cpu_count}"] diff --git a/shark/parser.py b/shark/parser.py index 47cc1a86df..7c51290ecf 100644 --- a/shark/parser.py +++ b/shark/parser.py @@ -119,5 +119,11 @@ "to augment the base device allocator", choices=["debug", "caching"], ) +parser.add_argument( + "--task_topology_max_group_count", + type=str, + default=None, + help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count", +) shark_args, unknown = parser.parse_known_args()