Skip to content

Commit

Permalink
Enable UI / bugfixes / tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Dec 15, 2023
1 parent ab32bfb commit b3d5add
Show file tree
Hide file tree
Showing 29 changed files with 2,648 additions and 466 deletions.
134 changes: 134 additions & 0 deletions apps/shark_studio/api/controlnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# from turbine_models.custom_models.controlnet import control_adapter, preprocessors


class control_adapter:
def __init__(
self,
model: str,
):
self.model = None

def export_control_adapter_model(model_keyword):
return None

def export_xl_control_adapter_model(model_keyword):
return None


class preprocessors:
def __init__(
self,
model: str,
):
self.model = None

def export_controlnet_model(model_keyword):
return None


control_adapter_map = {
"sd15": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {
"initializer": control_adapter.export_control_adapter_model
},
"scribble": {
"initializer": control_adapter.export_control_adapter_model
},
"zoedepth": {
"initializer": control_adapter.export_control_adapter_model
},
},
"sdxl": {
"canny": {
"initializer": control_adapter.export_xl_control_adapter_model
},
},
}
preprocessor_model_map = {
"canny": {"initializer": preprocessors.export_controlnet_model},
"openpose": {"initializer": preprocessors.export_controlnet_model},
"scribble": {"initializer": preprocessors.export_controlnet_model},
"zoedepth": {"initializer": preprocessors.export_controlnet_model},
}


class PreprocessorModel:
def __init__(
self,
hf_model_id,
device,
):
self.model = None

def compile(self, device):
print("compile not implemented for preprocessor.")
return

def run(self, inputs):
print("run not implemented for preprocessor.")
return


def cnet_preview(model, input_img, stencils, images, preprocessed_hints):
if isinstance(input_image, PIL.Image.Image):
img_dict = {
"background": None,
"layers": [None],
"composite": input_image,
}
input_image = EditorValue(img_dict)
images[index] = input_image
if model:
stencils[index] = model
match model:
case "canny":
canny = CannyDetector()
result = canny(
np.array(input_image["composite"]),
100,
200,
)
preprocessed_hints[index] = Image.fromarray(result)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
)
case "openpose":
openpose = OpenposeDetector()
result = openpose(np.array(input_image["composite"]))
preprocessed_hints[index] = Image.fromarray(result[0])
return (
Image.fromarray(result[0]),
stencils,
images,
preprocessed_hints,
)
case "zoedepth":
zoedepth = ZoeDetector()
result = zoedepth(np.array(input_image["composite"]))
preprocessed_hints[index] = Image.fromarray(result)
return (
Image.fromarray(result),
stencils,
images,
preprocessed_hints,
)
case "scribble":
preprocessed_hints[index] = input_image["composite"]
return (
input_image["composite"],
stencils,
images,
preprocessed_hints,
)
case _:
preprocessed_hints[index] = None
return (
None,
stencils,
images,
preprocessed_hints,
)
13 changes: 6 additions & 7 deletions apps/shark_studio/api/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,21 @@ def imports():

startup_timer.record("import gradio")

# from apps.shark_studio.modules import shared_init
# shared_init.initialize()
# startup_timer.record("initialize shared")
import apps.shark_studio.web.utils.globals as global_obj

global_obj._init()
startup_timer.record("initialize globals")

from apps.shark_studio.modules import (
processing,
gradio_extensons,
ui,
img_processing,
) # noqa: F401
from apps.shark_studio.modules.schedulers import scheduler_model_map

startup_timer.record("other imports")


def initialize():
configure_sigint_handler()
configure_opts_onchange()

# from apps.shark_studio.modules import modelloader
# modelloader.cleanup_models()
Expand Down
191 changes: 183 additions & 8 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from turbine_models.custom_models.sd_inference import clip, unet, vae
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.modules.pipeline import SharkPipelineBase
import iree.runtime as ireert
import gc
import torch
import gradio as gr

sd_model_map = {
"CompVis/stable-diffusion-v1-4": {
Expand Down Expand Up @@ -86,16 +90,15 @@


class StableDiffusion(SharkPipelineBase):

# This class is responsible for executing image generation and creating
# /managing a set of compiled modules to run Stable Diffusion. The init
# aims to be as general as possible, and the class will infer and compile
# a list of necessary modules or a combined "pipeline module" for a
# specified job based on the inference task.
#
#
# custom_model_ids: a dict of submodel + HF ID pairs for custom submodels.
# e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
#
#
# embeddings: a dict of embedding checkpoints or model IDs to use when
# initializing the compiled modules.

Expand All @@ -107,7 +110,6 @@ def __init__(
precision: str = "fp16",
device: str = None,
custom_model_map: dict = {},
custom_weights_map: dict = {},
embeddings: dict = {},
import_ir: bool = True,
):
Expand All @@ -118,12 +120,185 @@ def __init__(
self.iree_module_dict = None
self.get_compiled_map()

def prepare_pipeline(self, scheduler, custom_model_map):
return None

def generate_images(
self,
prompt,
):
return result_output,
self,
prompt,
negative_prompt,
steps,
strength,
guidance_scale,
seed,
ondemand,
repeatable_seeds,
resample_type,
control_mode,
preprocessed_hints,
):
return None, None, None, None, None


# NOTE: Each `hf_model_id` should have its own starting configuration.

# model_vmfb_key = ""


def shark_sd_fn(
prompt,
negative_prompt,
image_dict,
height: int,
width: int,
steps: int,
strength: float,
guidance_scale: float,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
base_model_id: str,
custom_weights: str,
custom_vae: str,
precision: str,
device: str,
lora_weights: str | list,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
preprocessed_hints: list,
progress=gr.Progress(),
):
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
continue
elif stencil is None and any(
img is not None for img in [images[i], preprocessed_hints[i]]
):
images[i] = None
preprocessed_hints[i] = None
elif images[i] is not None:
if isinstance(images[i], dict):
images[i] = images[i]["composite"]
images[i] = images[i].convert("RGB")

if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
elif image_dict:
image = image_dict["image"].convert("RGB")
else:
image = None
is_img2img = False
if image:
(
image,
_,
_,
) = resize_stencil(image, width, height)
is_img2img = True
print("Performing Stable Diffusion Pipeline setup...")

device_id = None

from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
import apps.shark_studio.web.utils.globals as global_obj

custom_model_map = {}
if custom_weights != "None":
custom_model_map["unet"] = {"custom_weights": custom_weights}
if custom_vae != "None":
custom_model_map["vae"] = {"custom_weights": custom_vae}
if stencils:
for i, stencil in enumerate(stencils):
if "xl" not in base_model_id.lower():
custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[
"runwayml/stable-diffusion-v1-5"
][stencil]
else:
custom_model_map[f"control_adapter_{i}"] = stencil_adapter_map[
"stabilityai/stable-diffusion-xl-1.0"
][stencil]

submit_pipe_kwargs = {
"base_model_id": base_model_id,
"height": height,
"width": width,
"precision": precision,
"device": device,
"custom_model_map": custom_model_map,
"import_ir": cmd_opts.import_mlir,
"is_img2img": is_img2img,
}
submit_prep_kwargs = {
"scheduler": scheduler,
"custom_model_map": custom_model_map,
"embeddings": lora_weights,
}
submit_run_kwargs = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"steps": steps,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
"ondemand": ondemand,
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"preprocessed_hints": preprocessed_hints,
}

global sd_pipe
global sd_pipe_kwargs

if sd_pipe_kwargs and sd_pipe_kwargs != submit_pipe_kwargs:
sd_pipe = None
sd_pipe_kwargs = submit_pipe_kwargs
gc.collect()

if sd_pipe is None:
history[-1][-1] = "Getting the pipeline ready..."
yield history, ""

# Initializes the pipeline and retrieves IR based on all
# parameters that are static in the turbine output format,
# which is currently MLIR in the torch dialect.

sd_pipe = SharkStableDiffusionPipeline(
**submit_pipe_kwargs,
)

sd_pipe.prepare_pipe(**submit_prep_kwargs)

for prompt, msg, exec_time in progress.tqdm(
out_imgs=sd_pipe.generate_images(**submit_run_kwargs),
desc="Generating Image...",
):
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
save_output_img(
out_imgs[0],
seeds[current_batch],
extra_info,
)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
), stencils, images

return generated_imgs, text_output, "", stencils, images


def cancel_sd():
print("Inject call to cancel longer API calls.")
return


if __name__ == "__main__":
sd = StableDiffusion(
Expand Down
Loading

0 comments on commit b3d5add

Please sign in to comment.