Skip to content

Commit

Permalink
set task_topology_max_group to cpu_count
Browse files Browse the repository at this point in the history
by default. Can be overridden with a flag of the same name
  • Loading branch information
dan-garvey committed Jun 26, 2023
1 parent 74a7202 commit 0a486a0
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
8 changes: 7 additions & 1 deletion shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import numpy as np
Expand Down Expand Up @@ -352,14 +353,19 @@ def load_vmfb_using_mmap(
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None

if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
Expand Down
15 changes: 14 additions & 1 deletion shark/iree_utils/cpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import subprocess
import platform
from shark.parser import shark_args


def get_cpu_count():
Expand Down Expand Up @@ -44,4 +45,16 @@ def get_iree_cpu_args():
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
return [
f"--iree-llvmcpu-target-triple={target_triple}",
]


# Get iree runtime flags for cpu
def get_iree_cpu_rt_args():
    """Return the IREE runtime flags to use for CPU (local-task) execution.

    Honors ``shark_args.task_topology_max_group_count`` when the user
    supplied it; otherwise falls back to the machine's CPU count as
    reported by ``get_cpu_count()``.
    """
    group_count = shark_args.task_topology_max_group_count
    if group_count is None:
        group_count = get_cpu_count()
    return [f"--task_topology_max_group_count={group_count}"]
6 changes: 6 additions & 0 deletions shark/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,11 @@
"to augment the base device allocator",
choices=["debug", "caching"],
)
# Passthrough for IREE's --task_topology_max_group_count runtime flag.
# Declared as int (rather than str) so a non-numeric value fails loudly at
# argument parsing instead of later producing a malformed IREE runtime flag.
# When left as None, the consumer (shark/iree_utils/cpu_utils.py) falls back
# to the detected CPU count.
parser.add_argument(
    "--task_topology_max_group_count",
    type=int,
    default=None,
    help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
)

shark_args, unknown = parser.parse_known_args()

0 comments on commit 0a486a0

Please sign in to comment.