Skip to content

Commit

Permalink
take all ireert calls out of studio flow
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jun 4, 2024
1 parent 4aa2d8b commit 67b438e
Showing 1 changed file with 63 additions and 62 deletions.
125 changes: 63 additions & 62 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def iree_target_map(device):


def get_available_devices():
return ['rocm', 'cpu']
def get_devices_by_name(driver_name):

device_list = []
Expand Down Expand Up @@ -225,65 +226,65 @@ def get_all_devices(driver_name):
return device_list_src


def get_device_mapping(driver, key_combination=3):
"""This method ensures consistent device ordering when choosing
specific devices for execution
Args:
driver (str): execution driver (vulkan, cuda, rocm, etc)
key_combination (int, optional): choice for mapping value for
device name.
1 : path
2 : name
3 : (name, path)
Defaults to 3.
Returns:
dict: map to possible device names user can input mapped to desired
combination of name/path.
"""

driver = iree_device_map(driver)
device_list = get_all_devices(driver)
device_map = dict()

def get_output_value(dev_dict):
if key_combination == 1:
return f"{driver}://{dev_dict['path']}"
if key_combination == 2:
return dev_dict["name"]
if key_combination == 3:
return dev_dict["name"], f"{driver}://{dev_dict['path']}"

# mapping driver name to default device (driver://0)
device_map[f"{driver}"] = get_output_value(device_list[0])
for i, device in enumerate(device_list):
# mapping with index
device_map[f"{driver}://{i}"] = get_output_value(device)
# mapping with full path
device_map[f"{driver}://{device['path']}"] = get_output_value(device)
return device_map


def get_opt_flags(model, precision="fp16"):
iree_flags = []
if len(cmd_opts.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
)
if "rocm" in cmd_opts.device:
from shark.iree_utils.gpu_utils import get_iree_rocm_args

rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
if cmd_opts.iree_constant_folding == False:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if cmd_opts.data_tiling == False:
iree_flags.append("--iree-opt-data-tiling=False")

if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags
# def get_device_mapping(driver, key_combination=3):
# """This method ensures consistent device ordering when choosing
# specific devices for execution
# Args:
# driver (str): execution driver (vulkan, cuda, rocm, etc)
# key_combination (int, optional): choice for mapping value for
# device name.
# 1 : path
# 2 : name
# 3 : (name, path)
# Defaults to 3.
# Returns:
# dict: map to possible device names user can input mapped to desired
# combination of name/path.
# """

# driver = iree_device_map(driver)
# device_list = get_all_devices(driver)
# device_map = dict()

# def get_output_value(dev_dict):
# if key_combination == 1:
# return f"{driver}://{dev_dict['path']}"
# if key_combination == 2:
# return dev_dict["name"]
# if key_combination == 3:
# return dev_dict["name"], f"{driver}://{dev_dict['path']}"

# # mapping driver name to default device (driver://0)
# device_map[f"{driver}"] = get_output_value(device_list[0])
# for i, device in enumerate(device_list):
# # mapping with index
# device_map[f"{driver}://{i}"] = get_output_value(device)
# # mapping with full path
# device_map[f"{driver}://{device['path']}"] = get_output_value(device)
# return device_map


# def get_opt_flags(model, precision="fp16"):
# iree_flags = []
# if len(cmd_opts.iree_vulkan_target_triple) > 0:
# iree_flags.append(
# f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}"
# )
# if "rocm" in cmd_opts.device:
# from shark.iree_utils.gpu_utils import get_iree_rocm_args

# rocm_args = get_iree_rocm_args()
# iree_flags.extend(rocm_args)
# if cmd_opts.iree_constant_folding == False:
# iree_flags.append("--iree-opt-const-expr-hoisting=False")
# iree_flags.append(
# "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
# )
# if cmd_opts.data_tiling == False:
# iree_flags.append("--iree-opt-data-tiling=False")

# if "vae" not in model:
# # Due to lack of support for multi-reduce, we always collapse reduction
# # dims before dispatch formation right now.
# iree_flags += ["--iree-flow-collapse-reduction-dims"]
# return iree_flags

0 comments on commit 67b438e

Please sign in to comment.