Skip to content

Commit

Permalink
set task_topology_max_group to cpu_count
Browse files Browse the repository at this point in the history
by default. Can be overriden with a flag of the same str
  • Loading branch information
dan-garvey committed Jun 26, 2023
1 parent cdd505e commit 51ec733
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
11 changes: 10 additions & 1 deletion shark/iree_utils/cpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import subprocess
import platform
from shark.parser import shark_args


def get_cpu_count():
Expand Down Expand Up @@ -44,4 +45,12 @@ def get_iree_cpu_args():
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
cpu_count = (
get_cpu_count()
if shark_args.task_topology_max_group_count is None
else shark_args.task_topology_max_group_count
)
return [
f"--iree-llvmcpu-target-triple={target_triple}",
f"--task_topology_max_group_count={cpu_count}",
]
6 changes: 6 additions & 0 deletions shark/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,11 @@
"to augment the base device allocator",
choices=["debug", "caching"],
)
parser.add_argument(
"--task_topology_max_group_count",
type=str,
default=None,
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
)

shark_args, unknown = parser.parse_known_args()

0 comments on commit 51ec733

Please sign in to comment.