Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding metal_utils for iree_utils #1561

Merged
merged 12 commits into from
Jun 22, 2023
4 changes: 2 additions & 2 deletions shark/iree_utils/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_supported_device_list():
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "level_zero",
}
Expand All @@ -84,7 +84,7 @@ def iree_target_map(device):
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "opencl-spirv",
}
Expand Down
8 changes: 7 additions & 1 deletion shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,18 @@ def get_iree_device_args(device, extra_args=[]):
from shark.iree_utils.gpu_utils import get_iree_gpu_args

return get_iree_gpu_args()
if device_uri[0] in ["metal", "vulkan"]:
if device_uri[0] == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args

return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args

return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args

Expand Down
176 changes: 176 additions & 0 deletions shark/iree_utils/metal_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Copyright 2020 The Nod Team. All rights reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2023

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# All the iree_vulkan related functionalities go here.

from os import linesep
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag


def get_metal_device_name(device_num=0):
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this wont work. I used

(shark.venv) ➜  web git:(main) ✗ sysctl -a | grep brand             
machdep.cpu.brand_string: Apple M1 Ultra

But I think Lei suggested using --dump_devices

vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]


def get_os_name():
if platform.startswith("linux"):
return "linux"
elif platform == "darwin":
return "macos"
elif platform == "win32":
return "windows"
else:
print("Cannot detect OS type, defaulting to linux.")
return "linux"


def get_metal_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.

Args:
device_name (str): name of the hardware device to be used with vulkan

Returns:
str or None: target triple or None if no match found for given name
"""
system_os = get_os_name()
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"

# Nvidia Targets
elif all(x in device_name for x in ("RTX", "2080")):
triple = f"turing-rtx2080-{system_os}"
elif all(x in device_name for x in ("A100", "SXM4")):
triple = f"ampere-a100-{system_os}"
elif all(x in device_name for x in ("RTX", "3090")):
triple = f"ampere-rtx3090-{system_os}"
elif all(x in device_name for x in ("RTX", "3080")):
triple = f"ampere-rtx3080-{system_os}"
elif all(x in device_name for x in ("RTX", "3070")):
triple = f"ampere-rtx3070-{system_os}"
elif all(x in device_name for x in ("RTX", "3060")):
triple = f"ampere-rtx3060-{system_os}"
elif all(x in device_name for x in ("RTX", "3050")):
triple = f"ampere-rtx3050-{system_os}"
# We use ampere until lovelace target triples are plumbed in.
elif all(x in device_name for x in ("RTX", "4090")):
triple = f"ampere-rtx4090-{system_os}"
elif all(x in device_name for x in ("RTX", "4080")):
triple = f"ampere-rtx4080-{system_os}"
elif all(x in device_name for x in ("RTX", "4070")):
triple = f"ampere-rtx4070-{system_os}"
elif all(x in device_name for x in ("RTX", "4000")):
triple = f"turing-rtx4000-{system_os}"
elif all(x in device_name for x in ("RTX", "5000")):
triple = f"turing-rtx5000-{system_os}"
elif all(x in device_name for x in ("RTX", "6000")):
triple = f"turing-rtx6000-{system_os}"
elif all(x in device_name for x in ("RTX", "8000")):
triple = f"turing-rtx8000-{system_os}"
elif all(x in device_name for x in ("TITAN", "RTX")):
triple = f"turing-titanrtx-{system_os}"
elif all(x in device_name for x in ("GTX", "1060")):
triple = f"pascal-gtx1060-{system_os}"
elif all(x in device_name for x in ("GTX", "1070")):
triple = f"pascal-gtx1070-{system_os}"
elif all(x in device_name for x in ("GTX", "1080")):
triple = f"pascal-gtx1080-{system_os}"

# Amd Targets
# Linux: Radeon RX 7900 XTX
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna2-unknown-{system_os}"
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"

# Adreno Targets
elif all(x in device_name for x in ("Adreno", "740")):
triple = f"adreno-a740-{system_os}"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can remove everything but the apple targets.

else:
triple = None
return triple


def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-metal-target-triple=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None

if device_name == "" or device_name == [] or device_name is None:
metal_device = get_metal_device_name(device_num=device_num)
else:
metal_device = device_name
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
)
return f"-iree-metal-target-triple={triple}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be -iree-metal-target-platform=

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, missed that one
Will fix it, and the vulkaninfo

print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {metal_device}")
return None

def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]

res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-triple=" in arg:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will fix that

print(f"Using target triple {arg} from command line args")
meatal_triple_flag = arg
break

if metal_triple_flag is None:
meatal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)

if meatal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag


def set_iree_metal_runtime_flags(flags):
for flag in flags:
ireert.flags.parse_flags(flag)
return