diff --git a/configs/diffusers_pipeline/README.md b/configs/diffusers_pipeline/README.md new file mode 100644 index 0000000000..ab2b246d15 --- /dev/null +++ b/configs/diffusers_pipeline/README.md @@ -0,0 +1,47 @@ +# Diffusers Pipeline (2023) + +> [Diffusers Pipeline](https://github.com/huggingface/diffusers) + +> **Task**: Diffusers Pipeline + + + +## Abstract + + + +We support diffusers pipelines for users to conveniently use diffusers to do inferece in our repo. + +## Configs + +| Model | Dataset | Download | +| :---------------------------------------: | :-----: | :------: | +| [diffusers pipeline](./sd_xl_pipeline.py) | - | - | + +## Quick Start + +```python +from mmagic.apis import MMagicInferencer + +# Create a MMEdit instance and infer +editor = MMagicInferencer(model_name='diffusers_pipeline') +text_prompts = 'Japanese anime style, girl, beautiful, cute, colorful, best quality, extremely detailed' +negative_prompt = 'bad face, bad hands' +result_out_dir = 'resources/output/text2image/sd_xl_japanese.png' +editor.infer(text=text_prompts, + negative_prompt=negative_prompt, + result_out_dir=result_out_dir) +``` + +## Citation + +```bibtex +@misc{von-platen-etal-2022-diffusers, + author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf}, + title = {Diffusers: State-of-the-art diffusion models}, + year = {2022}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/huggingface/diffusers}} +} +``` diff --git a/configs/diffusers_pipeline/metafile.yml b/configs/diffusers_pipeline/metafile.yml new file mode 100644 index 0000000000..c61b740cef --- /dev/null +++ b/configs/diffusers_pipeline/metafile.yml @@ -0,0 +1,17 @@ +Collections: +- Name: Diffusers Pipeline + Paper: + Title: Diffusers Pipeline + URL: https://github.com/huggingface/diffusers + README: configs/diffusers_pipeline/README.md + Task: + - diffusers pipeline + Year: 2023 +Models: +- Config: configs/diffusers_pipeline/sd_xl_pipeline.py + In Collection: Diffusers Pipeline + Name: sd_xl_pipeline + Results: + - Dataset: '-' + Metrics: {} + Task: Diffusers Pipeline diff --git a/configs/diffusers_pipeline/sd_xl_pipeline.py b/configs/diffusers_pipeline/sd_xl_pipeline.py new file mode 100644 index 0000000000..e1c66e47e7 --- /dev/null +++ b/configs/diffusers_pipeline/sd_xl_pipeline.py @@ -0,0 +1,6 @@ +# config for model + +model = dict( + type='DiffusionPipeline', + from_pretrained='stabilityai/stable-diffusion-xl-base-1.0' +) diff --git a/mmagic/apis/inferencers/__init__.py b/mmagic/apis/inferencers/__init__.py index b175a0c297..66c5710b57 100644 --- a/mmagic/apis/inferencers/__init__.py +++ b/mmagic/apis/inferencers/__init__.py @@ -7,6 +7,7 @@ from .colorization_inferencer import ColorizationInferencer from .conditional_inferencer import ConditionalInferencer from .controlnet_animation_inferencer import ControlnetAnimationInferencer +from .diffusers_pipeline_inferencer import DiffusersPipelineInferencer from .eg3d_inferencer import EG3DInferencer from .image_super_resolution_inferencer import ImageSuperResolutionInferencer from .inpainting_inferencer import InpaintingInferencer @@ -23,7 +24,7 @@ 'ImageSuperResolutionInferencer', 'Text2ImageInferencer', 'TranslationInferencer', 'UnconditionalInferencer', 'VideoInterpolationInferencer', 'VideoRestorationInferencer', - 'ControlnetAnimationInferencer' + 'ControlnetAnimationInferencer', 'DiffusersPipelineInferencer' ] @@ -91,6 +92,9 @@ def __init__(self, ]: self.inferencer = ImageSuperResolutionInferencer( config, ckpt, device, extra_parameters, seed=seed) + elif self.task in ['Diffusers Pipeline']: + self.inferencer = DiffusersPipelineInferencer( + config, ckpt, device, extra_parameters, seed=seed) else: raise ValueError(f'Unknown inferencer task: {self.task}') diff --git a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py new file mode 100644 index 0000000000..5af47ebf13 --- /dev/null +++ b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import Dict, List + +import numpy as np +from mmengine import mkdir_or_exist +from PIL.Image import Image +from torchvision.utils import save_image + +from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType + + +class DiffusersPipelineInferencer(BaseMMagicInferencer): + """inferencer that predicts with text2image models.""" + + func_kwargs = dict( + preprocess=['text', 'negative_prompt'], + forward=[], + visualize=['result_out_dir'], + postprocess=[]) + + extra_parameters = dict(height=None, width=None) + + def preprocess(self, + text: InputsType, + negative_prompt: InputsType = None) -> Dict: + """Process the inputs into a model-feedable format. + + Args: + text(InputsType): text input for text-to-image model. + negative_prompt(InputsType): negative prompt. + + Returns: + result(Dict): Results of preprocess. + """ + result = self.extra_parameters + result['prompt'] = text + + if negative_prompt: + result['negative_prompt'] = negative_prompt + + return result + + def forward(self, inputs: InputsType) -> PredType: + """Forward the inputs to the model.""" + images = self.model(**inputs).images + + return images + + def visualize(self, + preds: PredType, + result_out_dir: str = None) -> List[np.ndarray]: + """Visualize predictions. + + Args: + preds (List[Union[str, np.ndarray]]): Forward results + by the inferencer. + result_out_dir (str): Output directory of image. + Defaults to ''. + + Returns: + List[np.ndarray]: Result of visualize + """ + if result_out_dir: + mkdir_or_exist(os.path.dirname(result_out_dir)) + if type(preds) is list: + preds = preds[0] + if type(preds) is Image: + preds.save(result_out_dir) + else: + save_image(preds, result_out_dir, normalize=True) + + return preds diff --git a/mmagic/apis/mmagic_inferencer.py b/mmagic/apis/mmagic_inferencer.py index 634d5e5718..bad7e82df0 100644 --- a/mmagic/apis/mmagic_inferencer.py +++ b/mmagic/apis/mmagic_inferencer.py @@ -114,11 +114,14 @@ class MMagicInferencer: # 3D-aware generation 'eg3d', - # diffusers inferencer + # animation inferencer 'controlnet_animation', # draggan - 'draggan' + 'draggan', + + # diffusers pipeline inferencer + 'diffusers_pipeline', ] inference_supported_models_cfg = {} diff --git a/mmagic/models/archs/__init__.py b/mmagic/models/archs/__init__.py index f33271b509..a67bb3980b 100644 --- a/mmagic/models/archs/__init__.py +++ b/mmagic/models/archs/__init__.py @@ -63,10 +63,21 @@ def gen_wrapped_cls(module, module_name): wrapped_module = gen_wrapped_cls(module, module_name) MODELS.register_module(name=module_name, module=wrapped_module) DIFFUSERS_MODELS.append(module_name) - return DIFFUSERS_MODELS + DIFFUSERS_PIPELINES = [] + for pipeline_name in dir(diffusers.pipelines): + pipeline = getattr(diffusers.pipelines, pipeline_name) + if (inspect.isclass(pipeline) + and issubclass(pipeline, diffusers.DiffusionPipeline)): + wrapped_pipeline = gen_wrapped_cls(pipeline, pipeline_name) + MODELS.register_module(name=pipeline_name, module=wrapped_pipeline) + DIFFUSERS_PIPELINES.append(pipeline_name) -REGISTERED_DIFFUSERS_MODELS = register_diffusers_models() + return DIFFUSERS_MODELS, DIFFUSERS_PIPELINES + + +REGISTERED_DIFFUSERS_MODELS, REGISTERED_DIFFUSERS_PIPELINES = \ + register_diffusers_models() __all__ = [ 'ASPP', 'DepthwiseSeparableConvModule', 'SimpleGatedConvModule', diff --git a/mmagic/models/archs/wrapper.py b/mmagic/models/archs/wrapper.py index b4dc39f394..5b95a9d649 100644 --- a/mmagic/models/archs/wrapper.py +++ b/mmagic/models/archs/wrapper.py @@ -177,3 +177,11 @@ def forward(self, *args, **kwargs) -> Any: Any: The output of wrapped module's forward function. """ return self.model(*args, **kwargs) + + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + torch_dtype: Optional[torch.dtype] = None, + ): + self.model.to(torch_device, torch_dtype) + return self diff --git a/model-index.yml b/model-index.yml index a0cd9cd232..66c8c194db 100644 --- a/model-index.yml +++ b/model-index.yml @@ -12,6 +12,7 @@ Import: - configs/deepfillv1/metafile.yml - configs/deepfillv2/metafile.yml - configs/dic/metafile.yml +- configs/diffusers_pipeline/metafile.yml - configs/dim/metafile.yml - configs/disco_diffusion/metafile.yml - configs/draggan/metafile.yml