Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Shark Exe #1614

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
Binary file added .dmg_config.py.swp
Binary file not shown.
Binary file added apps/.DS_Store
Binary file not shown.
5 changes: 5 additions & 0 deletions apps/stable_diffusion/shark_sd.spec
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ datas += [
( 'web/ui/logos/*', 'logos' )
]

torch_path = sys.prefix + '/lib/python3.11/site-packages/torch/lib/libtorch_python.dylib'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work on Windows. I think you can just add to copy_data()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the spec file already has copy_metadata('torch')


binaries = []
binaries += [
(torch_path, 'resources')
]

block_cipher = None

Expand Down
17 changes: 7 additions & 10 deletions shark/iree_utils/metal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env


def get_metal_device_name(device_num=0):
Expand Down Expand Up @@ -58,14 +58,7 @@ def get_metal_target_triple(device_name):
str or None: target triple or None if no match found for given name
"""
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
triple = "m1-moltenvk-macos"
elif all(x in device_name for x in ("Apple", "M2")):
triple = "m1-moltenvk-macos"

else:
triple = None
return triple
return "macos"


def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
Expand Down Expand Up @@ -110,10 +103,14 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
)

if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
vulkan_target_env = get_metal_target_env_flag("=m1-moltenvk-macos")
res_metal_flag.append(vulkan_target_env)
return res_metal_flag

def get_metal_target_env_flag(metal_target_triple):
target_env = get_vulkan_target_env(metal_target_triple)
target_env_flag = f"--iree-metal-target-env={target_env}"
return target_env_flag

def set_iree_metal_runtime_flags(flags):
for flag in flags:
Expand Down
4 changes: 3 additions & 1 deletion shark/shark_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def get_git_revision_short_hash() -> str:
else:
import json

dir_path = os.path.dirname(os.path.realpath(__file__))
dir_path = os.path.dirname(os.path.abspath(__file__))
print(f"\n\nIn get_git_revision_short_hash: {dir_path}\n\n")
src = os.path.join(dir_path, "..", "tank_version.json")
with open(src, "r") as f:
data = json.loads(f.read())
Expand Down Expand Up @@ -221,6 +222,7 @@ def download_model(
else:
model_dir_name = model_name + "_" + frontend
model_dir = os.path.join(WORKDIR, model_dir_name)
print(f"\n\nDownlaod_model, dir_name: {model_dir}\n\n")

if not tank_url:
tank_url = "gs://shark_tank/" + shark_args.shark_prefix
Expand Down
Binary file added tank/.DS_Store
Binary file not shown.
Binary file added tank/examples/.DS_Store
Binary file not shown.
Binary file added tank/tflite/.DS_Store
Binary file not shown.
Loading