Skip to content

Commit

Permalink
(WIP): Studio2 app infra and SD API
Browse files Browse the repository at this point in the history
UI/app structure and utility implementation.

- Initializers for webui/API launch
- Schedulers file for SD scheduling utilities
- Additions to API-level utilities
- Added embeddings module for LoRA, Lycoris, yada yada
- Added image_processing module for resamplers, resize tools,
  transforms, and any image annotation (PNG metadata)
- shared_cmd_opts module -- sorry, this is stable_args.py. It lives on.
  We still want to have some global control over the app exclusively
  from the command-line. At least we will be free from shark_args.
- Moving around some utility pieces.
- Try to make api+webui concurrency possible in index.py
- SD UI -- this is just img2imgUI but hopefully a little better.
- UI utilities for your nod logos and your gradio temps.

Enable UI / bugfixes / tweaks
  • Loading branch information
monorimet committed Dec 12, 2023
1 parent e2307a4 commit 44e1d49
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 58 deletions.
4 changes: 4 additions & 0 deletions apps/shark_studio/api/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def imports():
) # noqa: F401
from apps.shark_studio.modules.schedulers import scheduler_model_map

from apps.shark_studio.modules import processing, gradio_extensons, ui # noqa: F401
startup_timer.record("other imports")


Expand All @@ -51,6 +52,7 @@ def initialize():

# initialize_rest(reload_script_modules=False)

#initialize_rest(reload_script_modules=False)

def initialize_rest(*, reload_script_modules=False):
"""
Expand Down Expand Up @@ -85,3 +87,5 @@ def sigint_handler(sig, frame):
os._exit(0)

signal.signal(signal.SIGINT, sigint_handler)


18 changes: 12 additions & 6 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class StableDiffusion(SharkPipelineBase):
# embeddings: a dict of embedding checkpoints or model IDs to use when
# initializing the compiled modules.

class SharkStableDiffusionPipeline:
def __init__(
self,
base_model_id: str = "runwayml/stable-diffusion-v1-5",
Expand All @@ -117,15 +118,16 @@ def __init__(
self.base_model_id = base_model_id
self.device = device
self.precision = precision
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.get_compiled_map()

def prepare_pipeline(self, scheduler, custom_model_map):
return None

def generate_images(
self,
prompt,
self,
prompt,
negative_prompt,
steps,
strength,
Expand Down Expand Up @@ -172,7 +174,7 @@ def shark_sd_fn(
images: list,
preprocessed_hints: list,
progress=gr.Progress(),
):
):
# Handling gradio ImageEditor datatypes so we have unified inputs to the SD API
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
Expand Down Expand Up @@ -271,7 +273,7 @@ def shark_sd_fn(

sd_pipe = SharkStableDiffusionPipeline(
**submit_pipe_kwargs,
)
)

sd_pipe.prepare_pipe(**submit_prep_kwargs)

Expand All @@ -281,19 +283,21 @@ def shark_sd_fn(
):
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
)
save_output_img(
out_imgs[0],
seeds[current_batch],
extra_info,
)
)
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Stable Diffusion", current_batch + 1, batch_count, batch_size
), stencils, images

return generated_imgs, text_output, "", stencils, images

if token == llm_model_map["llama2_7b"]["stop_token"]:
break

def cancel_sd():
print("Inject call to cancel longer API calls.")
Expand All @@ -306,3 +310,5 @@ def cancel_sd():
device="vulkan",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
print(i)
6 changes: 0 additions & 6 deletions apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from cpuinfo import get_cpu_info

# TODO: migrate these utils to studio
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)

checkpoints_filetypes = (
"*.ckpt",
Expand Down
6 changes: 4 additions & 2 deletions apps/shark_studio/modules/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []

# gather the weights from the LoRA in a more convenient form, assumes
# everything will have an up.weight.
Expand Down Expand Up @@ -73,15 +75,15 @@ def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
scale = lora_weight.alpha * lora_strength
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
weight_up = (
weight_up = (
lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
)
weight_down = (
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
)
change = (
torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
)
)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
elif lora_weight.down.size()[2:4] == (1, 1):
Expand Down
8 changes: 4 additions & 4 deletions apps/shark_studio/web/api/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ async def log_and_time(req: Request, call_next):
if shared.cmd_opts.api_log and endpoint.startswith("/sdapi"):
print(
"API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}".format(
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
code=res.status_code,
ver=req.scope.get("http_version", "0.0"),
cli=req.scope.get("client", ("0:0.0.0", 0))[0],
prot=req.scope.get("scheme", "err"),
method=req.scope.get("method", "err"),
endpoint=endpoint,
duration=duration,
endpoint=endpoint,
duration=duration,
)
)
return res
Expand Down
2 changes: 2 additions & 0 deletions apps/shark_studio/web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def launch_webui(address):
)
webview.start(private_mode=False, storage_path=os.getcwd())

def webui():
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

def webui():
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
Expand Down
Loading

0 comments on commit 44e1d49

Please sign in to comment.