Commit

Complete SD pipeline.
monorimet committed Dec 18, 2023
1 parent b0151a7 commit 5813043
Showing 8 changed files with 434 additions and 212 deletions.
410 changes: 316 additions & 94 deletions apps/shark_studio/api/sd.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions apps/shark_studio/modules/img_processing.py
@@ -75,9 +75,9 @@ def save_output_img(output_img, img_seed, extra_info=None):
             "parameters",
             f"{extra_info['prompt'][0]}"
             f"\nNegative prompt: {extra_info['negative_prompt'][0]}"
-            f"\nSteps: {extra_info['steps'][0]},"
-            f"Sampler: {extra_info['scheduler'][0]}, "
-            f"CFG scale: {extra_info['guidance_scale'][0]}, "
+            f"\nSteps: {extra_info['steps']},"
+            f"Sampler: {extra_info['scheduler']}, "
+            f"CFG scale: {extra_info['guidance_scale']}, "
             f"Seed: {img_seed},"
             f"Size: {png_size_text}, "
             f"Model: {img_model}, "
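The "parameters" block above is the A1111-style PNG metadata string; the fix drops the [0] indexing because steps, scheduler, and guidance_scale are now scalars in the config rather than single-element lists (see the default_sd_config.json change below). For orientation only, embedding such a string into a PNG with Pillow looks roughly like this; save_output_img's actual surrounding code is not shown in this diff:

    from PIL import Image
    from PIL.PngImagePlugin import PngInfo

    def save_png_with_parameters(img: Image.Image, path: str, parameters: str) -> None:
        # Store the generation settings under the "parameters" key, the
        # conventional location that SD front-ends read metadata back from.
        metadata = PngInfo()
        metadata.add_text("parameters", parameters)
        img.save(path, pnginfo=metadata)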
99 changes: 54 additions & 45 deletions apps/shark_studio/modules/pipeline.py
@@ -1,5 +1,10 @@
-from msvcrt import kbhit
-from shark.iree_utils.compile_utils import get_iree_compiled_module, load_vmfb_using_mmap
+from shark.iree_utils.compile_utils import (
+    get_iree_compiled_module,
+    load_vmfb_using_mmap,
+    clean_device_info,
+    get_iree_target_triple,
+)
 from apps.shark_studio.web.utils.file_utils import (
     get_checkpoints_path,
     get_resource_path,
@@ -32,8 +37,8 @@ def __init__(
         self.model_map = model_map
         self.static_kwargs = static_kwargs
         self.base_model_id = base_model_id
-        self.device_name = device
-        self.device = device.split("=>")[-1].strip(" ")
+        self.triple = get_iree_target_triple(device)
+        self.device, self.device_id = clean_device_info(device)
         self.import_mlir = import_mlir
         self.iree_module_dict = {}
         self.tempfiles = {}
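clean_device_info and get_iree_target_triple come from shark.iree_utils.compile_utils and are not shown in this diff. Judging from the replaced line — device.split("=>")[-1].strip(" ") — the incoming string looks like "AMD Radeon RX 7900 => vulkan://0", and the helper plausibly splits the driver name from the device index. A hypothetical sketch, not the real implementation:

    def clean_device_info_sketch(device_str: str) -> tuple:
        # Hypothetical stand-in for shark.iree_utils.compile_utils.clean_device_info:
        # "Name => vulkan://1" -> ("vulkan", 1); a bare "vulkan" -> ("vulkan", 0).
        spec = device_str.split("=>")[-1].strip()
        driver, _, idx = spec.partition("://")
        return driver, int(idx) if idx else 0

    assert clean_device_info_sketch("AMD Radeon RX 7900 => vulkan://0") == ("vulkan", 0)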
@@ -46,22 +51,24 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
         # initialization. As soon as you have a pipeline ID unique to your static torch IR parameters,
         # and your model map is populated with any IR - unique model IDs and their static params,
         # call this method to get the artifacts associated with your map.
-        self.pipe_id = pipe_id
+        self.pipe_id = self.safe_name(pipe_id)
         self.pipe_vmfb_path = Path(os.path.join(get_checkpoints_path(".."), self.pipe_id))
         self.pipe_vmfb_path.mkdir(parents=True, exist_ok=True)
+        print("\n[LOG] Checking for pre-compiled artifacts.")
         if submodel == "None":
-            print("\n[LOG] Gathering any pre-compiled artifacts....")
             for key in self.model_map:
                 self.get_compiled_map(pipe_id, submodel=key)
         else:
             self.get_precompiled(pipe_id, submodel)
             ireec_flags = []
             if submodel in self.iree_module_dict:
                 if "vmfb" in self.iree_module_dict[submodel]:
-                    print(f"[LOG] Found executable for {submodel} at {self.iree_module_dict[submodel]['vmfb']}...")
+                    print(f"\n[LOG] Executable for {submodel} already loaded...")
                     return
+            elif "vmfb_path" in self.model_map[submodel]:
+                return
             elif submodel not in self.tempfiles:
-                print(f"[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
+                print(f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR...")
                 if submodel in self.static_kwargs:
                     init_kwargs = self.static_kwargs[submodel]
                 for key in self.static_kwargs["pipe"]:
@@ -90,16 +97,6 @@ def get_compiled_map(self, pipe_id, submodel="None", init_kwargs={}) -> None:
             return
 
 
-    def hijack_weights(self, weights_path, submodel="None"):
-        if submodel == "None":
-            for i in self.model_map:
-                self.hijack_weights(weights_path, i)
-        else:
-            if submodel in self.iree_module_dict:
-                self.model_map[submodel]["external_weights_file"] = weights_path
-        return
-
-
     def get_precompiled(self, pipe_id, submodel="None"):
         if submodel == "None":
             for model in self.model_map:
@@ -112,33 +109,10 @@ def get_precompiled(self, pipe_id, submodel="None"):
                     break
             for file in vmfbs:
                 if submodel in file:
-                    print(f"Found existing .vmfb at {file}")
-                    self.iree_module_dict[submodel] = {}
-                    (
-                        self.iree_module_dict[submodel]["vmfb"],
-                        self.iree_module_dict[submodel]["config"],
-                        self.iree_module_dict[submodel]["temp_file_to_unlink"],
-                    ) = load_vmfb_using_mmap(
-                        os.path.join(vmfbs_path, file),
-                        self.device,
-                        device_idx=0,
-                        rt_flags=[],
-                        external_weight_file=self.model_map[submodel]['external_weight_file'],
-                    )
+                    self.model_map[submodel]["vmfb_path"] = os.path.join(vmfbs_path, file)
         return
 
 
-    def safe_dict(self, kwargs: dict):
-        flat_args = {}
-        for i in kwargs:
-            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
-                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
-            else:
-                flat_args[i] = kwargs[i]
-
-        return flat_args
-
-
     def import_torch_ir(self, submodel, kwargs):
         torch_ir = self.model_map[submodel]["initializer"](
             **self.safe_dict(kwargs), compile_to="torch"
@@ -160,18 +134,53 @@ def import_torch_ir(self, submodel, kwargs):
     def load_submodels(self, submodels: list):
         for submodel in submodels:
             if submodel in self.iree_module_dict:
+                print(f"\n[LOG] {submodel} is ready for inference.")
+            if "vmfb_path" in self.model_map[submodel]:
                 print(
-                    f"\n[LOG] Loading .vmfb for {submodel} from {self.iree_module_dict[submodel]['vmfb']}"
+                    f"\n[LOG] Loading .vmfb for {submodel} from {self.model_map[submodel]['vmfb_path']}"
                 )
+                self.iree_module_dict[submodel] = {}
+                (
+                    self.iree_module_dict[submodel]["vmfb"],
+                    self.iree_module_dict[submodel]["config"],
+                    self.iree_module_dict[submodel]["temp_file_to_unlink"],
+                ) = load_vmfb_using_mmap(
+                    self.model_map[submodel]["vmfb_path"],
+                    self.device,
+                    device_idx=0,
+                    rt_flags=[],
+                    external_weight_file=self.model_map[submodel]['external_weight_file'],
+                )
             else:
                 self.get_compiled_map(self.pipe_id, submodel)
         return
 
 
+    def unload_submodels(self, submodels: list):
+        for submodel in submodels:
+            if submodel in self.iree_module_dict:
+                del self.iree_module_dict[submodel]
+                gc.collect()
+        return
+
+
     def run(self, submodel, inputs):
-        inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, inputs)]
+        if not isinstance(inputs, list):
+            inputs = [inputs]
+        inp = [ireert.asdevicearray(self.iree_module_dict[submodel]["config"].device, input) for input in inputs]
         return self.iree_module_dict[submodel]['vmfb']['main'](*inp)
 
 
-def safe_name(name):
-    return name.replace("/", "_").replace("-", "_")
+    def safe_name(self, name):
+        return name.replace("/", "_").replace("-", "_").replace("\\", "_")
+
+
+    def safe_dict(self, kwargs: dict):
+        flat_args = {}
+        for i in kwargs:
+            if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]:
+                flat_args[i] = [kwargs[i][j] for j in kwargs[i]]
+            else:
+                flat_args[i] = kwargs[i]
+
+        return flat_args
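With hijack_weights gone and safe_name/safe_dict folded into the class, the pipeline flow is: get_compiled_map records a vmfb_path per submodel, load_submodels mmap-loads it, run executes it, and unload_submodels frees it. A runnable check of the sanitizer added above, with the intended call order sketched as comments (pipeline construction itself is elided from this diff):

    # Standalone copy of the method above (the class version takes self).
    def safe_name(name: str) -> str:
        return name.replace("/", "_").replace("-", "_").replace("\\", "_")

    assert safe_name("runwayml/stable-diffusion-v1-5") == "runwayml_stable_diffusion_v1_5"

    # Intended lifecycle, per this diff:
    #   pipe.get_compiled_map(pipe_id)    -> find/compile artifacts, record vmfb_path entries
    #   pipe.load_submodels(["clip"])     -> mmap-load the recorded .vmfb files
    #   out = pipe.run("clip", tokens)    -> a single array is now auto-wrapped into a list
    #   pipe.unload_submodels(["clip"])   -> delete the module and gc.collect()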
58 changes: 13 additions & 45 deletions apps/shark_studio/modules/prompt_encoding.py
@@ -3,6 +3,7 @@
 from iree import runtime as ireert
 import re
 import torch
+import numpy as np
 
 re_attention = re.compile(
     r"""
@@ -161,7 +162,7 @@ def pad_tokens_and_weights(
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
-    max_embeddings_multiples = 8
+    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
     weights_length = (
         max_length
         if no_boseos_middle
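The hard-coded 8 is replaced by deriving the chunk count from the lengths themselves: a chunk holds chunk_length - 2 real tokens once its BOS/EOS slots are set aside, and max_length reserves 2 slots for the outer BOS/EOS. A quick arithmetic check (77 is CLIP's window size; the 3-chunk value is illustrative):

    chunk_length = 77                          # CLIP text-encoder window, incl. BOS/EOS
    max_length = (chunk_length - 2) * 3 + 2    # padded length of a 3-chunk prompt -> 227
    assert (max_length - 2) // (chunk_length - 2) == 3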
@@ -194,13 +195,16 @@ def pad_tokens_and_weights(
 
     return tokens, weights
 
 
 def get_unweighted_text_embeddings(
     pipe,
-    text_input: torch.Tensor,
+    text_input,
     chunk_length: int,
     no_boseos_middle: Optional[bool] = True,
 ):
+    """
+    When the length of tokens is a multiple of the capacity of the text encoder,
+    it should be split into chunks and sent to the text encoder individually.
+    """
     max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
     if max_embeddings_multiples > 1:
         text_embeddings = []
@@ -214,7 +218,7 @@ def get_unweighted_text_embeddings(
             text_input_chunk[:, 0] = text_input[0, 0]
             text_input_chunk[:, -1] = text_input[0, -1]
 
-            text_embedding = pipe.run("clip", text_input_chunk)[0]
+            text_embedding = pipe.run("clip", text_input_chunk)[0].to_host()
 
             if no_boseos_middle:
                 if i == 0:
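The added .to_host() matters because pipe.run returns IREE DeviceArrays resident on the target device, while the BOS/EOS trimming below slices them as host arrays. A minimal, standalone illustration of the conversion (driver name assumed; any installed IREE driver would do):

    import numpy as np
    from iree import runtime as ireert

    config = ireert.Config("local-task")   # assumed driver for illustration
    dev_arr = ireert.asdevicearray(config.device, np.ones((1, 4), dtype=np.float32))
    host_arr = dev_arr.to_host()           # DeviceArray -> plain numpy.ndarray
    assert isinstance(host_arr, np.ndarray)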
Expand All @@ -231,50 +235,14 @@ def get_unweighted_text_embeddings(
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
text_embeddings = torch.from_numpy(text_embeddings_np)
else:
text_embeddings = pipe.run("clip", text_input)[0]
# text_embeddings = torch.from_numpy(text_embeddings)[None, :]
return torch.from_numpy(text_embeddings.to_host())
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = 8
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
].clone()

# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
# text_embedding = pipe.text_encoder(text_input_chunk)[0]

print(text_input_chunk)
breakpoint()
text_embedding = pipe.run("clip", text_input_chunk)
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]

text_embeddings.append(text_embedding)
# SHARK: Convert the result to tensor
# text_embeddings = torch.concat(text_embeddings, axis=1)
text_embeddings_np = np.concatenate(np.array(text_embeddings))
text_embeddings = torch.from_numpy(text_embeddings_np)[None, :]
text_embeddings = torch.from_numpy(text_embeddings.to_host())
return text_embeddings



# This function deals with NoneType values occuring in tokens after padding
# It switches out None with 49407 as truncating None values causes matrix dimension errors,
def filter_nonetype_tokens(tokens: List[List]):
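The body of filter_nonetype_tokens is collapsed in this view; going only by the comment above, a plausible sketch is the following (49407 is CLIP's end-of-text token id), though the repository's actual implementation may differ:

    def filter_nonetype_tokens_sketch(tokens):
        # Replace None entries left by padding with CLIP's EOS id so that
        # downstream tensor construction sees uniform integer rows.
        return [[49407 if tok is None else tok for tok in row] for row in tokens]

    assert filter_nonetype_tokens_sketch([[1, None, 3]]) == [[1, 49407, 3]]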
@@ -286,7 +254,7 @@ def get_weighted_text_embeddings(
     prompt: List[str],
     uncond_prompt: List[str] = None,
     max_embeddings_multiples: Optional[int] = 8,
-    no_boseos_middle: Optional[bool] = False,
+    no_boseos_middle: Optional[bool] = True,
     skip_parsing: Optional[bool] = False,
     skip_weighting: Optional[bool] = False,
 ):
@@ -325,12 +293,12 @@ def get_weighted_text_embeddings(
         max_length = max(
             max_length, max([len(token) for token in uncond_tokens])
         )
-
     max_embeddings_multiples = min(
         max_embeddings_multiples,
         (max_length - 1) // (pipe.model_max_length - 2) + 1,
     )
+    max_embeddings_multiples = max(1, max_embeddings_multiples)
 
     max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2
 
     # pad the length of tokens and weights
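The new max(1, ...) guard pins the multiplier at one; without it, an empty token list (max_length of 0) would yield (0 - 1) // 75 + 1 == 0 and a degenerate padded length. Worked through with an assumed 152-token prompt and CLIP's 77-token window:

    model_max_length = 77
    max_length = 152                   # longest tokenized prompt in the batch (example value)
    multiples = min(8, (max_length - 1) // (model_max_length - 2) + 1)
    multiples = max(1, multiples)      # the added guard; a no-op here
    assert multiples == 3
    assert (model_max_length - 2) * multiples + 2 == 227   # re-padded max_length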
2 changes: 1 addition & 1 deletion apps/shark_studio/modules/schedulers.py
@@ -18,7 +18,7 @@
 
 def get_schedulers(model_id):
     #TODO: switch over to turbine and run all on GPU
-    print(f"[LOG] Initializing schedulers from model id: {model_id}")
+    print(f"\n[LOG] Initializing schedulers from model id: {model_id}")
     schedulers = dict()
     schedulers["PNDM"] = PNDMScheduler.from_pretrained(
         model_id,
21 changes: 11 additions & 10 deletions apps/shark_studio/web/configs/default_sd_config.json
@@ -1,23 +1,24 @@
 {
     "prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ],
     "negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ],
-    "sd_init_image": [ "None" ],
+    "sd_init_image": [ null ],
     "height": 512,
     "width": 512,
-    "steps": [ 50 ],
-    "strength": [ 0.8 ],
-    "guidance_scale": [ 7.5 ],
-    "seed": [ -1 ],
+    "steps": 50,
+    "strength": 0.8,
+    "guidance_scale": 7.5,
+    "seed": -1,
     "batch_count": 1,
     "batch_size": 1,
-    "scheduler": [ "EulerDiscrete" ],
+    "scheduler": "EulerDiscrete",
     "base_model_id": "runwayml/stable-diffusion-v1-5",
-    "custom_weights": "",
-    "custom_vae": "",
+    "custom_weights": null,
+    "custom_vae": null,
     "use_base_vae": false,
     "precision": "fp16",
     "device": "vulkan",
-    "ondemand": "False",
-    "repeatable_seeds": "False",
+    "ondemand": false,
+    "repeatable_seeds": false,
     "resample_type": "Nearest Neighbor",
     "controlnets": {},
     "embeddings": {}
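The config now uses native JSON types instead of stringified Python values. That distinction is load-bearing: json.load turns null into None and false into False, whereas the old "False" is a non-empty string and therefore truthy wherever it was tested directly, so ondemand and repeatable_seeds could never actually be disabled by this file. A small demonstration:

    import json

    cfg = json.loads('{"ondemand": false, "seed": -1, "custom_weights": null}')
    assert cfg["ondemand"] is False        # a real boolean now
    assert cfg["custom_weights"] is None   # JSON null -> Python None
    assert bool("False") is True           # the old stringly-typed value was always truthy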