Skip to content

Commit

Permalink
add support for diffusers pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
liuwenran committed Sep 7, 2023
1 parent bced86e commit 06242d4
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 5 deletions.
47 changes: 47 additions & 0 deletions configs/diffusers_pipeline/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Diffusers Pipeline (2023)

> [Diffusers Pipeline](https://github.com/huggingface/diffusers)
> **Task**: Diffusers Pipeline
<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

We support diffusers pipelines so that users can conveniently run inference with diffusers in our repo.

## Configs

| Model | Dataset | Download |
| :---------------------------------------: | :-----: | :------: |
| [diffusers pipeline](./sd_xl_pipeline.py) | - | - |

## Quick Start

```python
from mmagic.apis import MMagicInferencer

# Create an MMagicInferencer instance and run inference
editor = MMagicInferencer(model_name='diffusers_pipeline')
text_prompts = 'Japanese anime style, girl, beautiful, cute, colorful, best quality, extremely detailed'
negative_prompt = 'bad face, bad hands'
result_out_dir = 'resources/output/text2image/sd_xl_japanese.png'
editor.infer(text=text_prompts,
negative_prompt=negative_prompt,
result_out_dir=result_out_dir)
```

## Citation

```bibtex
@misc{von-platen-etal-2022-diffusers,
author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Thomas Wolf},
title = {Diffusers: State-of-the-art diffusion models},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/huggingface/diffusers}}
}
```
17 changes: 17 additions & 0 deletions configs/diffusers_pipeline/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Collections:
- Name: Diffusers Pipeline
Paper:
Title: Diffusers Pipeline
URL: https://github.com/huggingface/diffusers
README: configs/diffusers_pipeline/README.md
Task:
- diffusers pipeline
Year: 2023
Models:
- Config: configs/diffusers_pipeline/sd_xl_pipeline.py
In Collection: Diffusers Pipeline
Name: sd_xl_pipeline
Results:
- Dataset: '-'
Metrics: {}
Task: Diffusers Pipeline
6 changes: 6 additions & 0 deletions configs/diffusers_pipeline/sd_xl_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Model config for the diffusers pipeline demo.
# `type` names the pipeline class to build and `from_pretrained` names the
# pretrained weights (a Hugging Face model id) to load it from.

model = {
    'type': 'DiffusionPipeline',
    'from_pretrained': 'stabilityai/stable-diffusion-xl-base-1.0',
}
6 changes: 5 additions & 1 deletion mmagic/apis/inferencers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .colorization_inferencer import ColorizationInferencer
from .conditional_inferencer import ConditionalInferencer
from .controlnet_animation_inferencer import ControlnetAnimationInferencer
from .diffusers_pipeline_inferencer import DiffusersPipelineInferencer
from .eg3d_inferencer import EG3DInferencer
from .image_super_resolution_inferencer import ImageSuperResolutionInferencer
from .inpainting_inferencer import InpaintingInferencer
Expand All @@ -23,7 +24,7 @@
'ImageSuperResolutionInferencer', 'Text2ImageInferencer',
'TranslationInferencer', 'UnconditionalInferencer',
'VideoInterpolationInferencer', 'VideoRestorationInferencer',
'ControlnetAnimationInferencer'
'ControlnetAnimationInferencer', 'DiffusersPipelineInferencer'
]


Expand Down Expand Up @@ -91,6 +92,9 @@ def __init__(self,
]:
self.inferencer = ImageSuperResolutionInferencer(
config, ckpt, device, extra_parameters, seed=seed)
elif self.task in ['Diffusers Pipeline']:
self.inferencer = DiffusersPipelineInferencer(
config, ckpt, device, extra_parameters, seed=seed)
else:
raise ValueError(f'Unknown inferencer task: {self.task}')

Expand Down
73 changes: 73 additions & 0 deletions mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List

import numpy as np
from mmengine import mkdir_or_exist
from PIL.Image import Image
from torchvision.utils import save_image

from .base_mmagic_inferencer import BaseMMagicInferencer, InputsType, PredType


class DiffusersPipelineInferencer(BaseMMagicInferencer):
    """Inferencer that generates images from text with a diffusers
    pipeline."""

    # Names of the user-facing keyword arguments routed to each stage.
    func_kwargs = dict(
        preprocess=['text', 'negative_prompt'],
        forward=[],
        visualize=['result_out_dir'],
        postprocess=[])

    # Default extra keyword arguments forwarded to the pipeline call.
    extra_parameters = dict(height=None, width=None)

    def preprocess(self,
                   text: InputsType,
                   negative_prompt: InputsType = None) -> Dict:
        """Process the inputs into a model-feedable format.

        Args:
            text (InputsType): Text prompt for the text-to-image pipeline.
            negative_prompt (InputsType): Negative prompt. It is only
                forwarded when truthy. Defaults to None.

        Returns:
            Dict: Keyword arguments for the pipeline call, i.e. a copy of
            ``extra_parameters`` plus ``prompt`` and, optionally,
            ``negative_prompt``.
        """
        # Work on a copy: writing into ``self.extra_parameters`` directly
        # would make a ``negative_prompt`` from one call silently persist
        # into later calls that do not pass one.
        result = dict(self.extra_parameters)
        result['prompt'] = text

        if negative_prompt:
            result['negative_prompt'] = negative_prompt

        return result

    def forward(self, inputs: InputsType) -> PredType:
        """Forward the inputs to the model.

        Args:
            inputs (InputsType): Keyword arguments produced by
                :meth:`preprocess`.

        Returns:
            PredType: The ``images`` attribute of the pipeline output.
        """
        images = self.model(**inputs).images

        return images

    def visualize(self,
                  preds: PredType,
                  result_out_dir: str = None) -> List[np.ndarray]:
        """Visualize predictions.

        Args:
            preds (List[Union[str, np.ndarray]]): Forward results
                by the inferencer.
            result_out_dir (str): Output file path of the image. Nothing
                is saved when it is empty. Defaults to None.

        Returns:
            List[np.ndarray]: Result of visualize. When ``result_out_dir``
            is given and ``preds`` is a list, only its first element is
            saved and returned; otherwise ``preds`` is returned unchanged.
        """
        if result_out_dir:
            mkdir_or_exist(os.path.dirname(result_out_dir))
            if isinstance(preds, list):
                preds = preds[0]
            if isinstance(preds, Image):
                # PIL images know how to write themselves to disk.
                preds.save(result_out_dir)
            else:
                save_image(preds, result_out_dir, normalize=True)

        return preds
7 changes: 5 additions & 2 deletions mmagic/apis/mmagic_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,14 @@ class MMagicInferencer:
# 3D-aware generation
'eg3d',

# diffusers inferencer
# animation inferencer
'controlnet_animation',

# draggan
'draggan'
'draggan',

# diffusers pipeline inferencer
'diffusers_pipeline',
]

inference_supported_models_cfg = {}
Expand Down
15 changes: 13 additions & 2 deletions mmagic/models/archs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,21 @@ def gen_wrapped_cls(module, module_name):
wrapped_module = gen_wrapped_cls(module, module_name)
MODELS.register_module(name=module_name, module=wrapped_module)
DIFFUSERS_MODELS.append(module_name)
return DIFFUSERS_MODELS

DIFFUSERS_PIPELINES = []
for pipeline_name in dir(diffusers.pipelines):
pipeline = getattr(diffusers.pipelines, pipeline_name)
if (inspect.isclass(pipeline)
and issubclass(pipeline, diffusers.DiffusionPipeline)):
wrapped_pipeline = gen_wrapped_cls(pipeline, pipeline_name)
MODELS.register_module(name=pipeline_name, module=wrapped_pipeline)
DIFFUSERS_PIPELINES.append(pipeline_name)

REGISTERED_DIFFUSERS_MODELS = register_diffusers_models()
return DIFFUSERS_MODELS, DIFFUSERS_PIPELINES


REGISTERED_DIFFUSERS_MODELS, REGISTERED_DIFFUSERS_PIPELINES = \
register_diffusers_models()

__all__ = [
'ASPP', 'DepthwiseSeparableConvModule', 'SimpleGatedConvModule',
Expand Down
8 changes: 8 additions & 0 deletions mmagic/models/archs/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,11 @@ def forward(self, *args, **kwargs) -> Any:
Any: The output of wrapped module's forward function.
"""
return self.model(*args, **kwargs)

def to(
    self,
    torch_device: Optional[Union[str, torch.device]] = None,
    torch_dtype: Optional[torch.dtype] = None,
):
    """Move the wrapped model to the given device and/or dtype.

    Args:
        torch_device (Optional[Union[str, torch.device]]): Target device
            passed on to the wrapped model. Defaults to None.
        torch_dtype (Optional[torch.dtype]): Target dtype passed on to the
            wrapped model. Defaults to None.

    Returns:
        The wrapper itself, so that calls can be chained.
    """
    wrapped = self.model
    # Both values are forwarded positionally, in this order; the return
    # value of the wrapped ``to`` is deliberately ignored.
    wrapped.to(torch_device, torch_dtype)
    return self
1 change: 1 addition & 0 deletions model-index.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Import:
- configs/deepfillv1/metafile.yml
- configs/deepfillv2/metafile.yml
- configs/dic/metafile.yml
- configs/diffusers_pipeline/metafile.yml
- configs/dim/metafile.yml
- configs/disco_diffusion/metafile.yml
- configs/draggan/metafile.yml
Expand Down

0 comments on commit 06242d4

Please sign in to comment.