Taited/stable diffusion inpaint (open-mmlab#1976)
* support stable diffusion inpaint

* add unit test for stable diffusion inpaint

* fix unit tests and update model zoo error

* fix unit test bug

* fix unit test error, add attribute latent_channels for vae

* fix lint: add the new config in readme

* improve docstrings and readme for stable diffusion inpaint

* add copyright of HuggingFace Team and markdown code language type

---------

Co-authored-by: zengyanhong <[email protected]>
Taited and zengyh1900 authored Aug 11, 2023
1 parent 66e4637 commit 0448b89
Showing 9 changed files with 730 additions and 10 deletions.
37 changes: 36 additions & 1 deletion configs/stable_diffusion/README.md
@@ -2,7 +2,7 @@

> [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
- > **Task**: Text2Image
+ > **Task**: Text2Image, Inpainting
<!-- [ALGORITHM] -->

@@ -45,6 +45,7 @@ Stable Diffusion is a latent diffusion model conditioned on the text embeddings
| :----------------------------------------------------------------------------------: | :-----: | :------: |
| [stable_diffusion_v1.5](./stable-diffusion_ddim_denoisingunet.py) | - | - |
| [stable_diffusion_v1.5_tomesd](./stable-diffusion_ddim_denoisingunet-tomesd_5e-1.py) | - | - |
+ | [stable_diffusion_v1.5_inpaint](./stable-diffusion_ddim_denoisingunet-inpaint.py) | - | - |

We use stable diffusion v1.5 weights. This model has several weights including vae, unet and clip.
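If you prefer to fetch the weights yourself (the collapsed part of this README documents the official route), a minimal sketch using `huggingface_hub` (an assumption here, not the documented method) is:

```python
from huggingface_hub import snapshot_download

# Download the SD v1.5 weights once; point the config's
# `from_pretrained` fields at the returned local path.
local_dir = snapshot_download('runwayml/stable-diffusion-v1-5')
print(local_dir)
```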

@@ -83,6 +84,40 @@
```python
image = StableDiffuser.infer(prompt)['samples'][0]
image.save('robot.png')
```
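The setup lines collapsed above mirror the inpainting example below; a minimal sketch of that omitted setup (the prompt text is a placeholder):

```python
from mmengine import MODELS, Config
from mmengine.registry import init_default_scope

# Sketch of the collapsed setup: build the text2image model from its
# config and move it to GPU before calling infer() as shown above.
init_default_scope('mmagic')
config = Config.fromfile(
    'configs/stable_diffusion/stable-diffusion_ddim_denoisingunet.py').copy()
StableDiffuser = MODELS.build(config.model).to('cuda')
prompt = 'a mecha robot sitting on a bench'  # placeholder prompt
```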

To inpaint an image, you can run the following code.

```python
import mmcv
from mmengine import MODELS, Config
from mmengine.registry import init_default_scope
from PIL import Image

init_default_scope('mmagic')

config = 'configs/stable_diffusion/stable-diffusion_ddim_denoisingunet-inpaint.py'
config = Config.fromfile(config).copy()
# change the 'pretrained_model_path' if you have downloaded the weights manually
# config.model.unet.from_pretrained = '/path/to/your/stable-diffusion-inpainting'
# config.model.vae.from_pretrained = '/path/to/your/stable-diffusion-inpainting'

StableDiffuser = MODELS.build(config.model)
prompt = 'a mecha robot sitting on a bench'

img_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png' # noqa
mask_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png' # noqa

image = Image.fromarray(mmcv.imread(img_url, channel_order='rgb'))
mask = Image.fromarray(mmcv.imread(mask_url)).convert('L')
StableDiffuser = StableDiffuser.to('cuda')

image = StableDiffuser.infer(
    prompt,
    image,
    mask
)['samples'][0]
image.save('inpaint.png')
```
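For background, the inpainting checkpoint this model targets (`runwayml/stable-diffusion-inpainting`) uses a UNet with nine input channels: the noisy latents concatenated with a downsampled mask and the masked-image latents. A conceptual sketch (not mmagic's exact code path):

```python
import torch

# Conceptual sketch: the inpainting UNet's input concatenates 4 noisy
# latent channels, 1 resized mask channel (1 = region to repaint) and
# 4 masked-image latent channels along the channel axis.
noisy_latents = torch.randn(1, 4, 64, 64)
mask = torch.rand(1, 1, 64, 64)
masked_image_latents = torch.randn(1, 4, 64, 64)
unet_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
assert unet_input.shape[1] == 9
```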

## Use ToMe to accelerate your stable diffusion model

We support **[tomesd](https://github.com/dbolya/tomesd)** now! It is developed based on [ToMe](https://github.com/facebookresearch/ToMe), an efficient ViT speed-up tool based on token merging. To work with **tomesd** in `mmagic`, you just need to add `tomesd_cfg` to `model` as shown in [stable_diffusion_v1.5_tomesd](stable-diffusion_ddim_denoisingunet-tomesd_5e-1.py) and in the sketch below. The only requirement is `torch >= 1.12.1` in order to properly support `torch.Tensor.scatter_reduce()` functionality. Please do check it before running the demo.
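As a minimal sketch of that change (assuming `tomesd_cfg` forwards tomesd's main knob, the token-merging `ratio`; the `5e-1` suffix of the config name suggests 0.5), the config could look like:

```python
# Hypothetical config sketch; check
# stable-diffusion_ddim_denoisingunet-tomesd_5e-1.py for the exact keys
# that mmagic's tomesd_cfg accepts.
_base_ = './stable-diffusion_ddim_denoisingunet.py'

model = dict(tomesd_cfg=dict(ratio=0.5))  # merge about half of the tokens
```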
16 changes: 12 additions & 4 deletions configs/stable_diffusion/metafile.yml
@@ -6,6 +6,7 @@ Collections:
  README: configs/stable_diffusion/README.md
  Task:
  - text2image
+   - inpainting
  Year: 2022
Models:
- Config: configs/stable_diffusion/stable-diffusion_ddim_denoisingunet.py
@@ -14,23 +15,30 @@ Models:
  Results:
  - Dataset: '-'
    Metrics: {}
-     Task: Text2Image
+     Task: Text2Image, Inpainting
- Config: configs/stable_diffusion/stable-diffusion_ddim_denoisingunet-tomesd_5e-1.py
  In Collection: Stable Diffusion
  Name: stable-diffusion_ddim_denoisingunet-tomesd_5e-1
  Results:
  - Dataset: '-'
    Metrics: {}
-     Task: Text2Image
+     Task: Text2Image, Inpainting
  - Dataset: '-'
    Metrics:
      Size / Num images per prompt:
        PSNR: 512.0
        SSIM: 5.0
-     Task: Text2Image
+     Task: Text2Image, Inpainting
  - Dataset: '-'
    Metrics:
      Size / Num images per prompt:
        PSNR: 512.0
        SSIM: 5.0
-     Task: Text2Image
+     Task: Text2Image, Inpainting
+ - Config: configs/stable_diffusion/stable-diffusion_ddim_denoisingunet-inpaint.py
+   In Collection: Stable Diffusion
+   Name: stable-diffusion_ddim_denoisingunet-inpaint
+   Results:
+   - Dataset: '-'
+     Metrics: {}
+     Task: Text2Image, Inpainting
34 changes: 34 additions & 0 deletions configs/stable_diffusion/stable-diffusion_ddim_denoisingunet-inpaint.py
@@ -0,0 +1,34 @@
# Use DiffuserWrapper!
stable_diffusion_v15_url = 'runwayml/stable-diffusion-inpainting'
unet = dict(
    type='UNet2DConditionModel',
    subfolder='unet',
    from_pretrained=stable_diffusion_v15_url)
vae = dict(
    type='AutoencoderKL',
    from_pretrained=stable_diffusion_v15_url,
    subfolder='vae')

diffusion_scheduler = dict(
    type='EditDDIMScheduler',
    variance_type='learned_range',
    beta_end=0.012,
    beta_schedule='scaled_linear',
    beta_start=0.00085,
    num_train_timesteps=1000,
    set_alpha_to_one=False,
    clip_sample=False)

model = dict(
    type='StableDiffusionInpaint',
    unet=unet,
    vae=vae,
    enable_xformers=False,
    text_encoder=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=stable_diffusion_v15_url,
        subfolder='text_encoder'),
    tokenizer=stable_diffusion_v15_url,
    scheduler=diffusion_scheduler,
    test_scheduler=diffusion_scheduler)
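For intuition, the scheduler block above reproduces SD v1.5's noise schedule. A short sketch of how a `scaled_linear` schedule is conventionally computed (mirroring the diffusers definition; shown for illustration only):

```python
import torch

# 'scaled_linear': betas are linear in sqrt-space between beta_start and
# beta_end, matching the config values above.
beta_start, beta_end, num_train_timesteps = 0.00085, 0.012, 1000
betas = torch.linspace(beta_start**0.5, beta_end**0.5,
                       num_train_timesteps) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(betas[0].item(), betas[-1].item(), alphas_cumprod[-1].item())
```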
5 changes: 3 additions & 2 deletions mmagic/models/editors/__init__.py
@@ -54,7 +54,7 @@
from .singan import SinGAN
from .srcnn import SRCNNNet
from .srgan import SRGAN, ModifiedVGG, MSRResNet
- from .stable_diffusion import StableDiffusion
+ from .stable_diffusion import StableDiffusion, StableDiffusionInpaint
from .stylegan1 import StyleGAN1
from .stylegan2 import StyleGAN2
from .stylegan3 import StyleGAN3, StyleGAN3Generator
@@ -92,5 +92,6 @@
    'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DenoisingUnet',
    'ClipWrapper', 'EG3D', 'Restormer', 'SwinIRNet', 'StableDiffusion',
    'ControlStableDiffusion', 'DreamBooth', 'TextualInversion', 'DeblurGanV2',
-     'DeblurGanV2Generator', 'DeblurGanV2Discriminator'
+     'DeblurGanV2Generator', 'DeblurGanV2Discriminator',
+     'StableDiffusionInpaint'
]
9 changes: 7 additions & 2 deletions mmagic/models/editors/ddpm/denoising_unet.py
@@ -759,6 +759,9 @@ class DenoisingUnet(BaseModule):
        image_size (int | list[int]): The size of image to denoise.
        in_channels (int, optional): The input channels of the input image.
            Defaults to ``3``.
+         out_channels (int, optional): The output channels of the output
+             prediction. Defaults to ``None``, meaning it is automatically
+             assigned according to ``var_mode``.
        base_channels (int, optional): The basic channel number of the
            generator. The other layers contain channels based on this number.
            Defaults to ``128``.
@@ -837,6 +840,7 @@ class DenoisingUnet(BaseModule):
    def __init__(self,
                 image_size,
                 in_channels=3,
+                  out_channels=None,
                 base_channels=128,
                 resblocks_per_downsample=3,
                 num_timesteps=1000,
@@ -886,8 +890,9 @@ def __init__(self,
        self.in_channels = in_channels

        # double output_channels to output mean and var at same time
-         out_channels = in_channels if 'FIXED' in self.var_mode.upper() \
-             else 2 * in_channels
+         if out_channels is None:
+             out_channels = in_channels if 'FIXED' in self.var_mode.upper() \
+                 else 2 * in_channels
        self.out_channels = out_channels

        # check type of image_size
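Why the new argument matters: an inpainting-style UNet takes extra conditioning channels on the input side, so the output width can no longer be derived from `in_channels` alone. A standalone sketch of the default logic added above (function name hypothetical):

```python
def default_out_channels(in_channels, var_mode, out_channels=None):
    # Fixed-variance modes predict only the mean, so the output matches
    # the input; learned-variance modes also predict a variance map and
    # double the channel count. An explicit value always takes precedence.
    if out_channels is not None:
        return out_channels
    return in_channels if 'FIXED' in var_mode.upper() else 2 * in_channels


assert default_out_channels(3, 'FIXED_SMALL') == 3
assert default_out_channels(3, 'LEARNED_RANGE') == 6
assert default_out_channels(9, 'LEARNED_RANGE', out_channels=4) == 4
```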
3 changes: 2 additions & 1 deletion mmagic/models/editors/stable_diffusion/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .stable_diffusion import StableDiffusion
+ from .stable_diffusion_inpaint import StableDiffusionInpaint
from .vae import AutoencoderKL

- __all__ = ['StableDiffusion', 'AutoencoderKL']
+ __all__ = ['StableDiffusion', 'AutoencoderKL', 'StableDiffusionInpaint']
