Skip to content

Commit

Permalink
Adding metal_utils for iree_utils (#1561)
Browse files Browse the repository at this point in the history
* Adding metal_utils for iree_utils

* Add patch for making compile API work for both MEGABYTE and MiniGPT4 (#1559)

-- It also modifies the mega_test.py script

Signed-off-by: Abhishek Varma <[email protected]>

* [SD] Update unet in_channels API and add PIL metadata to spec. (#1560)

* Fix deprecation warning for unet config.

* Include PIL metadata instead of hidden imports in SD spec.

* Fixing iree-metal-target-platform

* adding metal to txt2img pipeline

* Fixing Copyright date

* removing debug prints

* black lint formating

* fixing device dump

---------

Signed-off-by: Abhishek Varma <[email protected]>
Co-authored-by: Abhishek Varma <[email protected]>
Co-authored-by: Ean Garvey <[email protected]>
Co-authored-by: powderluv <[email protected]>
  • Loading branch information
4 people authored Jun 22, 2023
1 parent 18daec7 commit 07c1e1d
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 5 deletions.
7 changes: 7 additions & 0 deletions apps/stable_diffusion/src/utils/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,13 @@ def is_valid_file(arg):
help="Specify target triple for vulkan",
)

p.add_argument(
"--iree_metal_target_platfrom",
type=str,
default="",
help="Specify target triple for metal",
)

p.add_argument(
"--vulkan_debug_utils",
default=False,
Expand Down
12 changes: 12 additions & 0 deletions apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
Expand Down Expand Up @@ -274,6 +275,15 @@ def set_init_device_flags():
)
elif "cuda" in args.device:
args.device = "cuda"
elif "metal" in args.device:
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platfrom:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platfrom = triple
print(
f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}."
)
elif "cpu" in args.device:
args.device = "cpu"

Expand Down Expand Up @@ -431,6 +441,8 @@ def get_devices_by_name(driver_name):
available_devices = []
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
Expand Down
2 changes: 2 additions & 0 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platfrom = args.iree_metal_target_platfrom
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir

Expand Down Expand Up @@ -137,6 +138,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platfrom = init_iree_metal_target_platfrom
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
Expand Down
6 changes: 3 additions & 3 deletions shark/iree_utils/_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -66,7 +66,7 @@ def get_supported_device_list():
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "level_zero",
}
Expand All @@ -84,7 +84,7 @@ def iree_target_map(device):
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
"metal": "vulkan",
"metal": "metal",
"rocm": "rocm",
"intel-gpu": "opencl-spirv",
}
Expand Down
10 changes: 8 additions & 2 deletions shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2020 The Nod Team. All rights reserved.
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,12 +43,18 @@ def get_iree_device_args(device, extra_args=[]):
from shark.iree_utils.gpu_utils import get_iree_gpu_args

return get_iree_gpu_args()
if device_uri[0] in ["metal", "vulkan"]:
if device_uri[0] == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args

return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args

return get_iree_metal_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args

Expand Down
121 changes: 121 additions & 0 deletions shark/iree_utils/metal_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# All the iree_vulkan related functionalities go here.

from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag


def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")
metal_device_list = [
s.split("\n#")[2] for s in iree_device_dump if "--device=metal" in s
]
if len(metal_device_list) == 0:
raise ValueError("No device name found in device dump!")
if len(metal_device_list) > 1:
print("Following devices found:")
for i, dname in enumerate(metal_device_list):
print(f"{i}. {dname}")
print(f"Choosing device: {metal_device_list[device_num]}")
return metal_device_list[device_num]


def get_os_name():
if platform.startswith("linux"):
return "linux"
elif platform == "darwin":
return "macos"
elif platform == "win32":
return "windows"
else:
print("Cannot detect OS type, defaulting to linux.")
return "linux"


def get_metal_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
Args:
device_name (str): name of the hardware device to be used with vulkan
Returns:
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"

else:
triple = None
return triple


def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
for flag in extra_args:
if "-iree-metal-target-platform=" in flag:
print(f"Using target triple {flag.split('=')[1]}")
return None

if device_name == "" or device_name == [] or device_name is None:
metal_device = get_metal_device_name(device_num=device_num)
else:
metal_device = device_name
triple = get_metal_target_triple(metal_device)
if triple is not None:
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
)
return f"-iree-metal-target-platfrom={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
or pull up an issue."""
)
print(f"Target : {metal_device}")
return None


def get_iree_metal_args(device_num=0, extra_args=[]):
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]

res_metal_flag = []
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
meatal_triple_flag = arg
break

if metal_triple_flag is None:
meatal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)

if meatal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag


def set_iree_metal_runtime_flags(flags):
for flag in flags:
ireert.flags.parse_flags(flag)
return

0 comments on commit 07c1e1d

Please sign in to comment.