diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index f4b074ae13..84bc9f9976 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -5,8 +5,14 @@ import json import numpy as np import transformers -import logging import importlib +import inspect +import re +import fnmatch +import sys + + +from requests.exceptions import HTTPError from packaging import version from tqdm.auto import tqdm @@ -20,7 +26,7 @@ get_resource_path, get_checkpoints_path, ) -from apps.shark_studio.modules.pipeline import SharkPipelineBase +from apps.shark_studio.modules.meta_model import SharkMetaModelBase from apps.shark_studio.modules.schedulers import get_schedulers from apps.shark_studio.modules.prompt_encoding import ( get_weighted_text_embeddings, @@ -52,9 +58,15 @@ is_safetensors_compatible, variant_compatible_siblings, get_class_obj_and_candidates, + get_class_from_dynamic_module, maybe_raise_or_warn, - _get_pipeline_class, + load_sub_model, + _unwrap_model, + LOADABLE_CLASSES, ) +from diffusers.utils import logging, PushToHubMixin, CONFIG_NAME +from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from diffusers.configuration_utils import ConfigMixin logger = logging.get_logger(__name__) @@ -105,152 +117,205 @@ TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" CONNECTED_PIPES_KEYS = ["prior"] -LOADABLE_CLASSES = { - "diffusers": { - "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_pretrained", "from_pretrained"], - "DiffusionPipeline": ["save_pretrained", "from_pretrained"], - "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], - }, - "transformers": { - "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], - "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], - "PreTrainedModel": ["save_pretrained", "from_pretrained"], - "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], - "ProcessorMixin": ["save_pretrained", "from_pretrained"], - "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], - }, - #"onnxruntime.training": { - # "ORTModule": ["save_pretrained", "from_pretrained"], - #}, -} - ALL_IMPORTABLE_CLASSES = {} for library in LOADABLE_CLASSES: ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) -def load_sub_model( - library_name: str, - class_name: str, - importable_classes: List[Any], - pipelines: Any, - is_pipeline_module: bool, - pipeline_class: Any, - torch_dtype: torch.dtype, - provider: Any, - sess_options: Any, - device_map: Optional[Union[Dict[str, torch.device], str]], - max_memory: Optional[Dict[Union[int, str], Union[int, str]]], - offload_folder: Optional[Union[str, os.PathLike]], - offload_state_dict: bool, - model_variants: Dict[str, str], - name: str, - from_flax: bool, - variant: str, - low_cpu_mem_usage: bool, - cached_folder: Union[str, os.PathLike], - revision: str = None, +def setup_shark( + self, + base_model_id, + height: int, + width: int, + batch_size: int, + precision: str, + device: str, + custom_vae: str = None, + num_loras: int = 0, + import_ir: bool = True, + is_controlled: bool = False, ): - """Helper method to load the module `name` from `library_name` and `class_name`""" - # retrieve class candidates - class_obj, class_candidates = get_class_obj_and_candidates( - library_name, - class_name, - importable_classes, - pipelines, - is_pipeline_module, - component_name=name, - cache_dir=cached_folder, + self.model_max_length = 77 + self.batch_size = batch_size + self.precision = precision + self.dtype = torch.float16 
if precision == "fp16" else torch.float32 + self.height = height + self.width = width + self.scheduler_obj = {} + self.base_model_id = base_model_id #pipe.config["_name_or_path"] + self.compile_static_args = { + "pipe": { + "external_weights": "safetensors", + }, + "clip": {"hf_model_name": self.base_model_id}, + "unet": { + "hf_model_name": self.base_model_id, + "unet_model": self.unet, + "batch_size": batch_size, + # "is_controlled": is_controlled, + # "num_loras": num_loras, + "height": height, + "width": width, + "precision": precision, + "max_length": self.model_max_length, + }, + "vae_encode": { + "hf_model_name": self.base_model_id, + "vae_model": self.vae, + "batch_size": batch_size, + "height": height, + "width": width, + "precision": precision, + }, + "vae_decode": { + "hf_model_name": self.base_model_id, + "vae_model": self.vae, + "batch_size": batch_size, + "height": height, + "width": width, + "precision": precision, + }, + } + pipe_id_list = [ + safe_name(self.base_model_id), + str(batch_size), + str(self.compile_static_args["unet"]["max_length"]), + f"{str(height)}x{str(width)}", + precision, + ] + if num_loras > 0: + pipe_id_list.append(str(num_loras) + "lora") + if is_controlled: + pipe_id_list.append("controlled") + if custom_vae: + pipe_id_list.append(custom_vae) + self.pipe_id = "_".join(pipe_id_list) + print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") + gc.collect() + + +def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): + print(f"\n[LOG] Preparing pipeline...") + self.is_img2img = is_img2img + self.schedulers = get_schedulers(self.base_model_id) + + self.weights_path = os.path.join( + get_checkpoints_path(), self.safe_name(self.base_model_id) ) + if not os.path.exists(self.weights_path): + os.mkdir(self.weights_path) + + for model in adapters: + self.model_map[model] = adapters[model] + + for submodel in self.static_kwargs: + if custom_weights: + custom_weights_params, _ = process_custom_pipe_weights(custom_weights) + if submodel not in ["clip", "clip2"]: + self.static_kwargs[submodel][ + "external_weight_file" + ] = custom_weights_params + else: + self.static_kwargs[submodel]["external_weight_path"] = os.path.join( + self.weights_path, submodel + ".safetensors" + ) + else: + self.static_kwargs[submodel]["external_weight_path"] = os.path.join( + self.weights_path, submodel + ".safetensors" + ) - load_method_name = None - # retrive load method name - for class_name, class_candidate in class_candidates.items(): - if class_candidate is not None and issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] - - # if load method name is None, then we have a dummy module -> raise Error - if load_method_name is None: - none_module = class_obj.__module__ - is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( - TRANSFORMERS_DUMMY_MODULES_FOLDER - ) - if is_dummy_path and "dummy" in none_module: - # call class_obj for nice error message of missing requirements - class_obj() + self.get_compiled_map(pipe_id=self.pipe_id, static_kwargs=self.static_kwargs) + print("\n[LOG] Pipeline successfully prepared for runtime.") + return - raise ValueError( - f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" - f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." 
- ) - load_method = getattr(class_obj, load_method_name) +def _get_pipeline_class( + class_obj, + config, + load_connected_pipeline=False, + custom_pipeline=None, + repo_id=None, + hub_revision=None, + class_name=None, + cache_dir=None, + revision=None, + shark_device="cpu", +): + if custom_pipeline is not None: + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + elif repo_id is not None: + file_name = f"{custom_pipeline}.py" + custom_pipeline = repo_id + else: + file_name = CUSTOM_PIPELINE_FILE_NAME - # add kwargs to loading method - diffusers_module = importlib.import_module(__name__.split(".")[0]) - loading_kwargs = {} - if issubclass(class_obj, torch.nn.Module): - loading_kwargs["torch_dtype"] = torch_dtype - if issubclass(class_obj, diffusers_module.OnnxRuntimeModel): - raise Exception("Support for onnx imports not implemented.") + if repo_id is not None and hub_revision is not None: + # if we load the pipeline code from the Hub + # make sure to overwrite the `revison` + revision = hub_revision - is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) + return get_class_from_dynamic_module( + custom_pipeline, + module_file=file_name, + class_name=class_name, + cache_dir=cache_dir, + revision=revision, + ) - transformers_version = version.parse(version.parse(transformers.__version__).base_version) + if class_obj != SharkDiffusionPipeline: + return class_obj - is_transformers_model = ( - issubclass(class_obj, PreTrainedModel) - and transformers_version >= version.parse("4.20.0") - ) + diffusers_module = importlib.import_module("diffusers") + class_name = config["_class_name"] + class_name = class_name[4:] if class_name.startswith("Flax") else class_name - # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. - # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. - # This makes sure that the weights won't be initialized which significantly speeds up loading. 
- if is_diffusers_model or is_transformers_model: - loading_kwargs["device_map"] = device_map - loading_kwargs["max_memory"] = max_memory - loading_kwargs["offload_folder"] = offload_folder - loading_kwargs["offload_state_dict"] = offload_state_dict - loading_kwargs["variant"] = model_variants.pop(name, None) - if from_flax: - loading_kwargs["from_flax"] = True - - # the following can be deleted once the minimum required `transformers` version - # is higher than 4.27 - if ( - is_transformers_model - and loading_kwargs["variant"] is not None - and transformers_version < version.parse("4.27.0") - ): - raise ImportError( - f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" - ) - elif is_transformers_model and loading_kwargs["variant"] is None: - loading_kwargs.pop("variant") + pipeline_cls = getattr(diffusers_module, class_name) - # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` - if not (from_flax and is_transformers_model): - loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + if load_connected_pipeline: + from diffusers.pipelines.auto_pipeline import _get_connected_pipeline + + connected_pipeline_cls = _get_connected_pipeline(pipeline_cls) + if connected_pipeline_cls is not None: + logger.info( + f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`" + ) else: - loading_kwargs["low_cpu_mem_usage"] = False + logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.") + + pipeline_cls = connected_pipeline_cls or pipeline_cls - # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): - loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) - else: - # else load from the root directory - loaded_sub_model = load_method(cached_folder, **loading_kwargs) + # (SHARK): Monkey-patch in our classmethods to the auto-generated SD pipeline class. + pipeline_cls.setup_shark = setup_shark + pipeline_cls.prepare_pipe = prepare_pipe - return loaded_sub_model + class shark_pipeline_class(SharkMetaModelBase, pipeline_cls): + def __getattribute__(self, __name: str) -> Any: + return super().__getattribute__(__name) + pass -class SharkDiffusionPipeline(DiffusionPipeline, SharkPipelineBase): - # This class is responsible for executing image generation and creating - # /managing a set of compiled modules to run Stable Diffusion. The init - # aims to be as general as possible, and the class will infer and compile - # a list of necessary modules or a combined "pipeline module" for a - # specified job based on the inference task. + pipeline_cls = shark_pipeline_class + return pipeline_cls + + + +class SharkDiffusionPipeline(ConfigMixin, PushToHubMixin): + + r''' + Instantiates a diffusers pipeline and replaces any model preparation/device + methods and properties with our custom model loading and runtime + (provided by SharkPipelineBase). + + This class is responsible for creating and managing a set of compiled + modules to run a diffusers 'DiffusionPipeline'. The init + aims to be as general as possible, and the class will infer and compile + a list of necessary modules or a combined "pipeline module" via Turbine + for a specified job based on the inference task. 
+ ''' config_name = "model_index.json" model_cpu_offload_seq = None _optional_components = [] @@ -258,133 +323,204 @@ class SharkDiffusionPipeline(DiffusionPipeline, SharkPipelineBase): _load_connected_pipes = False _is_onnx = False - def __init__( - self, - base_model_id, - height: int, - width: int, - batch_size: int, - precision: str, - device: str, - custom_vae: str = None, - num_loras: int = 0, - import_ir: bool = True, - is_controlled: bool = False, - ): + @property + def dtype(self) -> torch.dtype: + r""" + Returns: + `torch.dtype`: The torch dtype on which the pipeline is located. + """ + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] - self.model_max_length = 77 - self.batch_size = batch_size - self.precision = precision - self.dtype = torch.float16 if precision == "fp16" else torch.float32 - self.height = height - self.width = width - self.scheduler_obj = {} - compile_static_args = { - "pipe": { - "external_weights": "safetensors", - }, - "clip": {"hf_model_name": base_model_id}, - "unet": { - "hf_model_name": base_model_id, - "unet_model": unet.UnetModel( - hf_model_name=base_model_id, hf_auth_token=None - ), - "batch_size": batch_size, - # "is_controlled": is_controlled, - # "num_loras": num_loras, - "height": height, - "width": width, - "precision": precision, - "max_length": self.model_max_length, - }, - "vae_encode": { - "hf_model_name": base_model_id, - "vae_model": self.vae_encode, - "batch_size": batch_size, - "height": height, - "width": width, - "precision": precision, - }, - "vae_decode": { - "hf_model_name": base_model_id, - "vae_model": self.vae_decode, - "batch_size": batch_size, - "height": height, - "width": width, - "precision": precision, - }, - } - super().__init__(sd_model_map, base_model_id, compile_static_args, device, import_ir) - pipe_id_list = [ - safe_name(base_model_id), - str(batch_size), - str(static_kwargs["unet"]["max_length"]), - f"{str(height)}x{str(width)}", - precision, - ] - if num_loras > 0: - pipe_id_list.append(str(num_loras) + "lora") - if is_controlled: - pipe_id_list.append("controlled") - if custom_vae: - pipe_id_list.append(custom_vae) - self.pipe_id = "_".join(pipe_id_list) - print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") - del static_kwargs - gc.collect() + for module in modules: + return module.dtype + return torch.float32 + @property - def device(self): + def device(self) -> torch.device: r""" Returns: - `device`: The device on which the pipeline is located. + `torch.device`: The torch device on which the pipeline is located. 
""" module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] for module in modules: - return module.device + return torch.device("cpu") - return torch.device('cpu') + return torch.device("cpu") - def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): - print(f"\n[LOG] Preparing pipeline...") - self.is_img2img = is_img2img - self.schedulers = get_schedulers(self.base_model_id) + def register_modules(self, **kwargs): + # import it here to avoid circular import + diffusers_module = importlib.import_module("diffusers") + pipelines = getattr(diffusers_module, "pipelines") - self.weights_path = os.path.join( - get_checkpoints_path(), self.safe_name(self.base_model_id) - ) - if not os.path.exists(self.weights_path): - os.mkdir(self.weights_path) - - for model in adapters: - self.model_map[model] = adapters[model] - - for submodel in self.static_kwargs: - if custom_weights: - custom_weights_params, _ = process_custom_pipe_weights(custom_weights) - if submodel not in ["clip", "clip2"]: - self.static_kwargs[submodel][ - "external_weight_file" - ] = custom_weights_params + for name, module in kwargs.items(): + # retrieve library + if module is None or isinstance(module, (tuple, list)) and module[0] is None: + register_dict = {name: (None, None)} + else: + # register the config from the original module, not the dynamo compiled one + not_compiled_module = _unwrap_model(module) + + library = not_compiled_module.__module__.split(".")[0] + + # check if the module is a pipeline module + module_path_items = not_compiled_module.__module__.split(".") + pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None + + path = not_compiled_module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if is_pipeline_module: + library = pipeline_dir + elif library not in LOADABLE_CLASSES: + library = not_compiled_module.__module__ + + # retrieve class_name + class_name = not_compiled_module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def __setattr__(self, name: str, value: Any): + if name in self.__dict__ and hasattr(self.config, name): + # We need to overwrite the config if name exists in config + if isinstance(getattr(self.config, name), (tuple, list)): + if value is not None and self.config[name][0] is not None: + class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__) else: - self.static_kwargs[submodel]["external_weight_path"] = os.path.join( - self.weights_path, submodel + ".safetensors" - ) + class_library_tuple = (None, None) + + self.register_to_config(**{name: class_library_tuple}) else: - self.static_kwargs[submodel]["external_weight_path"] = os.path.join( - self.weights_path, submodel + ".safetensors" - ) + self.register_to_config(**{name: value}) + + super().__setattr__(name, value) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = True, + variant: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ): + """ + Save all saveable variables of the pipeline to a directory. 
A pipeline variable can be saved and loaded if its + class implements both a save and loading method. The pipeline is easily reloaded using the + [`~DiffusionPipeline.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save a pipeline to. Will be created if it doesn't exist. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model..bin`. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name", None) + model_index_dict.pop("_diffusers_version", None) + model_index_dict.pop("_module", None) + model_index_dict.pop("_name_or_path", None) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + expected_modules, optional_kwargs = self._get_signature_keys(self) + + def is_saveable_module(name, value): + if name not in expected_modules: + return False + if name in self._optional_components and value[0] is None: + return False + return True + + model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + if library_name in sys.modules: + library = importlib.import_module(library_name) + else: + logger.info( + f"{library_name} is not installed. 
Cannot save {pipeline_component_name} as {library_classes} from {library_name}" + ) + + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + if save_method_name is None: + logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.") + # make sure that unsaveable components are not tried to be loaded afterward + self.register_to_config(**{pipeline_component_name: (None, None)}) + continue + + save_method = getattr(sub_model, save_method_name) + + # Call the save method with the argument safe_serialization only if it's supported + save_method_signature = inspect.signature(save_method) + save_method_accept_safe = "safe_serialization" in save_method_signature.parameters + save_method_accept_variant = "variant" in save_method_signature.parameters + + save_kwargs = {} + if save_method_accept_safe: + save_kwargs["safe_serialization"] = safe_serialization + if save_method_accept_variant: + save_kwargs["variant"] = variant + + save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) + + # finally save the config + self.save_config(save_directory) + + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + - self.get_compiled_map(pipe_id=self.pipe_id) - print("\n[LOG] Pipeline successfully prepared for runtime.") - return - @classmethod @validate_hf_hub_args - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], device, **kwargs): + r""" Instantiate a Sharkified PyTorch diffusion pipeline from pretrained pipeline weights. @@ -413,6 +549,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): @@ -437,6 +574,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P custom_revision=custom_revision, variant=variant, load_connected_pipeline=load_connected_pipeline, + shark_device=device, **kwargs, ) else: @@ -480,7 +618,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P class_name=custom_class_name, cache_dir=cache_dir, revision=custom_revision, + shark_device=device, ) + # DEPRECATED: To be removed in 1.0.0 if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( @@ -637,6 +777,8 @@ def get_connected_passed_kwargs(prefix): ) # 7. Potentially add passed objects if expected + init_kwargs['device'] = device + init_kwargs['model_map'] = sd_model_map missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) optional_modules = pipeline_class._optional_components @@ -654,8 +796,398 @@ def get_connected_passed_kwargs(prefix): # 9. 
Save where the model was instantiated from model.register_to_config(_name_or_path=pretrained_model_name_or_path) - breakpoint() return model + + @classmethod + @validate_hf_hub_args + def download(cls, pretrained_model_name, shark_device, **kwargs) -> Union[str, os.PathLike]: + + cache_dir = kwargs.pop("cache_dir", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + from_flax = kwargs.pop("from_flax", False) + custom_pipeline = kwargs.pop("custom_pipeline", None) + custom_revision = kwargs.pop("custom_revision", None) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + use_onnx = kwargs.pop("use_onnx", None) + load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + trust_remote_code = kwargs.pop("trust_remote_code", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + allow_patterns = None + ignore_patterns = None + + model_info_call_error: Optional[Exception] = None + if not local_files_only: + try: + info = model_info(pretrained_model_name, token=token, revision=revision) + except HTTPError as e: + logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") + local_files_only = True + model_info_call_error = e # save error to reraise it if model is not cached locally + + if not local_files_only: + config_file = hf_hub_download( + pretrained_model_name, + cls.config_name, + cache_dir=cache_dir, + revision=revision, + proxies=proxies, + force_download=force_download, + resume_download=resume_download, + token=token, + ) + + config_dict = cls._dict_from_json_file(config_file) + ignore_filenames = config_dict.pop("_ignore_files", []) + + # retrieve all folder_names that contain relevant files + folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] + + filenames = {sibling.rfilename for sibling in info.siblings} + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + + diffusers_module = importlib.import_module("diffusers") + pipelines = getattr(diffusers_module, "pipelines") + + # optionally create a custom component <> custom file mapping + custom_components = {} + for component in folder_names: + module_candidate = config_dict[component][0] + + if module_candidate is None or not isinstance(module_candidate, str): + continue + + # We compute candidate file path on the Hub. Do not use `os.path.join`. + candidate_file = f"{component}/{module_candidate}.py" + + if candidate_file in filenames: + custom_components[component] = module_candidate + elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): + raise ValueError( + f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." 
+ ) + + # remove ignored filenames + model_filenames = set(model_filenames) - set(ignore_filenames) + variant_filenames = set(variant_filenames) - set(ignore_filenames) + + # if the whole pipeline is cached we don't have to ping the Hub + + model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} + + custom_class_name = None + if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)): + custom_pipeline = config_dict["_class_name"][0] + custom_class_name = config_dict["_class_name"][1] + + # all filenames compatible with variant will be added + allow_patterns = list(model_filenames) + + # allow all patterns from non-model folders + # this enables downloading schedulers, tokenizers, ... + allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] + # add custom component files + allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()] + # add custom pipeline file + allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] + # also allow downloading config.json files with the model + allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] + + allow_patterns += [ + SCHEDULER_CONFIG_NAME, + CONFIG_NAME, + cls.config_name, + CUSTOM_PIPELINE_FILE_NAME, + ] + + load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames + load_components_from_hub = len(custom_components) > 0 + + if load_pipe_from_hub and not trust_remote_code: + raise ValueError( + f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly " + f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + ) + + if load_components_from_hub and not trust_remote_code: + raise ValueError( + f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly " + f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n" + f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." 
+ ) + + # retrieve passed components that should not be downloaded + pipeline_class = _get_pipeline_class( + cls, + config_dict, + load_connected_pipeline=load_connected_pipeline, + custom_pipeline=custom_pipeline, + repo_id=pretrained_model_name if load_pipe_from_hub else None, + hub_revision=revision, + class_name=custom_class_name, + cache_dir=cache_dir, + revision=custom_revision, + shark_device=shark_device, + ) + + expected_components, _ = cls._get_signature_keys(pipeline_class) + passed_components = [k for k in expected_components if k in kwargs] + + if ( + use_safetensors + and not allow_pickle + and not is_safetensors_compatible( + model_filenames, variant=variant, passed_components=passed_components + ) + ): + raise EnvironmentError( + f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" + ) + if from_flax: + ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] + elif use_safetensors and is_safetensors_compatible( + model_filenames, variant=variant, passed_components=passed_components + ): + ignore_patterns = ["*.bin", "*.msgpack"] + + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} + safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} + if ( + len(safetensors_variant_filenames) > 0 + and safetensors_model_filenames != safetensors_variant_filenames + ): + logger.warn( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." + ) + else: + ignore_patterns = ["*.safetensors", "*.msgpack"] + + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} + bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} + if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: + logger.warn( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." 
+ ) + + # Don't download any objects that are passed + allow_patterns = [ + p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) + ] + + if pipeline_class._load_connected_pipes: + allow_patterns.append("README.md") + + # Don't download index files of forbidden patterns either + ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns] + + re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns] + re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns] + + expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)] + expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)] + + snapshot_folder = Path(config_file).parent + pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files) + + if pipeline_is_cached and not force_download: + # if the pipeline is cached, we can directly return it + # else call snapshot_download + return snapshot_folder + + user_agent = {"pipeline_class": cls.__name__} + if custom_pipeline is not None and not custom_pipeline.endswith(".py"): + user_agent["custom_pipeline"] = custom_pipeline + + # download all allow_patterns - ignore_patterns + try: + cached_folder = snapshot_download( + pretrained_model_name, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + + # retrieve pipeline class from local file + cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) + cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name + + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None + + if pipeline_class is not None and pipeline_class._load_connected_pipes: + modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) + connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], []) + for connected_pipe_repo_id in connected_pipes: + download_kwargs = { + "cache_dir": cache_dir, + "resume_download": resume_download, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "variant": variant, + "use_safetensors": use_safetensors, + } + DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs) + + return cached_folder + + except FileNotFoundError: + # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache. + # This can happen in two cases: + # 1. If the user passed `local_files_only=True` => we raise the error directly + # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error + if model_info_call_error is None: + # 1. user passed `local_files_only=True` + raise + else: + # 2. we forced `local_files_only=True` when `model_info` failed + raise EnvironmentError( + f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occured" + " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace" + " above." 
+ ) from model_info_call_error + + def to(self, *args, **kwargs): + + torch_dtype = kwargs.pop("torch_dtype", None) + torch_device = kwargs.pop("torch_device", None) + dtype_kwarg = kwargs.pop("dtype", None) + device_kwarg = kwargs.pop("device", None) + silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) + + dtype = torch_dtype or dtype_kwarg + + device = torch_device or device_kwarg + + dtype_arg = None + device_arg = None + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_arg = args[0] + else: + device_arg = torch.device(args[0]) if args[0] is not None else None + elif len(args) == 2: + if isinstance(args[0], torch.dtype): + raise ValueError( + "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`." + ) + device_arg = torch.device(args[0]) if args[0] is not None else None + dtype_arg = args[1] + elif len(args) > 2: + raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`") + + if dtype is not None and dtype_arg is not None: + raise ValueError( + "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + dtype = dtype or dtype_arg + + if device is not None and device_arg is not None: + raise ValueError( + "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two." + ) + + device = device or device_arg + + # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. + def module_is_sequentially_offloaded(module): + return False + + def module_is_offloaded(module): + return False + + # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer + pipeline_is_sequentially_offloaded = any( + module_is_sequentially_offloaded(module) for _, module in self.components.items() + ) + if pipeline_is_sequentially_offloaded: + raise ValueError( + "Sequential offload not supported." + ) + + # Display a warning in this case (the operation succeeds but the benefits are lost) + pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) + if pipeline_is_offloaded: + logger.warning( + "Sequential offload not supported." + ) + + module_names, _ = self._get_signature_keys(self) + modules = [getattr(self, n, None) for n in module_names] + modules = [m for m in modules if isinstance(m, torch.nn.Module)] + + is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded + for module in modules: + is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit + + if is_loaded_in_8bit and dtype is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision." + ) + + if is_loaded_in_8bit and device is not None: + logger.warning( + f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}." + ) + else: + module.to(device, dtype) + + if ( + module.dtype == torch.float16 + and str(device) in ["cpu"] + and not silence_dtype_warnings + and not is_offloaded + ): + logger.warning( + "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It" + " is not recommended to move them to `cpu` as running them will fail. 
Please make" + " sure to use an accelerator to run the pipeline in inference, due to the lack of" + " support for`float16` operations on this device in PyTorch. Please, remove the" + " `torch_dtype=torch.float16` argument, or use another device for inference." + ) + return self + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + optional_names = list(optional_parameters) + for name in optional_names: + if name in cls._optional_components: + expected_modules.add(name) + optional_parameters.remove(name) + + return expected_modules, optional_parameters def shark_sd_fn_dict_input( @@ -784,7 +1316,7 @@ def shark_sd_fn( # parameters that are static in the turbine output format, # which is currently MLIR in the torch dialect. - sd_pipe = SharkDiffusionPipeline.from_pretrained( + sd_pipe = DiffusionPipeline.from_pretrained( **submit_pipe_kwargs, ) global_obj.set_sd_obj(sd_pipe) diff --git a/apps/shark_studio/modules/meta_model.py b/apps/shark_studio/modules/meta_model.py new file mode 100644 index 0000000000..7da3a90a6a --- /dev/null +++ b/apps/shark_studio/modules/meta_model.py @@ -0,0 +1,218 @@ +from msvcrt import kbhit +from shark.iree_utils.compile_utils import ( + get_iree_compiled_module, + load_vmfb_using_mmap, + clean_device_info, + get_iree_target_triple, +) +from apps.shark_studio.web.utils.file_utils import ( + get_checkpoints_path, + get_resource_path, +) +from apps.shark_studio.modules.shared_cmd_opts import ( + cmd_opts, +) +from iree import runtime as ireert +from pathlib import Path +import gc +import os + + +class SharkMetaModelBase: + # This class is a lightweight base for managing an + # inference API class. It should provide methods for: + # - compiling a set (model map) of torch IR modules + # - preparing weights for an inference job + # - loading weights for an inference job + # - utilites like benchmarks, tests + + def __init__( + self, + model_map: dict, + device: str, + dtype: str = "f16", + import_mlir: bool = True, + ): + self.model_map = model_map + self.pipe_map = {} + self.triple = get_iree_target_triple(device) + self._device, self.device_id = clean_device_info(device) + self.import_mlir = import_mlir + self.iree_module_dict = {} + self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp")) + if not os.path.exists(self.tmp_dir): + os.mkdir(self.tmp_dir) + self.tempfiles = {} + self.pipe_vmfb_path = "" + self._dtype = dtype + + @property + def device(self): + return self._device + + @device.setter + def device(self, device_str): + self._device = device_str + + @property + def dtype(self): + return self._dtype + + @dtype.setter + def dtype(self, dtype_val): + self._dtype = dtype_val + + def get_compiled_map(self, pipe_id, static_kwargs, submodel="None", init_kwargs={}) -> None: + # First checks whether we have .vmfbs precompiled, then populates the map + # with the precompiled executables and fetches executables for the rest of the map. + # The weights aren't static here anymore so this function should be a part of pipeline + # initialization. 
+        # As soon as you have a pipeline ID unique to your static torch IR parameters,
+        # and your model map is populated with any IR - unique model IDs and their static params,
+        # call this method to get the artifacts associated with your map.
+        self.pipe_id = self.safe_name(pipe_id)
+        self.pipe_vmfb_path = Path(
+            os.path.join(get_checkpoints_path(".."), self.pipe_id)
+        )
+        self.pipe_vmfb_path.mkdir(parents=False, exist_ok=True)
+        if submodel == "None":
+            print("\n[LOG] Gathering any pre-compiled artifacts....")
+            for key in self.model_map:
+                self.get_compiled_map(pipe_id, static_kwargs, submodel=key)
+        else:
+            self.pipe_map[submodel] = {}
+            self.get_precompiled(self.pipe_id, submodel)
+            ireec_flags = []
+            if submodel in self.iree_module_dict:
+                return
+            elif "vmfb_path" in self.pipe_map[submodel]:
+                return
+            elif submodel not in self.tempfiles:
+                print(
+                    f"\n[LOG] Tempfile for {submodel} not found. Fetching torch IR..."
+                )
+                if submodel in static_kwargs:
+                    init_kwargs = static_kwargs[submodel]
+                for key in static_kwargs["pipe"]:
+                    if key not in init_kwargs:
+                        init_kwargs[key] = static_kwargs["pipe"][key]
+                self.import_torch_ir(submodel, init_kwargs)
+                self.get_compiled_map(pipe_id, static_kwargs, submodel=submodel)
+            else:
+                ireec_flags = (
+                    self.model_map[submodel]["ireec_flags"]
+                    if "ireec_flags" in self.model_map[submodel]
+                    else []
+                )
+
+                self.iree_module_dict[submodel] = get_iree_compiled_module(
+                    self.tempfiles[submodel],
+                    device=self.device,
+                    frontend="torch",
+                    mmap=True,
+                    external_weight_file=self.get_io_params(submodel),
+                    extra_args=ireec_flags,
+                    write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
+                )
+        return
+
+    def get_io_params(self, submodel):
+        if "external_weight_file" in self.static_kwargs[submodel]:
+            # we are using custom weights
+            weights_path = self.static_kwargs[submodel]["external_weight_file"]
+        elif "external_weight_path" in self.static_kwargs[submodel]:
+            # we are using the default weights for the HF model
+            weights_path = self.static_kwargs[submodel]["external_weight_path"]
+        else:
+            # assume the torch IR contains the weights.
+ weights_path = None + return weights_path + + def get_precompiled(self, pipe_id, submodel="None"): + if submodel == "None": + for model in self.model_map: + self.get_precompiled(pipe_id, model) + vmfbs = [] + for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path): + vmfbs.extend(filenames) + break + for file in vmfbs: + if submodel in file: + self.pipe_map[submodel]["vmfb_path"] = os.path.join( + self.pipe_vmfb_path, file + ) + return + + def import_torch_ir(self, submodel, kwargs): + torch_ir = self.model_map[submodel]["initializer"]( + **self.safe_dict(kwargs), compile_to="torch" + ) + if submodel == "clip": + # clip.export_clip_model returns (torch_ir, tokenizer) + torch_ir = torch_ir[0] + + self.tempfiles[submodel] = os.path.join( + self.tmp_dir, f"{submodel}.torch.tempfile" + ) + + with open(self.tempfiles[submodel], "w+") as f: + f.write(torch_ir) + del torch_ir + gc.collect() + return + + def load_submodels(self, submodels: list): + for submodel in submodels: + if submodel in self.iree_module_dict: + print(f"\n[LOG] {submodel} is ready for inference.") + continue + if "vmfb_path" in self.pipe_map[submodel]: + weights_path = self.get_io_params(submodel) + # print( + # f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}" + # ) + self.iree_module_dict[submodel] = {} + ( + self.iree_module_dict[submodel]["vmfb"], + self.iree_module_dict[submodel]["config"], + self.iree_module_dict[submodel]["temp_file_to_unlink"], + ) = load_vmfb_using_mmap( + self.pipe_map[submodel]["vmfb_path"], + self.device, + device_idx=0, + rt_flags=[], + external_weight_file=weights_path, + ) + else: + self.get_compiled_map(self.pipe_id, submodel) + return + + def unload_submodels(self, submodels: list): + for submodel in submodels: + if submodel in self.iree_module_dict: + del self.iree_module_dict[submodel] + gc.collect() + return + + def run(self, submodel, inputs): + if not isinstance(inputs, list): + inputs = [inputs] + inp = [ + ireert.asdevicearray( + self.iree_module_dict[submodel]["config"].device, input + ) + for input in inputs + ] + return self.iree_module_dict[submodel]["vmfb"]["main"](*inp) + + def safe_name(self, name): + return name.replace("/", "_").replace("-", "_").replace("\\", "_") + + def safe_dict(self, kwargs: dict): + flat_args = {} + for i in kwargs: + if isinstance(kwargs[i], dict) and "pass_dict" not in kwargs[i]: + flat_args[i] = [kwargs[i][j] for j in kwargs[i]] + else: + flat_args[i] = kwargs[i] + + return flat_args diff --git a/apps/shark_studio/tests/diffusers_pipeline_test.py b/apps/shark_studio/tests/diffusers_pipeline_test.py new file mode 100644 index 0000000000..421d8a5457 --- /dev/null +++ b/apps/shark_studio/tests/diffusers_pipeline_test.py @@ -0,0 +1,47 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import torch
+import unittest
+import PIL
+from typing import List
+from apps.shark_studio.api.sd import SharkDiffusionPipeline
+
+# from diffusers import DiffusionPipeline
+
+
+class SDBaseAPITest(unittest.TestCase):
+    def testPipeSimple(self):
+        pipe = SharkDiffusionPipeline.from_pretrained(
+            pretrained_model_name_or_path="hf-internal-testing/tiny-stable-diffusion-torch",
+            device="vulkan",
+            torch_dtype=torch.float32,
+        )
+        pipe.setup_shark(
+            base_model_id="hf-internal-testing/tiny-stable-diffusion-torch",
+            height=512,
+            width=512,
+            batch_size=1,
+            precision="f32",
+            device="vulkan",
+        )
+
+        pipe.prepare_pipe(
+            custom_weights="",
+            adapters={},
+            embeddings=[],
+            is_img2img=False,
+        )
+
+        prompt = ["An astronaut riding a fearsome shark"]
+        negative_prompt = [""]
+        image = pipe(prompt, negative_prompt).images[0]
+        assert isinstance(image, PIL.Image.Image)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
\ No newline at end of file
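
Below is a minimal usage sketch of the compiled-module flow that SharkMetaModelBase introduces: populate a model map, compile it into .vmfb artifacts, load the submodels, and run one of them. This is not part of the patch; the exporter callable, model-map contents, HF model id, and input shape are illustrative placeholders, and only the SharkMetaModelBase methods mirror the code above.

import numpy as np

from apps.shark_studio.modules.meta_model import SharkMetaModelBase


def export_toy_unet(hf_model_name=None, precision="fp16", compile_to="torch", **kwargs):
    # Hypothetical stand-in for a Turbine-style exporter; get_compiled_map expects
    # each "initializer" to return torch-dialect MLIR (as a string) when called
    # with compile_to="torch".
    raise NotImplementedError("replace with a real turbine_models exporter")


model_map = {
    # "initializer" and the optional "ireec_flags" are the keys that
    # get_compiled_map / import_torch_ir read from each model-map entry.
    "unet": {"initializer": export_toy_unet, "ireec_flags": []},
}
static_kwargs = {
    "pipe": {"external_weights": "safetensors"},
    "unet": {"hf_model_name": "some/hf-model-id", "precision": "fp16"},
}

meta = SharkMetaModelBase(model_map=model_map, device="vulkan", dtype="f16")
# get_io_params() reads self.static_kwargs when resolving external weight files,
# so the static kwargs need to be attached to the instance before compiling.
meta.static_kwargs = static_kwargs
meta.get_compiled_map(pipe_id="example_sd_pipe", static_kwargs=static_kwargs)
meta.load_submodels(["unet"])
output = meta.run("unet", [np.zeros((1, 4, 64, 64), dtype=np.float16)])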