diff --git a/apps/shark_studio/api/controlnet.py b/apps/shark_studio/api/controlnet.py new file mode 100644 index 0000000000..2c8a8b566b --- /dev/null +++ b/apps/shark_studio/api/controlnet.py @@ -0,0 +1,107 @@ +# from turbine_models.custom_models.controlnet import control_adapter, preprocessors +import os +import PIL +import numpy as np +from apps.shark_studio.web.utils.file_utils import ( + get_generated_imgs_path, +) +from datetime import datetime +from PIL import Image +from gradio.components.image_editor import ( + EditorValue, +) + + +class control_adapter: + def __init__( + self, + model: str, + ): + self.model = None + + def export_control_adapter_model(model_keyword): + return None + + def export_xl_control_adapter_model(model_keyword): + return None + + +class preprocessors: + def __init__( + self, + model: str, + ): + self.model = None + + def export_controlnet_model(model_keyword): + return None + + +control_adapter_map = { + "sd15": { + "canny": {"initializer": control_adapter.export_control_adapter_model}, + "openpose": {"initializer": control_adapter.export_control_adapter_model}, + "scribble": {"initializer": control_adapter.export_control_adapter_model}, + "zoedepth": {"initializer": control_adapter.export_control_adapter_model}, + }, + "sdxl": { + "canny": {"initializer": control_adapter.export_xl_control_adapter_model}, + }, +} +preprocessor_model_map = { + "canny": {"initializer": preprocessors.export_controlnet_model}, + "openpose": {"initializer": preprocessors.export_controlnet_model}, + "scribble": {"initializer": preprocessors.export_controlnet_model}, + "zoedepth": {"initializer": preprocessors.export_controlnet_model}, +} + + +class PreprocessorModel: + def __init__( + self, + hf_model_id, + device="cpu", + ): + self.model = hf_model_id + self.device = device + + def compile(self): + print("compile not implemented for preprocessor.") + return + + def run(self, inputs): + print("run not implemented for preprocessor.") + return inputs + + +def cnet_preview(model, input_image): + curr_datetime = datetime.now().strftime("%Y-%m-%d.%H-%M-%S") + control_imgs_path = os.path.join(get_generated_imgs_path(), "control_hints") + if not os.path.exists(control_imgs_path): + os.mkdir(control_imgs_path) + img_dest = os.path.join(control_imgs_path, model + curr_datetime + ".png") + match model: + case "canny": + canny = PreprocessorModel("canny") + result = canny( + np.array(input_image), + 100, + 200, + ) + Image.fromarray(result).save(fp=img_dest) + return result, img_dest + case "openpose": + openpose = PreprocessorModel("openpose") + result = openpose(np.array(input_image)) + Image.fromarray(result[0]).save(fp=img_dest) + return result, img_dest + case "zoedepth": + zoedepth = PreprocessorModel("ZoeDepth") + result = zoedepth(np.array(input_image)) + Image.fromarray(result).save(fp=img_dest) + return result, img_dest + case "scribble": + input_image.save(fp=img_dest) + return input_image, img_dest + case _: + return None, None diff --git a/apps/shark_studio/api/initializers.py b/apps/shark_studio/api/initializers.py new file mode 100644 index 0000000000..48e7246df6 --- /dev/null +++ b/apps/shark_studio/api/initializers.py @@ -0,0 +1,125 @@ +import importlib +import os +import signal +import sys +import warnings +import json +from threading import Thread + +from apps.shark_studio.modules.timer import startup_timer + +from apps.shark_studio.web.utils.tmp_configs import ( + config_tmp, + clear_tmp_mlir, + clear_tmp_imgs, + shark_tmp, +) + + +def imports(): + import torch # noqa: 
F401 + + startup_timer.record("import torch") + warnings.filterwarnings( + action="ignore", category=DeprecationWarning, module="torch" + ) + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + warnings.filterwarnings(action="ignore", category=UserWarning, module="torch") + + import gradio # noqa: F401 + + startup_timer.record("import gradio") + + import apps.shark_studio.web.utils.globals as global_obj + + global_obj._init() + startup_timer.record("initialize globals") + + from apps.shark_studio.modules import ( + img_processing, + ) # noqa: F401 + + startup_timer.record("other imports") + + +def initialize(): + configure_sigint_handler() + # Setup to use shark_tmp for gradio's temporary image files and clear any + # existing temporary images there if they exist. Then we can import gradio. + # It has to be in this order or gradio ignores what we've set up. + + config_tmp() + # clear_tmp_mlir() + clear_tmp_imgs() + + from apps.shark_studio.web.utils.file_utils import ( + create_checkpoint_folders, + ) + + # Create custom models folders if they don't exist + create_checkpoint_folders() + + import gradio as gr + + # initialize_rest(reload_script_modules=False) + + +def initialize_rest(*, reload_script_modules=False): + """ + Called both from initialize() and when reloading the webui. + """ + # Keep this for adding reload options to the webUI. + + +def dumpstacks(): + import threading + import traceback + + id2name = {th.ident: th.name for th in threading.enumerate()} + code = [] + for threadId, stack in sys._current_frames().items(): + code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})") + for filename, lineno, name, line in traceback.extract_stack(stack): + code.append(f"""File: "{filename}", line {lineno}, in {name}""") + if line: + code.append(" " + line.strip()) + with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f: + f.write("\n".join(code)) + + +def setup_middleware(app): + from starlette.middleware.gzip import GZipMiddleware + + app.middleware_stack = ( + None # reset current middleware to allow modifying user provided list + ) + app.add_middleware(GZipMiddleware, minimum_size=1000) + configure_cors_middleware(app) + app.build_middleware_stack() # rebuild middleware stack on-the-fly + + +def configure_cors_middleware(app): + from starlette.middleware.cors import CORSMiddleware + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + + cors_options = { + "allow_methods": ["*"], + "allow_headers": ["*"], + "allow_credentials": True, + } + if cmd_opts.api_accept_origin: + cors_options["allow_origins"] = cmd_opts.api_accept_origin.split(",") + + app.add_middleware(CORSMiddleware, **cors_options) + + +def configure_sigint_handler(): + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f"Interrupted with signal {sig} in {frame}") + + dumpstacks() + + os._exit(0) + + signal.signal(signal.SIGINT, sigint_handler) diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 647d6a5af1..a88aaa9b02 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -3,7 +3,8 @@ from turbine_models.gen_external_params.gen_external_params import gen_external_params import time from shark.iree_utils.compile_utils import compile_module_to_flatbuffer -from apps.shark_studio.web.utils import get_resource_path +from apps.shark_studio.web.utils.file_utils import get_resource_path +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts 
import iree.runtime as ireert from itertools import chain import gc @@ -88,21 +89,29 @@ def __init__( if self.quantization != "None": self.file_spec += "_" + self.quantization - if external_weights is not None: + if external_weights in ["safetensors", "gguf"]: self.external_weight_file = get_resource_path( - self.file_spec + "." + external_weights + os.path.join("..", self.file_spec + "." + external_weights) ) + else: + self.external_weights = None + self.external_weight_file = None if streaming_llm: # Add streaming suffix to file spec after setting external weights filename. self.file_spec += "_streaming" self.streaming_llm = streaming_llm - self.tempfile_name = get_resource_path(f"{self.file_spec}.tempfile") + self.tempfile_name = get_resource_path( + os.path.join("..", f"{self.file_spec}.tempfile") + ) # TODO: Tag vmfb with target triple of device instead of HAL backend - self.vmfb_name = get_resource_path( - f"{self.file_spec}_{self.backend}.vmfb.tempfile" + self.vmfb_name = str( + get_resource_path( + os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile") + ) ) + self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None self.use_system_prompt = use_system_prompt @@ -126,6 +135,8 @@ def __init__( print( f"External weight file {self.external_weight_file} found for {self.vmfb_name}" ) + self.external_weight_file = str(self.external_weight_file) + if os.path.exists(self.vmfb_name) and ( external_weights is None or os.path.exists(str(self.external_weight_file)) ): @@ -209,10 +220,8 @@ def sanitize_prompt(self, prompt): prompt = prompt.replace("\r", " ") if self.use_system_prompt and self.global_iter == 0: prompt = append_user_prompt(DEFAULT_CHAT_SYS_PROMPT, prompt) - print(prompt) return prompt else: - print(prompt) return f"{B_INST} {prompt} {E_INST}" def chat(self, prompt): @@ -248,7 +257,10 @@ def format_out(results): token_len += 1 history.append(format_out(token)) - while format_out(token) != llm_model_map["llama2_7b"]["stop_token"]: + while ( + format_out(token) != llm_model_map["llama2_7b"]["stop_token"] + and len(history) < self.max_tokens + ): dec_time = time.time() if self.streaming_llm and self.model["get_seq_step"]() > 600: print("Evicting cache space!") @@ -315,6 +327,101 @@ def chat_hf(self, prompt): return result_output, total_time +def llm_chat_api(InputData: dict): + from datetime import datetime as dt + + import apps.shark_studio.web.utils.globals as global_obj + + print(f"Input keys : {InputData.keys()}") + + # print(f"model : {InputData['model']}") + + is_chat_completion_api = ( + "messages" in InputData.keys() + ) # else it is the legacy `completion` api + + # For Debugging input data from API + if is_chat_completion_api: + print(f"message -> role : {InputData['messages'][0]['role']}") + print(f"message -> content : {InputData['messages'][0]['content']}") + else: + print(f"prompt : {InputData['prompt']}") + + model_name = InputData["model"] if "model" in InputData.keys() else "llama2_7b" + model_path = llm_model_map[model_name] + device = InputData["device"] if "device" in InputData.keys() else "cpu" + precision = "fp16" + max_tokens = InputData["max_tokens"] if "max_tokens" in InputData.keys() else 4096 + + device_id = None + if not global_obj.get_llm_obj(): + print("\n[LOG] Initializing new pipeline...") + global_obj.clear_cache() + gc.collect() + if "cuda" in device: + device = "cuda" + elif "vulkan" in device: + device_id = int(device.split("://")[1]) + device = "vulkan" + elif "cpu" in device: + device = "cpu" + precision 
= "fp32" + else: + print("unrecognized device") + llm_model = LanguageModel( + model_name=model_name, + hf_auth_token=cmd_opts.hf_auth_token, + device=device, + quantization=cmd_opts.quantization, + external_weights="safetensors", + use_system_prompt=True, + streaming_llm=False, + ) + global_obj.set_llm_obj(llm_model) + else: + llm_model = global_obj.get_llm_obj() + + llm_model.max_tokens = max_tokens + # TODO: add role dict for different models + if is_chat_completion_api: + # TODO: add funtionality for multiple messages + prompt = append_user_prompt( + InputData["messages"][0]["role"], InputData["messages"][0]["content"] + ) + else: + prompt = InputData["prompt"] + print("prompt = ", prompt) + + for res_op, _ in llm_model.chat(prompt): + if is_chat_completion_api: + choices = [ + { + "index": 0, + "message": { + "role": "assistant", + "content": res_op, # since we are yeilding the result + }, + "finish_reason": "stop", # or length + } + ] + else: + choices = [ + { + "text": res_op, + "index": 0, + "logprobs": None, + "finish_reason": "stop", # or length + } + ] + end_time = dt.now().strftime("%Y%m%d%H%M%S%f") + return { + "id": end_time, + "object": "chat.completion" if is_chat_completion_api else "text_completion", + "created": int(end_time), + "choices": choices, + } + + if __name__ == "__main__": lm = LanguageModel( "Trelis/Llama-2-7b-chat-hf-function-calling-v2", diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py new file mode 100644 index 0000000000..1b37384725 --- /dev/null +++ b/apps/shark_studio/api/sd.py @@ -0,0 +1,611 @@ +import gc +import torch +import time +import os +import json +import numpy as np +from tqdm.auto import tqdm + +from pathlib import Path +from random import randint +from turbine_models.custom_models.sd_inference import clip, unet, vae +from apps.shark_studio.api.controlnet import control_adapter_map +from apps.shark_studio.web.utils.state import status_label +from apps.shark_studio.web.utils.file_utils import ( + safe_name, + get_resource_path, + get_checkpoints_path, +) +from apps.shark_studio.modules.pipeline import SharkPipelineBase +from apps.shark_studio.modules.schedulers import get_schedulers +from apps.shark_studio.modules.prompt_encoding import ( + get_weighted_text_embeddings, +) +from apps.shark_studio.modules.img_processing import ( + resize_stencil, + save_output_img, + resamplers, + resampler_list, +) + +from apps.shark_studio.modules.ckpt_processing import ( + preprocessCKPT, + process_custom_pipe_weights, +) +from transformers import CLIPTokenizer +from diffusers.image_processor import VaeImageProcessor + +sd_model_map = { + "clip": { + "initializer": clip.export_clip_model, + }, + "unet": { + "initializer": unet.export_unet_model, + }, + "vae_decode": { + "initializer": vae.export_vae_model, + }, +} + + +class StableDiffusion(SharkPipelineBase): + # This class is responsible for executing image generation and creating + # /managing a set of compiled modules to run Stable Diffusion. The init + # aims to be as general as possible, and the class will infer and compile + # a list of necessary modules or a combined "pipeline module" for a + # specified job based on the inference task. 
+ + def __init__( + self, + base_model_id, + height: int, + width: int, + batch_size: int, + precision: str, + device: str, + custom_vae: str = None, + num_loras: int = 0, + import_ir: bool = True, + is_controlled: bool = False, + hf_auth_token=None, + ): + self.model_max_length = 77 + self.batch_size = batch_size + self.precision = precision + self.dtype = torch.float16 if precision == "fp16" else torch.float32 + self.height = height + self.width = width + self.scheduler_obj = {} + static_kwargs = { + "pipe": { + "external_weights": "safetensors", + }, + "clip": {"hf_model_name": base_model_id}, + "unet": { + "hf_model_name": base_model_id, + "unet_model": unet.UnetModel(hf_model_name=base_model_id), + "batch_size": batch_size, + # "is_controlled": is_controlled, + # "num_loras": num_loras, + "height": height, + "width": width, + "precision": precision, + "max_length": self.model_max_length, + }, + "vae_encode": { + "hf_model_name": base_model_id, + "vae_model": vae.VaeModel( + hf_model_name=custom_vae if custom_vae else base_model_id, + ), + "batch_size": batch_size, + "height": height, + "width": width, + "precision": precision, + }, + "vae_decode": { + "hf_model_name": base_model_id, + "vae_model": vae.VaeModel( + hf_model_name=custom_vae if custom_vae else base_model_id, + ), + "batch_size": batch_size, + "height": height, + "width": width, + "precision": precision, + }, + } + super().__init__(sd_model_map, base_model_id, static_kwargs, device, import_ir) + pipe_id_list = [ + safe_name(base_model_id), + str(batch_size), + str(self.model_max_length), + f"{str(height)}x{str(width)}", + precision, + self.device, + ] + if num_loras > 0: + pipe_id_list.append(str(num_loras) + "lora") + if is_controlled: + pipe_id_list.append("controlled") + if custom_vae: + pipe_id_list.append(custom_vae) + self.pipe_id = "_".join(pipe_id_list) + print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") + del static_kwargs + gc.collect() + + def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): + print(f"\n[LOG] Preparing pipeline...") + self.is_img2img = is_img2img + self.schedulers = get_schedulers(self.base_model_id) + + self.weights_path = os.path.join( + get_checkpoints_path(), self.safe_name(self.base_model_id) + ) + if not os.path.exists(self.weights_path): + os.mkdir(self.weights_path) + + for model in adapters: + self.model_map[model] = adapters[model] + + for submodel in self.static_kwargs: + if custom_weights: + custom_weights_params, _ = process_custom_pipe_weights(custom_weights) + if submodel not in ["clip", "clip2"]: + self.static_kwargs[submodel][ + "external_weights" + ] = custom_weights_params + else: + self.static_kwargs[submodel]["external_weight_path"] = os.path.join( + self.weights_path, submodel + ".safetensors" + ) + else: + self.static_kwargs[submodel]["external_weight_path"] = os.path.join( + self.weights_path, submodel + ".safetensors" + ) + + self.get_compiled_map(pipe_id=self.pipe_id) + print("\n[LOG] Pipeline successfully prepared for runtime.") + return + + def encode_prompts_weight( + self, + prompt, + negative_prompt, + do_classifier_free_guidance=True, + ): + # Encodes the prompt into text encoder hidden states. 
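+        # Prompts are run through get_weighted_text_embeddings, so A1111-style
+        # attention syntax is honored, e.g. "a (red:1.3) ball" up-weights the
+        # tokens for "red" by 1.3 (see parse_prompt_attention in
+        # modules/prompt_encoding.py for the full grammar).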
+ self.load_submodels(["clip"]) + self.tokenizer = CLIPTokenizer.from_pretrained( + self.base_model_id, + subfolder="tokenizer", + ) + clip_inf_start = time.time() + + text_embeddings, uncond_embeddings = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + ) + + if do_classifier_free_guidance: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + pad = (0, 0) * (len(text_embeddings.shape) - 2) + pad = pad + ( + 0, + self.static_kwargs["unet"]["max_length"] - text_embeddings.shape[1], + ) + text_embeddings = torch.nn.functional.pad(text_embeddings, pad) + + # SHARK: Report clip inference time + clip_inf_time = (time.time() - clip_inf_start) * 1000 + if self.ondemand: + self.unload_submodels(["clip"]) + gc.collect() + print(f"\n[LOG] Clip Inference time (ms) = {clip_inf_time:.3f}") + + return text_embeddings.numpy().astype(np.float16) + + def prepare_latents( + self, + generator, + num_inference_steps, + image, + strength, + ): + noise = torch.randn( + ( + self.batch_size, + 4, + self.height // 8, + self.width // 8, + ), + generator=generator, + dtype=self.dtype, + ).to("cpu") + + self.scheduler.set_timesteps(num_inference_steps) + if self.is_img2img: + init_timestep = min( + int(num_inference_steps * strength), num_inference_steps + ) + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + latents = self.encode_image(image) + latents = self.scheduler.add_noise(latents, noise, timesteps[0].repeat(1)) + return latents, [timesteps] + else: + self.scheduler.is_scale_input_called = True + latents = noise * self.scheduler.init_noise_sigma + return latents, self.scheduler.timesteps + + def encode_image(self, input_image): + self.load_submodels(["vae_encode"]) + vae_encode_start = time.time() + latents = self.run("vae_encode", input_image) + vae_inf_time = (time.time() - vae_encode_start) * 1000 + if self.ondemand: + self.unload_submodels(["vae_encode"]) + print(f"\n[LOG] VAE Encode Inference time (ms): {vae_inf_time:.3f}") + + return latents + + def produce_img_latents( + self, + latents, + text_embeddings, + guidance_scale, + total_timesteps, + cpu_scheduling, + mask=None, + masked_image_latents=None, + return_all_latents=False, + ): + # self.status = SD_STATE_IDLE + step_time_sum = 0 + latent_history = [latents] + text_embeddings = torch.from_numpy(text_embeddings).to(self.dtype) + text_embeddings_numpy = text_embeddings.detach().numpy() + guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype) + self.load_submodels(["unet"]) + for i, t in tqdm(enumerate(total_timesteps)): + step_start_time = time.time() + timestep = torch.tensor([t]).to(self.dtype).detach().numpy() + latent_model_input = self.scheduler.scale_model_input(latents, t).to( + self.dtype + ) + if mask is not None and masked_image_latents is not None: + latent_model_input = torch.cat( + [ + torch.from_numpy(np.asarray(latent_model_input)).to(self.dtype), + mask, + masked_image_latents, + ], + dim=1, + ).to(self.dtype) + if cpu_scheduling: + latent_model_input = latent_model_input.detach().numpy() + + # Profiling Unet. 
+ # profile_device = start_profiling(file_path="unet.rdc") + noise_pred = self.run( + "unet", + [ + latent_model_input, + timestep, + text_embeddings_numpy, + guidance_scale, + ], + ) + # end_profiling(profile_device) + + if cpu_scheduling: + noise_pred = torch.from_numpy(noise_pred.to_host()) + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + else: + latents = self.run("scheduler_step", (noise_pred, t, latents)) + + latent_history.append(latents) + step_time = (time.time() - step_start_time) * 1000 + # print( + # f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms" + # ) + step_time_sum += step_time + + # if self.status == SD_STATE_CANCEL: + # break + + if self.ondemand: + self.unload_submodels(["unet"]) + gc.collect() + + avg_step_time = step_time_sum / len(total_timesteps) + print(f"\n[LOG] Average step time: {avg_step_time}ms/it") + + if not return_all_latents: + return latents + all_latents = torch.cat(latent_history, dim=0) + return all_latents + + def decode_latents(self, latents, cpu_scheduling=True): + latents_numpy = latents.to(self.dtype) + if cpu_scheduling: + latents_numpy = latents.detach().numpy() + + # profile_device = start_profiling(file_path="vae.rdc") + vae_start = time.time() + images = self.run("vae_decode", latents_numpy).to_host() + vae_inf_time = (time.time() - vae_start) * 1000 + # end_profiling(profile_device) + print(f"\n[LOG] VAE Inference time (ms): {vae_inf_time:.3f}") + + images = torch.from_numpy(images).permute(0, 2, 3, 1).float().numpy() + pil_images = self.image_processor.numpy_to_pil(images) + return pil_images + + def generate_images( + self, + prompt, + negative_prompt, + image, + scheduler, + steps, + strength, + guidance_scale, + seed, + ondemand, + repeatable_seeds, + resample_type, + control_mode, + hints, + ): + # TODO: Batched args + self.image_processor = VaeImageProcessor(do_convert_rgb=True) + self.scheduler = self.schedulers[scheduler] + self.ondemand = ondemand + if self.is_img2img: + image, _ = self.image_processor.preprocess(image, resample_type) + else: + image = None + + print("\n[LOG] Generating images...") + batched_args = [ + prompt, + negative_prompt, + image, + ] + for arg in batched_args: + if not isinstance(arg, list): + arg = [arg] * self.batch_size + if len(arg) < self.batch_size: + arg = arg * self.batch_size + else: + arg = [arg[i] for i in range(self.batch_size)] + + text_embeddings = self.encode_prompts_weight( + prompt, + negative_prompt, + ) + + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + + generator = torch.manual_seed(seed) + + init_latents, final_timesteps = self.prepare_latents( + generator=generator, + num_inference_steps=steps, + image=image, + strength=strength, + ) + + latents = self.produce_img_latents( + latents=init_latents, + text_embeddings=text_embeddings, + guidance_scale=guidance_scale, + total_timesteps=final_timesteps, + cpu_scheduling=True, # until we have schedulers through Turbine + ) + + # Img latents -> PIL images + all_imgs = [] + self.load_submodels(["vae_decode"]) + for i in tqdm(range(0, latents.shape[0], self.batch_size)): + imgs = self.decode_latents( + latents=latents[i : i + self.batch_size], + cpu_scheduling=True, + ) + all_imgs.extend(imgs) + if self.ondemand: + self.unload_submodels(["vae_decode"]) + + return all_imgs + + +def shark_sd_fn_dict_input( + sd_kwargs: dict, +): + print("[LOG] Submitting Request...") + + for 
key in sd_kwargs:
+        if sd_kwargs[key] in [None, []]:
+            sd_kwargs[key] = None
+        if sd_kwargs[key] in ["None"]:
+            sd_kwargs[key] = ""
+        if key == "seed":
+            sd_kwargs[key] = int(sd_kwargs[key])
+
+    for i in range(1):
+        generated_imgs = yield from shark_sd_fn(**sd_kwargs)
+        yield generated_imgs
+
+
+def shark_sd_fn(
+    prompt,
+    negative_prompt,
+    sd_init_image: list,
+    height: int,
+    width: int,
+    steps: int,
+    strength: float,
+    guidance_scale: float,
+    seed: list,
+    batch_count: int,
+    batch_size: int,
+    scheduler: str,
+    base_model_id: str,
+    custom_weights: str,
+    custom_vae: str,
+    precision: str,
+    device: str,
+    ondemand: bool,
+    repeatable_seeds: bool,
+    resample_type: str,
+    controlnets: dict,
+    embeddings: dict,
+):
+    sd_kwargs = locals()
+    if not isinstance(sd_init_image, list):
+        sd_init_image = [sd_init_image]
+    is_img2img = True if sd_init_image[0] is not None else False
+
+    print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
+
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+    import apps.shark_studio.web.utils.globals as global_obj
+
+    adapters = {}
+    is_controlled = False
+    control_mode = None
+    hints = []
+    num_loras = 0
+    for i in embeddings:
+        num_loras += 1 if embeddings[i] else 0
+    if "model" in controlnets:
+        for i, model in enumerate(controlnets["model"]):
+            if "xl" not in base_model_id.lower():
+                adapters[f"control_adapter_{model}"] = {
+                    "hf_id": control_adapter_map["sd15"][model],
+                    "strength": controlnets["strength"][i],
+                }
+            else:
+                adapters[f"control_adapter_{model}"] = {
+                    "hf_id": control_adapter_map["sdxl"][model],
+                    "strength": controlnets["strength"][i],
+                }
+            if model is not None:
+                is_controlled = True
+                control_mode = controlnets["control_mode"]
+        for i in controlnets["hint"]:
+            hints.append(i)
+
+    submit_pipe_kwargs = {
+        "base_model_id": base_model_id,
+        "height": height,
+        "width": width,
+        "batch_size": batch_size,
+        "precision": precision,
+        "device": device,
+        "custom_vae": custom_vae,
+        "num_loras": num_loras,
+        "import_ir": cmd_opts.import_mlir,
+        "is_controlled": is_controlled,
+    }
+    submit_prep_kwargs = {
+        "custom_weights": custom_weights,
+        "adapters": adapters,
+        "embeddings": embeddings,
+        "is_img2img": is_img2img,
+    }
+    submit_run_kwargs = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "image": sd_init_image,
+        "steps": steps,
+        "scheduler": scheduler,
+        "strength": strength,
+        "guidance_scale": guidance_scale,
+        "seed": seed,
+        "ondemand": ondemand,
+        "repeatable_seeds": repeatable_seeds,
+        "resample_type": resample_type,
+        "control_mode": control_mode,
+        "hints": hints,
+    }
+    if (
+        not global_obj.get_sd_obj()
+        or global_obj.get_pipe_kwargs() != submit_pipe_kwargs
+    ):
+        print("\n[LOG] Initializing new pipeline...")
+        global_obj.clear_cache()
+        gc.collect()
+
+        # Initializes the pipeline and retrieves IR based on all
+        # parameters that are static in the turbine output format,
+        # which is currently MLIR in the torch dialect.
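+        # The pipeline built here is cached via set_sd_obj/set_pipe_kwargs
+        # below, so later requests with identical pipe kwargs reuse the
+        # compiled modules instead of re-initializing them.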
+ + sd_pipe = StableDiffusion( + **submit_pipe_kwargs, + ) + global_obj.set_sd_obj(sd_pipe) + global_obj.set_pipe_kwargs(submit_pipe_kwargs) + if ( + not global_obj.get_prep_kwargs() + or global_obj.get_prep_kwargs() != submit_prep_kwargs + ): + global_obj.set_prep_kwargs(submit_prep_kwargs) + global_obj.get_sd_obj().prepare_pipe(**submit_prep_kwargs) + + generated_imgs = [] + for current_batch in range(batch_count): + start_time = time.time() + out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs) + total_time = time.time() - start_time + text_output = f"Total image(s) generation time: {total_time:.4f}sec" + print(f"\n[LOG] {text_output}") + # if global_obj.get_sd_status() == SD_STATE_CANCEL: + # break + # else: + save_output_img( + out_imgs[current_batch], + seed, + sd_kwargs, + ) + generated_imgs.extend(out_imgs) + yield generated_imgs, status_label( + "Stable Diffusion", current_batch + 1, batch_count, batch_size + ) + return generated_imgs, "" + + +def cancel_sd(): + print("Inject call to cancel longer API calls.") + return + + +def view_json_file(file_path): + content = "" + with open(file_path, "r") as fopen: + content = fopen.read() + return content + + +if __name__ == "__main__": + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + import apps.shark_studio.web.utils.globals as global_obj + + global_obj._init() + + sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json")) + sd_kwargs = json.loads(sd_json) + for arg in vars(cmd_opts): + if arg in sd_kwargs: + sd_kwargs[arg] = getattr(cmd_opts, arg) + for i in shark_sd_fn_dict_input(sd_kwargs): + print(i) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index 7a6e9bb4b7..e9268aa83b 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -8,8 +8,7 @@ ) from pathlib import Path - -# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts from cpuinfo import get_cpu_info # TODO: migrate these utils to studio @@ -79,11 +78,52 @@ def get_devices_by_name(driver_name): return available_devices +def set_init_device_flags(): + if "vulkan" in cmd_opts.device: + # set runtime flags for vulkan. + set_iree_runtime_flags() + + # set triple flag to avoid multiple calls to get_vulkan_triple_flag + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_vulkan_target_triple: + triple = get_vulkan_target_triple(device_name) + if triple is not None: + cmd_opts.iree_vulkan_target_triple = triple + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_vulkan_target_triple}." + ) + elif "cuda" in cmd_opts.device: + cmd_opts.device = "cuda" + elif "metal" in cmd_opts.device: + device_name, cmd_opts.device = map_device_to_name_path(cmd_opts.device) + if not cmd_opts.iree_metal_target_platform: + from shark.iree_utils.metal_utils import get_metal_target_triple + + triple = get_metal_target_triple(device_name) + if triple is not None: + cmd_opts.iree_metal_target_platform = triple.split("-")[-1] + print( + f"Found device {device_name}. Using target triple " + f"{cmd_opts.iree_metal_target_platform}." + ) + elif "cpu" in cmd_opts.device: + cmd_opts.device = "cpu" + + def set_iree_runtime_flags(): # TODO: This function should be device-agnostic and piped properly # to general runtime driver init. 
vulkan_runtime_flags = get_iree_vulkan_runtime_flags() - + if cmd_opts.enable_rgp: + vulkan_runtime_flags += [ + f"--enable_rgp=true", + f"--vulkan_debug_utils=true", + ] + if cmd_opts.device_allocator_heap_key: + vulkan_runtime_flags += [ + f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}", + ] set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) @@ -140,6 +180,32 @@ def get_output_value(dev_dict): return device_map +def get_opt_flags(model, precision="fp16"): + iree_flags = [] + if len(cmd_opts.iree_vulkan_target_triple) > 0: + iree_flags.append( + f"-iree-vulkan-target-triple={cmd_opts.iree_vulkan_target_triple}" + ) + if "rocm" in cmd_opts.device: + from shark.iree_utils.gpu_utils import get_iree_rocm_args + + rocm_args = get_iree_rocm_args() + iree_flags.extend(rocm_args) + if cmd_opts.iree_constant_folding == False: + iree_flags.append("--iree-opt-const-expr-hoisting=False") + iree_flags.append( + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" + ) + if cmd_opts.data_tiling == False: + iree_flags.append("--iree-opt-data-tiling=False") + + if "vae" not in model: + # Due to lack of support for multi-reduce, we always collapse reduction + # dims before dispatch formation right now. + iree_flags += ["--iree-flow-collapse-reduction-dims"] + return iree_flags + + def map_device_to_name_path(device, key_combination=3): """Gives the appropriate device data (supported name/path) for user selected execution device @@ -165,6 +231,63 @@ def map_device_to_name_path(device, key_combination=3): raise ValueError(f"Device '{device}' is not a valid device.") return device_mapping + def get_devices_by_name(driver_name): + from shark.iree_utils._common import iree_device_map + + device_list = [] + try: + driver_name = iree_device_map(driver_name) + device_list_dict = get_all_devices(driver_name) + print(f"{driver_name} devices are available.") + except: + print(f"{driver_name} devices are not available.") + else: + cpu_name = get_cpu_info()["brand_raw"] + for i, device in enumerate(device_list_dict): + device_name = ( + cpu_name if device["name"] == "default" else device["name"] + ) + if "local" in driver_name: + device_list.append( + f"{device_name} => {driver_name.replace('local', 'cpu')}" + ) + else: + # for drivers with single devices + # let the default device be selected without any indexing + if len(device_list_dict) == 1: + device_list.append(f"{device_name} => {driver_name}") + else: + device_list.append(f"{device_name} => {driver_name}://{i}") + return device_list + + set_iree_runtime_flags() + + available_devices = [] + from shark.iree_utils.vulkan_utils import ( + get_all_vulkan_devices, + ) + + vulkaninfo_list = get_all_vulkan_devices() + vulkan_devices = [] + id = 0 + for device in vulkaninfo_list: + vulkan_devices.append(f"{device.strip()} => vulkan://{id}") + id += 1 + if id != 0: + print(f"vulkan devices are available.") + available_devices.extend(vulkan_devices) + metal_devices = get_devices_by_name("metal") + available_devices.extend(metal_devices) + cuda_devices = get_devices_by_name("cuda") + available_devices.extend(cuda_devices) + rocm_devices = get_devices_by_name("rocm") + available_devices.extend(rocm_devices) + cpu_device = get_devices_by_name("cpu-sync") + available_devices.extend(cpu_device) + cpu_device = get_devices_by_name("cpu-task") + available_devices.extend(cpu_device) + return available_devices + # Generate and return a new seed if the provided one is not in the # supported range (including -1) diff --git 
a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py new file mode 100644 index 0000000000..08681f6c56 --- /dev/null +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -0,0 +1,122 @@ +import os +import json +import re +import requests +from io import BytesIO +from pathlib import Path +from tqdm import tqdm +from omegaconf import OmegaConf +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + download_from_original_stable_diffusion_ckpt, + create_vae_diffusers_config, + convert_ldm_vae_checkpoint, +) + + +def get_path_to_diffusers_checkpoint(custom_weights): + path = Path(custom_weights) + diffusers_path = path.parent.absolute() + diffusers_directory_name = os.path.join("diffusers", path.stem) + complete_path_to_diffusers = diffusers_path / diffusers_directory_name + complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) + path_to_diffusers = complete_path_to_diffusers.as_posix() + return path_to_diffusers + + +def preprocessCKPT(custom_weights, is_inpaint=False): + path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) + if next(Path(path_to_diffusers).iterdir(), None): + print("Checkpoint already loaded at : ", path_to_diffusers) + return + else: + print( + "Diffusers' checkpoint will be identified here : ", + path_to_diffusers, + ) + from_safetensors = ( + True if custom_weights.lower().endswith(".safetensors") else False + ) + # EMA weights usually yield higher quality images for inference but + # non-EMA weights have been yielding better results in our case. + # TODO: Add an option `--ema` (`--no-ema`) for users to specify if + # they want to go for EMA weight extraction or not. + extract_ema = False + print("Loading diffusers' pipeline from original stable diffusion checkpoint") + num_in_channels = 9 if is_inpaint else 4 + pipe = download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict=custom_weights, + extract_ema=extract_ema, + from_safetensors=from_safetensors, + num_in_channels=num_in_channels, + ) + pipe.save_pretrained(path_to_diffusers) + print("Loading complete") + + +def convert_original_vae(vae_checkpoint): + vae_state_dict = {} + for key in list(vae_checkpoint.keys()): + vae_state_dict["first_stage_model." 
+ key] = vae_checkpoint.get(key) + + config_url = ( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/" + "main/configs/stable-diffusion/v1-inference.yaml" + ) + original_config_file = BytesIO(requests.get(config_url).content) + original_config = OmegaConf.load(original_config_file) + vae_config = create_vae_diffusers_config(original_config, image_size=512) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_state_dict, vae_config) + return converted_vae_checkpoint + + +def process_custom_pipe_weights(custom_weights): + if custom_weights != "": + if custom_weights.startswith("https://civitai.com/api/"): + # download the checkpoint from civitai if we don't already have it + weights_path = get_civitai_checkpoint(custom_weights) + + # act as if we were given the local file as custom_weights originally + custom_weights_tgt = get_path_to_diffusers_checkpoint(weights_path) + custom_weights_params = weights_path + + else: + assert custom_weights.lower().endswith( + (".ckpt", ".safetensors") + ), "checkpoint files supported can be any of [.ckpt, .safetensors] type" + custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights) + custom_weights_params = custom_weights + return custom_weights_params, custom_weights_tgt + + +def get_civitai_checkpoint(url: str): + with requests.get(url, allow_redirects=True, stream=True) as response: + response.raise_for_status() + + # civitai api returns the filename in the content disposition + base_filename = re.findall( + '"([^"]*)"', response.headers["Content-Disposition"] + )[0] + destination_path = Path.cwd() / (cmd_opts.ckpt_dir or "models") / base_filename + + # we don't have this model downloaded yet + if not destination_path.is_file(): + print(f"downloading civitai model from {url} to {destination_path}") + + size = int(response.headers["content-length"], 0) + progress_bar = tqdm(total=size, unit="iB", unit_scale=True) + + with open(destination_path, "wb") as f: + for chunk in response.iter_content(chunk_size=65536): + f.write(chunk) + progress_bar.update(len(chunk)) + + progress_bar.close() + + # we already have this model downloaded + else: + print(f"civitai model already downloaded to {destination_path}") + + response.close() + return destination_path.as_posix() diff --git a/apps/shark_studio/modules/embeddings.py b/apps/shark_studio/modules/embeddings.py new file mode 100644 index 0000000000..95d228d7c5 --- /dev/null +++ b/apps/shark_studio/modules/embeddings.py @@ -0,0 +1,185 @@ +import os +import sys +import torch +import json +import safetensors +from dataclasses import dataclass +from safetensors.torch import load_file +from apps.shark_studio.web.utils.file_utils import ( + get_checkpoint_pathfile, + get_path_stem, +) + + +@dataclass +class LoRAweight: + up: torch.tensor + down: torch.tensor + mid: torch.tensor + alpha: torch.float32 = 1.0 + + +def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75): + state_dict = "" + if ".safetensors" in use_lora: + state_dict = load_file(use_lora) + else: + state_dict = torch.load(use_lora) + + # gather the weights from the LoRA in a more convenient form, assumes + # everything will have an up.weight. 
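+    # For a typical kohya-style export (an assumption about the checkpoint
+    # layout, not something enforced here) the keys look roughly like
+    #   lora_unet_..._to_q.lora_up.weight / ....lora_down.weight / ....alpha
+    # so splitting on "up.weight" and stripping the ".lora_"/"_lora_" suffixes
+    # below yields one LoRAweight entry per target layer.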
+ weight_dict: dict[str, LoRAweight] = {} + for key in state_dict: + if key.startswith(splitting_prefix) and key.endswith("up.weight"): + stem = key.split("up.weight")[0] + weight_key = stem.removesuffix(".lora_") + weight_key = weight_key.removesuffix("_lora_") + weight_key = weight_key.removesuffix(".lora_linear_layer.") + + if weight_key not in weight_dict: + weight_dict[weight_key] = LoRAweight( + state_dict[f"{stem}up.weight"], + state_dict[f"{stem}down.weight"], + state_dict.get(f"{stem}mid.weight", None), + ( + state_dict[f"{weight_key}.alpha"] + / state_dict[f"{stem}up.weight"].shape[1] + if f"{weight_key}.alpha" in state_dict + else 1.0 + ), + ) + + # Directly update weight in model + + # Mostly adaptions of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py + # and similar code in https://github.com/huggingface/diffusers/issues/3064 + + # TODO: handle mid weights (how do they even work?) + for key, lora_weight in weight_dict.items(): + curr_layer = model + layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_") + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + weight = curr_layer.weight.data + scale = lora_weight.alpha * lora_strength + if len(weight.size()) == 2: + if len(lora_weight.up.shape) == 4: + weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32) + weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) + change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + change = torch.mm(lora_weight.up, lora_weight.down) + elif lora_weight.down.size()[2:4] == (1, 1): + weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32) + weight_down = lora_weight.down.squeeze(3).squeeze(2).to(torch.float32) + change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + change = torch.nn.functional.conv2d( + lora_weight.down.permute(1, 0, 2, 3), + lora_weight.up, + ).permute(1, 0, 2, 3) + + curr_layer.weight.data += change * scale + + return model + + +def update_lora_weight_for_unet(unet, use_lora, lora_strength): + extensions = [".bin", ".safetensors", ".pt"] + if not any([extension in use_lora for extension in extensions]): + # We assume if it is a HF ID with standalone LoRA weights. 
+        unet.load_attn_procs(use_lora)
+        return unet
+
+    main_file_name = get_path_stem(use_lora)
+    if ".bin" in use_lora:
+        main_file_name += ".bin"
+    elif ".safetensors" in use_lora:
+        main_file_name += ".safetensors"
+    elif ".pt" in use_lora:
+        main_file_name += ".pt"
+    else:
+        sys.exit("Only .bin, .safetensors, and .pt LoRA formats are supported")
+
+    try:
+        dir_name = os.path.dirname(use_lora)
+        unet.load_attn_procs(dir_name, weight_name=main_file_name)
+        return unet
+    except:
+        return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
+
+
+def update_lora_weight(model, use_lora, model_name, lora_strength=1.0):
+    if "unet" in model_name:
+        return update_lora_weight_for_unet(model, use_lora, lora_strength)
+    try:
+        return processLoRA(model, use_lora, "lora_te_", lora_strength)
+    except:
+        return None
+
+
+def get_lora_metadata(lora_filename):
+    # get the metadata from the file
+    filename = get_checkpoint_pathfile(lora_filename, "lora")
+    with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
+        metadata = f.metadata()
+
+    # guard clause for if there isn't any metadata
+    if not metadata:
+        return None
+
+    # metadata is a dictionary of strings, the values of the keys we're
+    # interested in are actually json, and need to be loaded as such
+    tag_frequencies = json.loads(metadata.get("ss_tag_frequency", "{}"))
+    dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", "{}"))
+    tag_dirs = list(tag_frequencies.keys())
+
+    # gather the tag frequency information for all the datasets trained
+    all_frequencies = {}
+    for dataset in tag_dirs:
+        frequencies = sorted(
+            tag_frequencies[dataset].items(),
+            reverse=True,
+            key=lambda x: x[1],
+        )
+
+        # get a figure for the total number of images processed for this dataset:
+        # either the number listed in its dataset_dirs entry, or the highest
+        # frequency's count if that entry doesn't exist
+        img_count = dataset_dirs.get(dataset, {}).get("img_count", frequencies[0][1])
+
+        # add the dataset frequencies to the overall frequencies, replacing the
+        # frequency counts on the tags with a percentage/ratio
+        all_frequencies.update(
+            [(entry[0], entry[1] / img_count) for entry in frequencies]
+        )
+
+    trained_model_id = " ".join(
+        [
+            metadata.get("ss_sd_model_hash", ""),
+            metadata.get("ss_sd_model_name", ""),
+            metadata.get("ss_base_model_version", ""),
+        ]
+    ).strip()
+
+    # return the topmost of all frequencies in all datasets
+    return {
+        "model": trained_model_id,
+        "frequencies": sorted(
+            all_frequencies.items(), reverse=True, key=lambda x: x[1]
+        ),
+    }
diff --git a/apps/shark_studio/modules/img_processing.py b/apps/shark_studio/modules/img_processing.py
new file mode 100644
index 0000000000..401c042ad2
--- /dev/null
+++ b/apps/shark_studio/modules/img_processing.py
@@ -0,0 +1,202 @@
+import os
+import re
+import json
+import torch
+import numpy as np
+
+from csv import DictWriter
+from PIL import Image, PngImagePlugin
+from pathlib import Path
+from datetime import datetime as dt
+from base64 import decode
+
+
+resamplers = {
+    "Lanczos": Image.Resampling.LANCZOS,
+    "Nearest Neighbor": Image.Resampling.NEAREST,
+    "Bilinear": Image.Resampling.BILINEAR,
+    "Bicubic": Image.Resampling.BICUBIC,
+    "Hamming": Image.Resampling.HAMMING,
+    "Box": Image.Resampling.BOX,
+}
+
+resampler_list = resamplers.keys()
+
+
+# save output images and the inputs corresponding to them.
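+# For each generated image this writes the image itself (png by default, jpg
+# if requested), a <name>.json with the generation parameters, and a row in a
+# shared imgs_details.csv, all under the dated output subdirectory; e.g.
+# (hypothetical prompt/seed) 142530_a_photo_of_a_cat_1234.png.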
+def save_output_img(output_img, img_seed, extra_info=None):
+    from apps.shark_studio.web.utils.file_utils import (
+        get_generated_imgs_path,
+        get_generated_imgs_todays_subdir,
+    )
+    from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
+
+    if extra_info is None:
+        extra_info = {}
+    generated_imgs_path = Path(
+        get_generated_imgs_path(), get_generated_imgs_todays_subdir()
+    )
+    generated_imgs_path.mkdir(parents=True, exist_ok=True)
+    csv_path = Path(generated_imgs_path, "imgs_details.csv")
+
+    prompt_slice = re.sub("[^a-zA-Z0-9]", "_", extra_info["prompt"][0][:15])
+    out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}"
+
+    img_model = extra_info["base_model_id"]
+    if extra_info["custom_weights"] not in [None, "None"]:
+        img_model = Path(os.path.basename(extra_info["custom_weights"])).stem
+
+    img_vae = None
+    if extra_info["custom_vae"]:
+        img_vae = Path(os.path.basename(extra_info["custom_vae"])).stem
+
+    img_loras = None
+    if extra_info["embeddings"]:
+        img_lora = []
+        for i in extra_info["embeddings"]:
+            img_lora.append(Path(os.path.basename(cmd_opts.use_lora)).stem)
+        img_loras = ", ".join(img_lora)
+
+    if cmd_opts.output_img_format == "jpg":
+        out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
+        output_img.save(out_img_path, quality=95, subsampling=0)
+    else:
+        out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
+        pngInfo = PngImagePlugin.PngInfo()
+
+        if cmd_opts.write_metadata_to_png:
+            # Using a conditional expression caused problems, so setting a new
+            # variable for now.
+            # if cmd_opts.use_hiresfix:
+            #     png_size_text = (
+            #         f"{cmd_opts.hiresfix_width}x{cmd_opts.hiresfix_height}"
+            #     )
+            # else:
+            png_size_text = f"{extra_info['width']}x{extra_info['height']}"
+
+            pngInfo.add_text(
+                "parameters",
+                f"{extra_info['prompt'][0]}"
+                f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
+                f"\nSteps: {extra_info['steps']}, "
+                f"Sampler: {extra_info['scheduler']}, "
+                f"CFG scale: {extra_info['guidance_scale']}, "
+                f"Seed: {img_seed}, "
+                f"Size: {png_size_text}, "
+                f"Model: {img_model}, "
+                f"VAE: {img_vae}, "
+                f"LoRA: {img_loras}",
+            )
+
+        output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
+
+        if cmd_opts.output_img_format not in ["png", "jpg"]:
+            print(
+                f"[ERROR] Format {cmd_opts.output_img_format} is not "
+                f"supported yet. Image saved as png instead. "
+                f"Supported formats: png / jpg"
+            )
+
+    # To be as low-impact as possible to the existing CSV format, we append
+    # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of
+    # importance for each data point. Something to consider.
+    new_entry = {}
+
+    new_entry.update(extra_info)
+
+    csv_mode = "a" if os.path.isfile(csv_path) else "w"
+    with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
+        dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
+        if csv_mode == "w":
+            dictwriter_obj.writeheader()
+        dictwriter_obj.writerow(new_entry)
+
+    json_path = Path(generated_imgs_path, f"{out_img_name}.json")
+    with open(json_path, "w") as f:
+        json.dump(new_entry, f, indent=4)
+
+
+# For stencil, the input image can be of any size, but we need to ensure that
+# it conforms with our model constraints:
+# both width and height should be in the range [128, 768] and multiples of 8.
+# This utility function performs the transformation on the input image while
+# also maintaining the aspect ratio before sending it to the stencil pipeline.
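+# Worked example: a 100x300 input is first scaled up so the short side is 128
+# (-> 128x384, aspect ratio preserved), then both sides are rounded down to
+# multiples of 8 (already satisfied here), so the stencil input becomes 128x384.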
+def resize_stencil(image: Image.Image, width, height, resampler_type=None): + aspect_ratio = width / height + min_size = min(width, height) + if min_size < 128: + n_size = 128 + if width == min_size: + width = n_size + height = n_size / aspect_ratio + else: + height = n_size + width = n_size * aspect_ratio + width = int(width) + height = int(height) + n_width = width // 8 + n_height = height // 8 + n_width *= 8 + n_height *= 8 + + min_size = min(width, height) + if min_size > 768: + n_size = 768 + if width == min_size: + height = n_size + width = n_size * aspect_ratio + else: + width = n_size + height = n_size / aspect_ratio + width = int(width) + height = int(height) + n_width = width // 8 + n_height = height // 8 + n_width *= 8 + n_height *= 8 + if resampler_type in resamplers: + resampler = resamplers[resampler_type] + else: + resampler = resamplers["Nearest Neighbor"] + new_image = image.resize((n_width, n_height), resampler=resampler) + return new_image, n_width, n_height + + +def process_sd_init_image(self, sd_init_image, resample_type): + if isinstance(sd_init_image, list): + images = [] + for img in sd_init_image: + img, _ = self.process_sd_init_image(img, resample_type) + images.append(img) + is_img2img = True + return images, is_img2img + if isinstance(sd_init_image, str): + if os.path.isfile(sd_init_image): + sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB") + image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type) + else: + image = None + is_img2img = False + elif isinstance(sd_init_image, Image.Image): + image = sd_init_image.convert("RGB") + elif sd_init_image: + image = sd_init_image["image"].convert("RGB") + else: + image = None + is_img2img = False + if image: + resample_type = ( + resamplers[resample_type] + if resample_type in resampler_list + # Fallback to Lanczos + else Image.Resampling.LANCZOS + ) + image = image.resize((self.width, self.height), resample=resample_type) + image_arr = np.stack([np.array(i) for i in (image,)], axis=0) + image_arr = image_arr / 255.0 + image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype) + image_arr = 2 * (image_arr - 0.5) + is_img2img = True + image = image_arr + return image, is_img2img diff --git a/apps/shark_studio/modules/logger.py b/apps/shark_studio/modules/logger.py new file mode 100644 index 0000000000..bff6c933b7 --- /dev/null +++ b/apps/shark_studio/modules/logger.py @@ -0,0 +1,37 @@ +import sys + + +class Logger: + def __init__(self, filename, filter=None): + self.terminal = sys.stdout + self.log = open(filename, "w") + self.filter = filter + + def write(self, message): + for x in message.split("\n"): + if self.filter in x: + self.log.write(message) + else: + self.terminal.write(message) + + def flush(self): + self.terminal.flush() + self.log.flush() + + def isatty(self): + return False + + +def logger_test(x): + print("[LOG] This is a test") + print(f"This is another test, without the filter") + return x + + +def read_sd_logs(): + sys.stdout.flush() + with open("shark_tmp/sd.log", "r") as f: + return f.read() + + +sys.stdout = Logger("shark_tmp/sd.log", filter="[LOG]") diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py new file mode 100644 index 0000000000..053858c5df --- /dev/null +++ b/apps/shark_studio/modules/pipeline.py @@ -0,0 +1,207 @@ +from shark.iree_utils.compile_utils import ( + get_iree_compiled_module, + load_vmfb_using_mmap, + clean_device_info, + get_iree_target_triple, +) +from 
apps.shark_studio.web.utils.file_utils import ( + get_checkpoints_path, + get_resource_path, +) +from apps.shark_studio.modules.shared_cmd_opts import ( + cmd_opts, +) +from iree import runtime as ireert +from pathlib import Path +import gc +import os + + +class SharkPipelineBase: + # This class is a lightweight base for managing an + # inference API class. It should provide methods for: + # - compiling a set (model map) of torch IR modules + # - preparing weights for an inference job + # - loading weights for an inference job + # - utilites like benchmarks, tests + + def __init__( + self, + model_map: dict, + base_model_id: str, + static_kwargs: dict, + device: str, + import_mlir: bool = True, + ): + self.model_map = model_map + self.pipe_map = {} + self.static_kwargs = static_kwargs + self.base_model_id = base_model_id + self.triple = get_iree_target_triple(device) + self.device, self.device_id = clean_device_info(device) + self.import_mlir = import_mlir + self.iree_module_dict = {} + self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp")) + if not os.path.exists(self.tmp_dir): + os.mkdir(self.tmp_dir) + self.tempfiles = {} + self.pipe_vmfb_path = "" + + def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None: + # First checks whether we have .vmfbs precompiled, then populates the map + # with the precompiled executables and fetches executables for the rest of the map. + # The weights aren't static here anymore so this function should be a part of pipeline + # initialization. As soon as you have a pipeline ID unique to your static torch IR parameters, + # and your model map is populated with any IR - unique model IDs and their static params, + # call this method to get the artifacts associated with your map. + self.pipe_id = self.safe_name(pipe_id) + self.pipe_vmfb_path = Path( + os.path.join(get_checkpoints_path(".."), self.pipe_id) + ) + self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True) + if submodel == "None": + print("\n[LOG] Gathering any pre-compiled artifacts....") + for key in self.model_map: + self.get_compiled_map(pipe_id, submodel=key) + else: + self.pipe_map[submodel] = {} + self.get_precompiled(self.pipe_id, submodel) + ireec_flags = [] + if submodel in self.iree_module_dict: + return + elif "vmfb_path" in self.pipe_map[submodel]: + return + elif submodel not in self.tempfiles: + print( + f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..." 
+ ) + if submodel in self.static_kwargs: + init_kwargs = self.static_kwargs[submodel] + for key in self.static_kwargs["pipe"]: + if key not in init_kwargs: + init_kwargs[key] = self.static_kwargs["pipe"][key] + self.import_torch_ir(submodel, init_kwargs) + self.get_compiled_map(pipe_id, submodel) + else: + ireec_flags = ( + self.model_map[submodel]["ireec_flags"] + if "ireec_flags" in self.model_map[submodel] + else [] + ) + + weights_path = self.get_io_params(submodel) + if weights_path: + ireec_flags.append("--iree-opt-const-eval=False") + + self.iree_module_dict[submodel] = get_iree_compiled_module( + self.tempfiles[submodel], + device=self.device, + frontend="torch", + mmap=True, + external_weight_file=weights_path, + extra_args=ireec_flags, + write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"), + ) + return + + def get_io_params(self, submodel): + if "external_weight_file" in self.static_kwargs[submodel]: + # we are using custom weights + weights_path = self.static_kwargs[submodel]["external_weight_file"] + elif "external_weight_path" in self.static_kwargs[submodel]: + # we are using the default weights for the HF model + weights_path = self.static_kwargs[submodel]["external_weight_path"] + else: + # assume the torch IR contains the weights. + weights_path = None + return weights_path + + def get_precompiled(self, pipe_id, submodel="None"): + if submodel == "None": + for model in self.model_map: + self.get_precompiled(pipe_id, model) + vmfbs = [] + for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path): + vmfbs.extend(filenames) + break + for file in vmfbs: + if submodel in file: + self.pipe_map[submodel]["vmfb_path"] = os.path.join( + self.pipe_vmfb_path, file + ) + return + + def import_torch_ir(self, submodel, kwargs): + torch_ir = self.model_map[submodel]["initializer"]( + **self.safe_dict(kwargs), compile_to="torch" + ) + if submodel == "clip": + # clip.export_clip_model returns (torch_ir, tokenizer) + torch_ir = torch_ir[0] + + self.tempfiles[submodel] = os.path.join( + self.tmp_dir, f"{submodel}.torch.tempfile" + ) + + with open(self.tempfiles[submodel], "w+") as f: + f.write(torch_ir) + del torch_ir + gc.collect() + return + + def load_submodels(self, submodels: list): + for submodel in submodels: + if submodel in self.iree_module_dict: + print(f"\n[LOG] {submodel} is ready for inference.") + continue + if "vmfb_path" in self.pipe_map[submodel]: + weights_path = self.get_io_params(submodel) + # print( + # f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}" + # ) + self.iree_module_dict[submodel] = {} + ( + self.iree_module_dict[submodel]["vmfb"], + self.iree_module_dict[submodel]["config"], + self.iree_module_dict[submodel]["temp_file_to_unlink"], + ) = load_vmfb_using_mmap( + self.pipe_map[submodel]["vmfb_path"], + self.device, + device_idx=0, + rt_flags=[], + external_weight_file=weights_path, + ) + else: + self.get_compiled_map(self.pipe_id, submodel) + return + + def unload_submodels(self, submodels: list): + for submodel in submodels: + if submodel in self.iree_module_dict: + del self.iree_module_dict[submodel] + gc.collect() + return + + def run(self, submodel, inputs): + if not isinstance(inputs, list): + inputs = [inputs] + inp = [ + ireert.asdevicearray( + self.iree_module_dict[submodel]["config"].device, input + ) + for input in inputs + ] + return self.iree_module_dict[submodel]["vmfb"]["main"](*inp) + + def safe_name(self, name): + return name.replace("/", "_").replace("-", "_").replace("\\", "_") + + def 
safe_dict(self, kwargs: dict): + flat_args = {} + for i in kwargs: + if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]: + flat_args[i] = [kwargs[i][j] for j in kwargs[i]] + else: + flat_args[i] = kwargs[i] + + return flat_args diff --git a/apps/shark_studio/modules/prompt_encoding.py b/apps/shark_studio/modules/prompt_encoding.py new file mode 100644 index 0000000000..3dc61aba08 --- /dev/null +++ b/apps/shark_studio/modules/prompt_encoding.py @@ -0,0 +1,376 @@ +from typing import List, Optional, Union +from iree import runtime as ireert +import re +import torch +import numpy as np + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: + text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. 
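+    For example (illustrative only; the exact token ids depend on the tokenizer in
+    use), the prompt "an (important) word" is split by parse_prompt_attention into
+    ["an ", 1.0], ["important", 1.1] and [" word", 1.0], so every sub-token of
+    "important" receives weight 1.1 while the remaining sub-tokens keep weight 1.0.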
+ """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print( + "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples" + ) + return tokens, weights + + +def pad_tokens_and_weights( + tokens, + weights, + max_length, + bos, + eos, + no_boseos_middle=True, + chunk_length=77, +): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = ( + max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + ) + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][ + j + * (chunk_length - 2) : min( + len(weights[i]), (j + 1) * (chunk_length - 2) + ) + ] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe, + text_input, + chunk_length: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. 
+ """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[ + :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 + ].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + + text_embedding = pipe.run("clip", text_input_chunk)[0].to_host() + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + # SHARK: Convert the result to tensor + # text_embeddings = torch.concat(text_embeddings, axis=1) + text_embeddings_np = np.concatenate(np.array(text_embeddings)) + text_embeddings = torch.from_numpy(text_embeddings_np) + else: + text_embeddings = pipe.run("clip", text_input)[0] + text_embeddings = torch.from_numpy(text_embeddings.to_host()) + return text_embeddings + + +# This function deals with NoneType values occuring in tokens after padding +# It switches out None with 49407 as truncating None values causes matrix dimension errors, +def filter_nonetype_tokens(tokens: List[List]): + return [[49407 if token is None else token for token in tokens[0]]] + + +def get_weighted_text_embeddings( + pipe, + prompt: List[str], + uncond_prompt: List[str] = None, + max_embeddings_multiples: Optional[int] = 8, + no_boseos_middle: Optional[bool] = True, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, +): + max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights( + pipe, prompt, max_length - 2 + ) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = get_prompts_with_weights( + pipe, uncond_prompt, max_length - 2 + ) + else: + prompt_tokens = [ + token[1:-1] + for token in pipe.tokenizer( + prompt, max_length=max_length, truncation=True + ).input_ids + ] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] + for token in pipe.tokenizer( + uncond_prompt, max_length=max_length, truncation=True + ).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + + max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.model_max_length, + ) + + # FIXME: This is a hacky 
fix caused by tokenizer padding with None values + prompt_tokens = filter_nonetype_tokens(prompt_tokens) + + # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu") + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.model_max_length, + ) + + # FIXME: This is a hacky fix caused by tokenizer padding with None values + uncond_tokens = filter_nonetype_tokens(uncond_tokens) + + # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device="cpu") + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + prompt_weights = torch.tensor(prompt_weights, dtype=torch.float, device="cpu") + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + uncond_weights = torch.tensor(uncond_weights, dtype=torch.float, device="cpu") + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = ( + text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + ) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = ( + text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + ) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = ( + uncond_embeddings.float() + .mean(axis=[-2, -1]) + .to(uncond_embeddings.dtype) + ) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = ( + uncond_embeddings.float() + .mean(axis=[-2, -1]) + .to(uncond_embeddings.dtype) + ) + uncond_embeddings *= ( + (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + ) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + return text_embeddings, None diff --git a/apps/shark_studio/modules/schedulers.py b/apps/shark_studio/modules/schedulers.py new file mode 100644 index 0000000000..3e931b1c78 --- /dev/null +++ b/apps/shark_studio/modules/schedulers.py @@ -0,0 +1,117 @@ +# from shark_turbine.turbine_models.schedulers import export_scheduler_model +from diffusers import ( + LCMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDPMScheduler, + DDIMScheduler, + DPMSolverMultistepScheduler, + KDPM2DiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DEISMultistepScheduler, + DPMSolverSinglestepScheduler, + KDPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler, +) + + +def get_schedulers(model_id): + # TODO: switch over to turbine and run all on GPU + print(f"\n[LOG] Initializing schedulers from model id: {model_id}") + schedulers = dict() + schedulers["PNDM"] = PNDMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DDPM"] = DDPMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + 
schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DDIM"] = DDIMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["LCMScheduler"] = LCMScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained( + model_id, subfolder="scheduler", algorithm_type="dpmsolver" + ) + schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained( + model_id, subfolder="scheduler", algorithm_type="dpmsolver++" + ) + schedulers["DPMSolverMultistepKarras"] = ( + DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + use_karras_sigmas=True, + ) + ) + schedulers["DPMSolverMultistepKarras++"] = ( + DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + algorithm_type="dpmsolver++", + use_karras_sigmas=True, + ) + ) + schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["EulerAncestralDiscrete"] = ( + EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + ) + schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["DPMSolverSinglestep"] = DPMSolverSinglestepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + schedulers["KDPM2AncestralDiscrete"] = ( + KDPM2AncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + ) + schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) + return schedulers + + +def export_scheduler_model(model): + return "None", "None" + + +scheduler_model_map = { + "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"), + "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"), + "LCM": export_scheduler_model("LCMScheduler"), + "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"), + "PNDM": export_scheduler_model("PNDMScheduler"), + "DDPM": export_scheduler_model("DDPMScheduler"), + "DDIM": export_scheduler_model("DDIMScheduler"), + "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"), + "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"), + "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"), + "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"), + "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"), + "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"), +} diff --git a/apps/shark_studio/modules/seed.py b/apps/shark_studio/modules/seed.py new file mode 100644 index 0000000000..d0b022a6f1 --- /dev/null +++ b/apps/shark_studio/modules/seed.py @@ -0,0 +1,66 @@ +import numpy as np +import json +from random import ( + randint, + seed as seed_random, + getstate as random_getstate, + setstate as random_setstate, +) + + +# Generate and return a new seed if the provided one is not in the +# supported range (including -1) +def sanitize_seed(seed: int | str): + seed = int(seed) + uint32_info = np.iinfo(np.uint32) + uint32_min, uint32_max = uint32_info.min, uint32_info.max + if seed < uint32_min or seed >= uint32_max: + seed = randint(uint32_min, uint32_max) + return seed + + +# take a seed expression 
in an input format and convert it to +# a list of integers, where possible +def parse_seed_input(seed_input: str | list | int): + if isinstance(seed_input, str): + try: + seed_input = json.loads(seed_input) + except (ValueError, TypeError): + seed_input = None + + if isinstance(seed_input, int): + return [seed_input] + + if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input): + return seed_input + + raise TypeError( + "Seed input must be an integer or an array of integers in JSON format" + ) + + +# Generate a set of seeds from an input expression for batch_count batches, +# optionally using that input as the rng seed for any randomly generated seeds. +def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False): + # turn the input into a list if possible + seeds = parse_seed_input(seed_input) + + # slice or pad the list to be of batch_count length + seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds)) + + if repeatable: + if all(seed < 0 for seed in seeds): + seeds[0] = sanitize_seed(seeds[0]) + + # set seed for the rng based on what we have so far + saved_random_state = random_getstate() + seed_random(str([n for n in seeds if n > -1])) + + # generate any seeds that are unspecified + seeds = [sanitize_seed(seed) for seed in seeds] + + if repeatable: + # reset the rng back to normal + random_setstate(saved_random_state) + + return seeds diff --git a/apps/shark_studio/modules/shared_cmd_opts.py b/apps/shark_studio/modules/shared_cmd_opts.py new file mode 100644 index 0000000000..7992660d96 --- /dev/null +++ b/apps/shark_studio/modules/shared_cmd_opts.py @@ -0,0 +1,776 @@ +import argparse +import os +from pathlib import Path + +from apps.shark_studio.modules.img_processing import resampler_list + + +def path_expand(s): + return Path(s).expanduser().resolve() + + +def is_valid_file(arg): + if not os.path.exists(arg): + return None + else: + return arg + + +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +############################################################################## +# Stable Diffusion Params +############################################################################## + +p.add_argument( + "-a", + "--app", + default="txt2img", + help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.", +) +p.add_argument( + "-p", + "--prompt", + nargs="+", + default=[ + "a photo taken of the front of a super-car drifting on a road near " + "mountains at high speeds with smoke coming off the tires, front " + "angle, front point of view, trees in the mountains of the " + "background, ((sharp focus))" + ], + help="Text of which images to be generated.", +) + +p.add_argument( + "--negative_prompt", + nargs="+", + default=[ + "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), " + "blurry, ugly, blur, oversaturated, cropped" + ], + help="Text you don't want to see in the generated image.", +) + +p.add_argument( + "--sd_init_image", + type=str, + help="Path to the image input for img2img/inpainting.", +) + +p.add_argument( + "--steps", + type=int, + default=50, + help="The number of steps to do the sampling.", +) + +p.add_argument( + "--seed", + type=str, + default=-1, + help="The seed or list of seeds to use. 
-1 for a random one.", +) + +p.add_argument( + "--batch_size", + type=int, + default=1, + choices=range(1, 4), + help="The number of inferences to be made in a single `batch_count`.", +) + +p.add_argument( + "--height", + type=int, + default=512, + choices=range(128, 1025, 8), + help="The height of the output image.", +) + +p.add_argument( + "--width", + type=int, + default=512, + choices=range(128, 1025, 8), + help="The width of the output image.", +) + +p.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="The value to be used for guidance scaling.", +) + +p.add_argument( + "--noise_level", + type=int, + default=20, + help="The value to be used for noise level of upscaler.", +) + +p.add_argument( + "--max_length", + type=int, + default=64, + help="Max length of the tokenizer output, options are 64 and 77.", +) + +p.add_argument( + "--max_embeddings_multiples", + type=int, + default=5, + help="The max multiple length of prompt embeddings compared to the max " + "output length of text encoder.", +) + +p.add_argument( + "--strength", + type=float, + default=0.8, + help="The strength of change applied on the given input image for " "img2img.", +) + +p.add_argument( + "--use_hiresfix", + type=bool, + default=False, + help="Use Hires Fix to do higher resolution images, while trying to " + "avoid the issues that come with it. This is accomplished by first " + "generating an image using txt2img, then running it through img2img.", +) + +p.add_argument( + "--hiresfix_height", + type=int, + default=768, + choices=range(128, 769, 8), + help="The height of the Hires Fix image.", +) + +p.add_argument( + "--hiresfix_width", + type=int, + default=768, + choices=range(128, 769, 8), + help="The width of the Hires Fix image.", +) + +p.add_argument( + "--hiresfix_strength", + type=float, + default=0.6, + help="The denoising strength to apply for the Hires Fix.", +) + +p.add_argument( + "--resample_type", + type=str, + default="Nearest Neighbor", + choices=resampler_list, + help="The resample type to use when resizing an image before being run " + "through stable diffusion.", +) + +############################################################################## +# Stable Diffusion Training Params +############################################################################## + +p.add_argument( + "--lora_save_dir", + type=str, + default="models/lora/", + help="Directory to save the lora fine tuned model.", +) + +p.add_argument( + "--training_images_dir", + type=str, + default="models/lora/training_images/", + help="Directory containing images that are an example of the prompt.", +) + +p.add_argument( + "--training_steps", + type=int, + default=2000, + help="The number of steps to train.", +) + +############################################################################## +# Inpainting and Outpainting Params +############################################################################## + +p.add_argument( + "--mask_path", + type=str, + help="Path to the mask image input for inpainting.", +) + +p.add_argument( + "--inpaint_full_res", + default=False, + action=argparse.BooleanOptionalAction, + help="If inpaint only masked area or whole picture.", +) + +p.add_argument( + "--inpaint_full_res_padding", + type=int, + default=32, + choices=range(0, 257, 4), + help="Number of pixels for only masked padding.", +) + +p.add_argument( + "--pixels", + type=int, + default=128, + choices=range(8, 257, 8), + help="Number of expended pixels for one direction for outpainting.", +) + +p.add_argument( + 
"--mask_blur", + type=int, + default=8, + choices=range(0, 65), + help="Number of blur pixels for outpainting.", +) + +p.add_argument( + "--left", + default=False, + action=argparse.BooleanOptionalAction, + help="If extend left for outpainting.", +) + +p.add_argument( + "--right", + default=False, + action=argparse.BooleanOptionalAction, + help="If extend right for outpainting.", +) + +p.add_argument( + "--up", + "--top", + default=False, + action=argparse.BooleanOptionalAction, + help="If extend top for outpainting.", +) + +p.add_argument( + "--down", + "--bottom", + default=False, + action=argparse.BooleanOptionalAction, + help="If extend bottom for outpainting.", +) + +p.add_argument( + "--noise_q", + type=float, + default=1.0, + help="Fall-off exponent for outpainting (lower=higher detail) " + "(min=0.0, max=4.0).", +) + +p.add_argument( + "--color_variation", + type=float, + default=0.05, + help="Color variation for outpainting (min=0.0, max=1.0).", +) + +############################################################################## +# Model Config and Usage Params +############################################################################## + +p.add_argument("--device", type=str, default="vulkan", help="Device to run the model.") + +p.add_argument( + "--precision", type=str, default="fp16", help="Precision to run the model." +) + +p.add_argument( + "--import_mlir", + default=True, + action=argparse.BooleanOptionalAction, + help="Imports the model from torch module to shark_module otherwise " + "downloads the model from shark_tank.", +) + +p.add_argument( + "--use_tuned", + default=False, + action=argparse.BooleanOptionalAction, + help="Download and use the tuned version of the model if available.", +) + +p.add_argument( + "--use_base_vae", + default=False, + action=argparse.BooleanOptionalAction, + help="Do conversion from the VAE output to pixel space on cpu.", +) + +p.add_argument( + "--scheduler", + type=str, + default="DDIM", + help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, " + "DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, " + "DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, " + "DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, " + "HeunDiscrete].", +) + +p.add_argument( + "--output_img_format", + type=str, + default="png", + help="Specify the format in which output image is save. 
" + "Supported options: jpg / png.", +) + +p.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory path to save the output images and json.", +) + +p.add_argument( + "--batch_count", + type=int, + default=1, + help="Number of batches to be generated with random seeds in " "single execution.", +) + +p.add_argument( + "--repeatable_seeds", + default=False, + action=argparse.BooleanOptionalAction, + help="The seed of the first batch will be used as the rng seed to " + "generate the subsequent seeds for subsequent batches in that run.", +) + +p.add_argument( + "--custom_weights", + type=str, + default="", + help="Path to a .safetensors or .ckpt file for SD pipeline weights.", +) + +p.add_argument( + "--custom_vae", + type=str, + default="", + help="HuggingFace repo-id or path to SD model's checkpoint whose VAE " + "needs to be plugged in.", +) + +p.add_argument( + "--base_model_id", + type=str, + default="stabilityai/stable-diffusion-2-1-base", + help="The repo-id of hugging face.", +) + +p.add_argument( + "--low_cpu_mem_usage", + default=False, + action=argparse.BooleanOptionalAction, + help="Use the accelerate package to reduce cpu memory consumption.", +) + +p.add_argument( + "--attention_slicing", + type=str, + default="none", + help="Amount of attention slicing to use (one of 'max', 'auto', 'none', " + "or an integer).", +) + +p.add_argument( + "--use_stencil", + choices=["canny", "openpose", "scribble", "zoedepth"], + help="Enable the stencil feature.", +) + +p.add_argument( + "--control_mode", + choices=["Prompt", "Balanced", "Controlnet"], + default="Balanced", + help="How Controlnet injection should be prioritized.", +) + +p.add_argument( + "--use_lora", + type=str, + default="", + help="Use standalone LoRA weight using a HF ID or a checkpoint " "file (~3 MB).", +) + +p.add_argument( + "--use_quantize", + type=str, + default="none", + help="Runs the quantized version of stable diffusion model. " + "This is currently in experimental phase. " + "Currently, only runs the stable-diffusion-2-1-base model in " + "int8 quantization.", +) + +p.add_argument( + "--lowvram", + default=False, + action=argparse.BooleanOptionalAction, + help="Load and unload models for low VRAM.", +) + +p.add_argument( + "--hf_auth_token", + type=str, + default=None, + help="Specify your own huggingface authentication tokens for models like Llama2.", +) + +p.add_argument( + "--external_weights", + type=str, + default=None, + help="What type of externalized weights to use. Currently options are 'safetensors' and defaults to inlined weights.", +) + +p.add_argument( + "--device_allocator_heap_key", + type=str, + default="", + help="Specify heap key for device caching allocator." + "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count" + "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)", +) + +############################################################################## +# IREE - Vulkan supported flags +############################################################################## + +p.add_argument( + "--iree_vulkan_target_triple", + type=str, + default="", + help="Specify target triple for vulkan.", +) + +p.add_argument( + "--iree_metal_target_platform", + type=str, + default="", + help="Specify target triple for metal.", +) + +############################################################################## +# Misc. 
Debug and Optimization flags +############################################################################## + +p.add_argument( + "--use_compiled_scheduler", + default=True, + action=argparse.BooleanOptionalAction, + help="Use the default scheduler precompiled into the model if available.", +) + +p.add_argument( + "--local_tank_cache", + default="", + help="Specify where to save downloaded shark_tank artifacts. " + "If this is not set, the default is ~/.local/shark_tank/.", +) + +p.add_argument( + "--dump_isa", + default=False, + action="store_true", + help="When enabled call amdllpc to get ISA dumps. " "Use with dispatch benchmarks.", +) + +p.add_argument( + "--dispatch_benchmarks", + default=None, + help="Dispatches to return benchmark data on. " + 'Use "All" for all, and None for none.', +) + +p.add_argument( + "--dispatch_benchmarks_dir", + default="temp_dispatch_benchmarks", + help="Directory where you want to store dispatch data " + 'generated with "--dispatch_benchmarks".', +) + +p.add_argument( + "--enable_rgp", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for inserting debug frames between iterations " "for use with rgp.", +) + +p.add_argument( + "--hide_steps", + default=True, + action=argparse.BooleanOptionalAction, + help="Flag for hiding the details of iteration/sec for each step.", +) + +p.add_argument( + "--warmup_count", + type=int, + default=0, + help="Flag setting warmup count for CLIP and VAE [>= 0].", +) + +p.add_argument( + "--clear_all", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag to clear all mlir and vmfb from common locations. " + "Recompiling will take several minutes.", +) + +p.add_argument( + "--save_metadata_to_json", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for whether or not to save a generation information " + "json file with the image.", +) + +p.add_argument( + "--write_metadata_to_png", + default=True, + action=argparse.BooleanOptionalAction, + help="Flag for whether or not to save generation information in " + "PNG chunk text to generated images.", +) + +p.add_argument( + "--import_debug", + default=False, + action=argparse.BooleanOptionalAction, + help="If import_mlir is True, saves mlir via the debug option " + "in shark importer. Does nothing if import_mlir is false (the default).", +) + +p.add_argument( + "--compile_debug", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag to toggle debug assert/verify flags for imported IR in the" + "iree-compiler. 
Default to false.", +) + +p.add_argument( + "--iree_constant_folding", + default=True, + action=argparse.BooleanOptionalAction, + help="Controls constant folding in iree-compile for all SD models.", +) + +p.add_argument( + "--data_tiling", + default=False, + action=argparse.BooleanOptionalAction, + help="Controls data tiling in iree-compile for all SD models.", +) + +p.add_argument( + "--quantization", + type=str, + default="None", + help="Quantization to be used for api-exposed model.", +) + +############################################################################## +# Web UI flags +############################################################################## + +p.add_argument( + "--webui", + default=True, + action=argparse.BooleanOptionalAction, + help="controls whether the webui is launched.", +) + +p.add_argument( + "--progress_bar", + default=True, + action=argparse.BooleanOptionalAction, + help="Flag for removing the progress bar animation during " "image generation.", +) + +p.add_argument( + "--ckpt_dir", + type=str, + default="../models", + help="Path to directory where all .ckpts are stored in order to populate " + "them in the web UI.", +) +# TODO: replace API flag when these can be run together +p.add_argument( + "--ui", + type=str, + default="app" if os.name == "nt" else "web", + help="One of: [api, app, web].", +) + +p.add_argument( + "--share", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for generating a public URL.", +) + +p.add_argument( + "--server_port", + type=int, + default=8080, + help="Flag for setting server port.", +) + +p.add_argument( + "--api", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for enabling rest API.", +) + +p.add_argument( + "--api_accept_origin", + action="append", + type=str, + help="An origin to be accepted by the REST api for Cross Origin" + "Resource Sharing (CORS). Use multiple times for multiple origins, " + 'or use --api_accept_origin="*" to accept all origins. If no origins ' + "are set no CORS headers will be returned by the api. 
Use, for " + "instance, if you need to access the REST api from Javascript running " + "in a web browser.", +) + +p.add_argument( + "--debug", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for enabling debugging log in WebUI.", +) + +p.add_argument( + "--output_gallery", + default=True, + action=argparse.BooleanOptionalAction, + help="Flag for removing the output gallery tab, and avoid exposing " + "images under --output_dir in the UI.", +) + +p.add_argument( + "--configs_path", + default=None, + type=str, + help="Path to .json config directory.", +) + +p.add_argument( + "--output_gallery_followlinks", + default=False, + action=argparse.BooleanOptionalAction, + help="Flag for whether the output gallery tab in the UI should " + "follow symlinks when listing subdirectories under --output_dir.", +) + +p.add_argument( + "--api_log", + default=False, + action=argparse.BooleanOptionalAction, + help="Enables Compatibility API logging.", +) + +############################################################################## +# SD model auto-annotation flags +############################################################################## + +p.add_argument( + "--annotation_output", + type=path_expand, + default="./", + help="Directory to save the annotated mlir file.", +) + +p.add_argument( + "--annotation_model", + type=str, + default="unet", + help="Options are unet and vae.", +) + +p.add_argument( + "--save_annotation", + default=False, + action=argparse.BooleanOptionalAction, + help="Save annotated mlir file.", +) +############################################################################## +# SD model auto-tuner flags +############################################################################## + +p.add_argument( + "--tuned_config_dir", + type=path_expand, + default="./", + help="Directory to save the tuned config file.", +) + +p.add_argument( + "--num_iters", + type=int, + default=400, + help="Number of iterations for tuning.", +) + +p.add_argument( + "--search_op", + type=str, + default="all", + help="Op to be optimized, options are matmul, bmm, conv and all.", +) + +############################################################################## +# DocuChat Flags +############################################################################## + +p.add_argument( + "--run_docuchat_web", + default=False, + action=argparse.BooleanOptionalAction, + help="Specifies whether the docuchat's web version is running or not.", +) + +############################################################################## +# rocm Flags +############################################################################## + +p.add_argument( + "--iree_rocm_target_chip", + type=str, + default="", + help="Add the rocm device architecture ex gfx1100, gfx90a, etc. 
Use `hipinfo` " + "or `iree-run-module --dump_devices=rocm` to get the desired arch name", +) + +cmd_opts, unknown = p.parse_known_args() +if cmd_opts.import_debug: + os.environ["IREE_SAVE_TEMPS"] = os.path.join( + os.getcwd(), cmd_opts.base_model_id.replace("/", "_") + ) diff --git a/apps/shark_studio/modules/timer.py b/apps/shark_studio/modules/timer.py new file mode 100644 index 0000000000..d6918e9c8c --- /dev/null +++ b/apps/shark_studio/modules/timer.py @@ -0,0 +1,106 @@ +import time +import argparse + + +class TimerSubcategory: + def __init__(self, timer, category): + self.timer = timer + self.category = category + self.start = None + self.original_base_category = timer.base_category + + def __enter__(self): + self.start = time.time() + self.timer.base_category = self.original_base_category + self.category + "/" + self.timer.subcategory_level += 1 + + if self.timer.print_log: + print(f"{' ' * self.timer.subcategory_level}{self.category}:") + + def __exit__(self, exc_type, exc_val, exc_tb): + elapsed_for_subcategory = time.time() - self.start + self.timer.base_category = self.original_base_category + self.timer.add_time_to_record( + self.original_base_category + self.category, + elapsed_for_subcategory, + ) + self.timer.subcategory_level -= 1 + self.timer.record(self.category, disable_log=True) + + +class Timer: + def __init__(self, print_log=False): + self.start = time.time() + self.records = {} + self.total = 0 + self.base_category = "" + self.print_log = print_log + self.subcategory_level = 0 + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def add_time_to_record(self, category, amount): + if category not in self.records: + self.records[category] = 0 + + self.records[category] += amount + + def record(self, category, extra_time=0, disable_log=False): + e = self.elapsed() + + self.add_time_to_record(self.base_category + category, e + extra_time) + + self.total += e + extra_time + + if self.print_log and not disable_log: + print( + f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s" + ) + + def subcategory(self, name): + self.elapsed() + + subcat = TimerSubcategory(self, name) + return subcat + + def summary(self): + res = f"{self.total:.1f}s" + + additions = [ + (category, time_taken) + for category, time_taken in self.records.items() + if time_taken >= 0.1 and "/" not in category + ] + if not additions: + return res + + res += " (" + res += ", ".join( + [f"{category}: {time_taken:.1f}s" for category, time_taken in additions] + ) + res += ")" + + return res + + def dump(self): + return {"total": self.total, "records": self.records} + + def reset(self): + self.__init__() + + +parser = argparse.ArgumentParser(add_help=False) +parser.add_argument( + "--log-startup", + action="store_true", + help="print a detailed log of what's happening at startup", +) +args = parser.parse_known_args()[0] + +startup_timer = Timer(print_log=args.log_startup) + +startup_record = None diff --git a/apps/shark_studio/shark_studio.spec b/apps/shark_studio/shark_studio.spec new file mode 100644 index 0000000000..1c87c953db --- /dev/null +++ b/apps/shark_studio/shark_studio.spec @@ -0,0 +1,48 @@ +# -*- mode: python ; coding: utf-8 -*- +from apps.shark_studio.studio_imports import pathex, datas, hiddenimports + +binaries = [] + +block_cipher = None + +a = Analysis( + ['web/index.py'], + pathex=pathex, + binaries=binaries, + datas=datas, + hiddenimports=hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + 
excludes=[], + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False, + module_collection_mode={ + 'gradio': 'py', # Collect gradio package as source .py files + }, +) +pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) + +exe = EXE( + pyz, + a.scripts, + a.binaries, + a.zipfiles, + a.datas, + [], + name='nodai_shark_studio', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=False, + upx_exclude=[], + runtime_tmpdir=None, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) diff --git a/apps/shark_studio/studio_imports.py b/apps/shark_studio/studio_imports.py new file mode 100644 index 0000000000..3f7aa319ba --- /dev/null +++ b/apps/shark_studio/studio_imports.py @@ -0,0 +1,68 @@ +from PyInstaller.utils.hooks import collect_data_files +from PyInstaller.utils.hooks import copy_metadata +from PyInstaller.utils.hooks import collect_submodules + +import sys + +sys.setrecursionlimit(sys.getrecursionlimit() * 5) + +# python path for pyinstaller +pathex = [ + ".", +] + +# datafiles for pyinstaller +datas = [] +datas += copy_metadata("torch") +datas += copy_metadata("tokenizers") +datas += copy_metadata("tqdm") +datas += copy_metadata("regex") +datas += copy_metadata("requests") +datas += copy_metadata("packaging") +datas += copy_metadata("filelock") +datas += copy_metadata("numpy") +datas += copy_metadata("importlib_metadata") +datas += copy_metadata("omegaconf") +datas += copy_metadata("safetensors") +datas += copy_metadata("Pillow") +datas += copy_metadata("sentencepiece") +datas += copy_metadata("pyyaml") +datas += copy_metadata("huggingface-hub") +datas += copy_metadata("gradio") +datas += copy_metadata("scipy") +datas += collect_data_files("torch") +datas += collect_data_files("tokenizers") +datas += collect_data_files("accelerate") +datas += collect_data_files("diffusers") +datas += collect_data_files("transformers") +datas += collect_data_files("gradio") +datas += collect_data_files("gradio_client") +datas += collect_data_files("iree", include_py_files=True) +datas += collect_data_files("shark", include_py_files=True) +datas += collect_data_files("tqdm") +datas += collect_data_files("tkinter") +datas += collect_data_files("sentencepiece") +datas += collect_data_files("jsonschema") +datas += collect_data_files("jsonschema_specifications") +datas += collect_data_files("cpuinfo") +datas += collect_data_files("scipy", include_py_files=True) +datas += [ + ("web/ui/css/*", "ui/css"), + ("web/ui/js/*", "ui/js"), + ("web/ui/logos/*", "logos"), +] + + +# hidden imports for pyinstaller +hiddenimports = ["shark", "apps"] +hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x] +hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x] +blacklist = ["tests", "convert"] +hiddenimports += [ + x + for x in collect_submodules("transformers") + if not any(kw in x for kw in blacklist) +] +hiddenimports += [x for x in collect_submodules("iree") if "test" not in x] +hiddenimports += ["iree._runtime"] +hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x] diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py index d07bb05b90..7bed2cb7b0 100644 --- a/apps/shark_studio/tests/api_test.py +++ b/apps/shark_studio/tests/api_test.py @@ -6,8 +6,26 @@ import logging import unittest -from apps.shark_studio.api.llm import LanguageModel +import json 
import gc +from apps.shark_studio.api.llm import LanguageModel, llm_chat_api +from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file +from apps.shark_studio.web.utils.file_utils import get_resource_path + +# class SDAPITest(unittest.TestCase): +# def testSDSimple(self): +# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +# import apps.shark_studio.web.utils.globals as global_obj + +# global_obj._init() + +# sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json")) +# sd_kwargs = json.loads(sd_json) +# for arg in vars(cmd_opts): +# if arg in sd_kwargs: +# sd_kwargs[arg] = getattr(cmd_opts, arg) +# for i in shark_sd_fn_dict_input(sd_kwargs): +# print(i) class LLMAPITest(unittest.TestCase): diff --git a/apps/shark_studio/tests/export_unet.py b/apps/shark_studio/tests/export_unet.py new file mode 100644 index 0000000000..0cc8b2deb0 --- /dev/null +++ b/apps/shark_studio/tests/export_unet.py @@ -0,0 +1,41 @@ +import torch +from diffusers import ( + UNet2DConditionModel, +) +from torch.fx.experimental.proxy_tensor import make_fx + + +class UnetModel(torch.nn.Module): + def __init__(self, hf_model_name): + super().__init__() + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + ) + + def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): + samples = torch.cat([sample] * 2) + unet_out = self.unet.forward( + samples, timestep, encoder_hidden_states, return_dict=False + )[0] + noise_pred_uncond, noise_pred_text = unet_out.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + +if __name__ == "__main__": + hf_model_name = "CompVis/stable-diffusion-v1-4" + unet = UnetModel(hf_model_name) + inputs = (torch.randn(1, 4, 64, 64), 1, torch.randn(2, 77, 768), 7.5) + + fx_g = make_fx( + unet, + decomposition_table={}, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + _allow_fake_constant=False, + )(*inputs) + + print(fx_g) diff --git a/apps/shark_studio/tests/jupiter.png b/apps/shark_studio/tests/jupiter.png new file mode 100644 index 0000000000..e479e20548 Binary files /dev/null and b/apps/shark_studio/tests/jupiter.png differ diff --git a/apps/shark_studio/tests/rest_api_test.py b/apps/shark_studio/tests/rest_api_test.py new file mode 100644 index 0000000000..741fa523cc --- /dev/null +++ b/apps/shark_studio/tests/rest_api_test.py @@ -0,0 +1,45 @@ +import requests +from PIL import Image +import base64 +from io import BytesIO +import json + + +def llm_chat_test(verbose=False): + # Define values here + prompt = "What is the significance of the number 42?" + + url = "http://127.0.0.1:8080/v1/chat/completions" + + headers = { + "User-Agent": "PythonTest", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } + + data = { + "model": "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + "messages": [ + { + "role": "", + "content": prompt, + } + ], + "device": "vulkan://0", + "max_tokens": 4096, + } + + res = requests.post(url=url, json=data, headers=headers, timeout=1000) + res_dict = json.loads(res.content.decode("utf-8")) + print(f"[chat] response from server was : {res.status_code} {res.reason}") + + if verbose or res.status_code != 200: + print(f"\n{res_dict['choices'][0]['message']['content']}\n") + + +if __name__ == "__main__": + # "Exercises the chatbot REST API of Shark. Make sure " + # "Shark is running in API mode on 127.0.0.1:8080 before running" + # "this script." 
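+    # Example invocation (assumes the Shark Studio server has already been started
+    # separately in API mode and is listening on 127.0.0.1:8080):
+    #   python apps/shark_studio/tests/rest_api_test.py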
+ + llm_chat_test(verbose=True) diff --git a/apps/shark_studio/web/api/compat.py b/apps/shark_studio/web/api/compat.py new file mode 100644 index 0000000000..b5e81f2e9a --- /dev/null +++ b/apps/shark_studio/web/api/compat.py @@ -0,0 +1,286 @@ +import base64 +import io +import os +import time +import datetime +import uvicorn +import ipaddress +import requests +import threading +import collections +import gradio as gr +from PIL import Image, PngImagePlugin +from threading import Lock +from io import BytesIO +from fastapi import APIRouter, Depends, FastAPI, Request, Response +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.exceptions import HTTPException +from fastapi.responses import JSONResponse +from fastapi.encoders import jsonable_encoder + +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + +# from sdapi_v1 import shark_sd_api +from apps.shark_studio.api.llm import llm_chat_api + + +def decode_base64_to_image(encoding): + if encoding.startswith("http://") or encoding.startswith("https://"): + headers = {} + response = requests.get(encoding, timeout=30, headers=headers) + try: + image = Image.open(BytesIO(response.content)) + return image + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid image url") from e + + if encoding.startswith("data:image/"): + encoding = encoding.split(";")[1].split(",")[1] + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) + return image + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid encoded image") from e + + +def encode_pil_to_base64(image): + with io.BytesIO() as output_bytes: + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + image.save( + output_bytes, + format="PNG", + pnginfo=(metadata if use_metadata else None), + ) + + bytes_data = output_bytes.getvalue() + + return base64.b64encode(bytes_data) + + +# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a +class FIFOLock(object): + def __init__(self): + self._lock = threading.Lock() + self._inner_lock = threading.Lock() + self._pending_threads = collections.deque() + + def acquire(self, blocking=True): + with self._inner_lock: + lock_acquired = self._lock.acquire(False) + if lock_acquired: + return True + elif not blocking: + return False + + release_event = threading.Event() + self._pending_threads.append(release_event) + + release_event.wait() + return self._lock.acquire() + + def release(self): + with self._inner_lock: + if self._pending_threads: + release_event = self._pending_threads.popleft() + release_event.set() + + self._lock.release() + + __enter__ = acquire + + def __exit__(self, t, v, tb): + self.release() + + +def api_middleware(app: FastAPI): + rich_available = False + try: + if os.environ.get("WEBUI_RICH_EXCEPTIONS", None) is not None: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + + console = Console() + rich_available = True + except Exception: + pass + + @app.middleware("http") + async def log_and_time(req: Request, call_next): + ts = time.time() + res: Response = await call_next(req) + duration = str(round(time.time() - ts, 4)) + res.headers["X-Process-Time"] = duration + endpoint = req.scope.get("path", "err") + if cmd_opts.api_log and 
endpoint.startswith("/sdapi"): + print( + "API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format( + t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code=res.status_code, + ver=req.scope.get("http_version", "0.0"), + cli=req.scope.get("client", ("0:0.0.0", 0))[0], + prot=req.scope.get("scheme", "err"), + method=req.scope.get("method", "err"), + endpoint=endpoint, + duration=duration, + ) + ) + return res + + def handle_exception(request: Request, e: Exception): + err = { + "error": type(e).__name__, + "detail": vars(e).get("detail", ""), + "body": vars(e).get("body", ""), + "errors": str(e), + } + if not isinstance( + e, HTTPException + ): # do not print backtrace on known httpexceptions + message = f"API error: {request.method}: {request.url} {err}" + if rich_available: + print(message) + console.print_exception( + show_locals=True, + max_frames=2, + extra_lines=1, + suppress=[anyio, starlette], + word_wrap=False, + width=min([console.width, 200]), + ) + else: + print(message) + raise (e) + return JSONResponse( + status_code=vars(e).get("status_code", 500), + content=jsonable_encoder(err), + ) + + @app.middleware("http") + async def exception_handling(request: Request, call_next): + try: + return await call_next(request) + except Exception as e: + return handle_exception(request, e) + + @app.exception_handler(Exception) + async def fastapi_exception_handler(request: Request, e: Exception): + return handle_exception(request, e) + + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, e: HTTPException): + return handle_exception(request, e) + + +class ApiCompat: + def __init__(self, app: FastAPI, queue_lock: Lock): + self.router = APIRouter() + self.app = app + self.queue_lock = queue_lock + api_middleware(self.app) + # self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"]) + # self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"]) + # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"]) + # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) + # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse) + # self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse) + # self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse) + # self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) + # self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) + # self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) + # self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) + # self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) + # self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) + # self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) + # 
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) + # self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) + # self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) + # self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) + # self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) + # self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) + # self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) + # self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + # self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) + # self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) + # self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) + # self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) + # self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) + # self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) + # self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) + # self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) + # self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) + # self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + + # chat APIs needed for compatibility with multiple extensions using OpenAI API + self.add_api_route("/v1/chat/completions", llm_chat_api, methods=["POST"]) + self.add_api_route("/v1/completions", llm_chat_api, methods=["POST"]) + self.add_api_route("/chat/completions", llm_chat_api, methods=["POST"]) + self.add_api_route("/completions", llm_chat_api, methods=["POST"]) + self.add_api_route( + "/v1/engines/codegen/completions", llm_chat_api, methods=["POST"] + ) + + self.default_script_arg_txt2img = [] + self.default_script_arg_img2img = [] + + def add_api_route(self, path: str, endpoint, **kwargs): + return self.app.add_api_route(path, endpoint, **kwargs) + + # def refresh_checkpoints(self): + # with self.queue_lock: + # studio_data.refresh_checkpoints() + + # def refresh_vae(self): + # with self.queue_lock: + # studio_data.refresh_vae_list() + + # def unloadapi(self): + # unload_model_weights() + + # return {} + + # def reloadapi(self): + # reload_model_weights() + + # return {} + + # def skip(self): + # studio.state.skip() + + def launch(self, server_name, port, root_path): + self.app.include_router(self.router) + uvicorn.run( + self.app, + host=server_name, + port=port, + root_path=root_path, + ) + + # def kill_studio(self): + # 
restart.stop_program() + + # def restart_studio(self): + # if restart.is_restartable(): + # restart.restart_program() + # return Response(status_code=501) + + # def preprocess(self, args: dict): + # try: + # studio.state.begin(job="preprocess") + # preprocess(**args) + # studio.state.end() + # return models.PreprocessResponse(info="preprocess complete") + # except: + # studio.state.end() + + # def stop_studio(request): + # studio.state.server_command = "stop" + # return Response("Stopping.") diff --git a/apps/shark_studio/web/api/sd.py b/apps/shark_studio/web/api/sd.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/apps/shark_studio/web/api/sd.py @@ -0,0 +1 @@ + diff --git a/apps/shark_studio/web/configs/default_sd_config.json b/apps/shark_studio/web/configs/default_sd_config.json new file mode 100644 index 0000000000..323a6a329c --- /dev/null +++ b/apps/shark_studio/web/configs/default_sd_config.json @@ -0,0 +1,28 @@ +{ + "prompt": [ + "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" + ], + "negative_prompt": [ + "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" + ], + "sd_init_image": [null], + "height": 512, + "width": 512, + "steps": 50, + "strength": 0.8, + "guidance_scale": 7.5, + "seed": "-1", + "batch_count": 1, + "batch_size": 1, + "scheduler": "EulerDiscrete", + "base_model_id": "stabilityai/stable-diffusion-2-1-base", + "custom_weights": null, + "custom_vae": null, + "precision": "fp16", + "device": "AMD Radeon RX 7900 XTX => vulkan://0", + "ondemand": false, + "repeatable_seeds": false, + "resample_type": "Nearest Neighbor", + "controlnets": {}, + "embeddings": {} +} \ No newline at end of file diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index 3ef6bc5739..d1b97c2f78 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -1,20 +1,59 @@ from multiprocessing import Process, freeze_support + +freeze_support() +from PIL import Image + import os +import time import sys import logging -from ui.chat import chat_element +import apps.shark_studio.api.initializers as initialize + + +from apps.shark_studio.modules import timer + +startup_timer = timer.startup_timer +startup_timer.record("launcher") + +initialize.imports() if sys.platform == "darwin": os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib" # import before IREE to avoid MLIR library issues import torch_mlir -# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation -# from apps.stable_diffusion.src import args, clear_all -# import apps.stable_diffusion.web.utils.global_obj as global_obj +def create_api(app): + from apps.shark_studio.web.api.compat import ApiCompat, FIFOLock + + queue_lock = FIFOLock() + api = ApiCompat(app, queue_lock) + return api + + +def api_only(): + from fastapi import FastAPI + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + + initialize.initialize() + + app = FastAPI() + initialize.setup_middleware(app) + api = create_api(app) -def launch_app(address): + # from modules import script_callbacks + # script_callbacks.before_ui_callback() + # script_callbacks.app_started_callback(None, app) + + print(f"Startup time: {startup_timer.summary()}.") + api.launch( + server_name="0.0.0.0", + port=cmd_opts.server_port, + root_path="", + ) + + +def 
launch_webui(address): from tkinter import Tk import webview @@ -34,138 +73,78 @@ def launch_app(address): webview.start(private_mode=False, storage_path=os.getcwd()) -if __name__ == "__main__": - # if args.debug: - logging.basicConfig(level=logging.DEBUG) +def webui(): + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + from apps.shark_studio.web.ui.utils import ( + nodicon_loc, + nodlogo_loc, + ) + + launch_api = cmd_opts.api + initialize.initialize() + + from ui.chat import chat_element + from ui.sd import sd_element + from ui.outputgallery import outputgallery_element + # required to do multiprocessing in a pyinstaller freeze freeze_support() - # if args.api or "api" in args.ui.split(","): - # from apps.stable_diffusion.web.ui import ( - # txt2img_api, - # img2img_api, - # upscaler_api, - # inpaint_api, - # outpaint_api, - # llm_chat_api, - # ) + + # if args.api or "api" in args.ui.split(","): + # from apps.shark_studio.api.llm import ( + # chat, + # ) + # from apps.shark_studio.web.api import sdapi + # + # from fastapi import FastAPI, APIRouter + # from fastapi.middleware.cors import CORSMiddleware + # import uvicorn # - # from fastapi import FastAPI, APIRouter - # import uvicorn + # # init global sd pipeline and config + # global_obj._init() # - # # init global sd pipeline and config - # global_obj._init() + # api = FastAPI() + # api.mount("/sdapi/", sdapi) # - # app = FastAPI() - # app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"]) - # app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"]) - # app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"]) - # app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"]) - # app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"]) + # # chat APIs needed for compatibility with multiple extensions using OpenAI API + # api.add_api_route( + # "/v1/chat/completions", llm_chat_api, methods=["post"] + # ) + # api.add_api_route("/v1/completions", llm_chat_api, methods=["post"]) + # api.add_api_route("/chat/completions", llm_chat_api, methods=["post"]) + # api.add_api_route("/completions", llm_chat_api, methods=["post"]) + # api.add_api_route( + # "/v1/engines/codegen/completions", llm_chat_api, methods=["post"] + # ) + # api.include_router(APIRouter()) # - # # chat APIs needed for compatibility with multiple extensions using OpenAI API - # app.add_api_route( - # "/v1/chat/completions", llm_chat_api, methods=["post"] - # ) - # app.add_api_route("/v1/completions", llm_chat_api, methods=["post"]) - # app.add_api_route("/chat/completions", llm_chat_api, methods=["post"]) - # app.add_api_route("/completions", llm_chat_api, methods=["post"]) - # app.add_api_route( - # "/v1/engines/codegen/completions", llm_chat_api, methods=["post"] - # ) - # app.include_router(APIRouter()) - # uvicorn.run(app, host="0.0.0.0", port=args.server_port) - # sys.exit(0) + # # deal with CORS requests if CORS accept origins are set + # if args.api_accept_origin: + # print( + # f"API Configured for CORS. Accepting origins: { args.api_accept_origin }" + # ) + # api.add_middleware( + # CORSMiddleware, + # allow_origins=args.api_accept_origin, + # allow_methods=["GET", "POST"], + # allow_headers=["*"], + # ) + # else: + # print("API not configured for CORS") # - # Setup to use shark_tmp for gradio's temporary image files and clear any - # existing temporary images there if they exist. Then we can import gradio. - # It has to be in this order or gradio ignores what we've set up. 
- # from apps.stable_diffusion.web.utils.gradio_configs import ( - # config_gradio_tmp_imgs_folder, - # ) - - # config_gradio_tmp_imgs_folder() + # uvicorn.run(api, host="0.0.0.0", port=args.server_port) + # sys.exit(0) import gradio as gr - # Create custom models folders if they don't exist - # from apps.stable_diffusion.web.ui.utils import create_custom_models_folders - - # create_custom_models_folders() - def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) dark_theme = resource_path("ui/css/sd_dark_theme.css") + gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js") - # from apps.stable_diffusion.web.ui import ( - # txt2img_web, - # txt2img_custom_model, - # txt2img_gallery, - # txt2img_png_info_img, - # txt2img_status, - # txt2img_sendto_img2img, - # txt2img_sendto_inpaint, - # txt2img_sendto_outpaint, - # txt2img_sendto_upscaler, - ## h2ogpt_upload, - ## h2ogpt_web, - # img2img_web, - # img2img_custom_model, - # img2img_gallery, - # img2img_init_image, - # img2img_status, - # img2img_sendto_inpaint, - # img2img_sendto_outpaint, - # img2img_sendto_upscaler, - # inpaint_web, - # inpaint_custom_model, - # inpaint_gallery, - # inpaint_init_image, - # inpaint_status, - # inpaint_sendto_img2img, - # inpaint_sendto_outpaint, - # inpaint_sendto_upscaler, - # outpaint_web, - # outpaint_custom_model, - # outpaint_gallery, - # outpaint_init_image, - # outpaint_status, - # outpaint_sendto_img2img, - # outpaint_sendto_inpaint, - # outpaint_sendto_upscaler, - # upscaler_web, - # upscaler_custom_model, - # upscaler_gallery, - # upscaler_init_image, - # upscaler_status, - # upscaler_sendto_img2img, - # upscaler_sendto_inpaint, - # upscaler_sendto_outpaint, - ## lora_train_web, - ## model_web, - ## model_config_web, - # hf_models, - # modelmanager_sendto_txt2img, - # modelmanager_sendto_img2img, - # modelmanager_sendto_inpaint, - # modelmanager_sendto_outpaint, - # modelmanager_sendto_upscaler, - # stablelm_chat, - # minigpt4_web, - # outputgallery_web, - # outputgallery_tab_select, - # outputgallery_watch, - # outputgallery_filename, - # outputgallery_sendto_txt2img, - # outputgallery_sendto_img2img, - # outputgallery_sendto_inpaint, - # outputgallery_sendto_outpaint, - # outputgallery_sendto_upscaler, - # ) - - # init global sd pipeline and config - # global_obj._init() + # from apps.shark_studio.web.ui import load_ui_from_script def register_button_click(button, selectedid, inputs, outputs): button.click( @@ -177,17 +156,6 @@ def register_button_click(button, selectedid, inputs, outputs): outputs, ) - def register_modelmanager_button(button, selectedid, inputs, outputs): - button.click( - lambda x: ( - "None", - x, - gr.Tabs.update(selected=selectedid), - ), - inputs, - outputs, - ) - def register_outputgallery_button(button, selectedid, inputs, outputs): button.click( lambda x: ( @@ -199,8 +167,19 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): ) with gr.Blocks( - css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta" - ) as sd_web: + css=dark_theme, + js=gradio_workarounds, + analytics_enabled=False, + title="Shark Studio 2.0 Beta", + ) as studio_web: + nod_logo = Image.open(nodlogo_loc) + gr.Image( + value=nod_logo, + show_label=False, + interactive=False, + elem_id="tab_bar_logo", + show_download_button=False, + ) with gr.Tabs() as tabs: # NOTE: If adding, 
removing, or re-ordering tabs, make sure that they # have a unique id that doesn't clash with any of the other tabs, @@ -211,216 +190,33 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): # destination of one of the 'send to' buttons. If you do have to change # that id, make sure you update the relevant register_button_click calls # further down with the new id. - # with gr.TabItem(label="Text-to-Image", id=0): - # txt2img_web.render() - # with gr.TabItem(label="Image-to-Image", id=1): - # img2img_web.render() - # with gr.TabItem(label="Inpainting", id=2): - # inpaint_web.render() - # with gr.TabItem(label="Outpainting", id=3): - # outpaint_web.render() - # with gr.TabItem(label="Upscaler", id=4): - # upscaler_web.render() - # if args.output_gallery: - # with gr.TabItem(label="Output Gallery", id=5) as og_tab: - # outputgallery_web.render() - - # # extra output gallery configuration - # outputgallery_tab_select(og_tab.select) - # outputgallery_watch( - # [ - # txt2img_status, - # img2img_status, - # inpaint_status, - # outpaint_status, - # upscaler_status, - # ] - # ) - ## with gr.TabItem(label="Model Manager", id=6): - ## model_web.render() - ## with gr.TabItem(label="LoRA Training (Experimental)", id=7): - ## lora_train_web.render() - with gr.TabItem(label="Chat Bot", id=0): + with gr.TabItem(label="Stable Diffusion", id=0): + sd_element.render() + with gr.TabItem(label="Output Gallery", id=1): + outputgallery_element.render() + with gr.TabItem(label="Chat Bot", id=2): chat_element.render() - ## with gr.TabItem( - ## label="Generate Sharding Config (Experimental)", id=9 - ## ): - ## model_config_web.render() - # with gr.TabItem(label="MultiModal (Experimental)", id=10): - # minigpt4_web.render() - # with gr.TabItem(label="DocuChat Upload", id=11): - # h2ogpt_upload.render() - # with gr.TabItem(label="DocuChat(Experimental)", id=12): - # h2ogpt_web.render() - - # send to buttons - # register_button_click( - # txt2img_sendto_img2img, - # 1, - # [txt2img_gallery], - # [img2img_init_image, tabs], - # ) - # register_button_click( - # txt2img_sendto_inpaint, - # 2, - # [txt2img_gallery], - # [inpaint_init_image, tabs], - # ) - # register_button_click( - # txt2img_sendto_outpaint, - # 3, - # [txt2img_gallery], - # [outpaint_init_image, tabs], - # ) - # register_button_click( - # txt2img_sendto_upscaler, - # 4, - # [txt2img_gallery], - # [upscaler_init_image, tabs], - # ) - # register_button_click( - # img2img_sendto_inpaint, - # 2, - # [img2img_gallery], - # [inpaint_init_image, tabs], - # ) - # register_button_click( - # img2img_sendto_outpaint, - # 3, - # [img2img_gallery], - # [outpaint_init_image, tabs], - # ) - # register_button_click( - # img2img_sendto_upscaler, - # 4, - # [img2img_gallery], - # [upscaler_init_image, tabs], - # ) - # register_button_click( - # inpaint_sendto_img2img, - # 1, - # [inpaint_gallery], - # [img2img_init_image, tabs], - # ) - # register_button_click( - # inpaint_sendto_outpaint, - # 3, - # [inpaint_gallery], - # [outpaint_init_image, tabs], - # ) - # register_button_click( - # inpaint_sendto_upscaler, - # 4, - # [inpaint_gallery], - # [upscaler_init_image, tabs], - # ) - # register_button_click( - # outpaint_sendto_img2img, - # 1, - # [outpaint_gallery], - # [img2img_init_image, tabs], - # ) - # register_button_click( - # outpaint_sendto_inpaint, - # 2, - # [outpaint_gallery], - # [inpaint_init_image, tabs], - # ) - # register_button_click( - # outpaint_sendto_upscaler, - # 4, - # [outpaint_gallery], - # [upscaler_init_image, tabs], - # 
) - # register_button_click( - # upscaler_sendto_img2img, - # 1, - # [upscaler_gallery], - # [img2img_init_image, tabs], - # ) - # register_button_click( - # upscaler_sendto_inpaint, - # 2, - # [upscaler_gallery], - # [inpaint_init_image, tabs], - # ) - # register_button_click( - # upscaler_sendto_outpaint, - # 3, - # [upscaler_gallery], - # [outpaint_init_image, tabs], - # ) - # if args.output_gallery: - # register_outputgallery_button( - # outputgallery_sendto_txt2img, - # 0, - # [outputgallery_filename], - # [txt2img_png_info_img, tabs], - # ) - # register_outputgallery_button( - # outputgallery_sendto_img2img, - # 1, - # [outputgallery_filename], - # [img2img_init_image, tabs], - # ) - # register_outputgallery_button( - # outputgallery_sendto_inpaint, - # 2, - # [outputgallery_filename], - # [inpaint_init_image, tabs], - # ) - # register_outputgallery_button( - # outputgallery_sendto_outpaint, - # 3, - # [outputgallery_filename], - # [outpaint_init_image, tabs], - # ) - # register_outputgallery_button( - # outputgallery_sendto_upscaler, - # 4, - # [outputgallery_filename], - # [upscaler_init_image, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_txt2img, - # 0, - # [hf_models], - # [txt2img_custom_model, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_img2img, - # 1, - # [hf_models], - # [img2img_custom_model, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_inpaint, - # 2, - # [hf_models], - # [inpaint_custom_model, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_outpaint, - # 3, - # [hf_models], - # [outpaint_custom_model, tabs], - # ) - # register_modelmanager_button( - # modelmanager_sendto_upscaler, - # 4, - # [hf_models], - # [upscaler_custom_model, tabs], - # ) - - sd_web.queue() + + studio_web.queue() + # if args.ui == "app": # t = Process( # target=launch_app, args=[f"http://localhost:{args.server_port}"] # ) # t.start() - sd_web.launch( - share=True, + studio_web.launch( + share=cmd_opts.share, inbrowser=True, server_name="0.0.0.0", - server_port=11911, # args.server_port, + server_port=cmd_opts.server_port, + favicon_path=nodicon_loc, ) + + +if __name__ == "__main__": + from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + + if cmd_opts.webui == False: + api_only() + else: + webui() diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index 6e10cfaf6b..f41eaaaba0 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -5,13 +5,15 @@ from datetime import datetime as dt import json import sys -from apps.shark_studio.api.utils import ( - get_available_devices, -) from apps.shark_studio.api.llm import ( llm_model_map, LanguageModel, ) +import apps.shark_studio.web.utils.globals as global_obj + +B_SYS, E_SYS = "", "" + +B_SYS, E_SYS = "", "" B_SYS, E_SYS = "", "" @@ -99,7 +101,7 @@ def view_json_file(file_obj): choices=model_choices, allow_custom_value=True, ) - supported_devices = get_available_devices() + supported_devices = global_obj.get_device_list() enabled = True if len(supported_devices) == 0: supported_devices = ["cpu-task"] diff --git a/apps/shark_studio/web/ui/common_events.py b/apps/shark_studio/web/ui/common_events.py new file mode 100644 index 0000000000..7dda8ba268 --- /dev/null +++ b/apps/shark_studio/web/ui/common_events.py @@ -0,0 +1,67 @@ +from apps.shark_studio.web.ui.utils import ( + HSLHue, + hsl_color, +) +from apps.shark_studio.modules.embeddings import get_lora_metadata + + +# Answers HTML to 
show the most frequent tags used when a LoRA was trained, +# taken from the metadata of its .safetensors file. +def lora_changed(lora_files): + # tag frequency percentage that gets the maximum amount of the starting hue + TAG_COLOR_THRESHOLD = 0.55 + # tag frequency percentage, above which a tag is displayed + TAG_DISPLAY_THRESHOLD = 0.65 + # template for the html used to display a tag + TAG_HTML_TEMPLATE = ( + '<span class="lora-tag" style="background-color: {color}">{tag}</span>' + ) + output = [] + for lora_file in lora_files: + if lora_file == "": + output.extend(["
<div><i>No LoRA selected</i></div>
"]) + elif not lora_file.lower().endswith(".safetensors"): + output.extend( + [ + "
<div><i>Only metadata queries for .safetensors files are currently supported</i></div>
" + ] + ) + else: + metadata = get_lora_metadata(lora_file) + if metadata: + frequencies = metadata["frequencies"] + output.extend( + [ + "".join( + [ + f'
<div class="lora-model">Trained against weights in: {metadata["model"]}</div>
' + ] + + [ + TAG_HTML_TEMPLATE.format( + color=hsl_color( + (tag[1] - TAG_COLOR_THRESHOLD) + / (1 - TAG_COLOR_THRESHOLD), + start=HSLHue.RED, + end=HSLHue.GREEN, + ), + tag=tag[0], + ) + for tag in frequencies + if tag[1] > TAG_DISPLAY_THRESHOLD + ], + ) + ] + ) + elif metadata is None: + output.extend( + [ + "
<div><i>This LoRA does not publish tag frequency metadata</i></div>
" + ] + ) + else: + output.extend( + [ + "
<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>
" + ] + ) + return output diff --git a/apps/shark_studio/web/ui/css/sd_dark_theme.css b/apps/shark_studio/web/ui/css/sd_dark_theme.css new file mode 100644 index 0000000000..e17b90c862 --- /dev/null +++ b/apps/shark_studio/web/ui/css/sd_dark_theme.css @@ -0,0 +1,373 @@ +/* +Apply Gradio dark theme to the default Gradio theme. +Procedure to upgrade the dark theme: +- Using your browser, visit http://localhost:8080/?__theme=dark +- Open your browser inspector, search for the .dark css class +- Copy .dark class declarations, apply them here into :root +*/ + +:root { + --body-background-fill: var(--background-fill-primary); + --body-text-color: var(--neutral-100); + --color-accent-soft: var(--neutral-700); + --background-fill-primary: var(--neutral-950); + --background-fill-secondary: var(--neutral-900); + --border-color-accent: var(--neutral-600); + --border-color-primary: var(--neutral-700); + --link-text-color-active: var(--secondary-500); + --link-text-color: var(--secondary-500); + --link-text-color-hover: var(--secondary-400); + --link-text-color-visited: var(--secondary-600); + --body-text-color-subdued: var(--neutral-400); + --shadow-spread: 1px; + --block-background-fill: var(--neutral-800); + --block-border-color: var(--border-color-primary); + --block_border_width: None; + --block-info-text-color: var(--body-text-color-subdued); + --block-label-background-fill: var(--background-fill-secondary); + --block-label-border-color: var(--border-color-primary); + --block_label_border_width: None; + --block-label-text-color: var(--neutral-200); + --block_shadow: None; + --block_title_background_fill: None; + --block_title_border_color: None; + --block_title_border_width: None; + --block-title-text-color: var(--neutral-200); + --panel-background-fill: var(--background-fill-secondary); + --panel-border-color: var(--border-color-primary); + --panel_border_width: None; + --checkbox-background-color: var(--neutral-800); + --checkbox-background-color-focus: var(--checkbox-background-color); + --checkbox-background-color-hover: var(--checkbox-background-color); + --checkbox-background-color-selected: var(--secondary-600); + --checkbox-border-color: var(--neutral-700); + --checkbox-border-color-focus: var(--secondary-500); + --checkbox-border-color-hover: var(--neutral-600); + --checkbox-border-color-selected: var(--secondary-600); + --checkbox-border-width: var(--input-border-width); + --checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); + --checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); + --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill); + --checkbox-label-border-color: var(--border-color-primary); + --checkbox-label-border-color-hover: var(--checkbox-label-border-color); + --checkbox-label-border-width: var(--input-border-width); + --checkbox-label-text-color: var(--body-text-color); + --checkbox-label-text-color-selected: var(--checkbox-label-text-color); + --error-background-fill: var(--background-fill-primary); + --error-border-color: var(--border-color-primary); + --error_border_width: None; + --error-text-color: #ef4444; + --input-background-fill: var(--neutral-800); + --input-background-fill-focus: var(--secondary-600); + --input-background-fill-hover: var(--input-background-fill); + --input-border-color: var(--border-color-primary); + --input-border-color-focus: var(--neutral-700); + --input-border-color-hover: var(--input-border-color); + 
--input_border_width: None; + --input-placeholder-color: var(--neutral-500); + --input_shadow: None; + --input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset); + --loader_color: None; + --slider_color: None; + --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600)); + --table-border-color: var(--neutral-700); + --table-even-background-fill: var(--neutral-950); + --table-odd-background-fill: var(--neutral-900); + --table-row-focus: var(--color-accent-soft); + --button-border-width: var(--input-border-width); + --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c); + --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626); + --button-cancel-border-color: #dc2626; + --button-cancel-border-color-hover: var(--button-cancel-border-color); + --button-cancel-text-color: white; + --button-cancel-text-color-hover: var(--button-cancel-text-color); + --button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600)); + --button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500)); + --button-primary-border-color: var(--primary-500); + --button-primary-border-color-hover: var(--button-primary-border-color); + --button-primary-text-color: white; + --button-primary-text-color-hover: var(--button-primary-text-color); + --button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700)); + --button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600)); + --button-secondary-border-color: var(--neutral-600); + --button-secondary-border-color-hover: var(--button-secondary-border-color); + --button-secondary-text-color: white; + --button-secondary-text-color-hover: var(--button-secondary-text-color); + --block-border-width: 1px; + --block-label-border-width: 1px; + --form-gap-width: 1px; + --error-border-width: 1px; + --input-border-width: 1px; +} + +/* SHARK theme */ +body { + background-color: var(--background-fill-primary); +} + +.generating.svelte-zlszon.svelte-zlszon { + border: none; +} + +.generating { + border: none !important; +} + +#chatbot { + height: 100% !important; +} + +/* display in full width for desktop devices, but see below */ +@media (min-width: 1536px) +{ + .gradio-container { + max-width: var(--size-full) !important; + } +} + +/* media rules in custom css are don't appear to be applied in + gradio versions > 4.7, so we have to define a class which + we will manually need add and remove using javascript. + Remove this once this fixed in gradio. 
+*/ +.gradio-container-size-full { + max-width: var(--size-full) !important; +} + +.gradio-container .contain { + padding: 0 var(--size-4) !important; +} + +#top_logo { + color: transparent; + background-color: transparent; + border-radius: 0 !important; + border: 0; +} + +#ui_title { + padding: var(--size-2) 0 0 var(--size-1); +} + +#demo_title_outer { + border-radius: 0; +} + +#prompt_box_outer div:first-child { + border-radius: 0 !important +} + +#prompt_box textarea, #negative_prompt_box textarea { + background-color: var(--background-fill-primary) !important; +} + +#prompt_examples { + margin: 0 !important; +} + +#prompt_examples svg { + display: none !important; +} + +#ui_body { + padding: var(--size-2) !important; + border-radius: 0.5em !important; +} + +#img_result+div { + display: none !important; +} + +footer { + display: none !important; +} + +#gallery + div { + border-radius: 0 !important; +} + +/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */ +#gallery .thumbnail-item.thumbnail-lg { + aspect-ratio: unset; + max-height: calc(55vh - (2 * var(--spacing-lg))); +} +/* fix width and height of gallery items when on very large desktop screens, but see below */ +@media (min-width: 1921px) { + /* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */ + #gallery .grid-wrap, #gallery .preview{ + min-height: calc(768px + 4px + var(--size-14)); + max-height: calc(768px + 4px + var(--size-14)); + } + /* Limit height to 768px_height + 2px_margin_height for the thumbnails */ + #gallery .thumbnail-item.thumbnail-lg { + max-height: 770px !important; + } +} + +/* media rules in custom css are don't appear to be applied in + gradio versions > 4.7, so we have to define classes which + we will manually need add and remove using javascript. + Remove this once this fixed in gradio. 
+*/ +.gallery-force-height768 .grid-wrap, .gallery-force-height768 .preview { + min-height: calc(768px + 4px + var(--size-14)) !important; + max-height: calc(768px + 4px + var(--size-14)) !important; +} +.gallery-limit-height768 .thumbnail-item.thumbnail-lg { + max-height: 770px !important; +} + +/* Don't upscale when viewing in solo image mode */ +#gallery .preview img { + object-fit: scale-down; +} +/* Navbar images in cover mode*/ +#gallery .preview .thumbnail-item img { + object-fit: cover; +} + +/* Limit the stable diffusion text output height */ +#std_output textarea { + max-height: 215px; +} + +/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */ +#gallery .wrap.default { + pointer-events: none; +} + +/* Import Png info box */ +#txt2img_prompt_image { + height: var(--size-32) !important; +} + +/* Hide "remove buttons" from ui dropdowns */ +#custom_model .token-remove.remove-all, +#lora_weights .token-remove.remove-all, +#scheduler .token-remove.remove-all, +#device .token-remove.remove-all, +#stencil_model .token-remove.remove-all { + display: none; +} + +/* Hide selected items from ui dropdowns */ +#custom_model .options .item .inner-item, +#scheduler .options .item .inner-item, +#device .options .item .inner-item, +#stencil_model .options .item .inner-item { + display:none; +} + +/* workarounds for container=false not currently working for dropdowns */ +.dropdown_no_container { + padding: 0 !important; +} + +#output_subdir_container :first-child { + border: none; +} + +/* reduced animation load when generating */ +.generating { + animation-play-state: paused !important; +} + +/* better clarity when progress bars are minimal */ +.meta-text { + background-color: var(--block-label-background-fill); +} + +/* lora tag pills */ +.lora-tags { + border: 1px solid var(--border-color-primary); + color: var(--block-info-text-color) !important; + padding: var(--block-padding); +} + +.lora-tag { + display: inline-block; + height: 2em; + color: rgb(212 212 212) !important; + margin-right: 5pt; + margin-bottom: 5pt; + padding: 2pt 5pt; + border-radius: 5pt; + white-space: nowrap; +} + +.lora-model { + margin-bottom: var(--spacing-lg); + color: var(--block-info-text-color) !important; + line-height: var(--line-sm); +} + +/* output gallery tab */ +.output_parameters_dataframe table.table { + /* works around a gradio bug that always shows scrollbars */ + overflow: clip auto; +} + +.output_parameters_dataframe tbody td { + font-size: small; + line-height: var(--line-xs); +} + +.output_icon_button { + max-width: 30px; + align-self: end; + padding-bottom: 8px; +} + +.outputgallery_sendto { + min-width: 7em !important; +} + +/* output gallery should take up most of the viewport height regardless of image size/number */ +#outputgallery_gallery .fixed-height { + min-height: 89vh !important; +} + +.sd-right-panel { + height: calc(100vmin - var(--size-32) - var(--size-10)) !important; + overflow-y: scroll; +} + +.sd-right-panel .fill { + flex: 1; +} + +/* don't stretch non-square images to be square, breaking their aspect ratio */ +#outputgallery_gallery .thumbnail-item.thumbnail-lg > img { + object-fit: contain !important; +} + +/* centered logo for when there are no images */ +#top_logo.logo_centered { + height: 100%; + width: 100%; +} + +#top_logo.logo_centered img { + object-fit: scale-down; + position: absolute; + width: 80%; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); +} + +#tab_bar_logo { + overflow: visible !important; + border-width: 0 
!important; + height: 0px !important; + padding: 0; + margin: 0; +} + +#tab_bar_logo .image-container { + object-fit: scale-down; + position: absolute !important; + top: 14px; + right: 0px; + height: 36px; +} \ No newline at end of file diff --git a/apps/shark_studio/web/ui/js/sd_gradio_workarounds.js b/apps/shark_studio/web/ui/js/sd_gradio_workarounds.js new file mode 100644 index 0000000000..b1f893ee27 --- /dev/null +++ b/apps/shark_studio/web/ui/js/sd_gradio_workarounds.js @@ -0,0 +1,49 @@ +// workaround gradio after 4.7, not applying any @media rules form the custom .css file + +() => { + console.log(`innerWidth: ${window.innerWidth}` ) + + // 1536px rules + + const mediaQuery1536 = window.matchMedia('(min-width: 1536px)') + + function handleWidth1536(event) { + + // display in full width for desktop devices + document.querySelectorAll(".gradio-container") + .forEach( (node) => { + if (event.matches) { + node.classList.add("gradio-container-size-full"); + } else { + node.classList.remove("gradio-container-size-full") + } + }); + } + + mediaQuery1536.addEventListener("change", handleWidth1536); + mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536})); + + // 1921px rules + + const mediaQuery1921 = window.matchMedia('(min-width: 1921px)') + + function handleWidth1921(event) { + + /* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */ + /* Limit height to 768px_height + 2px_margin_height for the thumbnails */ + document.querySelectorAll("#gallery") + .forEach( (node) => { + if (event.matches) { + node.classList.add("gallery-force-height768"); + node.classList.add("gallery-limit-height768"); + } else { + node.classList.remove("gallery-force-height768"); + node.classList.remove("gallery-limit-height768"); + } + }); + } + + mediaQuery1921.addEventListener("change", handleWidth1921); + mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921})); + +} diff --git a/apps/shark_studio/web/ui/logos/nod-icon.png b/apps/shark_studio/web/ui/logos/nod-icon.png new file mode 100644 index 0000000000..29f7e32220 Binary files /dev/null and b/apps/shark_studio/web/ui/logos/nod-icon.png differ diff --git a/apps/shark_studio/web/ui/logos/nod-logo.png b/apps/shark_studio/web/ui/logos/nod-logo.png new file mode 100644 index 0000000000..4727e15a19 Binary files /dev/null and b/apps/shark_studio/web/ui/logos/nod-logo.png differ diff --git a/apps/shark_studio/web/ui/outputgallery.py b/apps/shark_studio/web/ui/outputgallery.py new file mode 100644 index 0000000000..a3de6f7b57 --- /dev/null +++ b/apps/shark_studio/web/ui/outputgallery.py @@ -0,0 +1,406 @@ +import glob +import gradio as gr +import os +import subprocess +import sys +from PIL import Image + +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from apps.shark_studio.web.utils.file_utils import ( + get_generated_imgs_path, + get_generated_imgs_todays_subdir, +) +from apps.shark_studio.web.ui.utils import nodlogo_loc +from apps.shark_studio.web.utils.metadata import displayable_metadata + +# -- Functions for file, directory and image info querying + +output_dir = get_generated_imgs_path() + + +def outputgallery_filenames(subdir) -> list[str]: + new_dir_path = os.path.join(output_dir, subdir) + if os.path.exists(new_dir_path): + filenames = [ + glob.glob(new_dir_path + "/" + ext) for ext in ("*.png", "*.jpg", "*.jpeg") + ] + + return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True) + else: + return [] + + +def 
output_subdirs() -> list[str]: + # Gets a list of subdirectories of output_dir and below, as relative paths. + relative_paths = [ + os.path.relpath(entry[0], output_dir) + for entry in os.walk( + output_dir, followlinks=cmd_opts.output_gallery_followlinks + ) + ] + + # It is less confusing to always including the subdir that will take any + # images generated today even if it doesn't exist yet + if get_generated_imgs_todays_subdir() not in relative_paths: + relative_paths.append(get_generated_imgs_todays_subdir()) + + # sort subdirectories so that the date named ones we probably + # created in this or previous sessions come first, sorted with the most + # recent first. Other subdirs are listed after. + generated_paths = sorted( + [path for path in relative_paths if path.isnumeric()], reverse=True + ) + result_paths = generated_paths + sorted( + [path for path in relative_paths if (not path.isnumeric()) and path != "."] + ) + + return result_paths + + +# --- Define UI layout for Gradio + +with gr.Blocks() as outputgallery_element: + nod_logo = Image.open(nodlogo_loc) + + with gr.Row(elem_id="outputgallery_gallery"): + # needed to workaround gradio issue: + # https://github.com/gradio-app/gradio/issues/2907 + dev_null = gr.Textbox("", visible=False) + + gallery_files = gr.State(value=[]) + subdirectory_paths = gr.State(value=[]) + + with gr.Column(scale=6): + logo = gr.Image( + label="Getting subdirectories...", + value=nod_logo, + interactive=False, + visible=True, + show_label=True, + elem_id="top_logo", + elem_classes="logo_centered", + show_download_button=False, + ) + + gallery = gr.Gallery( + label="", + value=gallery_files.value, + visible=False, + show_label=True, + columns=4, + ) + + with gr.Column(scale=4): + with gr.Group(): + with gr.Row(): + with gr.Column( + scale=15, + min_width=160, + elem_id="output_subdir_container", + ): + subdirectories = gr.Dropdown( + label=f"Subdirectories of {output_dir}", + type="value", + choices=subdirectory_paths.value, + value="", + interactive=True, + elem_classes="dropdown_no_container", + allow_custom_value=True, + ) + with gr.Column( + scale=1, + min_width=32, + elem_classes="output_icon_button", + ): + open_subdir = gr.Button( + variant="secondary", + value="\U0001F5C1", # unicode open folder + interactive=False, + size="sm", + ) + with gr.Column( + scale=1, + min_width=32, + elem_classes="output_icon_button", + ): + refresh = gr.Button( + variant="secondary", + value="\u21BB", # unicode clockwise arrow circle + size="sm", + ) + + image_columns = gr.Slider( + label="Columns shown", value=4, minimum=1, maximum=16, step=1 + ) + outputgallery_filename = gr.Textbox( + label="Filename", + value="None", + interactive=False, + show_copy_button=True, + ) + + with gr.Accordion( + label="Parameter Information", open=False + ) as parameters_accordian: + image_parameters = gr.DataFrame( + headers=["Parameter", "Value"], + col_count=2, + wrap=True, + elem_classes="output_parameters_dataframe", + value=[["Status", "No image selected"]], + interactive=True, + ) + + with gr.Accordion(label="Send To", open=True): + with gr.Row(): + outputgallery_sendto_sd = gr.Button( + value="Stable Diffusion", + interactive=False, + elem_classes="outputgallery_sendto", + size="sm", + ) + + # --- Event handlers + + def on_clear_gallery(): + return [ + gr.Gallery( + value=[], + visible=False, + ), + gr.Image( + visible=True, + ), + ] + + def on_image_columns_change(columns): + return gr.Gallery(columns=columns) + + def on_select_subdir(subdir) -> list: + # evt.value is the 
subdirectory name + new_images = outputgallery_filenames(subdir) + new_label = f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" + return [ + new_images, + gr.Gallery( + value=new_images, + label=new_label, + visible=len(new_images) > 0, + ), + gr.Image( + label=new_label, + visible=len(new_images) == 0, + ), + ] + + def on_open_subdir(subdir): + subdir_path = os.path.normpath(os.path.join(output_dir, subdir)) + + if os.path.isdir(subdir_path): + if sys.platform == "linux": + subprocess.run(["xdg-open", subdir_path]) + elif sys.platform == "darwin": + subprocess.run(["open", subdir_path]) + elif sys.platform == "win32": + os.startfile(subdir_path) + + def on_refresh(current_subdir: str) -> list: + # get an up-to-date subdirectory list + refreshed_subdirs = output_subdirs() + # get the images using either the current subdirectory or the most + # recent valid one + new_subdir = ( + current_subdir + if current_subdir in refreshed_subdirs + else refreshed_subdirs[0] + ) + new_images = outputgallery_filenames(new_subdir) + new_label = ( + f"{len(new_images)} images in " f"{os.path.join(output_dir, new_subdir)}" + ) + + return [ + gr.Dropdown( + choices=refreshed_subdirs, + value=new_subdir, + ), + refreshed_subdirs, + new_images, + gr.Gallery(value=new_images, label=new_label, visible=len(new_images) > 0), + gr.Image( + label=new_label, + visible=len(new_images) == 0, + ), + ] + + def on_new_image(subdir, subdir_paths, status) -> list: + # prevent error triggered when an image generates before the tab + # has even been selected + subdir_paths = ( + subdir_paths + if len(subdir_paths) > 0 + else [get_generated_imgs_todays_subdir()] + ) + + # only update if the current subdir is the most recent one as + # new images only go there + if subdir_paths[0] == subdir: + new_images = outputgallery_filenames(subdir) + new_label = ( + f"{len(new_images)} images in " + f"{os.path.join(output_dir, subdir)} - {status}" + ) + + return [ + new_images, + gr.Gallery( + value=new_images, + label=new_label, + visible=len(new_images) > 0, + ), + gr.Image( + label=new_label, + visible=len(new_images) == 0, + ), + ] + else: + # otherwise change nothing, + # (only untyped gradio gr.update() does this) + return [gr.update(), gr.update(), gr.update()] + + def on_select_image(images: list[str], evt: gr.SelectData) -> list: + # evt.index is an index into the full list of filenames for + # the current subdirectory + filename = images[evt.index] + params = displayable_metadata(filename) + + if params: + if params["source"] == "missing": + return [ + "Could not find this image file, refresh the gallery and update the images", + [["Status", "File missing"]], + ] + else: + return [ + filename, + list(map(list, params["parameters"].items())), + ] + + return [ + filename, + [["Status", "No parameters found"]], + ] + + def on_outputgallery_filename_change(filename: str) -> list: + exists = filename != "None" and os.path.exists(filename) + return [ + # disable or enable each of the sendto button based on whether + # an image is selected + gr.Button(interactive=exists), + ] + + # The time first our tab is selected we need to do an initial refresh + # to populate the subdirectory select box and the images from the most + # recent subdirectory. 
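+    # As an illustrative wiring sketch (og_tab and sd_status are assumed names
+    # provided by the hosting UI, not defined in this file), a caller would
+    # typically hook up the helpers defined at the bottom of this file like:
+    #     outputgallery_tab_select(og_tab.select)
+    #     outputgallery_watch([sd_status])
+    # so the first selection of the tab performs this initial refresh, and the
+    # gallery updates whenever a watched status textbox reports a new image.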
+ # + # We do it at this point rather than setting this up in the controls' + # definitions as when you refresh the browser you always get what was + # *initially* set, which won't include any new subdirectories or images + # that might have created since the application was started. Doing it + # this way means a browser refresh/reload always gets the most + # up-to-date data. + def on_select_tab(subdir_paths, request: gr.Request): + local_client = request.headers["host"].startswith( + "127.0.0.1:" + ) or request.headers["host"].startswith("localhost:") + + if len(subdir_paths) == 0: + return on_refresh("") + [gr.update(interactive=local_client)] + else: + return ( + # Change nothing, (only untyped gr.update() does this) + gr.update(), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + ) + + # clearing images when we need to completely change what's in the + # gallery avoids current images being shown replacing piecemeal and + # prevents weirdness and errors if the user selects an image during the + # replacement phase. + clear_gallery = dict( + fn=on_clear_gallery, + inputs=None, + outputs=[gallery, logo], + queue=False, + ) + + subdirectories.select(**clear_gallery).then( + on_select_subdir, + [subdirectories], + [gallery_files, gallery, logo], + queue=False, + ) + + open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False) + + refresh.click(**clear_gallery).then( + on_refresh, + [subdirectories], + [subdirectories, subdirectory_paths, gallery_files, gallery, logo], + queue=False, + ) + + image_columns.change( + fn=on_image_columns_change, + inputs=[image_columns], + outputs=[gallery], + queue=False, + ) + + gallery.select( + on_select_image, + [gallery_files], + [outputgallery_filename, image_parameters], + queue=False, + ) + + outputgallery_filename.change( + on_outputgallery_filename_change, + [outputgallery_filename], + [ + outputgallery_sendto_sd, + ], + queue=False, + ) + + # We should have been given the .select function for our tab, so set it up + def outputgallery_tab_select(select): + select( + fn=on_select_tab, + inputs=[subdirectory_paths], + outputs=[ + subdirectories, + subdirectory_paths, + gallery_files, + gallery, + logo, + open_subdir, + ], + queue=False, + ) + + # We should have been passed a list of components on other tabs that update + # when a new image has generated on that tab, so set things up so the user + # will see that new image if they are looking at today's subdirectory + def outputgallery_watch(components: gr.Textbox): + for component in components: + component.change( + on_new_image, + inputs=[subdirectories, subdirectory_paths, component], + outputs=[gallery_files, gallery, logo], + queue=False, + ) diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py new file mode 100644 index 0000000000..799504cb75 --- /dev/null +++ b/apps/shark_studio/web/ui/sd.py @@ -0,0 +1,769 @@ +import os +import json +import gradio as gr +import numpy as np +from inspect import signature +from PIL import Image +from pathlib import Path +from datetime import datetime as dt +from gradio.components.image_editor import ( + EditorValue, +) +from apps.shark_studio.web.utils.file_utils import ( + get_generated_imgs_path, + get_checkpoints_path, + get_checkpoints, + get_configs_path, + write_default_sd_config, +) +from apps.shark_studio.api.sd import ( + sd_model_map, + shark_sd_fn_dict_input, + cancel_sd, +) +from apps.shark_studio.api.controlnet import ( + cnet_preview, +) +from apps.shark_studio.modules.schedulers import 
( + scheduler_model_map, +) +from apps.shark_studio.modules.img_processing import ( + resampler_list, + resize_stencil, +) +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts +from apps.shark_studio.web.ui.utils import ( + nodlogo_loc, + none_to_str_none, + str_none_to_none, +) +from apps.shark_studio.web.utils.state import ( + status_label, +) +from apps.shark_studio.web.ui.common_events import lora_changed +from apps.shark_studio.modules import logger +import apps.shark_studio.web.utils.globals as global_obj + +sd_default_models = [ + "CompVis/stable-diffusion-v1-4", + "runwayml/stable-diffusion-v1-5", + "stabilityai/stable-diffusion-2-1-base", + "stabilityai/stable-diffusion-2-1", + "stabilityai/stable-diffusion-xl-1.0", + "stabilityai/sdxl-turbo", +] + + +def view_json_file(file_path): + content = "" + with open(file_path, "r") as fopen: + content = fopen.read() + return content + + +def submit_to_cnet_config( + stencil: str, + preprocessed_hint: str, + cnet_strength: int, + control_mode: str, + curr_config: dict, +): + if any(i in [None, ""] for i in [stencil, preprocessed_hint]): + return gr.update() + if curr_config is not None: + if "controlnets" in curr_config: + curr_config["controlnets"]["control_mode"] = control_mode + curr_config["controlnets"]["model"].append(stencil) + curr_config["controlnets"]["hint"].append(preprocessed_hint) + curr_config["controlnets"]["strength"].append(cnet_strength) + return curr_config + + cnet_map = {} + cnet_map["controlnets"] = { + "control_mode": control_mode, + "model": [stencil], + "hint": [preprocessed_hint], + "strength": [cnet_strength], + } + return cnet_map + + +def update_embeddings_json(embedding): + return {"embeddings": [embedding]} + + +def submit_to_main_config(input_cfg: dict, main_cfg: dict): + if main_cfg in [None, "", {}]: + return input_cfg + + for base_key in input_cfg: + main_cfg[base_key] = input_cfg[base_key] + return main_cfg + + +def pull_sd_configs( + prompt, + negative_prompt, + sd_init_image, + height, + width, + steps, + strength, + guidance_scale, + seed, + batch_count, + batch_size, + scheduler, + base_model_id, + custom_weights, + custom_vae, + precision, + device, + ondemand, + repeatable_seeds, + resample_type, + controlnets, + embeddings, +): + sd_args = str_none_to_none(locals()) + sd_cfg = {} + for arg in sd_args: + if arg in [ + "prompt", + "negative_prompt", + "sd_init_image", + ]: + sd_cfg[arg] = [sd_args[arg]] + elif arg in ["controlnets", "embeddings"]: + if isinstance(arg, dict): + sd_cfg[arg] = json.loads(sd_args[arg]) + else: + sd_cfg[arg] = {} + else: + sd_cfg[arg] = sd_args[arg] + + return json.dumps(sd_cfg) + + +def load_sd_cfg(sd_json: dict, load_sd_config: str): + new_sd_config = none_to_str_none(json.loads(view_json_file(load_sd_config))) + if sd_json: + for key in new_sd_config: + sd_json[key] = new_sd_config[key] + else: + sd_json = new_sd_config + for i in sd_json["sd_init_image"]: + if i is not None: + if os.path.isfile(i): + sd_image = [Image.open(i, mode="r")] + else: + sd_image = None + + return [ + sd_json["prompt"][0], + sd_json["negative_prompt"][0], + sd_image, + sd_json["height"], + sd_json["width"], + sd_json["steps"], + sd_json["strength"], + sd_json["guidance_scale"], + sd_json["seed"], + sd_json["batch_count"], + sd_json["batch_size"], + sd_json["scheduler"], + sd_json["base_model_id"], + sd_json["custom_weights"], + sd_json["custom_vae"], + sd_json["precision"], + sd_json["device"], + sd_json["ondemand"], + sd_json["repeatable_seeds"], + sd_json["resample_type"], + 
sd_json["controlnets"], + sd_json["embeddings"], + sd_json, + ] + + +def save_sd_cfg(config: dict, save_name: str): + if os.path.exists(save_name): + filepath = save_name + elif cmd_opts.configs_path: + filepath = os.path.join(cmd_opts.configs_path, save_name) + else: + filepath = os.path.join(get_configs_path(), save_name) + if ".json" not in filepath: + filepath += ".json" + with open(filepath, mode="w") as f: + f.write(json.dumps(config)) + return "..." + + +def create_canvas(width, height): + data = Image.fromarray( + np.zeros( + shape=(height, width, 3), + dtype=np.uint8, + ) + + 255 + ) + img_dict = { + "background": data, + "layers": [], + "composite": None, + } + return EditorValue(img_dict) + + +def import_original(original_img, width, height): + if original_img is None: + resized_img = create_canvas(width, height) + return resized_img + else: + resized_img, _, _ = resize_stencil(original_img, width, height) + img_dict = { + "background": resized_img, + "layers": [], + "composite": None, + } + return EditorValue(img_dict) + + +def base_model_changed(base_model_id): + new_choices = get_checkpoints( + os.path.join("checkpoints", os.path.basename(str(base_model_id))) + ) + get_checkpoints(model_type="checkpoints") + + return gr.Dropdown( + value=new_choices[0] if len(new_choices) > 0 else "None", + choices=["None"] + new_choices, + ) + + +with gr.Blocks(title="Stable Diffusion") as sd_element: + with gr.Column(elem_id="ui_body"): + with gr.Row(): + with gr.Column(scale=2, min_width=600): + with gr.Accordion( + label="\U0001F4D0\U0000FE0F Device Settings", open=False + ): + device = gr.Dropdown( + elem_id="device", + label="Device", + value=global_obj.get_device_list()[0], + choices=global_obj.get_device_list(), + allow_custom_value=False, + ) + with gr.Row(): + ondemand = gr.Checkbox( + value=cmd_opts.lowvram, + label="Low VRAM", + interactive=True, + ) + precision = gr.Radio( + label="Precision", + value=cmd_opts.precision, + choices=[ + "fp16", + "fp32", + ], + visible=True, + ) + sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}" + base_model_id = gr.Dropdown( + label="\U000026F0\U0000FE0F Base Model", + info="Select or enter HF model ID", + elem_id="custom_model", + value="stabilityai/stable-diffusion-2-1-base", + choices=sd_default_models, + ) # base_model_id + with gr.Row(): + height = gr.Slider( + 384, + 768, + value=cmd_opts.height, + step=8, + label="\U00002195\U0000FE0F Height", + ) + width = gr.Slider( + 384, + 768, + value=cmd_opts.width, + step=8, + label="\U00002194\U0000FE0F Width", + ) + with gr.Accordion( + label="\U00002696\U0000FE0F Model Weights", open=False + ): + with gr.Column(): + custom_weights = gr.Dropdown( + label="Checkpoint Weights", + info="Select or enter HF model ID", + elem_id="custom_model", + value="None", + allow_custom_value=True, + choices=["None"] + + get_checkpoints(os.path.basename(str(base_model_id))), + ) # custom_weights + base_model_id.change( + fn=base_model_changed, + inputs=[base_model_id], + outputs=[custom_weights], + ) + sd_vae_info = (str(get_checkpoints_path("vae"))).replace( + "\\", "\n\\" + ) + sd_vae_info = f"VAE Path: {sd_vae_info}" + custom_vae = gr.Dropdown( + label=f"VAE Model", + info=sd_vae_info, + elem_id="custom_model", + value=( + os.path.basename(cmd_opts.custom_vae) + if cmd_opts.custom_vae + else "None" + ), + choices=["None"] + get_checkpoints("vae"), + allow_custom_value=True, + scale=1, + ) + sd_lora_info = (str(get_checkpoints_path("loras"))).replace( + "\\", "\n\\" + ) + lora_opt = gr.Dropdown( + 
allow_custom_value=True, + label=f"Standalone LoRA Weights", + info=sd_lora_info, + elem_id="lora_weights", + value=None, + multiselect=True, + choices=[] + get_checkpoints("lora"), + scale=2, + ) + lora_tags = gr.HTML( + value="
<div><i>No LoRA selected</i></div>
", + elem_classes="lora-tags", + ) + embeddings_config = gr.JSON( + label="Embeddings Options", min_width=50, scale=1 + ) + gr.on( + triggers=[lora_opt.change], + fn=lora_changed, + inputs=[lora_opt], + outputs=[lora_tags], + queue=True, + show_progress=False, + ).then( + fn=update_embeddings_json, + inputs=[lora_opt], + outputs=[embeddings_config], + show_progress=False, + ) + with gr.Accordion( + label="\U0001F9EA\U0000FE0F Input Image Processing", open=False + ): + strength = gr.Slider( + 0, + 1, + value=cmd_opts.strength, + step=0.01, + label="Denoising Strength", + ) + resample_type = gr.Dropdown( + value=cmd_opts.resample_type, + choices=resampler_list, + label="Resample Type", + allow_custom_value=True, + ) + with gr.Group(elem_id="prompt_box_outer"): + prompt = gr.Textbox( + label="\U00002795\U0000FE0F Prompt", + value=cmd_opts.prompt[0], + lines=2, + elem_id="prompt_box", + show_copy_button=True, + ) + negative_prompt = gr.Textbox( + label="\U00002796\U0000FE0F Negative Prompt", + value=cmd_opts.negative_prompt[0], + lines=2, + elem_id="negative_prompt_box", + show_copy_button=True, + ) + with gr.Row(equal_height=True): + seed = gr.Textbox( + value=cmd_opts.seed, + label="\U0001F331\U0000FE0F Seed", + info="An integer or a JSON list of integers, -1 for random", + show_copy_button=True, + ) + scheduler = gr.Dropdown( + elem_id="scheduler", + label="\U0001F4C5\U0000FE0F Scheduler", + info="\U000E0020", # forces same height as seed + value="EulerDiscrete", + choices=scheduler_model_map.keys(), + allow_custom_value=False, + ) + with gr.Row(): + steps = gr.Slider( + 1, + 100, + value=cmd_opts.steps, + step=1, + label="\U0001F3C3\U0000FE0F Steps", + ) + guidance_scale = gr.Slider( + 0, + 50, + value=cmd_opts.guidance_scale, + step=0.1, + label="\U0001F5C3\U0000FE0F CFG Scale", + ) + with gr.Accordion( + label="Controlnet Options", + open=False, + visible=False, + ): + preprocessed_hints = gr.State([]) + with gr.Column(): + sd_cnet_info = ( + str(get_checkpoints_path("controlnet")) + ).replace("\\", "\n\\") + with gr.Row(): + cnet_config = gr.JSON() + with gr.Column(): + clear_config = gr.ClearButton( + value="Clear Controlnet Config", + size="sm", + components=cnet_config, + ) + control_mode = gr.Radio( + choices=["Prompt", "Balanced", "Controlnet"], + value="Balanced", + label="Control Mode", + ) + with gr.Row(): + with gr.Column(scale=1): + cnet_model = gr.Dropdown( + allow_custom_value=True, + label=f"Controlnet Model", + info=sd_cnet_info, + value="None", + choices=[ + "None", + "canny", + "openpose", + "scribble", + "zoedepth", + ] + + get_checkpoints("controlnet"), + ) + cnet_strength = gr.Slider( + label="Controlnet Strength", + minimum=0, + maximum=100, + value=50, + step=1, + ) + with gr.Row(): + canvas_width = gr.Slider( + label="Canvas Width", + minimum=256, + maximum=1024, + value=512, + step=8, + ) + canvas_height = gr.Slider( + label="Canvas Height", + minimum=256, + maximum=1024, + value=512, + step=8, + ) + make_canvas = gr.Button( + value="Make Canvas!", + ) + use_input_img = gr.Button( + value="Use Original Image", + size="sm", + ) + cnet_input = gr.Image( + value=None, + type="pil", + image_mode="RGB", + interactive=True, + ) + with gr.Column(scale=1): + cnet_output = gr.Image( + value=None, + visible=True, + label="Preprocessed Hint", + interactive=False, + show_label=True, + ) + cnet_gen = gr.Button( + value="Preprocess controlnet input", + ) + use_result = gr.Button( + "Submit", + size="sm", + ) + make_canvas.click( + fn=create_canvas, + inputs=[canvas_width, 
canvas_height], + outputs=[cnet_input], + queue=False, + ) + cnet_gen.click( + fn=cnet_preview, + inputs=[ + cnet_model, + cnet_input, + ], + outputs=[ + cnet_output, + preprocessed_hints, + ], + ) + use_result.click( + fn=submit_to_cnet_config, + inputs=[ + cnet_model, + cnet_output, + cnet_strength, + control_mode, + cnet_config, + ], + outputs=[ + cnet_config, + ], + queue=False, + ) + with gr.Column(scale=3, min_width=600): + with gr.Tabs() as sd_tabs: + sd_element.load( + # Workaround for Gradio issue #7085 + # TODO: revert to setting selected= in gr.Tabs declaration + # once this is resolved in Gradio + lambda: gr.Tabs(selected=101), + outputs=[sd_tabs], + ) + with gr.Tab(label="Input Image", id=100) as sd_tab_init_image: + with gr.Column(elem_classes=["sd-right-panel"]): + with gr.Row(elem_classes=["fill"]): + # TODO: make this import image prompt info if it exists + sd_init_image = gr.Image( + type="pil", + interactive=True, + show_label=False, + ) + use_input_img.click( + fn=import_original, + inputs=[ + sd_init_image, + canvas_width, + canvas_height, + ], + outputs=[cnet_input], + queue=False, + ) + with gr.Tab(label="Generate Images", id=101) as sd_tab_gallery: + with gr.Column(elem_classes=["sd-right-panel"]): + with gr.Row(elem_classes=["fill"]): + sd_gallery = gr.Gallery( + label="Generated images", + show_label=False, + elem_id="gallery", + columns=2, + object_fit="fit", + preview=True, + ) + with gr.Row(): + std_output = gr.Textbox( + value=f"{sd_model_info}\n" + f"Images will be saved at " + f"{get_generated_imgs_path()}", + lines=2, + elem_id="std_output", + show_label=True, + label="Log", + show_copy_button=True, + ) + sd_element.load( + logger.read_sd_logs, None, std_output, every=1 + ) + sd_status = gr.Textbox(visible=False) + with gr.Row(): + batch_count = gr.Slider( + 1, + 100, + value=cmd_opts.batch_count, + step=1, + label="Batch Count", + interactive=True, + ) + batch_size = gr.Slider( + 1, + 4, + value=cmd_opts.batch_size, + step=1, + label="Batch Size", + interactive=True, + visible=True, + ) + repeatable_seeds = gr.Checkbox( + cmd_opts.repeatable_seeds, + label="Use Repeatable Seeds for Batches", + ) + with gr.Row(): + stable_diffusion = gr.Button("Start") + random_seed = gr.Button("Randomize Seed") + random_seed.click( + lambda: -1, + inputs=[], + outputs=[seed], + queue=False, + show_progress=False, + ) + stop_batch = gr.Button("Stop") + with gr.Tab(label="Config", id=102) as sd_tab_config: + with gr.Column(elem_classes=["sd-right-panel"]): + with gr.Row(elem_classes=["fill"]): + Path(get_configs_path()).mkdir( + parents=True, exist_ok=True + ) + default_config_file = os.path.join( + get_configs_path(), + "default_sd_config.json", + ) + write_default_sd_config(default_config_file) + sd_json = gr.JSON( + elem_classes=["fill"], + value=view_json_file(default_config_file), + ) + with gr.Row(): + with gr.Column(scale=3): + load_sd_config = gr.FileExplorer( + label="Load Config", + file_count="single", + root_dir=( + cmd_opts.configs_path + if cmd_opts.configs_path + else get_configs_path() + ), + height=75, + ) + with gr.Column(scale=1): + save_sd_config = gr.Button( + value="Save Config", size="sm" + ) + clear_sd_config = gr.ClearButton( + value="Clear Config", + size="sm", + components=sd_json, + ) + with gr.Row(): + sd_config_name = gr.Textbox( + value="Config Name", + info="Name of the file this config will be saved to.", + interactive=True, + show_label=False, + ) + load_sd_config.change( + fn=load_sd_cfg, + inputs=[sd_json, load_sd_config], + outputs=[ + 
prompt, + negative_prompt, + sd_init_image, + height, + width, + steps, + strength, + guidance_scale, + seed, + batch_count, + batch_size, + scheduler, + base_model_id, + custom_weights, + custom_vae, + precision, + device, + ondemand, + repeatable_seeds, + resample_type, + cnet_config, + embeddings_config, + sd_json, + ], + ) + save_sd_config.click( + fn=save_sd_cfg, + inputs=[sd_json, sd_config_name], + outputs=[sd_config_name], + ) + save_sd_config.click( + fn=save_sd_cfg, + inputs=[sd_json, sd_config_name], + outputs=[sd_config_name], + ) + + pull_kwargs = dict( + fn=pull_sd_configs, + inputs=[ + prompt, + negative_prompt, + sd_init_image, + height, + width, + steps, + strength, + guidance_scale, + seed, + batch_count, + batch_size, + scheduler, + base_model_id, + custom_weights, + custom_vae, + precision, + device, + ondemand, + repeatable_seeds, + resample_type, + cnet_config, + embeddings_config, + ], + outputs=[ + sd_json, + ], + ) + + status_kwargs = dict( + fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs), + inputs=[batch_count, batch_size], + outputs=sd_status, + ) + + gen_kwargs = dict( + fn=shark_sd_fn_dict_input, + inputs=[sd_json], + outputs=[ + sd_gallery, + sd_status, + ], + ) + + prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs) + neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs) + generate_click = ( + stable_diffusion.click(**status_kwargs).then(**pull_kwargs).then(**gen_kwargs) + ) + stop_batch.click( + fn=cancel_sd, + cancels=[prompt_submit, neg_prompt_submit, generate_click], + ) diff --git a/apps/shark_studio/web/ui/utils.py b/apps/shark_studio/web/ui/utils.py new file mode 100644 index 0000000000..cee1a6d02e --- /dev/null +++ b/apps/shark_studio/web/ui/utils.py @@ -0,0 +1,43 @@ +from enum import IntEnum +import math +import sys +import os + + +def resource_path(relative_path): + """Get absolute path to resource, works for dev and for PyInstaller""" + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) + return os.path.join(base_path, relative_path) + + +nodlogo_loc = resource_path("logos/nod-logo.png") +nodicon_loc = resource_path("logos/nod-icon.png") + + +class HSLHue(IntEnum): + RED = 0 + YELLOW = 60 + GREEN = 120 + CYAN = 180 + BLUE = 240 + MAGENTA = 300 + + +def hsl_color(alpha: float, start, end): + b = (end - start) * (alpha if alpha > 0 else 0) + result = b + start + + # Return a CSS HSL string + return f"hsl({math.floor(result)}, 80%, 35%)" + + +def none_to_str_none(props: dict): + for key in props: + props[key] = "None" if props[key] == None else props[key] + return props + + +def str_none_to_none(props: dict): + for key in props: + props[key] = None if props[key] == "None" else props[key] + return props diff --git a/apps/shark_studio/web/utils/__init__.py b/apps/shark_studio/web/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/apps/shark_studio/web/utils/file_utils.py b/apps/shark_studio/web/utils/file_utils.py new file mode 100644 index 0000000000..0f1953f5ac --- /dev/null +++ b/apps/shark_studio/web/utils/file_utils.py @@ -0,0 +1,121 @@ +import os +import sys +import glob +from datetime import datetime as dt +from pathlib import Path + +from apps.shark_studio.modules.shared_cmd_opts import cmd_opts + +checkpoints_filetypes = ( + "*.ckpt", + "*.safetensors", +) + +default_sd_config = r"""{ + "prompt": [ + "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front 
angle, front point of view, trees in the mountains of the background, ((sharp focus))" + ], + "negative_prompt": [ + "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" + ], + "sd_init_image": [null], + "height": 512, + "width": 512, + "steps": 50, + "strength": 0.8, + "guidance_scale": 7.5, + "seed": "-1", + "batch_count": 1, + "batch_size": 1, + "scheduler": "EulerDiscrete", + "base_model_id": "stabilityai/stable-diffusion-2-1-base", + "custom_weights": null, + "custom_vae": null, + "precision": "fp16", + "device": "AMD Radeon RX 7900 XTX => vulkan://0", + "ondemand": false, + "repeatable_seeds": false, + "resample_type": "Nearest Neighbor", + "controlnets": {}, + "embeddings": {} +}""" + + +def write_default_sd_config(path): + with open(path, "w") as f: + f.write(default_sd_config) + + +def safe_name(name): + return name.replace("/", "_").replace("-", "_") + + +def get_path_stem(path): + path = Path(path) + return path.stem + + +def get_resource_path(path): + """Get absolute path to resource, works for dev and for PyInstaller""" + if os.path.isabs(path): + return path + else: + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) + result = Path(os.path.join(base_path, path)).resolve(strict=False) + return result + + +def get_configs_path() -> Path: + configs = get_resource_path(os.path.join("..", "configs")) + if not os.path.exists(configs): + os.mkdir(configs) + return Path(get_resource_path("../configs")) + + +def get_generated_imgs_path() -> Path: + return Path( + cmd_opts.output_dir + if cmd_opts.output_dir + else get_resource_path("../generated_imgs") + ) + + +def get_generated_imgs_todays_subdir() -> str: + return dt.now().strftime("%Y%m%d") + + +def create_checkpoint_folders(): + dir = ["checkpoints", "vae", "lora", "vmfb"] + if not os.path.isdir(cmd_opts.ckpt_dir): + try: + os.makedirs(cmd_opts.ckpt_dir) + except OSError: + sys.exit( + f"Invalid --ckpt_dir argument, " + f"{cmd_opts.ckpt_dir} folder does not exist, and cannot be created." + ) + + for root in dir: + Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True) + + +def get_checkpoints_path(model_type=""): + return get_resource_path(os.path.join(cmd_opts.ckpt_dir, model_type)) + + +def get_checkpoints(model_type="checkpoints"): + ckpt_files = [] + file_types = checkpoints_filetypes + if model_type == "lora": + file_types = file_types + ("*.pt", "*.bin") + for extn in file_types: + files = [ + os.path.basename(x) + for x in glob.glob(os.path.join(get_checkpoints_path(model_type), extn)) + ] + ckpt_files.extend(files) + return sorted(ckpt_files, key=str.casefold) + + +def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"): + return os.path.join(get_checkpoints_path(model_type), checkpoint_name) diff --git a/apps/shark_studio/web/utils/globals.py b/apps/shark_studio/web/utils/globals.py new file mode 100644 index 0000000000..27910e74ef --- /dev/null +++ b/apps/shark_studio/web/utils/globals.py @@ -0,0 +1,134 @@ +import gc +from ...api.utils import get_available_devices + +""" +The global objects include SD pipeline and config. +Maintaining the global objects would avoid creating extra pipeline objects when switching modes. +Also we could avoid memory leak when switching models by clearing the cache. 
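+All of this state lives in module-level globals that _init() resets and the setter/getter helpers below expose.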
+""" + + +def _init(): + global _sd_obj + global _llm_obj + global _devices + global _pipe_kwargs + global _prep_kwargs + global _gen_kwargs + global _schedulers + _sd_obj = None + _llm_obj = None + _devices = None + _pipe_kwargs = None + _prep_kwargs = None + _gen_kwargs = None + _schedulers = None + set_devices() + + +def set_sd_obj(value): + global _sd_obj + global _llm_obj + _llm_obj = None + _sd_obj = value + + +def set_llm_obj(value): + global _sd_obj + global _llm_obj + _llm_obj = value + _sd_obj = None + + +def set_devices(): + global _devices + _devices = get_available_devices() + + +def set_sd_scheduler(key): + global _sd_obj + _sd_obj.scheduler = _schedulers[key] + + +def set_sd_status(value): + global _sd_obj + _sd_obj.status = value + + +def set_pipe_kwargs(value): + global _pipe_kwargs + _pipe_kwargs = value + + +def set_prep_kwargs(value): + global _prep_kwargs + _prep_kwargs = value + + +def set_gen_kwargs(value): + global _gen_kwargs + _gen_kwargs = value + + +def set_schedulers(value): + global _schedulers + _schedulers = value + + +def get_sd_obj(): + global _sd_obj + return _sd_obj + + +def get_llm_obj(): + global _llm_obj + return _llm_obj + + +def get_device_list(): + global _devices + return _devices + + +def get_sd_status(): + global _sd_obj + return _sd_obj.status + + +def get_pipe_kwargs(): + global _pipe_kwargs + return _pipe_kwargs + + +def get_prep_kwargs(): + global _prep_kwargs + return _prep_kwargs + + +def get_gen_kwargs(): + global _gen_kwargs + return _gen_kwargs + + +def get_scheduler(key): + global _schedulers + return _schedulers[key] + + +def clear_cache(): + global _sd_obj + global _llm_obj + global _pipe_kwargs + global _prep_kwargs + global _gen_kwargs + global _schedulers + del _sd_obj + del _llm_obj + del _schedulers + gc.collect() + _sd_obj = None + _llm_obj = None + _pipe_kwargs = None + _prep_kwargs = None + _gen_kwargs = None + _schedulers = None diff --git a/apps/shark_studio/web/utils/metadata/__init__.py b/apps/shark_studio/web/utils/metadata/__init__.py new file mode 100644 index 0000000000..bcbcf746ca --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/__init__.py @@ -0,0 +1,6 @@ +from .png_metadata import ( + import_png_metadata, +) +from .display import ( + displayable_metadata, +) diff --git a/apps/shark_studio/web/utils/metadata/csv_metadata.py b/apps/shark_studio/web/utils/metadata/csv_metadata.py new file mode 100644 index 0000000000..d515234083 --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/csv_metadata.py @@ -0,0 +1,43 @@ +import csv +import os +from .format import humanize, humanizable + + +def csv_path(image_filename: str): + return os.path.join(os.path.dirname(image_filename), "imgs_details.csv") + + +def has_csv(image_filename: str) -> bool: + return os.path.exists(csv_path(image_filename)) + + +def matching_filename(image_filename: str, row): + # we assume the final column of the csv has the original filename with full path and match that + # against the image_filename if we are given a list. Otherwise we assume a dict and and take + # the value of the OUTPUT key + return os.path.basename(image_filename) in ( + row[-1] if isinstance(row, list) else row["OUTPUT"] + ) + + +def parse_csv(image_filename: str): + csv_filename = csv_path(image_filename) + + with open(csv_filename, "r", newline="") as csv_file: + # We use a reader or DictReader here for images_details.csv depending on whether we think it + # has headers or not. Having headers means less guessing of the format. 
+ has_header = csv.Sniffer().has_header(csv_file.read(2048)) + csv_file.seek(0) + + reader = csv.DictReader(csv_file) if has_header else csv.reader(csv_file) + + matches = [ + # we rely on humanize and humanizable to work out the parsing of the individual .csv rows + humanize(row) + for row in reader + if row + and (has_header or humanizable(row)) + and matching_filename(image_filename, row) + ] + + return matches[0] if matches else {} diff --git a/apps/shark_studio/web/utils/metadata/display.py b/apps/shark_studio/web/utils/metadata/display.py new file mode 100644 index 0000000000..26234aab5c --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/display.py @@ -0,0 +1,53 @@ +import json +import os +from PIL import Image +from .png_metadata import parse_generation_parameters +from .exif_metadata import has_exif, parse_exif +from .csv_metadata import has_csv, parse_csv +from .format import compact, humanize + + +def displayable_metadata(image_filename: str) -> dict: + if not os.path.isfile(image_filename): + return {"source": "missing", "parameters": {}} + + pil_image = Image.open(image_filename) + + # we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads, + # and we go via that for SendTo, and is directly tied to the image) + if "parameters" in pil_image.info: + return { + "source": "png", + "parameters": compact( + parse_generation_parameters(pil_image.info["parameters"]) + ), + } + + # we have a matching json file (next most likely to be accurate when it's there) + json_path = os.path.splitext(image_filename)[0] + ".json" + if os.path.isfile(json_path): + with open(json_path) as params_file: + return { + "source": "json", + "parameters": compact( + humanize(json.load(params_file), includes_filename=False) + ), + } + + # we have a CSV file so try that (can be different shapes, and it usually has no + # headers/param names so of the things we we *know* have parameters, it's the + # last resort) + if has_csv(image_filename): + params = parse_csv(image_filename) + if params: # we might not have found the filename in the csv + return { + "source": "csv", + "parameters": compact(params), # already humanized + } + + # EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something* + if has_exif(image_filename): + return {"source": "exif", "parameters": parse_exif(pil_image)} + + # we've got nothing + return None diff --git a/apps/shark_studio/web/utils/metadata/exif_metadata.py b/apps/shark_studio/web/utils/metadata/exif_metadata.py new file mode 100644 index 0000000000..c72da8a935 --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/exif_metadata.py @@ -0,0 +1,52 @@ +from PIL import Image +from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS + + +def has_exif(image_filename: str) -> bool: + return True if Image.open(image_filename).getexif() else False + + +def parse_exif(pil_image: Image) -> dict: + img_exif = pil_image.getexif() + + # See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594 + # I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I + # I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a + # dependency + exif_tags = { + TAGS.get(key, key): str(val) + for (key, val) in img_exif.items() + if key in TAGS + and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo) + and val + and (not isinstance(val, bytes)) + and (not str(val).isspace()) + } + + def try_get_ifd(ifd_id): + try: + 
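+ # missing IFDs are reported as a KeyError, which is caught below and treated as empty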
return img_exif.get_ifd(ifd_id).items() + except KeyError: + return {} + + ifd_tags = { + TAGS.get(key, key): str(val) + for ifd_id in IFD + for (key, val) in try_get_ifd(ifd_id) + if ifd_id != IFD.GPSInfo + and key in TAGS + and val + and (not isinstance(val, bytes)) + and (not str(val).isspace()) + } + + gps_tags = { + GPSTAGS.get(key, key): str(val) + for (key, val) in try_get_ifd(IFD.GPSInfo) + if key in GPSTAGS + and val + and (not isinstance(val, bytes)) + and (not str(val).isspace()) + } + + return {**exif_tags, **ifd_tags, **gps_tags} diff --git a/apps/shark_studio/web/utils/metadata/format.py b/apps/shark_studio/web/utils/metadata/format.py new file mode 100644 index 0000000000..308d9f8e8b --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/format.py @@ -0,0 +1,139 @@ +# As SHARK has evolved, more columns have been added to images_details.csv. However, since +# no version of the CSV has any headers (yet) we don't actually have anything within the +# file that tells us which parameter each column is for. So this is a list of known patterns +# indexed by length, which is what we're going to have to use to guess which columns are the +# right ones for the file we're looking at. + +# The same ordering is used for JSON, but these do have key names, however they are not very +# human friendly, nor do they match up with what is written to the .png headers + +# So these are functions to try and get something consistent out of the raw input from all +# these sources + +PARAMS_FORMATS = { + 9: { + "VARIANT": "Model", + "SCHEDULER": "Sampler", + "PROMPT": "Prompt", + "NEG_PROMPT": "Negative prompt", + "SEED": "Seed", + "CFG_SCALE": "CFG scale", + "PRECISION": "Precision", + "STEPS": "Steps", + "OUTPUT": "Filename", + }, + 10: { + "MODEL": "Model", + "VARIANT": "Variant", + "SCHEDULER": "Sampler", + "PROMPT": "Prompt", + "NEG_PROMPT": "Negative prompt", + "SEED": "Seed", + "CFG_SCALE": "CFG scale", + "PRECISION": "Precision", + "STEPS": "Steps", + "OUTPUT": "Filename", + }, + 12: { + "VARIANT": "Model", + "SCHEDULER": "Sampler", + "PROMPT": "Prompt", + "NEG_PROMPT": "Negative prompt", + "SEED": "Seed", + "CFG_SCALE": "CFG scale", + "PRECISION": "Precision", + "STEPS": "Steps", + "HEIGHT": "Height", + "WIDTH": "Width", + "MAX_LENGTH": "Max Length", + "OUTPUT": "Filename", + }, +} + +PARAMS_FORMAT_CURRENT = { + "VARIANT": "Model", + "VAE": "VAE", + "LORA": "LoRA", + "SCHEDULER": "Sampler", + "PROMPT": "Prompt", + "NEG_PROMPT": "Negative prompt", + "SEED": "Seed", + "CFG_SCALE": "CFG scale", + "PRECISION": "Precision", + "STEPS": "Steps", + "HEIGHT": "Height", + "WIDTH": "Width", + "MAX_LENGTH": "Max Length", + "OUTPUT": "Filename", +} + + +def compact(metadata: dict) -> dict: + # we don't want to alter the original dictionary + result = dict(metadata) + + # discard the filename because we should already have it + if result.keys() & {"Filename"}: + result.pop("Filename") + + # make showing the sizes more compact by using only one line each + if result.keys() & {"Size-1", "Size-2"}: + result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}" + elif result.keys() & {"Height", "Width"}: + result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}" + + if result.keys() & {"Hires resize-1", "Hires resize-2"}: + hires_y = result.pop("Hires resize-1") + hires_x = result.pop("Hires resize-2") + + if hires_x == 0 and hires_y == 0: + result["Hires resize"] = "None" + else: + result["Hires resize"] = f"{hires_y}x{hires_x}" + + # remove VAE if it exists and is empty + if (result.keys() &
{"VAE"}) and (not result["VAE"] or result["VAE"] == "None"): + result.pop("VAE") + + # remove LoRA if it exists and is empty + if (result.keys() & {"LoRA"}) and (not result["LoRA"] or result["LoRA"] == "None"): + result.pop("LoRA") + + return result + + +def humanizable(metadata: dict | list[str], includes_filename=True) -> dict: + lookup_key = len(metadata) + (0 if includes_filename else 1) + return lookup_key in PARAMS_FORMATS.keys() + + +def humanize(metadata: dict | list[str], includes_filename=True) -> dict: + lookup_key = len(metadata) + (0 if includes_filename else 1) + + # For lists we can only work based on the length, we have no other information + if isinstance(metadata, list): + if humanizable(metadata, includes_filename): + return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata)) + else: + raise KeyError( + f"Humanize could not find the format for a parameter list of length {len(metadata)}" + ) + + # For dictionaries we try to use the matching length parameter format if + # available, otherwise we just use the current format which is assumed to + # have everything currently known about. Then we swap keys in the metadata + # that match keys in the format for the friendlier name that we have set + # in the format value + if isinstance(metadata, dict): + if humanizable(metadata, includes_filename): + format = PARAMS_FORMATS[lookup_key] + else: + format = PARAMS_FORMAT_CURRENT + + return { + format[key]: metadata[key] + for key in format.keys() + if key in metadata.keys() and metadata[key] + } + + raise TypeError("Can only humanize parameter lists or dictionaries") diff --git a/apps/shark_studio/web/utils/metadata/png_metadata.py b/apps/shark_studio/web/utils/metadata/png_metadata.py new file mode 100644 index 0000000000..72f663f246 --- /dev/null +++ b/apps/shark_studio/web/utils/metadata/png_metadata.py @@ -0,0 +1,217 @@ +import re +from pathlib import Path +from apps.shark_studio.web.utils.file_utils import ( + get_checkpoint_pathfile, +) +from apps.shark_studio.api.sd import ( + sd_model_map, +) +from apps.shark_studio.modules.schedulers import ( + scheduler_model_map, +) + +re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' +re_param = re.compile(re_param_code) +re_imagesize = re.compile(r"^(\d+)x(\d+)$") + + +def parse_generation_parameters(x: str): + res = {} + prompt = "" + negative_prompt = "" + done_with_prompt = False + + *lines, lastline = x.strip().split("\n") + if len(re_param.findall(lastline)) < 3: + lines.append(lastline) + lastline = "" + + for i, line in enumerate(lines): + line = line.strip() + if line.startswith("Negative prompt:"): + done_with_prompt = True + line = line[16:].strip() + + if done_with_prompt: + negative_prompt += ("" if negative_prompt == "" else "\n") + line + else: + prompt += ("" if prompt == "" else "\n") + line + + res["Prompt"] = prompt + res["Negative prompt"] = negative_prompt + + for k, v in re_param.findall(lastline): + v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v + m = re_imagesize.match(v) + if m is not None: + res[k + "-1"] = m.group(1) + res[k + "-2"] = m.group(2) + else: + res[k] = v + + # Missing CLIP skip means it was set to 1 (the default) + if "Clip skip" not in res: + res["Clip skip"] = "1" + + hypernet = res.get("Hypernet", None) + if hypernet is not None: + res[ + "Prompt" + ] += f"""""" + + if "Hires resize-1" not in res: + res["Hires resize-1"] = 0 + res["Hires resize-2"] = 0 + + return res + + +def try_find_model_base_from_png_metadata(file: str, folder: str = "models") -> str: + 
custom = "" + + # Remove extension from file info + if file.endswith(".safetensors") or file.endswith(".ckpt"): + file = Path(file).stem + # Check for the file name match with one of the local ckpt or safetensors files + if Path(get_checkpoint_pathfile(file + ".ckpt", folder)).is_file(): + custom = file + ".ckpt" + if Path(get_checkpoint_pathfile(file + ".safetensors", folder)).is_file(): + custom = file + ".safetensors" + + return custom + + +def find_model_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + png_hf_id = "" + png_custom = "" + + if key in metadata: + model_file = metadata[key] + png_custom = try_find_model_base_from_png_metadata(model_file) + # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") + if model_file in sd_model_map: + png_custom = model_file + # If nothing had matched, check vendor/hf_model_id + if not png_custom and model_file.count("/"): + png_hf_id = model_file + # No matching model was found + if not png_custom and not png_hf_id: + print( + "Import PNG info: Unable to find a matching model for %s" % model_file + ) + + return png_custom, png_hf_id + + +def find_vae_from_png_metadata(key: str, metadata: dict[str, str | int]) -> str: + vae_custom = "" + + if key in metadata: + vae_file = metadata[key] + vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae") + + # VAE input is optional, should not print or throw an error if missing + + return vae_custom + + +def find_lora_from_png_metadata( + key: str, metadata: dict[str, str | int] +) -> tuple[str, str]: + lora_hf_id = "" + lora_custom = "" + + if key in metadata: + lora_file = metadata[key] + lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora") + # If nothing had matched, check vendor/hf_model_id + if not lora_custom and lora_file.count("/"): + lora_hf_id = lora_file + + # LoRA input is optional, should not print or throw an error if missing + + return lora_custom, lora_hf_id + + +def import_png_metadata( + pil_data, + prompt, + negative_prompt, + steps, + sampler, + cfg_scale, + seed, + width, + height, + custom_model, + custom_lora, + hf_lora_id, + custom_vae, +): + try: + png_info = pil_data.info["parameters"] + metadata = parse_generation_parameters(png_info) + + (png_custom_model, png_hf_model_id) = find_model_from_png_metadata( + "Model", metadata + ) + (lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata( + "LoRA", metadata + ) + vae_custom_model = find_vae_from_png_metadata("VAE", metadata) + + negative_prompt = metadata["Negative prompt"] + steps = int(metadata["Steps"]) + cfg_scale = float(metadata["CFG scale"]) + seed = int(metadata["Seed"]) + width = float(metadata["Size-1"]) + height = float(metadata["Size-2"]) + + if "Model" in metadata and png_custom_model: + custom_model = png_custom_model + elif "Model" in metadata and png_hf_model_id: + custom_model = png_hf_model_id + + if "LoRA" in metadata and lora_custom_model: + custom_lora = lora_custom_model + hf_lora_id = "" + if "LoRA" in metadata and lora_hf_model_id: + custom_lora = "None" + hf_lora_id = lora_hf_model_id + + if "VAE" in metadata and vae_custom_model: + custom_vae = vae_custom_model + + if "Prompt" in metadata: + prompt = metadata["Prompt"] + if "Sampler" in metadata: + if metadata["Sampler"] in scheduler_model_map: + sampler = metadata["Sampler"] + else: + print( + "Import PNG info: Unable to find a scheduler for %s" + % metadata["Sampler"] + ) + + except Exception as ex: + if pil_data and 
pil_data.info.get("parameters"): + print("import_png_metadata failed with %s" % ex) + pass + + return ( + None, + prompt, + negative_prompt, + steps, + sampler, + cfg_scale, + seed, + width, + height, + custom_model, + custom_lora, + hf_lora_id, + custom_vae, + ) diff --git a/apps/shark_studio/web/utils/state.py b/apps/shark_studio/web/utils/state.py new file mode 100644 index 0000000000..133c8fd82f --- /dev/null +++ b/apps/shark_studio/web/utils/state.py @@ -0,0 +1,39 @@ +import apps.shark_studio.web.utils.globals as global_obj +import gc + + +def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1): + if batch_index < batch_count: + bs = f"x{batch_size}" if batch_size > 1 else "" + return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}" + else: + return f"{tab_name} complete" + + +def get_generation_text_info(seeds, device): + cfg_dump = {} + for cfg in global_obj.get_config_dict(): + cfg_dump[cfg] = cfg + text_output = f"prompt={cfg_dump['prompts']}" + text_output += f"\nnegative prompt={cfg_dump['negative_prompts']}" + text_output += ( + f"\nmodel_id={cfg_dump['hf_model_id']}, " f"ckpt_loc={cfg_dump['ckpt_loc']}" + ) + text_output += f"\nscheduler={cfg_dump['scheduler']}, " f"device={device}" + text_output += ( + f"\nsteps={cfg_dump['steps']}, " + f"guidance_scale={cfg_dump['guidance_scale']}, " + f"seed={seeds}" + ) + text_output += ( + f"\nsize={cfg_dump['height']}x{cfg_dump['width']}, " + if not cfg_dump.use_hiresfix + else f"\nsize={cfg_dump['hiresfix_height']}x{cfg_dump['hiresfix_width']}, " + ) + text_output += ( + f"batch_count={cfg_dump['batch_count']}, " + f"batch_size={cfg_dump['batch_size']}, " + f"max_length={cfg_dump['max_length']}" + ) + + return text_output diff --git a/apps/shark_studio/web/utils/tmp_configs.py b/apps/shark_studio/web/utils/tmp_configs.py new file mode 100644 index 0000000000..4415276ea3 --- /dev/null +++ b/apps/shark_studio/web/utils/tmp_configs.py @@ -0,0 +1,73 @@ +import os +import shutil +from time import time + +shark_tmp = os.path.join(os.getcwd(), "shark_tmp/") + + +def clear_tmp_mlir(): + cleanup_start = time() + print("Clearing .mlir temporary files from a prior run. This may take some time...") + mlir_files = [ + filename + for filename in os.listdir(shark_tmp) + if os.path.isfile(os.path.join(shark_tmp, filename)) + and filename.endswith(".mlir") + ] + for filename in mlir_files: + os.remove(shark_tmp + filename) + print(f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds.") + + +def clear_tmp_imgs(): + # tell gradio to use a directory under shark_tmp for its temporary + # image files unless somewhere else has been set + if "GRADIO_TEMP_DIR" not in os.environ: + os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio") + + print( + f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. " + + "You may change this by setting the GRADIO_TEMP_DIR environment variable." + ) + + # Clear all gradio tmp images from the last session + if os.path.exists(os.environ["GRADIO_TEMP_DIR"]): + cleanup_start = time() + print( + "Clearing gradio UI temporary image files from a prior run. This may take some time..." + ) + shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True) + print( + f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds." 
+ ) + + # older SHARK versions had to workaround gradio bugs and stored things differently + else: + image_files = [ + filename + for filename in os.listdir(shark_tmp) + if os.path.isfile(os.path.join(shark_tmp, filename)) + and filename.startswith("tmp") + and filename.endswith(".png") + ] + if len(image_files) > 0: + print( + "Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..." + ) + cleanup_start = time() + for filename in image_files: + os.remove(shark_tmp + filename) + print( + f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds." + ) + else: + print("No temporary images files to clear.") + + +def config_tmp(): + # create shark_tmp if it does not exist + if not os.path.exists(shark_tmp): + os.mkdir(shark_tmp) + + clear_tmp_mlir() + clear_tmp_imgs() diff --git a/process_skipfiles.py b/process_skipfiles.py index 339c7ebec6..9086ce59bf 100644 --- a/process_skipfiles.py +++ b/process_skipfiles.py @@ -5,6 +5,7 @@ from distutils.sysconfig import get_python_lib import fileinput from pathlib import Path +import os # Temporary workaround for transformers/__init__.py. path_to_transformers_hook = Path( @@ -16,51 +17,16 @@ with open(path_to_transformers_hook, "w") as f: f.write("module_collection_mode = 'pyz+py'") -path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py") +paths_to_skipfiles = [Path(get_python_lib() + "/torch/_dynamo/skipfiles.py"), Path(get_python_lib() + "/torch/_dynamo/trace_rules.py")] -modules_to_comment = ["abc,", "os,", "posixpath,", "_collections_abc,"] -startMonitoring = 0 -for line in fileinput.input(path_to_skipfiles, inplace=True): - if "SKIP_DIRS = " in line: - startMonitoring = 1 - print(line, end="") - elif startMonitoring in [1, 2]: - if "]" in line: - startMonitoring += 1 +for path in paths_to_skipfiles: + if not os.path.isfile(path): + continue + for line in fileinput.input(path, inplace=True): + if "[_module_dir(m) for m in BUILTIN_SKIPLIST]" in line and "x.__name__ for x in BUILTIN_SKIPLIST" not in line: + print(f"{line.rstrip()} + [x.__name__ for x in BUILTIN_SKIPLIST]") + elif "(_module_dir(m) for m in BUILTIN_SKIPLIST)" in line and "x.__name__ for x in BUILTIN_SKIPLIST" not in line: print(line, end="") + print(f"SKIP_DIRS.extend(filter(None, (x.__name__ for x in BUILTIN_SKIPLIST)))") else: - flag = True - for module in modules_to_comment: - if module in line: - if not line.startswith("#"): - print(f"#{line}", end="") - else: - print(f"{line[1:]}", end="") - flag = False - break - if flag: - print(line, end="") - else: - print(line, end="") - -# For getting around scikit-image's packaging, laze_loader has had a patch merged but yet to be released. -# Refer: https://github.com/scientific-python/lazy_loader -path_to_lazy_loader = Path(get_python_lib() + "/lazy_loader/__init__.py") - -for line in fileinput.input(path_to_lazy_loader, inplace=True): - if 'stubfile = filename if filename.endswith("i")' in line: - print( - ' stubfile = (filename if filename.endswith("i") else f"{os.path.splitext(filename)[0]}.pyi")', - end="", - ) - else: - print(line, end="") - -# For getting around timm's packaging. 
-# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505 -path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py") -for line in fileinput.input(path_to_timm_activations, inplace=True): - if "@torch.jit.script" in line: - print("@torch.jit._script_if_tracing", end="\n") - else: - print(line, end="") + print(line, end="") diff --git a/requirements-importer-macos.txt b/requirements-importer-macos.txt deleted file mode 100644 index 36e837b320..0000000000 --- a/requirements-importer-macos.txt +++ /dev/null @@ -1,34 +0,0 @@ --f https://download.pytorch.org/whl/nightly/cpu/ ---pre - -numpy -torch -torchvision - -tqdm - -#iree-compiler | iree-runtime should already be installed - -transformers -#jax[cpu] - -# tflitehub dependencies. -Pillow - -# web dependecies. -gradio -altair - -# Testing and support. -#lit -#pyyaml - -#ONNX and ORT for benchmarking -#--extra-index-url https://test.pypi.org/simple/ -#protobuf -#coloredlogs -#flatbuffers -#sympy -#psutil -#onnx-weekly -#ort-nightly diff --git a/requirements-importer.txt b/requirements-importer.txt deleted file mode 100644 index 3fe3a64659..0000000000 --- a/requirements-importer.txt +++ /dev/null @@ -1,41 +0,0 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html ---pre - -numpy>1.22.4 -pytorch-triton -torchvision -tabulate - -tqdm - -#iree-compiler | iree-runtime should already be installed -iree-tools-xla - -# Modelling and JAX. -gin-config -transformers -diffusers -#jax[cpu] -Pillow - -# Testing and support. -lit -pyyaml -python-dateutil -sacremoses -sentencepiece - -# web dependecies. -gradio==3.44.3 -altair -scipy - -#ONNX and ORT for benchmarking -#--extra-index-url https://test.pypi.org/simple/ -#protobuf -#coloredlogs -#flatbuffers -#sympy -#psutil -#onnx-weekly -#ort-nightly diff --git a/requirements.txt b/requirements.txt index 19d4521280..eb5ee5c505 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,8 +5,9 @@ setuptools wheel -shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine#egg=shark-turbine&subdirectory=core -turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine#egg=turbine-models&subdirectory=models +torch==2.3.0.dev20240305 +shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sd-fp16#subdirectory=core +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-sd-fp16#subdirectory=models # SHARK Runner tqdm @@ -26,29 +27,15 @@ parameterized accelerate scipy ftfy -gradio==4.8.0 +gradio==4.19.2 altair omegaconf # 0.3.2 doesn't have binaries for arm64 safetensors==0.3.1 -opencv-python -scikit-image -pytorch_lightning # for runwayml models -tk -pywebview -sentencepiece py-cpuinfo -tiktoken # for codegen -joblib # for langchain -timm # for MiniGPT4 -langchain -einops # for zoedepth pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions +mpmath==1.3.0 # Keep PyInstaller at the end. 
Sometimes Windows Defender flags it but most folks can continue even if it errors pefile pyinstaller - -# For quantized GPTQ models -optimum -auto_gptq diff --git a/setup.py b/setup.py index 061873e7a8..d1aea1687a 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() -PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5" +PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "2.0.0" backend_deps = [] setup( diff --git a/setup_venv.ps1 b/setup_venv.ps1 index 09489bf4cc..c67b8fc83b 100644 --- a/setup_venv.ps1 +++ b/setup_venv.ps1 @@ -7,13 +7,13 @@ It checks the Python version installed and installs any required build dependencies into a Python virtual environment. If that environment does not exist, it creates it. - + .PARAMETER update-src git pulls latest version .PARAMETER force removes and recreates venv to force update of all dependencies - + .EXAMPLE .\setup_venv.ps1 --force @@ -39,7 +39,7 @@ if ($arguments -eq "--force"){ Write-Host "deactivating..." Deactivate } - + if (Test-Path .\shark.venv\) { Write-Host "removing and recreating venv..." Remove-Item .\shark.venv -Force -Recurse @@ -89,9 +89,7 @@ else {python -m venv .\shark.venv\} python -m pip install --upgrade pip pip install wheel pip install -r requirements.txt -pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime -Write-Host "Building SHARK..." -pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html -Write-Host "Build and installation completed successfully" +# remove this when windows DLL issues are fixed from LLVM changes +pip install --force-reinstall https://github.com/openxla/iree/releases/download/candidate-20240326.843/iree_compiler-20240326.843-cp311-cp311-win_amd64.whl https://github.com/openxla/iree/releases/download/candidate-20240326.843/iree_runtime-20240326.843-cp311-cp311-win_amd64.whl + Write-Host "Source your venv with ./shark.venv/Scripts/activate" diff --git a/setup_venv.sh b/setup_venv.sh index 62c6513a85..64f769d794 100755 --- a/setup_venv.sh +++ b/setup_venv.sh @@ -49,58 +49,20 @@ Red=`tput setaf 1` Green=`tput setaf 2` Yellow=`tput setaf 3` -# Assume no binary torch-mlir. -# Currently available for macOS m1&intel (3.11) and Linux(3.8,3.10,3.11) -torch_mlir_bin=false -if [[ $(uname -s) = 'Darwin' ]]; then - echo "${Yellow}Apple macOS detected" - if [[ $(uname -m) == 'arm64' ]]; then - echo "${Yellow}Apple M1 Detected" - hash rustc 2>/dev/null - if [ $? -eq 0 ];then - echo "${Green}rustc found to compile HF tokenizers" - else - echo "${Red}Could not find rustc" >&2 - echo "${Red}Please run:" - echo "${Red}curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh" - exit 1 - fi - fi - echo "${Yellow}Run the following commands to setup your SSL certs for your Python version if you see SSL errors with tests" - echo "${Yellow}/Applications/Python\ 3.XX/Install\ Certificates.command" - if [ "$PYTHON_VERSION_X_Y" == "3.11" ]; then - torch_mlir_bin=true - fi -elif [[ $(uname -s) = 'Linux' ]]; then - echo "${Yellow}Linux detected" - if [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] || [ "$PYTHON_VERSION_X_Y" == "3.11" ] ; then - torch_mlir_bin=true - fi -else - echo "${Red}OS not detected. 
Pray and Play" -fi - # Upgrade pip and install requirements. $PYTHON -m pip install --upgrade pip || die "Could not upgrade pip" $PYTHON -m pip install --upgrade -r "$TD/requirements.txt" -if [ "$torch_mlir_bin" = true ]; then - if [[ $(uname -s) = 'Darwin' ]]; then - echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch." - $PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC - $PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/ +if [[ $(uname -s) = 'Darwin' ]]; then + echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch." + $PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC + $PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/ +else + $PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ + if [ $? -eq 0 ];then + echo "Successfully Installed torch-mlir" else - $PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ - if [ $? -eq 0 ];then - echo "Successfully Installed torch-mlir" - else - echo "Could not install torch-mlir" >&2 - fi + echo "Could not install torch-mlir" >&2 fi -else - echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)" - echo "${Yello}Python 3.11 supported on macOS and 3.8,3.10 and 3.11 on Linux" - echo "${Red}Please build torch-mlir from source in your environment" - exit 1 fi if [[ -z "${USE_IREE}" ]]; then rm .use-iree @@ -116,19 +78,6 @@ else echo "Not installing a backend, please make sure to add your backend to PYTHONPATH" fi -if [[ ! -z "${IMPORTER}" ]]; then - echo "${Yellow}Installing importer tools.." - if [[ $(uname -s) = 'Linux' ]]; then - echo "${Yellow}Linux detected.. installing Linux importer tools" - #Always get the importer tools from upstream IREE - $PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://openxla.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu - elif [[ $(uname -s) = 'Darwin' ]]; then - echo "${Yellow}macOS detected.. installing macOS importer tools" - #Conda seems to have some problems installing these packages and hope they get resolved upstream. - $PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu - fi -fi - if [[ $(uname -s) = 'Darwin' ]]; then PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/ else diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index ca6a12c45b..5fd1d4006a 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -64,6 +64,14 @@ def get_iree_device_args(device, extra_args=[]): return get_iree_rocm_args(device_num=device_num, extra_args=extra_args) return [] +def get_iree_target_triple(device): + args = get_iree_device_args(device) + for flag in args: + if "triple" in flag.split("-"): + triple = flag.split("=") + return triple + return "" + def clean_device_info(raw_device): # return appropriate device and device_id for consumption by Studio pipeline @@ -105,9 +113,8 @@ def get_iree_frontend_args(frontend): # Common args to be used given any frontend or device. 
def get_iree_common_args(debug=False): common_args = [ - "--iree-stream-resource-max-allocation-size=4294967295", - "--iree-vm-bytecode-module-strip-source-map=true", "--iree-util-zero-fill-elided-attrs", + "--mlir-elide-elementsattrs-if-larger=10", ] if debug == True: common_args.extend( diff --git a/shark/iree_utils/vulkan_target_env_utils.py b/shark/iree_utils/vulkan_target_env_utils.py index 92d2f53442..7cd1b05241 100644 --- a/shark/iree_utils/vulkan_target_env_utils.py +++ b/shark/iree_utils/vulkan_target_env_utils.py @@ -33,7 +33,7 @@ def get_vulkan_target_env(vulkan_target_triple): device_type = get_device_type(triple) # get capabilities capabilities = get_vulkan_target_capabilities(triple) - target_env = f"#vk.target_env<{version}, r({revision}), {extensions}, {vendor}:{device_type}, #vk.caps< {capabilities} >>" + target_env = f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>" return target_env @@ -63,62 +63,62 @@ def make_ext_list(ext_list): arch, product, os = triple if arch == "m1": ext = [ - "VK_KHR_16bit_storage", - "VK_KHR_8bit_storage", - "VK_KHR_shader_float16_int8", - "VK_KHR_storage_buffer_storage_class", - "VK_KHR_variable_pointers", + "SPV_KHR_16bit_storage", + "SPV_KHR_8bit_storage", + "SPV_KHR_shader_float16_int8", + "SPV_KHR_storage_buffer_storage_class", + "SPV_KHR_variable_pointers", ] return make_ext_list(ext_list=ext) if arch == "valhall": ext = [ - "VK_KHR_16bit_storage", - "VK_KHR_8bit_storage", - "VK_KHR_shader_float16_int8", - "VK_KHR_spirv_1_4", - "VK_KHR_storage_buffer_storage_class", - "VK_KHR_variable_pointers", + "SPV_KHR_16bit_storage", + "SPV_KHR_8bit_storage", + "SPV_KHR_shader_float16_int8", + "SPV_KHR_spirv_1_4", + "SPV_KHR_storage_buffer_storage_class", + "SPV_KHR_variable_pointers", ] return make_ext_list(ext_list=ext) if arch == "adreno": ext = [ - "VK_KHR_16bit_storage", - "VK_KHR_shader_float16_int8", - "VK_KHR_spirv_1_4", - "VK_KHR_storage_buffer_storage_class", - "VK_KHR_variable_pointers", + "SPV_KHR_16bit_storage", + "SPV_KHR_shader_float16_int8", + "SPV_KHR_spirv_1_4", + "SPV_KHR_storage_buffer_storage_class", + "SPV_KHR_variable_pointers", ] if os == "android31": - ext.append("VK_KHR_8bit_storage") + ext.append("SPV_KHR_8bit_storage") return make_ext_list(ext_list=ext) if get_vendor(triple) == "SwiftShader": - ext = ["VK_KHR_storage_buffer_storage_class"] + ext = ["SPV_KHR_storage_buffer_storage_class"] return make_ext_list(ext_list=ext) if arch == "unknown": ext = [ - "VK_KHR_storage_buffer_storage_class", - "VK_KHR_variable_pointers", + "SPV_KHR_storage_buffer_storage_class", + "SPV_KHR_variable_pointers", ] return make_ext_list(ext_list=ext) ext = [ - "VK_KHR_16bit_storage", - "VK_KHR_8bit_storage", - "VK_KHR_shader_float16_int8", - "VK_KHR_spirv_1_4", - "VK_KHR_storage_buffer_storage_class", - "VK_KHR_variable_pointers", + "SPV_KHR_16bit_storage", + "SPV_KHR_8bit_storage", + "SPV_KHR_shader_float16_int8", + "SPV_KHR_spirv_1_4", + "SPV_KHR_storage_buffer_storage_class", + "SPV_KHR_variable_pointers", "VK_EXT_subgroup_size_control", ] if get_vendor(triple) == "NVIDIA" or arch == "rdna3": - ext.append("VK_KHR_cooperative_matrix") + ext.append("SPV_KHR_cooperative_matrix") if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]: - ext.append("VK_KHR_shader_integer_dot_product") + ext.append("SPV_KHR_shader_integer_dot_product") return make_ext_list(ext_list=ext) @@ -186,13 +186,13 @@ def get_subgroup_val(l): "Quad": 128, "PartitionedNV": 256, } - 
cap["maxComputeSharedMemorySize"] = 16384 - cap["maxComputeWorkGroupInvocations"] = 128 - cap["maxComputeWorkGroupSize"] = [128, 128, 64] - cap["subgroupSize"] = 32 + cap["max_compute_shared_memory_size"] = 16384 + cap["max_compute_workgroup_invocations"] = 128 + cap["max_compute_workgroup_size"] = [128, 128, 64] + cap["subgroup_size"] = 32 cap["subgroupFeatures"] = ["Basic"] - cap["minSubgroupSize"] = None - cap["maxSubgroupSize"] = None + cap["min_subgroup_size"] = None + cap["max_subgroup_size"] = None cap["shaderFloat16"] = False cap["shaderFloat64"] = False cap["shaderInt8"] = False @@ -209,13 +209,13 @@ def get_subgroup_val(l): cap["coopmatCases"] = None if arch in ["rdna1", "rdna2", "rdna3"]: - cap["maxComputeSharedMemorySize"] = 65536 - cap["maxComputeWorkGroupInvocations"] = 1024 - cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024] + cap["max_compute_shared_memory_size"] = 65536 + cap["max_compute_workgroup_invocations"] = 1024 + cap["max_compute_workgroup_size"] = [1024, 1024, 1024] - cap["subgroupSize"] = 64 - cap["minSubgroupSize"] = 32 - cap["maxSubgroupSize"] = 64 + cap["subgroup_size"] = 64 + cap["min_subgroup_size"] = 32 + cap["max_subgroup_size"] = 64 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ -244,7 +244,8 @@ def get_subgroup_val(l): if arch == "rdna3": # TODO: Get scope value cap["coopmatCases"] = [ - "mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope" + "m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = ", + "m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = " ] if product == "rx5700xt": @@ -252,11 +253,11 @@ def get_subgroup_val(l): cap["storagePushConstant8"] = False elif arch in ["rgcn5", "rgcn4", "rgcn3"]: - cap["maxComputeSharedMemorySize"] = 65536 - cap["maxComputeWorkGroupInvocations"] = 1024 - cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024] + cap["max_compute_shared_memory_size"] = 65536 + cap["max_compute_workgroup_invocations"] = 1024 + cap["max_compute_workgroup_size"] = [1024, 1024, 1024] - cap["subgroupSize"] = 64 + cap["subgroup_size"] = 64 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ -267,8 +268,8 @@ def get_subgroup_val(l): "Clustered", "Quad", ] - cap["minSubgroupSize"] = 64 - cap["maxSubgroupSize"] = 64 + cap["min_subgroup_size"] = 64 + cap["max_subgroup_size"] = 64 if arch == "rgcn5": cap["shaderFloat16"] = True @@ -290,11 +291,11 @@ def get_subgroup_val(l): cap["variablePointersStorageBuffer"] = True elif arch == "m1": - cap["maxComputeSharedMemorySize"] = 32768 - cap["maxComputeWorkGroupInvocations"] = 1024 - cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024] + cap["max_compute_shared_memory_size"] = 32768 + cap["max_compute_workgroup_invocations"] = 1024 + cap["max_compute_workgroup_size"] = [1024, 1024, 1024] - cap["subgroupSize"] = 32 + cap["subgroup_size"] = 32 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ -321,11 +322,11 @@ def get_subgroup_val(l): cap["variablePointersStorageBuffer"] = True elif arch == "valhall": - cap["maxComputeSharedMemorySize"] = 32768 - cap["maxComputeWorkGroupInvocations"] = 512 - cap["maxComputeWorkGroupSize"] = [512, 512, 512] + cap["max_compute_shared_memory_size"] = 32768 + cap["max_compute_workgroup_invocations"] = 512 + cap["max_compute_workgroup_size"] = [512, 512, 512] - cap["subgroupSize"] = 16 + cap["subgroup_size"] = 16 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ 
-352,11 +353,11 @@ def get_subgroup_val(l): cap["variablePointersStorageBuffer"] = True elif arch == "arc": - cap["maxComputeSharedMemorySize"] = 32768 - cap["maxComputeWorkGroupInvocations"] = 1024 - cap["maxComputeWorkGroupSize"] = [1024, 1024, 64] + cap["max_compute_shared_memory_size"] = 32768 + cap["max_compute_workgroup_invocations"] = 1024 + cap["max_compute_workgroup_size"] = [1024, 1024, 64] - cap["subgroupSize"] = 32 + cap["subgroup_size"] = 32 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ -385,8 +386,8 @@ def get_subgroup_val(l): elif arch == "cpu": if product == "swiftshader": - cap["maxComputeSharedMemorySize"] = 16384 - cap["subgroupSize"] = 4 + cap["max_compute_shared_memory_size"] = 16384 + cap["subgroup_size"] = 4 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ -397,13 +398,13 @@ def get_subgroup_val(l): ] elif arch in ["pascal"]: - cap["maxComputeSharedMemorySize"] = 49152 - cap["maxComputeWorkGroupInvocations"] = 1536 - cap["maxComputeWorkGroupSize"] = [1536, 1024, 64] + cap["max_compute_shared_memory_size"] = 49152 + cap["max_compute_workgroup_invocations"] = 1536 + cap["max_compute_workgroup_size"] = [1536, 1024, 64] - cap["subgroupSize"] = 32 - cap["minSubgroupSize"] = 32 - cap["maxSubgroupSize"] = 32 + cap["subgroup_size"] = 32 + cap["min_subgroup_size"] = 32 + cap["max_subgroup_size"] = 32 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ -431,13 +432,13 @@ def get_subgroup_val(l): cap["variablePointersStorageBuffer"] = True elif arch in ["ampere", "turing"]: - cap["maxComputeSharedMemorySize"] = 49152 - cap["maxComputeWorkGroupInvocations"] = 1024 - cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024] + cap["max_compute_shared_memory_size"] = 49152 + cap["max_compute_workgroup_invocations"] = 1024 + cap["max_compute_workgroup_size"] = [1024, 1024, 1024] - cap["subgroupSize"] = 32 - cap["minSubgroupSize"] = 32 - cap["maxSubgroupSize"] = 32 + cap["subgroup_size"] = 32 + cap["min_subgroup_size"] = 32 + cap["max_subgroup_size"] = 32 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ -471,11 +472,11 @@ def get_subgroup_val(l): ] elif arch == "adreno": - cap["maxComputeSharedMemorySize"] = 32768 - cap["maxComputeWorkGroupInvocations"] = 1024 - cap["maxComputeWorkGroupSize"] = [1024, 1024, 64] + cap["max_compute_shared_memory_size"] = 32768 + cap["max_compute_workgroup_invocations"] = 1024 + cap["max_compute_workgroup_size"] = [1024, 1024, 64] - cap["subgroupSize"] = 64 + cap["subgroup_size"] = 64 cap["subgroupFeatures"] = [ "Basic", "Vote", @@ -491,14 +492,14 @@ def get_subgroup_val(l): cap["shaderInt16"] = True cap["storageBuffer16BitAccess"] = True - if os == "andorid31": + if os == "android31": cap["uniformAndStorageBuffer8BitAccess"] = True cap["variablePointers"] = True cap["variablePointersStorageBuffer"] = True elif arch == "unknown": - cap["subgroupSize"] = 64 + cap["subgroup_size"] = 64 cap["variablePointers"] = False cap["variablePointersStorageBuffer"] = False else: @@ -521,14 +522,14 @@ def get_comma_sep_str(ele_list): res += f"{k} = {'unit' if v == True else None}, " elif isinstance(v, list): if k == "subgroupFeatures": - res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, " - elif k == "maxComputeWorkGroupSize": - res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, " + res += f"subgroup_features = {get_subgroup_val(v)}: i32, " + elif k == "max_compute_workgroup_size": + res += f"max_compute_workgroup_size = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, " elif k == "coopmatCases": cmc = "" for case in v: - cmc += 
f"#vk.coop_matrix_props<{case}>, " - res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], " + cmc += f"#spirv.coop_matrix_props_khr<{case}>, " + res += f"cooperative_matrix_properties_khr = [{cmc[:-2]}], " else: res += f"{k} = {get_comma_sep_str(v)}, " else: diff --git a/shark/iree_utils/vulkan_utils.py b/shark/iree_utils/vulkan_utils.py index a08fb6f5aa..ff394ea349 100644 --- a/shark/iree_utils/vulkan_utils.py +++ b/shark/iree_utils/vulkan_utils.py @@ -144,6 +144,8 @@ def get_vulkan_target_triple(device_name): # Intel Targets elif any(x in device_name for x in ("A770", "A750")): triple = f"arc-770-{system_os}" + elif "v620" in device_name: + triple = f"rdna2-v620-{system_os}" # Adreno Targets elif all(x in device_name for x in ("Adreno", "740")): @@ -169,7 +171,7 @@ def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]): print( f"Found vulkan device {vulkan_device}. Using target triple {triple}" ) - return f"-iree-vulkan-target-triple={triple}" + return f"--iree-vulkan-target-triple={triple}" print( """Optimized kernel for your target device is not added yet. Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u] @@ -184,7 +186,8 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]): res_vulkan_flag = [] res_vulkan_flag += [ - "--iree-stream-resource-max-allocation-size=3221225472" + "--iree-stream-resource-max-allocation-size=3221225472", + "--iree-flow-inline-constants-max-byte-length=0" ] vulkan_triple_flag = None for arg in extra_args: @@ -197,6 +200,7 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]): vulkan_triple_flag = get_vulkan_triple_flag( device_num=device_num, extra_args=extra_args ) + res_vulkan_flag += [vulkan_triple_flag] if vulkan_triple_flag is not None: vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)