Skip to content

Commit

Permalink
Add igpu and custom triple support. (#2148)
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored May 29, 2024
1 parent 2074df4 commit 13e1d8d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
5 changes: 4 additions & 1 deletion apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
scheduler: str,
precision: str,
device: str,
target_triple: str = None,
custom_vae: str = None,
num_loras: int = 0,
import_ir: bool = True,
Expand All @@ -91,7 +92,7 @@ def __init__(
self.model_map = EMPTY_SD_MAP
external_weights = "safetensors"
max_length = 64
target_backend, self.rt_device, triple = parse_device(device)
target_backend, self.rt_device, triple = parse_device(device, target_triple)
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
Expand Down Expand Up @@ -273,6 +274,7 @@ def shark_sd_fn(
custom_vae: str,
precision: str,
device: str,
target_triple: str,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
Expand Down Expand Up @@ -326,6 +328,7 @@ def shark_sd_fn(
"batch_size": batch_size,
"precision": precision,
"device": device,
"target_triple": target_triple,
"custom_vae": custom_vae,
"num_loras": num_loras,
"import_ir": import_ir,
Expand Down
20 changes: 19 additions & 1 deletion apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,21 @@ def get_devices_by_name(driver_name):
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
available_devices.extend(cpu_device)
print(available_devices)
for idx, device_str in enumerate(available_devices):
if "AMD Radeon(TM) Graphics =>" in device_str:
igpu_id_candidates = [
x.split("w/")[-1].split("=>")[0]
for x in available_devices
if "M Graphics" in x
]
for igpu_name in igpu_id_candidates:
if igpu_name:
print(f"Found iGPU: {igpu_name} for {device_str}")
available_devices[idx] = device_str.replace(
"AMD Radeon(TM) Graphics", f"AMD iGPU: {igpu_name}"
)
break
return available_devices


Expand Down Expand Up @@ -129,7 +144,7 @@ def set_iree_runtime_flags():
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)


def parse_device(device_str):
def parse_device(device_str, target_override=""):
from shark.iree_utils.compile_utils import (
clean_device_info,
get_iree_target_triple,
Expand All @@ -143,6 +158,8 @@ def parse_device(device_str):
else:
rt_device = rt_driver

if target_override:
return target_backend, rt_device, target_override
match target_backend:
case "vulkan-spirv":
triple = get_iree_target_triple(device_str)
Expand All @@ -168,6 +185,7 @@ def get_rocm_target_chip(device_str):
"MI100": "gfx908",
"MI50": "gfx906",
"MI60": "gfx906",
"780M": "gfx1103",
}
for key in rocm_chip_map:
if key in device_str:
Expand Down
9 changes: 9 additions & 0 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def pull_sd_configs(
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,
Expand Down Expand Up @@ -175,6 +176,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str):
sd_json["custom_vae"],
sd_json["precision"],
sd_json["device"],
sd_json["target_triple"],
sd_json["ondemand"],
sd_json["repeatable_seeds"],
sd_json["resample_type"],
Expand Down Expand Up @@ -253,6 +255,11 @@ def base_model_changed(base_model_id):
choices=global_obj.get_device_list(),
allow_custom_value=False,
)
target_triple = gr.Textbox(
elem_id="triple",
label="Architecture",
value="",
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
Expand Down Expand Up @@ -691,6 +698,7 @@ def base_model_changed(base_model_id):
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,
Expand Down Expand Up @@ -730,6 +738,7 @@ def base_model_changed(base_model_id):
custom_vae,
precision,
device,
target_triple,
ondemand,
repeatable_seeds,
resample_type,
Expand Down

0 comments on commit 13e1d8d

Please sign in to comment.