Pipeline tweaks, add cmd_opts parsing to sd api
monorimet committed Dec 20, 2023
1 parent 1288459 commit a42ecb0
Showing 8 changed files with 157 additions and 102 deletions.
133 changes: 72 additions & 61 deletions apps/shark_studio/api/sd.py
@@ -34,8 +34,6 @@
 )
 from transformers import CLIPTokenizer
 from diffusers.image_processor import VaeImageProcessor
-from math import ceil
-from PIL import Image
 
 sd_model_map = {
     "clip": {
@@ -166,35 +164,40 @@ def __init__(
         del static_kwargs
         gc.collect()
 
-    def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img):
-        print(
-            f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
-            f"\n[LOG] Custom embeddings currently unsupported."
-        )
+    def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
+        print(f"\n[LOG] Preparing pipeline...")
         self.is_img2img = is_img2img
-        schedulers = get_schedulers(self.base_model_id)
-        self.scheduler = schedulers[scheduler]
-        self.image_processor = VaeImageProcessor()#do_convert_rgb=True)
-        self.weights_path = os.path.join(get_checkpoints_path(), self.safe_name(self.base_model_id))
+        self.schedulers = get_schedulers(self.base_model_id)
+
+        self.weights_path = os.path.join(
+            get_checkpoints_path(), self.safe_name(self.base_model_id)
+        )
         if not os.path.exists(self.weights_path):
             os.mkdir(self.weights_path)
-        print(f"[LOG] Loaded scheduler: {scheduler}")
 
         for model in adapters:
             self.model_map[model] = adapters[model]
 
         for submodel in self.static_kwargs:
             if custom_weights:
+                custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
                 if submodel not in ["clip", "clip2"]:
-                    self.static_kwargs[submodel]["external_weight_file"] = custom_weights
+                    self.static_kwargs[submodel][
+                        "external_weight_file"
+                    ] = custom_weights_params
                 else:
-                    self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
+                    self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
+                        self.weights_path, submodel + ".safetensors"
+                    )
             else:
-                self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
+                self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
+                    self.weights_path, submodel + ".safetensors"
+                )
 
         self.get_compiled_map(pipe_id=self.pipe_id)
         print("\n[LOG] Pipeline successfully prepared for runtime.")
         return
 
 
     def encode_prompts_weight(
         self,
         prompt,
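
Note: the net effect of this hunk is that prepare_pipe() now only builds the scheduler map, while the active scheduler is looked up per call in generate_images() (see the self.scheduler = self.schedulers[scheduler] line further down). A minimal, self-contained sketch of that pattern; the class and values here are hypothetical stand-ins, not shark_studio code:

    class TinyPipe:
        def prepare_pipe(self):
            # Built once; get_schedulers(base_model_id) in the real code.
            self.schedulers = {"DDIM": "<ddim impl>", "EulerDiscrete": "<euler impl>"}

        def generate_images(self, scheduler: str):
            # Selected per call, so switching schedulers between generations
            # no longer requires re-preparing (and recompiling) the pipeline.
            self.scheduler = self.schedulers[scheduler]
            return f"generated with {self.scheduler}"

    pipe = TinyPipe()
    pipe.prepare_pipe()
    pipe.generate_images("DDIM")
    pipe.generate_images("EulerDiscrete")  # no re-prepare needed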
@@ -335,9 +338,9 @@ def produce_img_latents(
 
             latent_history.append(latents)
             step_time = (time.time() - step_start_time) * 1000
-            # self.log += (
-            #     f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
-            # )
+            # print(
+            #     f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms"
+            # )
             step_time_sum += step_time
 
             # if self.status == SD_STATE_CANCEL:
@@ -371,51 +374,52 @@ def decode_latents(self, latents, cpu_scheduling=True):
         pil_images = self.image_processor.numpy_to_pil(images)
         return pil_images
 
-    def process_sd_init_image(self, sd_init_image, resample_type):
-        if isinstance(sd_init_image, list):
-            images = []
-            for img in sd_init_image:
-                img, _ = self.process_sd_init_image(img, resample_type)
-                images.append(img)
-            is_img2img = True
-            return images, is_img2img
-        if isinstance(sd_init_image, str):
-            if os.path.isfile(sd_init_image):
-                sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
-                image, is_img2img = self.process_sd_init_image(
-                    sd_init_image, resample_type
-                )
-            else:
-                image = None
-                is_img2img = False
-        elif isinstance(sd_init_image, Image.Image):
-            image = sd_init_image.convert("RGB")
-        elif sd_init_image:
-            image = sd_init_image["image"].convert("RGB")
-        else:
-            image = None
-            is_img2img = False
-        if image:
-            resample_type = (
-                resamplers[resample_type]
-                if resample_type in resampler_list
-                # Fallback to Lanczos
-                else Image.Resampling.LANCZOS
-            )
-            image = image.resize((self.width, self.height), resample=resample_type)
-            image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
-            image_arr = image_arr / 255.0
-            image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
-            image_arr = 2 * (image_arr - 0.5)
-            is_img2img = True
-            image = image_arr
-        return image, is_img2img
+    # def process_sd_init_image(self, sd_init_image, resample_type):
+    #     if isinstance(sd_init_image, list):
+    #         images = []
+    #         for img in sd_init_image:
+    #             img, _ = self.process_sd_init_image(img, resample_type)
+    #             images.append(img)
+    #         is_img2img = True
+    #         return images, is_img2img
+    #     if isinstance(sd_init_image, str):
+    #         if os.path.isfile(sd_init_image):
+    #             sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
+    #             image, is_img2img = self.process_sd_init_image(
+    #                 sd_init_image, resample_type
+    #             )
+    #         else:
+    #             image = None
+    #             is_img2img = False
+    #     elif isinstance(sd_init_image, Image.Image):
+    #         image = sd_init_image.convert("RGB")
+    #     elif sd_init_image:
+    #         image = sd_init_image["image"].convert("RGB")
+    #     else:
+    #         image = None
+    #         is_img2img = False
+    #     if image:
+    #         resample_type = (
+    #             resamplers[resample_type]
+    #             if resample_type in resampler_list
+    #             # Fallback to Lanczos
+    #             else Image.Resampling.LANCZOS
+    #         )
+    #         image = image.resize((self.width, self.height), resample=resample_type)
+    #         image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
+    #         image_arr = image_arr / 255.0
+    #         image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
+    #         image_arr = 2 * (image_arr - 0.5)
+    #         is_img2img = True
+    #         image = image_arr
+    #     return image, is_img2img
 
     def generate_images(
         self,
         prompt,
         negative_prompt,
         image,
+        scheduler,
         steps,
         strength,
         guidance_scale,
@@ -427,9 +431,11 @@
         hints,
     ):
         # TODO: Batched args
+        self.image_processor = VaeImageProcessor(do_convert_rgb=True)
+        self.scheduler = self.schedulers[scheduler]
         self.ondemand = ondemand
         if self.is_img2img:
-            image, _ = self.process_sd_init_image(image, resample_type)
+            image, _ = self.image_processor.preprocess(image, resample_type)
         else:
             image = None
 
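Note: the hand-rolled process_sd_init_image() helper above is retired in favor of diffusers' VaeImageProcessor.preprocess, which covers the same steps (RGB conversion, resizing, scaling to [0, 1], NCHW layout, normalization to [-1, 1]). A standalone sketch of that replacement path, assuming a recent diffusers release:

    from PIL import Image
    from diffusers.image_processor import VaeImageProcessor

    # do_convert_rgb=True mirrors the .convert("RGB") calls in the old helper.
    proc = VaeImageProcessor(do_convert_rgb=True)
    img = Image.new("RGB", (640, 480), (128, 128, 128))

    # Returns a float NCHW tensor normalized to [-1, 1].
    t = proc.preprocess(img, height=512, width=512)
    print(t.shape)  # torch.Size([1, 3, 512, 512])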
@@ -532,6 +538,8 @@ def shark_sd_fn(
     embeddings: dict,
 ):
     sd_kwargs = locals()
+    if not isinstance(sd_init_image, list):
+        sd_init_image = [sd_init_image]
     is_img2img = True if sd_init_image[0] is not None else False
 
     print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
@@ -581,7 +589,6 @@ def shark_sd_fn(
         "is_controlled": is_controlled,
     }
     submit_prep_kwargs = {
-        "scheduler": scheduler,
         "custom_weights": custom_weights,
         "adapters": adapters,
         "embeddings": embeddings,
@@ -592,6 +599,7 @@
         "negative_prompt": negative_prompt,
         "image": sd_init_image,
         "steps": steps,
+        "scheduler": scheduler,
         "strength": strength,
         "guidance_scale": guidance_scale,
         "seed": seed,
@@ -667,5 +675,8 @@ def view_json_file(file_path):
 
     sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
     sd_kwargs = json.loads(sd_json)
+    for arg in vars(cmd_opts):
+        if arg in sd_kwargs:
+            sd_kwargs[arg] = getattr(cmd_opts, arg)
     for i in shark_sd_fn_dict_input(sd_kwargs):
         print(i)
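
The new __main__ block is the headline change: any key in the default JSON config can now be overridden by the matching command-line flag. A self-contained sketch of the merge pattern; the flags and config keys here are illustrative, not the full shark_studio set:

    import argparse
    import json

    p = argparse.ArgumentParser()
    p.add_argument("--prompt", nargs="+", default=["a photo of a super-car"])
    p.add_argument("--steps", type=int, default=50)
    cmd_opts = p.parse_args(["--steps", "20"])

    sd_kwargs = json.loads('{"prompt": ["placeholder"], "steps": 50, "height": 512}')

    # Only keys already present in the config are overridden; unrelated
    # argparse attributes are ignored.
    for arg in vars(cmd_opts):
        if arg in sd_kwargs:
            sd_kwargs[arg] = getattr(cmd_opts, arg)

    print(sd_kwargs)  # {'prompt': ['a photo of a super-car'], 'steps': 20, 'height': 512}

One design consequence: argparse always fills in defaults, so every matching config key is replaced whether or not the flag was passed explicitly; the defaults in shared_cmd_opts.py effectively become the config.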
14 changes: 8 additions & 6 deletions apps/shark_studio/modules/pipeline.py
@@ -102,7 +102,7 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
                 write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
             )
         return
-
+
     def get_io_params(self, submodel):
         if "external_weight_file" in self.static_kwargs[submodel]:
             # we are using custom weights
@@ -114,7 +114,7 @@ def get_io_params(self, submodel):
             # assume the torch IR contains the weights.
             weights_path = None
         return weights_path
-
+
     def get_precompiled(self, pipe_id, submodel="None"):
         if submodel == "None":
             for model in self.model_map:
@@ -125,7 +125,9 @@
                 break
         for file in vmfbs:
             if submodel in file:
-                self.pipe_map[submodel]["vmfb_path"] = os.path.join(self.pipe_vmfb_path, file)
+                self.pipe_map[submodel]["vmfb_path"] = os.path.join(
+                    self.pipe_vmfb_path, file
+                )
         return
 
     def import_torch_ir(self, submodel, kwargs):
@@ -153,9 +155,9 @@ def load_submodels(self, submodels: list):
                 continue
             if "vmfb_path" in self.pipe_map[submodel]:
                 weights_path = self.get_io_params(submodel)
-                print(
-                    f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
-                )
+                # print(
+                #     f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
+                # )
                 self.iree_module_dict[submodel] = {}
                 (
                     self.iree_module_dict[submodel]["vmfb"],
66 changes: 66 additions & 0 deletions apps/shark_studio/modules/seed.py
@@ -0,0 +1,66 @@
import numpy as np
import json
from random import (
    randint,
    seed as seed_random,
    getstate as random_getstate,
    setstate as random_setstate,
)


# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
    seed = int(seed)
    uint32_info = np.iinfo(np.uint32)
    uint32_min, uint32_max = uint32_info.min, uint32_info.max
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
    return seed


# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
    if isinstance(seed_input, str):
        try:
            seed_input = json.loads(seed_input)
        except (ValueError, TypeError):
            seed_input = None

    if isinstance(seed_input, int):
        return [seed_input]

    if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
        return seed_input

    raise TypeError(
        "Seed input must be an integer or an array of integers in JSON format"
    )


# Generate a set of seeds from an input expression for batch_count batches,
# optionally using that input as the rng seed for any randomly generated seeds.
def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False):
    # turn the input into a list if possible
    seeds = parse_seed_input(seed_input)

    # slice or pad the list to be of batch_count length
    seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))

    if repeatable:
        if all(seed < 0 for seed in seeds):
            seeds[0] = sanitize_seed(seeds[0])

        # set seed for the rng based on what we have so far
        saved_random_state = random_getstate()
        seed_random(str([n for n in seeds if n > -1]))

    # generate any seeds that are unspecified
    seeds = [sanitize_seed(seed) for seed in seeds]

    if repeatable:
        # reset the rng back to normal
        random_setstate(saved_random_state)

    return seeds
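
A short usage sketch of the new helpers; the import path follows the file location above, and the randomly generated values are illustrative:

    from apps.shark_studio.modules.seed import batch_seeds, parse_seed_input

    parse_seed_input("[42, 7]")  # -> [42, 7]
    parse_seed_input(42)         # -> [42]

    # Pad two explicit seeds out to a batch of four; the two -1 slots are
    # filled with random uint32 seeds. With repeatable=True the same input
    # always yields the same four seeds.
    seeds = batch_seeds("[42, 7]", batch_count=4, repeatable=True)
    assert seeds[:2] == [42, 7] and len(seeds) == 4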
14 changes: 7 additions & 7 deletions apps/shark_studio/modules/shared_cmd_opts.py
@@ -32,7 +32,7 @@ def is_valid_file(arg):
 )
 p.add_argument(
     "-p",
-    "--prompts",
+    "--prompt",
     nargs="+",
     default=[
         "a photo taken of the front of a super-car drifting on a road near "
@@ -44,7 +44,7 @@
 )
 
 p.add_argument(
-    "--negative_prompts",
+    "--negative_prompt",
     nargs="+",
     default=[
         "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
@@ -54,7 +54,7 @@
 )
 
 p.add_argument(
-    "--img_path",
+    "--sd_init_image",
     type=str,
     help="Path to the image input for img2img/inpainting.",
 )
@@ -320,7 +320,7 @@ def is_valid_file(arg):
 p.add_argument(
     "--scheduler",
     type=str,
-    default="SharkEulerDiscrete",
+    default="DDIM",
     help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
     "DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
     "DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
@@ -359,10 +359,10 @@
 )
 
 p.add_argument(
-    "--ckpt_loc",
+    "--custom_weights",
     type=str,
     default="",
-    help="Path to SD's .ckpt file.",
+    help="Path to a .safetensors or .ckpt file for SD pipeline weights.",
 )
 
 p.add_argument(
@@ -374,7 +374,7 @@
 )
 
 p.add_argument(
-    "--hf_model_id",
+    "--base_model_id",
     type=str,
     default="stabilityai/stable-diffusion-2-1-base",
     help="The repo-id of hugging face.",
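
These renames line the CLI flags up with the keyword names shark_sd_fn() collects via locals(), which is what lets the new vars(cmd_opts) loop in sd.py map flags onto config keys by name. A hypothetical invocation using the renamed flags (entry point and paths illustrative):

    python apps/shark_studio/api/sd.py \
        --prompt "a photo taken of the front of a super-car drifting on a road" \
        --negative_prompt "watermark, signature, lowres" \
        --sd_init_image /path/to/init.png \
        --scheduler DDIM \
        --custom_weights /path/to/model.safetensors \
        --base_model_id stabilityai/stable-diffusion-2-1-base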