Skip to content

Commit

Permalink
set task_topology_max_group to cpu_count
Browse files Browse the repository at this point in the history
by default. Can be overriden with a flag of the same str
  • Loading branch information
dan-garvey committed Jun 26, 2023
1 parent cdd505e commit 89baefa
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
7 changes: 7 additions & 0 deletions shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import numpy as np
Expand Down Expand Up @@ -345,6 +346,12 @@ def load_vmfb_using_mmap(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
if "local_task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu.\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
Expand Down
15 changes: 14 additions & 1 deletion shark/iree_utils/cpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import subprocess
import platform
from shark.parser import shark_args


def get_cpu_count():
Expand Down Expand Up @@ -44,4 +45,16 @@ def get_iree_cpu_args():
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
return [
f"--iree-llvmcpu-target-triple={target_triple}",
]


# Get iree runtime flags for cpu
def get_iree_cpu_rt_args():
cpu_count = (
get_cpu_count()
if shark_args.task_topology_max_group_count is None
else shark_args.task_topology_max_group_count
)
return [f"--task_topology_max_group_count={cpu_count}"]
6 changes: 6 additions & 0 deletions shark/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,11 @@
"to augment the base device allocator",
choices=["debug", "caching"],
)
parser.add_argument(
"--task_topology_max_group_count",
type=str,
default=None,
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
)

shark_args, unknown = parser.parse_known_args()

0 comments on commit 89baefa

Please sign in to comment.