Skip to content

Commit

Permalink
Diffusers wrapping, cont'd
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Dec 24, 2023
1 parent 68d80dd commit 42c8e5a
Show file tree
Hide file tree
Showing 5 changed files with 2,376 additions and 282 deletions.
86 changes: 46 additions & 40 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
get_resource_path,
get_checkpoints_path,
)
from apps.shark_studio.modules.meta_model import SharkMetaModelBase
from apps.shark_studio.modules.meta_model import SharkMetaLoader
from apps.shark_studio.modules.schedulers import get_schedulers
from apps.shark_studio.modules.prompt_encoding import (
get_weighted_text_embeddings,
Expand Down Expand Up @@ -80,23 +80,23 @@
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
"vae_encode": {
"initializer": vae.export_vae_model,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
# "vae_encode": {
# "initializer": vae.export_vae_model,
# "ireec_flags": [
# "--iree-flow-collapse-reduction-dims",
# "--iree-opt-const-expr-hoisting=False",
# "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
# "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
# "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
# ],
# },
"unet": {
"initializer": unet.export_unet_model,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
],
},
"vae_decode": {
Expand Down Expand Up @@ -138,7 +138,7 @@ def setup_shark(
self.model_max_length = 77
self.batch_size = batch_size
self.precision = precision
self.dtype = torch.float16 if precision == "fp16" else torch.float32
# self.dtype = torch.float16 if precision == "fp16" else torch.float32
self.height = height
self.width = width
self.scheduler_obj = {}
Expand All @@ -147,10 +147,10 @@ def setup_shark(
"pipe": {
"external_weights": "safetensors",
},
"clip": {"hf_model_name": self.base_model_id},
"clip": { "hf_model_name": base_model_id },
"unet": {
"hf_model_name": self.base_model_id,
"unet_model": self.unet,
"hf_model_name": base_model_id,
"unet_model": unet.UnetModel(base_model_id, self.unet),
"batch_size": batch_size,
# "is_controlled": is_controlled,
# "num_loras": num_loras,
Expand All @@ -160,16 +160,16 @@ def setup_shark(
"max_length": self.model_max_length,
},
"vae_encode": {
"hf_model_name": self.base_model_id,
"vae_model": self.vae,
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(base_model_id, self.vae),
"batch_size": batch_size,
"height": height,
"width": width,
"precision": precision,
},
"vae_decode": {
"hf_model_name": self.base_model_id,
"vae_model": self.vae,
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(base_model_id, self.vae),
"batch_size": batch_size,
"height": height,
"width": width,
Expand All @@ -179,7 +179,7 @@ def setup_shark(
pipe_id_list = [
safe_name(self.base_model_id),
str(batch_size),
str(self.compile_static_args["unet"]["max_length"]),
str(self.model_max_length),
f"{str(height)}x{str(width)}",
precision,
]
Expand All @@ -200,31 +200,31 @@ def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
self.schedulers = get_schedulers(self.base_model_id)

self.weights_path = os.path.join(
get_checkpoints_path(), self.safe_name(self.base_model_id)
get_checkpoints_path(), self.shark_meta.safe_name(self.base_model_id)
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)

for model in adapters:
self.model_map[model] = adapters[model]

for submodel in self.static_kwargs:
for submodel in self.compile_static_args:
if custom_weights:
custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
if submodel not in ["clip", "clip2"]:
self.static_kwargs[submodel][
self.compile_static_args[submodel][
"external_weight_file"
] = custom_weights_params
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
self.compile_static_args[submodel]["external_weight_path"] = os.path.join(
self.weights_path, submodel + ".safetensors"
)
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
self.compile_static_args[submodel]["external_weight_path"] = os.path.join(
self.weights_path, submodel + ".safetensors"
)

self.get_compiled_map(pipe_id=self.pipe_id, static_kwargs=self.static_kwargs)
self.shark_meta.get_compiled_map(pipe_id=self.pipe_id, static_kwargs=self.compile_static_args)
print("\n[LOG] Pipeline successfully prepared for runtime.")
return

Expand Down Expand Up @@ -288,17 +288,6 @@ def _get_pipeline_class(

pipeline_cls = connected_pipeline_cls or pipeline_cls

# (SHARK): Monkey-patch in our classmethods to the auto-generated SD pipeline class.
pipeline_cls.setup_shark = setup_shark
pipeline_cls.prepare_pipe = prepare_pipe

class shark_pipeline_class(SharkMetaModelBase, pipeline_cls):
def __getattribute__(self, __name: str) -> Any:
return super().__getattribute__(__name)
pass

pipeline_cls = shark_pipeline_class

return pipeline_cls


Expand Down Expand Up @@ -354,6 +343,7 @@ def device(self) -> torch.device:
return torch.device("cpu")

def register_modules(self, **kwargs):
print("sharkified")
# import it here to avoid circular import
diffusers_module = importlib.import_module("diffusers")
pipelines = getattr(diffusers_module, "pipelines")
Expand Down Expand Up @@ -777,8 +767,6 @@ def get_connected_passed_kwargs(prefix):
)

# 7. Potentially add passed objects if expected
init_kwargs['device'] = device
init_kwargs['model_map'] = sd_model_map
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
Expand All @@ -790,12 +778,30 @@ def get_connected_passed_kwargs(prefix):
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)

# 10. (SHARK): Monkey-patch in our classmethods to the auto-generated SD pipeline class.
pipeline_class.setup_shark = setup_shark
pipeline_class.prepare_pipe = prepare_pipe
pipeline_class.shark_meta = SharkMetaLoader(sd_model_map, device)

# class shark_pipeline_class(pipeline_class):
# def __init__(self, pipeline_class, device, model_map, **init_kwargs):
# super().__init__(device, model_map, pipeline_class, **init_kwargs)


# def __getattribute__(cls, __name: str) -> Any:
# return super().__getattribute__(__name)

#pipeline_class = shark_pipeline_class

# 8. Instantiate the pipeline
model = pipeline_class(**init_kwargs)

# 9. Save where the model was instantiated from
model.register_to_config(_name_or_path=pretrained_model_name_or_path)



return model

@classmethod
Expand Down
40 changes: 10 additions & 30 deletions apps/shark_studio/modules/meta_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os


class SharkMetaModelBase:
class SharkMetaLoader:
# This class is a lightweight base for managing an
# inference API class. It should provide methods for:
# - compiling a set (model map) of torch IR modules
Expand All @@ -30,37 +30,17 @@ def __init__(
self,
model_map: dict,
device: str,
dtype: str = "f16",
import_mlir: bool = True,
):
self.model_map = model_map
self.pipe_map = {}
self.triple = get_iree_target_triple(device)
self._device, self.device_id = clean_device_info(device)
self.import_mlir = import_mlir
self.device, self.device_id = clean_device_info(device)
self.iree_module_dict = {}
self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
self.tempfiles = {}
self.pipe_vmfb_path = ""
self._dtype = dtype

@property
def device(self):
return self._device

@device.setter
def device(self, device_str):
self._device = device_str

@property
def dtype(self):
return self._dtype

@dtype.setter
def dtype(self, dtype_val):
self._dtype = dtype_val

def get_compiled_map(self, pipe_id, static_kwargs, submodel="None", init_kwargs={}) -> None:
# First checks whether we have .vmfbs precompiled, then populates the map
Expand Down Expand Up @@ -96,32 +76,31 @@ def get_compiled_map(self, pipe_id, static_kwargs, submodel="None", init_kwargs=
if key not in init_kwargs:
init_kwargs[key] = static_kwargs["pipe"][key]
self.import_torch_ir(submodel, init_kwargs)
self.get_compiled_map(pipe_id, submodel)
self.get_compiled_map(pipe_id, static_kwargs, submodel, init_kwargs)
else:
ireec_flags = (
self.model_map[submodel]["ireec_flags"]
if "ireec_flags" in self.model_map[submodel]
else []
)

self.iree_module_dict[submodel] = get_iree_compiled_module(
self.tempfiles[submodel],
device=self.device,
frontend="torch",
mmap=True,
external_weight_file=self.get_io_params(submodel),
external_weight_file=self.get_io_params(submodel, static_kwargs),
extra_args=ireec_flags,
write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
)
return

def get_io_params(self, submodel):
if "external_weight_file" in self.static_kwargs[submodel]:
def get_io_params(self, submodel, static_kwargs: dict):
if "external_weight_file" in static_kwargs[submodel]:
# we are using custom weights
weights_path = self.static_kwargs[submodel]["external_weight_file"]
elif "external_weight_path" in self.static_kwargs[submodel]:
weights_path = static_kwargs[submodel]["external_weight_file"]
elif "external_weight_path" in static_kwargs[submodel]:
# we are using the default weights for the HF model
weights_path = self.static_kwargs[submodel]["external_weight_path"]
weights_path = static_kwargs[submodel]["external_weight_path"]
else:
# assume the torch IR contains the weights.
weights_path = None
Expand All @@ -143,6 +122,7 @@ def get_precompiled(self, pipe_id, submodel="None"):
return

def import_torch_ir(self, submodel, kwargs):
breakpoint()
torch_ir = self.model_map[submodel]["initializer"](
**self.safe_dict(kwargs), compile_to="torch"
)
Expand Down
Loading

0 comments on commit 42c8e5a

Please sign in to comment.