Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding metal_utils for iree_utils #1561

Merged
merged 12 commits into from
Jun 22, 2023
7 changes: 7 additions & 0 deletions apps/stable_diffusion/src/utils/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,13 @@ def is_valid_file(arg):
help="Specify target triple for vulkan",
)

p.add_argument(
"--iree_metal_target_platfrom",
type=str,
default="",
help="Specify target triple for metal",
)

p.add_argument(
"--vulkan_debug_utils",
default=False,
Expand Down
12 changes: 12 additions & 0 deletions apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
Expand Down Expand Up @@ -274,6 +275,15 @@ def set_init_device_flags():
)
elif "cuda" in args.device:
args.device = "cuda"
elif "metal" in args.device:
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platfrom:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platfrom = triple
print(
f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}."
)
elif "cpu" in args.device:
args.device = "cpu"

Expand Down Expand Up @@ -431,6 +441,8 @@ def get_devices_by_name(driver_name):
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
Expand Down
2 changes: 2 additions & 0 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platfrom = args.iree_metal_target_platfrom
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir

Expand Down Expand Up @@ -137,6 +138,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platfrom = init_iree_metal_target_platfrom
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
Expand Down
6 changes: 3 additions & 3 deletions shark/iree_utils/_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -66,7 +66,7 @@ def get_supported_device_list():
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "level_zero",
}
Expand All @@ -84,7 +84,7 @@ def iree_target_map(device):
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "opencl-spirv",
}
Expand Down
10 changes: 8 additions & 2 deletions shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,12 +43,18 @@ def get_iree_device_args(device, extra_args=[]):
from shark.iree_utils.gpu_utils import get_iree_gpu_args

return get_iree_gpu_args()
if device_uri[0] in ["metal", "vulkan"]:
if device_uri[0] == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args

return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args

return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args

Expand Down
121 changes: 121 additions & 0 deletions shark/iree_utils/metal_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# All the iree_vulkan related functionalities go here.

from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag


def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")
metal_device_list = [
s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s
]
if len(metal_device_list) == 0:
raise ValueError("No device name found in device dump!")
if len(metal_device_list) > 1:
print("Following devices found:")
for i, dname in enumerate(metal_device_list):
print(f"{i}. {dname}")
print(f"Choosing device: {metal_device_list[device_num]}")
return metal_device_list[device_num]


def get_os_name():
if platform.startswith("linux"):
return "linux"
elif platform == "darwin":
return "macos"
elif platform == "win32":
return "windows"
else:
print("Cannot detect OS type, defaulting to linux.")
return "linux"


def get_metal_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.

Args:
device_name (str): name of the hardware device to be used with vulkan

Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"

else:
triple = None
return triple


def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-metal-target-platform=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None

if device_name == "" or device_name == [] or device_name is None:
metal_device = get_metal_device_name(device_num=device_num)
else:
metal_device = device_name
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
)
return f"-iree-metal-target-platfrom={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {metal_device}")
return None


def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]

res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
meatal_triple_flag = arg
break

if metal_triple_flag is None:
meatal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)

if meatal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag


def set_iree_metal_runtime_flags(flags):
for flag in flags:
ireert.flags.parse_flags(flag)
return