Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed May 23, 2024
1 parent 08d4824 commit f76fc4a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 24 deletions.
8 changes: 3 additions & 5 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def __init__(
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name]["initializer"](
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
Expand Down Expand Up @@ -273,10 +275,6 @@ def format_out(results):
self.prev_token_len = token_len + len(history)

if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
if (
format_out(token)
== llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
):
break

for i in range(len(history)):
Expand Down
35 changes: 20 additions & 15 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import SharkSDXLPipeline
from turbine_models.custom_models.sdxl_inference.sdxl_compiled_pipeline import (
SharkSDXLPipeline,
)


from apps.shark_studio.api.controlnet import control_adapter_map
Expand Down Expand Up @@ -104,9 +106,9 @@ def __init__(
pipe_id_list.append(custom_vae)
self.pipe_id = "_".join(pipe_id_list)
self.pipeline_dir = Path(os.path.join(get_checkpoints_path(), self.pipe_id))
self.weights_path = Path(os.path.join(
get_checkpoints_path(), safe_name(self.base_model_id)
))
self.weights_path = Path(
os.path.join(get_checkpoints_path(), safe_name(self.base_model_id))
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)

Expand Down Expand Up @@ -140,18 +142,21 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
weights = copy.deepcopy(self.model_map)

if custom_weights:
custom_weights_params, _ = process_custom_pipe_weights(
custom_weights
)
custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
for key in weights:
if key not in ["vae_decode", "pipeline", "full_pipeline"]:
weights[key] = custom_weights_params


vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False)
vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline)
print("\n[LOG] Pipeline successfully prepared for runtime. Generating images...")
self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
print(
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
)
return

def generate_images(
Expand Down Expand Up @@ -236,7 +241,7 @@ def shark_sd_fn(
control_mode = None
hints = []
num_loras = 0
import_ir=True
import_ir = True
for i in embeddings:
num_loras += 1 if embeddings[i] else 0
if "model" in controlnets:
Expand Down Expand Up @@ -305,7 +310,6 @@ def shark_sd_fn(
# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
# which is currently MLIR in the torch dialect.


sd_pipe = StableDiffusion(
**submit_pipe_kwargs,
Expand All @@ -325,7 +329,7 @@ def shark_sd_fn(
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
# total_time = time.time() - start_time
# text_output = f"Total image(s) generation time: {total_time:.4f}sec"
#print(f"\n[LOG] {text_output}")
# print(f"\n[LOG] {text_output}")
# if global_obj.get_sd_status() == SD_STATE_CANCEL:
# break
# else:
Expand All @@ -352,8 +356,9 @@ def view_json_file(file_path):
content = fopen.read()
return content


def safe_name(name):
return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_")
return name.replace("/", "_").replace("-", "_").replace("\\", "_").replace(".", "_")


if __name__ == "__main__":
Expand Down
19 changes: 15 additions & 4 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_iree_vulkan_runtime_flags,
)


def get_available_devices():
def get_devices_by_name(driver_name):
from shark.iree_utils._common import iree_device_map
Expand Down Expand Up @@ -49,7 +50,7 @@ def get_devices_by_name(driver_name):
return device_list

set_iree_runtime_flags()

available_devices = []
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
Expand Down Expand Up @@ -78,6 +79,7 @@ def get_devices_by_name(driver_name):
available_devices.extend(cpu_device)
return available_devices


def set_init_device_flags():
if "vulkan" in cmd_opts.device:
# set runtime flags for vulkan.
Expand Down Expand Up @@ -126,8 +128,14 @@ def set_iree_runtime_flags():
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)


def parse_device(device_str):
from shark.iree_utils.compile_utils import clean_device_info, get_iree_target_triple, iree_target_map
from shark.iree_utils.compile_utils import (
clean_device_info,
get_iree_target_triple,
iree_target_map,
)

rt_driver, device_id = clean_device_info(device_str)
target_backend = iree_target_map(rt_driver)
if device_id:
Expand All @@ -147,7 +155,7 @@ def parse_device(device_str):


def get_rocm_target_chip(device_str):
#TODO: Use a data file to map device_str to target chip.
# TODO: Use a data file to map device_str to target chip.
rocm_chip_map = {
"6700": "gfx1031",
"6800": "gfx1030",
Expand All @@ -164,7 +172,10 @@ def get_rocm_target_chip(device_str):
for key in rocm_chip_map:
if key in device_str:
return rocm_chip_map[key]
raise AssertionError(f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues.")
raise AssertionError(
f"Device {device_str} not recognized. Please file an issue at https://github.com/nod-ai/SHARK/issues."
)


def get_all_devices(driver_name):
"""
Expand Down

0 comments on commit f76fc4a

Please sign in to comment.