From 9658431c5199be86efeaadf6ff66568ac2450270 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 22 May 2024 20:03:25 -0500 Subject: [PATCH] Small fixes for unifying pipelines. --- apps/shark_studio/api/sd.py | 51 ++++++++++--------------------------- 1 file changed, 14 insertions(+), 37 deletions(-) diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index 600a6a4561..ac82dd1494 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -4,12 +4,12 @@ import os import json import numpy as np +import copy from tqdm.auto import tqdm from pathlib import Path from random import randint -from turbine_models.custom_models.sd_inference import clip, unet, vae -from turbine_models.custom_models.sdxl_inference import sdxl_compiled_pipeline +from turbine_models.custom_models.sd_inference.sd_pipeline import SharkSDPipeline from apps.shark_studio.api.controlnet import control_adapter_map from apps.shark_studio.api.utils import parse_device from apps.shark_studio.web.utils.state import status_label @@ -34,16 +34,11 @@ process_custom_pipe_weights, ) -sd_model_map = { - "clip": { - "initializer": clip.export_clip_model, - }, - "unet": { - "initializer": unet.export_unet_model, - }, - "vae_decode": { - "initializer": vae.export_vae_model, - }, + +EMPTY_SD_MAP = { + "clip": None, + "unet": None, + "vae_decode": None, } EMPTY_FLAGS = { @@ -75,7 +70,6 @@ def __init__( num_loras: int = 0, import_ir: bool = True, is_controlled: bool = False, - hf_auth_token=None, ): self.compiled_pipeline = False self.base_model_id = base_model_id @@ -102,7 +96,7 @@ def __init__( ) if not os.path.exists(self.weights_path): os.mkdir(self.weights_path) - self.sd_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline( + self.sd_pipe = SharkSDPipeline( hf_model_name=base_model_id, scheduler_id=scheduler, height=height, @@ -125,28 +119,10 @@ def __init__( def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): print(f"\n[LOG] Preparing pipeline...") - self.is_img2img = is_img2img - mlirs = { - "prompt_encoder": None, - "scheduled_unet": None, - "vae_decode": None, - "pipeline": None, - "full_pipeline": None, - } - vmfbs = { - "prompt_encoder": None, - "scheduled_unet": None, - "vae_decode": None, - "pipeline": None, - "full_pipeline": None, - } - weights = { - "prompt_encoder": None, - "scheduled_unet": None, - "vae_decode": None, - "pipeline": None, - "full_pipeline": None, - } + self.is_img2img = False + mlirs = copy.deepcopy(EMPTY_SD_MAP) + vmfbs = copy.deepcopy(EMPTY_SD_MAP) + weights = copy.deepcopy(EMPTY_SD_MAP) vmfbs, weights = self.sd_pipe.check_prepared(mlirs, vmfbs, weights, interactive=False) print(f"\n[LOG] Loading pipeline to device {self.rt_device}.") self.sd_pipe.load_pipeline(vmfbs, weights, self.rt_device, self.compiled_pipeline) @@ -235,6 +211,7 @@ def shark_sd_fn( control_mode = None hints = [] num_loras = 0 + import_ir=True for i in embeddings: num_loras += 1 if embeddings[i] else 0 if "model" in controlnets: @@ -268,7 +245,7 @@ def shark_sd_fn( "device": device, "custom_vae": custom_vae, "num_loras": num_loras, - "import_ir": cmd_opts.import_mlir, + "import_ir": import_ir, "is_controlled": is_controlled, "steps": steps, "scheduler": scheduler,