From 13e1d8d98a232c01e8aac1d2f37f26a477eef179 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 29 May 2024 17:39:36 -0500 Subject: [PATCH] Add igpu and custom triple support. (#2148) --- apps/shark_studio/api/sd.py | 5 ++++- apps/shark_studio/api/utils.py | 20 +++++++++++++++++++- apps/shark_studio/web/ui/sd.py | 9 +++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 83574d294d..5f27c11c71 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -73,6 +73,7 @@ def __init__( scheduler: str, precision: str, device: str, + target_triple: str = None, custom_vae: str = None, num_loras: int = 0, import_ir: bool = True, @@ -91,7 +92,7 @@ def __init__( self.model_map = EMPTY_SD_MAP external_weights = "safetensors" max_length = 64 - target_backend, self.rt_device, triple = parse_device(device) + target_backend, self.rt_device, triple = parse_device(device, target_triple) pipe_id_list = [ safe_name(base_model_id), str(batch_size), @@ -273,6 +274,7 @@ def shark_sd_fn( custom_vae: str, precision: str, device: str, + target_triple: str, ondemand: bool, repeatable_seeds: bool, resample_type: str, @@ -326,6 +328,7 @@ def shark_sd_fn( "batch_size": batch_size, "precision": precision, "device": device, + "target_triple": target_triple, "custom_vae": custom_vae, "num_loras": num_loras, "import_ir": import_ir, diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 0516255d2b..b87ee0e628 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -77,6 +77,21 @@ def get_devices_by_name(driver_name): available_devices.extend(cpu_device) cpu_device = get_devices_by_name("cpu-task") available_devices.extend(cpu_device) + print(available_devices) + for idx, device_str in enumerate(available_devices): + if "AMD Radeon(TM) Graphics =>" in device_str: + igpu_id_candidates = [ + x.split("w/")[-1].split("=>")[0] + for x in available_devices + if "M Graphics" in x + ] + for igpu_name in igpu_id_candidates: + if igpu_name: + print(f"Found iGPU: {igpu_name} for {device_str}") + available_devices[idx] = device_str.replace( + "AMD Radeon(TM) Graphics", f"AMD iGPU: {igpu_name}" + ) + break return available_devices @@ -129,7 +144,7 @@ def set_iree_runtime_flags(): set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) -def parse_device(device_str): +def parse_device(device_str, target_override=""): from shark.iree_utils.compile_utils import ( clean_device_info, get_iree_target_triple, @@ -143,6 +158,8 @@ def parse_device(device_str): else: rt_device = rt_driver + if target_override: + return target_backend, rt_device, target_override match target_backend: case "vulkan-spirv": triple = get_iree_target_triple(device_str) @@ -168,6 +185,7 @@ def get_rocm_target_chip(device_str): "MI100": "gfx908", "MI50": "gfx906", "MI60": "gfx906", + "780M": "gfx1103", } for key in rocm_chip_map: if key in device_str: diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index a4df173b1c..0698ebda5e 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -117,6 +117,7 @@ def pull_sd_configs( custom_vae, precision, device, + target_triple, ondemand, repeatable_seeds, resample_type, @@ -175,6 +176,7 @@ def load_sd_cfg(sd_json: dict, load_sd_config: str): sd_json["custom_vae"], sd_json["precision"], sd_json["device"], + sd_json["target_triple"], sd_json["ondemand"], sd_json["repeatable_seeds"], sd_json["resample_type"], @@ -253,6 +255,11 @@ def base_model_changed(base_model_id): choices=global_obj.get_device_list(), allow_custom_value=False, ) + target_triple = gr.Textbox( + elem_id="triple", + label="Architecture", + value="", + ) with gr.Row(): ondemand = gr.Checkbox( value=cmd_opts.lowvram, @@ -691,6 +698,7 @@ def base_model_changed(base_model_id): custom_vae, precision, device, + target_triple, ondemand, repeatable_seeds, resample_type, @@ -730,6 +738,7 @@ def base_model_changed(base_model_id): custom_vae, precision, device, + target_triple, ondemand, repeatable_seeds, resample_type,