Merge branch 'main' into manual_annotation
nithinsubbiah committed Jun 22, 2023
2 parents 382e70f + a811b86 commit d7d7b01
Showing 15 changed files with 388 additions and 46 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/test-models.yml
@@ -35,6 +35,8 @@ jobs:
include:
- os: ubuntu-latest
suite: lint
- os: MacStudio
suite: metal
exclude:
- os: ubuntu-latest
suite: vulkan
@@ -46,6 +48,8 @@
suite: cuda
- os: MacStudio
suite: cpu
- os: MacStudio
suite: vulkan
- os: icelake
suite: vulkan
- os: icelake
@@ -125,15 +129,14 @@ jobs:
# python build_tools/stable_diffusion_testing.py --device=cuda
- name: Validate Vulkan Models (MacOS)
if: matrix.suite == 'vulkan' && matrix.os == 'MacStudio'
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.mlir
*.vmfb

# C extensions
*.so
28 changes: 23 additions & 5 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -28,10 +28,10 @@ def __init__(
max_num_tokens=512,
device="cuda",
precision="fp32",
first_vicuna_mlir_path=Path("first_vicuna.mlir"),
second_vicuna_mlir_path=Path("second_vicuna.mlir"),
first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
first_vicuna_mlir_path=None,
second_vicuna_mlir_path=None,
first_vicuna_vmfb_path=None,
second_vicuna_vmfb_path=None,
load_mlir_from_shark_tank=True,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
@@ -42,9 +42,27 @@ def __init__(
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
self.first_vicuna_mlir_path = first_vicuna_mlir_path
self.second_vicuna_mlir_path = second_vicuna_mlir_path
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
if self.first_vicuna_mlir_path == None:
self.first_vicuna_mlir_path = self.get_model_path()
if self.second_vicuna_mlir_path == None:
self.second_vicuna_mlir_path = self.get_model_path("second")
if self.first_vicuna_vmfb_path == None:
self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
if self.second_vicuna_vmfb_path == None:
self.second_vicuna_vmfb_path = self.get_model_path(
"second", "vmfb"
)
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
self.load_mlir_from_shark_tank = load_mlir_from_shark_tank

def get_model_path(self, model_number="first", suffix="mlir"):
safe_device = "_".join(self.device.split("-"))
if suffix == "mlir":
return Path(f"{model_number}_vicuna_{self.precision}.{suffix}")
return Path(
f"{model_number}_vicuna_{safe_device}_{self.precision}.{suffix}"
)

def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
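The constructor defaults above were previously hardcoded filenames; they are now derived from device and precision. A minimal standalone sketch of the naming scheme (a reimplementation for illustration, following the same f-string format as the get_model_path method in the diff):

    from pathlib import Path

    def get_model_path(device, precision, model_number="first", suffix="mlir"):
        # Device names such as "cpu-task" contain hyphens; swap them for
        # underscores so the artifact filename stays filesystem-friendly.
        safe_device = "_".join(device.split("-"))
        if suffix == "mlir":
            # MLIR artifacts are device-independent: only precision is encoded.
            return Path(f"{model_number}_vicuna_{precision}.{suffix}")
        # Compiled .vmfb artifacts are device-specific.
        return Path(f"{model_number}_vicuna_{safe_device}_{precision}.{suffix}")

    print(get_model_path("cpu-task", "fp32"))                    # first_vicuna_fp32.mlir
    print(get_model_path("cpu-task", "fp32", "second", "vmfb"))  # second_vicuna_cpu_task_fp32.vmfb
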
2 changes: 1 addition & 1 deletion apps/stable_diffusion/src/utils/stable_args.py
@@ -380,7 +380,7 @@ def is_valid_file(arg):
)

p.add_argument(
"--iree_metal_target_platfrom",
"--iree_metal_target_platform",
type=str,
default="",
help="Specify target triple for metal",
11 changes: 7 additions & 4 deletions apps/stable_diffusion/src/utils/utils.py
@@ -277,12 +277,12 @@ def set_init_device_flags():
args.device = "cuda"
elif "metal" in args.device:
device_name, args.device = map_device_to_name_path(args.device)
if not args.iree_metal_target_platfrom:
if not args.iree_metal_target_platform:
triple = get_metal_target_triple(device_name)
if triple is not None:
args.iree_metal_target_platfrom = triple
args.iree_metal_target_platform = triple
print(
f"Found device {device_name}. Using target triple {args.iree_metal_target_platfrom}."
f"Found device {device_name}. Using target triple {args.iree_metal_target_platform}."
)
elif "cpu" in args.device:
args.device = "cpu"
@@ -445,7 +445,10 @@ def get_devices_by_name(driver_name):
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
available_devices.append("device => cpu")
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
return available_devices


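Rather than appending a hardcoded "device => cpu" entry, the list is now built by querying the runtime for both CPU flavors. A hedged sketch of the underlying query (assuming the iree.runtime bindings are installed; "local-sync" and "local-task" are the IREE HAL driver names that SHARK's "cpu-sync"/"cpu-task" labels presumably map to):

    from iree import runtime as ireert

    def enumerate_devices(driver_name):
        # query_available_devices() returns one dict per device; the
        # compile_utils.py diff below indexes the same structure by "device_id".
        try:
            driver = ireert.get_driver(driver_name)
            return driver.query_available_devices()
        except Exception:
            return []  # driver not available on this machine

    for name in ("local-sync", "local-task"):
        print(name, enumerate_devices(name))
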
5 changes: 5 additions & 0 deletions apps/stable_diffusion/web/index.py
@@ -1,6 +1,11 @@
from multiprocessing import Process, freeze_support
import os
import sys

if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
import torch_mlir

import shutil
import PIL, transformers # ensures inclusion in pyinstaller exe generation
from apps.stable_diffusion.src import args, clear_all
16 changes: 9 additions & 7 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -41,17 +41,21 @@ def chat(curr_system_message, history, model, device, precision):

curr_system_message = start_message_vicuna
if vicuna_model == 0:
first_vic_vmfb_path = Path("first_vicuna.vmfb")
second_vic_vmfb_path = Path("second_vicuna.vmfb")
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = Vicuna(
"vicuna",
hf_model_path=model,
device=device,
precision=precision,
first_vicuna_vmfb_path=first_vic_vmfb_path,
second_vicuna_vmfb_path=second_vic_vmfb_path,
)
messages = curr_system_message + "".join(
[
@@ -120,9 +124,7 @@ def chat(curr_system_message, history, model, device, precision):
"TheBloke/vicuna-7B-1.1-HF",
],
)
supported_devices = [
device for device in available_devices if "cuda" in device
]
supported_devices = available_devices
enabled = len(supported_devices) > 0
device = gr.Dropdown(
label="Device",
4 changes: 2 additions & 2 deletions apps/stable_diffusion/web/ui/txt2img_ui.py
@@ -34,7 +34,7 @@

# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platfrom = args.iree_metal_target_platfrom
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir

@@ -138,7 +138,7 @@ def txt2img_inf(
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platfrom = init_iree_metal_target_platfrom
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
4 changes: 3 additions & 1 deletion shark/iree_utils/_common.py
@@ -101,11 +101,13 @@ def check_device_drivers(device):
subprocess.check_output("nvidia-smi")
except Exception:
return True
elif device in ["metal", "vulkan"]:
elif device in ["vulkan"]:
try:
subprocess.check_output("vulkaninfo")
except Exception:
return True
elif device == "metal":
return False
elif device in ["intel-gpu"]:
try:
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
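Note the convention here: check_device_drivers returns True when a required driver is missing. Metal ships with macOS, so the new branch reports it as present unconditionally instead of probing with vulkaninfo. A hedged usage sketch:

    from shark.iree_utils._common import check_device_drivers

    # Returns True when the driver is MISSING, False when it is usable.
    for dev in ("metal", "vulkan", "cuda"):
        status = "driver missing" if check_device_drivers(dev) else "ready"
        print(f"{dev}: {status}")
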
111 changes: 101 additions & 10 deletions shark/iree_utils/compile_utils.py
@@ -19,6 +19,7 @@
import numpy as np
import os
import re
import tempfile


# Get the iree-compile arguments given device.
@@ -181,8 +182,10 @@ def compile_benchmark_dirs(bench_dir, device, dispatch_benchmarks):
vmfb_file.close()

config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
vm_module = ireert.VmModule.from_buffer(
config.vm_instance,
flatbuffer_blob,
warn_if_copy=False,
)

benchmark_cl = build_benchmark_args_non_tensor_input(
@@ -313,35 +316,123 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
vm_module = ireert.VmModule.from_flatbuffer(
config.vm_instance, flatbuffer_blob
vm_module = ireert.VmModule.from_buffer(
config.vm_instance, flatbuffer_blob, warn_if_copy=False
)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(vm_module)
ModuleCompiled = getattr(ctx.modules, vm_module.name)
return ModuleCompiled, config


def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First get configs.
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)

haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink


def get_iree_compiled_module(
module,
device: str,
frontend: str = "torch",
model_config_path: str = None,
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module, device, frontend, model_config_path, extra_args
)
return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)

temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
# Got to find a cleaner way to unlink/delete the temporary file since
# we're setting delete=False when creating NamedTemporaryFile. That's why
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
print(f"Will load the compiled module as a mmapped temporary file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
)
else:
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params

def load_flatbuffer(flatbuffer_path: str, device: str, device_idx: int = None):
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()

return get_iree_module(flatbuffer_blob, device, device_idx=device_idx)
def load_flatbuffer(
flatbuffer_path: str,
device: str,
device_idx: int = None,
mmap: bool = False,
):
temp_file_to_unlink = None
if mmap:
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
)
else:
with open(os.path.join(flatbuffer_path), "rb") as f:
flatbuffer_blob = f.read()
vmfb, config = get_iree_module(
flatbuffer_blob, device, device_idx=device_idx
)
ret_params = {
"vmfb": vmfb,
"config": config,
"temp_file_to_unlink": temp_file_to_unlink,
}
return ret_params


def export_iree_module_to_vmfb(
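Callers should note that get_iree_compiled_module and load_flatbuffer now return a dict rather than a tuple, and that with mmap=True the loader may spill a flatbuffer blob to a NamedTemporaryFile(delete=False) whose path is handed back for the caller to unlink. A hedged usage sketch based on the keys in the diff (the .vmfb filename is illustrative):

    import os

    from shark.iree_utils.compile_utils import load_flatbuffer

    ret = load_flatbuffer("first_vicuna_fp32.vmfb", device="cpu-task", mmap=True)
    module, config = ret["vmfb"], ret["config"]

    # ... run inference with `module` ...

    # Loading from an existing .vmfb path mmaps it in place and reports no
    # temp file; only the compile-on-the-fly path sets this key.
    if ret["temp_file_to_unlink"] is not None:
        os.unlink(ret["temp_file_to_unlink"])
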
10 changes: 5 additions & 5 deletions shark/iree_utils/metal_utils.py
@@ -83,7 +83,7 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
print(
f"Found metal device {metal_device}. Using metal target triple {triple}"
)
return f"-iree-metal-target-platfrom={triple}"
return f"-iree-metal-target-platform={triple}"
print(
"""Optimized kernel for your target device is not added yet.
Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
@@ -101,16 +101,16 @@ def get_iree_metal_args(device_num=0, extra_args=[]):
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
meatal_triple_flag = arg
metal_triple_flag = arg
break

if metal_triple_flag is None:
meatal_triple_flag = get_metal_triple_flag(
metal_triple_flag = get_metal_triple_flag(
device_num=device_num, extra_args=extra_args
)

if meatal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(meatal_triple_flag)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(metal_triple_flag)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag
