Commit 2

monorimet committed Dec 11, 2023
1 parent edcb5b7 commit fd7941a
Showing 10 changed files with 228 additions and 133 deletions.
42 changes: 27 additions & 15 deletions apps/shark_studio/api/initializers.py
@@ -13,39 +13,53 @@

 def imports():
     import torch  # noqa: F401

     startup_timer.record("import torch")
-    warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="torch")
-    warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
+    warnings.filterwarnings(
+        action="ignore", category=DeprecationWarning, module="torch"
+    )
+    warnings.filterwarnings(
+        action="ignore", category=UserWarning, module="torchvision"
+    )

     import gradio  # noqa: F401

     startup_timer.record("import gradio")

-    #from apps.shark_studio.modules import shared_init
-    #shared_init.initialize()
-    #startup_timer.record("initialize shared")
+    # from apps.shark_studio.modules import shared_init
+    # shared_init.initialize()
+    # startup_timer.record("initialize shared")

-    from apps.shark_studio.modules import processing, gradio_extensons, ui # noqa: F401
+    from apps.shark_studio.modules import (
+        processing,
+        gradio_extensons,
+        ui,
+    )  # noqa: F401
+
     startup_timer.record("other imports")


 def initialize():
     configure_sigint_handler()
     configure_opts_onchange()

-    #from apps.shark_studio.modules import modelloader
-    #modelloader.cleanup_models()
+    # from apps.shark_studio.modules import modelloader
+    # modelloader.cleanup_models()

-    #from apps.shark_studio.modules import sd_models
-    #sd_models.setup_model()
-    #startup_timer.record("setup SD model")
+    # from apps.shark_studio.modules import sd_models
+    # sd_models.setup_model()
+    # startup_timer.record("setup SD model")

-    #initialize_rest(reload_script_modules=False)
+    # initialize_rest(reload_script_modules=False)


 def initialize_rest(*, reload_script_modules=False):
     """
     Called both from initialize() and when reloading the webui.
     """
     # Keep this for adding reload options to the webUI.


 def dumpstacks():
     import threading
     import traceback
@@ -65,12 +79,10 @@ def dumpstacks():
 def configure_sigint_handler():
     # make the program just exit at ctrl+c without waiting for anything
     def sigint_handler(sig, frame):
-        print(f'Interrupted with signal {sig} in {frame}')
+        print(f"Interrupted with signal {sig} in {frame}")

         dumpstacks()

         os._exit(0)

     signal.signal(signal.SIGINT, sigint_handler)


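For orientation, the startup flow these helpers imply looks roughly like the sketch below. Only imports() and initialize() come from initializers.py; the launcher itself is a hypothetical illustration, not part of this commit.

# Hypothetical launcher sketch: imports() and initialize() are defined in
# apps/shark_studio/api/initializers.py above; everything else is assumed.
from apps.shark_studio.api import initializers

def main():
    initializers.imports()      # heavy imports, timed via startup_timer
    initializers.initialize()   # installs the SIGINT handler shown above
    # ... construct and launch the gradio UI here ...

if __name__ == "__main__":
    main()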
20 changes: 15 additions & 5 deletions apps/shark_studio/api/schedulers.py
@@ -1,20 +1,30 @@
-#from shark_turbine.turbine_models.schedulers import export_scheduler_model
+# from shark_turbine.turbine_models.schedulers import export_scheduler_model


 def export_scheduler_model(model):
     return "None", "None"


 scheduler_model_map = {
     "EulerDiscrete": export_scheduler_model("EulerDiscreteScheduler"),
-    "EulerAncestralDiscrete": export_scheduler_model("EulerAncestralDiscreteScheduler"),
+    "EulerAncestralDiscrete": export_scheduler_model(
+        "EulerAncestralDiscreteScheduler"
+    ),
     "LCM": export_scheduler_model("LCMScheduler"),
     "LMSDiscrete": export_scheduler_model("LMSDiscreteScheduler"),
     "PNDM": export_scheduler_model("PNDMScheduler"),
     "DDPM": export_scheduler_model("DDPMScheduler"),
     "DDIM": export_scheduler_model("DDIMScheduler"),
-    "DPMSolverMultistep": export_scheduler_model("DPMSolverMultistepScheduler"),
+    "DPMSolverMultistep": export_scheduler_model(
+        "DPMSolverMultistepScheduler"
+    ),
     "KDPM2Discrete": export_scheduler_model("KDPM2DiscreteScheduler"),
     "DEISMultistep": export_scheduler_model("DEISMultistepScheduler"),
-    "DPMSolverSinglestep": export_scheduler_model("DPMSolverSingleStepScheduler"),
-    "KDPM2AncestralDiscrete": export_scheduler_model("KDPM2AncestralDiscreteScheduler"),
+    "DPMSolverSinglestep": export_scheduler_model(
+        "DPMSolverSingleStepScheduler"
+    ),
+    "KDPM2AncestralDiscrete": export_scheduler_model(
+        "KDPM2AncestralDiscreteScheduler"
+    ),
     "HeunDiscrete": export_scheduler_model("HeunDiscreteScheduler"),
 }
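Since export_scheduler_model is stubbed out, every entry in scheduler_model_map evaluates at import time to the same placeholder tuple. A minimal sketch of what a lookup returns, assuming the module imports cleanly:

# Placeholder behavior: each value is the tuple ("None", "None") until the
# shark_turbine scheduler export (commented out above) is wired back in.
from apps.shark_studio.api.schedulers import scheduler_model_map

scheduler = scheduler_model_map["EulerDiscrete"]
print(scheduler)  # -> ('None', 'None')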
171 changes: 101 additions & 70 deletions apps/shark_studio/api/sd.py
@@ -6,97 +6,128 @@
 import torch

-sd_model_map = {
-    "sd15": {
-        "base_model_id": "runwayml/stable-diffusion-v1-5"
-        "clip": {
-            "initializer": clip.export_clip_model,
-            "max_tokens": 77,
-        }
-        "unet": {
-            "initializer": unet.export_unet_model,
-            "max_tokens": 512,
-        }
-        "vae_decode": {
-            "initializer": vae.export_vae_model,,
-        }
-    }
-}
+sd_model_map = {
+    "CompVis/stable-diffusion-v1-4": {
+        "clip": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "vae_encode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+        "unet": {
+            "initializer": unet.export_unet_model,
+            "max_tokens": 512,
+        },
+        "vae_decode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+    },
+    "runwayml/stable-diffusion-v1-5": {
+        "clip": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "vae_encode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+        "unet": {
+            "initializer": unet.export_unet_model,
+            "max_tokens": 512,
+        },
+        "vae_decode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+    },
+    "stabilityai/stable-diffusion-2-1-base": {
+        "clip": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "vae_encode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+        "unet": {
+            "initializer": unet.export_unet_model,
+            "max_tokens": 512,
+        },
+        "vae_decode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+    },
+    "stabilityai/stable_diffusion-xl-1.0": {
+        "clip_1": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "clip_2": {
+            "initializer": clip.export_clip_model,
+            "max_tokens": 64,
+        },
+        "vae_encode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+        "unet": {
+            "initializer": unet.export_unet_model,
+            "max_tokens": 512,
+        },
+        "vae_decode": {
+            "initializer": vae.export_vae_model,
+            "max_tokens": 64,
+        },
+    },
+}


-class SharkStableDiffusionPipeline:
+class StableDiffusion(SharkPipelineBase):

+    # This class is responsible for executing image generation and creating
+    # /managing a set of compiled modules to run Stable Diffusion. The init
+    # aims to be as general as possible, and the class will infer and compile
+    # a list of necessary modules or a combined "pipeline module" for a
+    # specified job based on the inference task.
+    #
+    # custom_model_ids: a dict of submodel + HF ID pairs for custom submodels.
+    # e.g. {"vae_decode": "madebyollin/sdxl-vae-fp16-fix"}
+    #
+    # embeddings: a dict of embedding checkpoints or model IDs to use when
+    # initializing the compiled modules.
+
     def __init__(
-        self, model_name, , device=None, precision="fp32"
+        self,
+        base_model_id: str = "runwayml/stable-diffusion-v1-5",
+        height: int = 512,
+        width: int = 512,
+        precision: str = "fp16",
+        device: str = None,
+        custom_model_map: dict = {},
+        custom_weights_map: dict = {},
+        embeddings: dict = {},
+        import_ir: bool = True,
     ):
-        print(sd_model_map[model_name])
-        self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
-        self.torch_ir, self.tokenizer = llm_model_map[model_name][
-            "initializer"
-        ](self.hf_model_name, hf_auth_token, compile_to="torch")
-        self.tempfile_name = get_resource_path("llm.torch.tempfile")
-        with open(self.tempfile_name, "w+") as f:
-            f.write(self.torch_ir)
-        del self.torch_ir
-        gc.collect()
-
+        super().__init__(sd_model_map[base_model_id], device, import_ir)
+        self.base_model_id = base_model_id
         self.device = device
         self.precision = precision
-        self.max_tokens = llm_model_map[model_name]["max_tokens"]
-        self.iree_module_dict = None
-        self.compile()
+        self.get_compiled_map()

-    def compile(self) -> None:
-        # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
-        self.iree_module_dict = get_iree_compiled_module(
-            self.tempfile_name, device=self.device, frontend="torch"
-        )
-        # TODO: delete the temp file
-
-    def generate_images(
-        self,
-        prompt,
-    ):
-        history = []
-        for iter in range(self.max_tokens):
-            input_tensor = self.tokenizer(
-                prompt, return_tensors="pt"
-            ).input_ids
-            device_inputs = [
-                ireert.asdevicearray(
-                    self.iree_module_dict["config"], input_tensor
-                )
-            ]
-            if iter == 0:
-                token = torch.tensor(
-                    self.iree_module_dict["vmfb"]["run_initialize"](
-                        *device_inputs
-                    ).to_host()[0][0]
-                )
-            else:
-                token = torch.tensor(
-                    self.iree_module_dict["vmfb"]["run_forward"](
-                        *device_inputs
-                    ).to_host()[0][0]
-                )
-
-            history.append(token)
-            yield self.tokenizer.decode(history)
-
-            if token == llm_model_map["llama2_7b"]["stop_token"]:
-                break
-
-        for i in range(len(history)):
-            if type(history[i]) != int:
-                history[i] = int(history[i])
-        result_output = self.tokenizer.decode(history)
-        yield result_output
-
-        return result_output,


 if __name__ == "__main__":
-    lm = LanguageModel(
-        "llama2_7b",
-        hf_auth_token="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",
-        device="cpu-task",
+    sd = StableDiffusion(
+        "runwayml/stable-diffusion-v1-5",
+        device="vulkan",
     )
     print("model loaded")
-    for i in lm.chat("Hello, I am a robot."):
-        print(i)
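The comment block above documents custom submodels under the name custom_model_ids, while the new __init__ signature spells it custom_model_map. A hedged construction sketch using the signature as committed; runtime behavior depends on SharkPipelineBase, which is outside this diff:

# Sketch only: argument names come from the new __init__; the custom
# vae_decode ID is the example quoted in the class comment above.
from apps.shark_studio.api.sd import StableDiffusion

sd = StableDiffusion(
    base_model_id="stabilityai/stable-diffusion-2-1-base",
    precision="fp16",
    device="vulkan",
    custom_model_map={"vae_decode": "madebyollin/sdxl-vae-fp16-fix"},
)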
2 changes: 1 addition & 1 deletion apps/shark_studio/modules/embeddings.py
@@ -1,6 +1,7 @@
 import torch
 from safetensors.torch import load_file

+
 def processLoRA(model, use_lora, splitting_prefix):
     state_dict = ""
     if ".safetensors" in use_lora:
@@ -108,4 +109,3 @@ def update_lora_weight(model, use_lora, model_name):
         return processLoRA(model, use_lora, "lora_te_")
     except:
         return None
-
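For context, a usage sketch of update_lora_weight based on its signature in the hunk header; the pipeline object and the safetensors path are illustrative assumptions, and the bare except above means failures surface only as None:

# Hypothetical call: `pipe` stands in for whatever model object the studio
# passes in; a None result means LoRA processing raised and was swallowed.
pipe = update_lora_weight(pipe, "loras/my_style.safetensors", "unet")
if pipe is None:
    print("LoRA weights could not be applied")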
20 changes: 18 additions & 2 deletions apps/shark_studio/modules/shared.py
@@ -2,8 +2,24 @@

 import gradio as gr

-from modules import shared_cmd_options, shared_gradio, options, shared_items, sd_models_types
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
+from modules import (
+    shared_cmd_options,
+    shared_gradio,
+    options,
+    shared_items,
+    sd_models_types,
+)
+from modules.paths_internal import (
+    models_path,
+    script_path,
+    data_path,
+    sd_configs_path,
+    sd_default_config,
+    sd_model_file,
+    default_sd_model_file,
+    extensions_dir,
+    extensions_builtin_dir,
+)  # noqa: F401
 from modules import util

 cmd_opts = shared_cmd_options.cmd_opts