[Feature] Support SDXL (open-mmlab#2035)
* support sdxl

* fix test

* fix test

* fix test

* add docs

---------

Co-authored-by: rangoliu <[email protected]>
okotaku and liuwenran authored Sep 21, 2023
1 parent 130b603 commit 8977fa1
Showing 10 changed files with 1,314 additions and 5 deletions.
57 changes: 57 additions & 0 deletions configs/stable_diffusion_xl/README.md
@@ -0,0 +1,57 @@
# Stable Diffusion XL (2023)

> [Stable Diffusion XL](https://arxiv.org/abs/2307.01952)
> **Task**: Text2Image, Inpainting
<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

We present SDXL, a latent diffusion model for text-to-image synthesis. Compared to previous versions of Stable Diffusion, SDXL leverages a three times larger UNet backbone: The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. We design multiple novel conditioning schemes and train SDXL on multiple aspect ratios. We also introduce a refinement model which is used to improve the visual fidelity of samples generated by SDXL using a post-hoc image-to-image technique. We demonstrate that SDXL shows drastically improved performance compared to the previous versions of Stable Diffusion and achieves results competitive with those of black-box state-of-the-art image generators.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/okotaku/diffengine/assets/24734142/27d4ebad-5705-4500-826f-41f425a08c0d"/>
</div>

## Pretrained models

| Model | Task | Dataset | Download |
| :----------------------------------------------------------------: | :--------: | :-----: | :------: |
| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - |

We use the Stable Diffusion XL weights. The model is made up of several sets of weights, including the VAE, the UNet, and the CLIP text encoders.

You can download the weights from [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and point `from_pretrained` in the config at the local weights directory.
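
As a minimal sketch of that change, assuming the weights were downloaded to a local directory (the path below is hypothetical), the helper here swaps the Hugging Face repo id for the local path wherever it appears in a config-like dict. `use_local_weights` is an illustrative helper, not part of MMagic, and the dict is a trimmed stand-in for the real model config:

```python
HUB_ID = 'stabilityai/stable-diffusion-xl-base-1.0'


def use_local_weights(cfg, local_dir, hub_id=HUB_ID):
    """Return a copy of ``cfg`` with every ``hub_id`` value replaced.

    Hypothetical helper for illustration; not an MMagic API.
    """
    if isinstance(cfg, dict):
        return {k: use_local_weights(v, local_dir, hub_id)
                for k, v in cfg.items()}
    if isinstance(cfg, (list, tuple)):
        return type(cfg)(use_local_weights(v, local_dir, hub_id) for v in cfg)
    return local_dir if cfg == hub_id else cfg


# Trimmed stand-in for the model config; only a few fields are shown.
model = dict(
    type='StableDiffusionXL',
    unet=dict(type='UNet2DConditionModel', subfolder='unet',
              from_pretrained=HUB_ID),
    vae=dict(type='AutoencoderKL', subfolder='vae', from_pretrained=HUB_ID),
    tokenizer_one=HUB_ID)

# Hypothetical download location.
local_model = use_local_weights(model, '/data/weights/sdxl-base-1.0')
print(local_model['unet']['from_pretrained'])
```

In the shipped config every weight reference reuses the `stable_diffusion_xl_url` variable, so editing that one assignment is likely the simplest way to get the same effect.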

## Quick Start

Run the following code to generate an image from a text prompt.

```python
from mmengine import MODELS, Config

from mmengine.registry import init_default_scope

# Resolve registry names such as 'StableDiffusionXL' in the mmagic scope.
init_default_scope('mmagic')

# Load the SDXL config.
config = 'configs/stable_diffusion_xl/stable-diffusion_xl_ddim_denoisingunet.py'
config = Config.fromfile(config).copy()

# Build the model from the config and move it to the GPU.
StableDiffuser = MODELS.build(config.model)
prompt = 'A mecha robot in a favela in expressionist style'
StableDiffuser = StableDiffuser.to('cuda')

# Generate one sample for the prompt and save it to disk.
image = StableDiffuser.infer(prompt)['samples'][0]
image.save('robot.png')
```

## Comments

Our codebase for the stable diffusion models builds heavily on the [diffusers codebase](https://github.com/huggingface/diffusers), and the model weights come from [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).

Thanks for the efforts of the community!
18 changes: 18 additions & 0 deletions configs/stable_diffusion_xl/metafile.yml
@@ -0,0 +1,18 @@
Collections:
  - Name: Stable Diffusion XL
    Paper:
      Title: Stable Diffusion XL
      URL: https://arxiv.org/abs/2307.01952
    README: configs/stable_diffusion_xl/README.md
    Task:
      - text2image
      - inpainting
    Year: 2023
Models:
  - Config: configs/stable_diffusion_xl/stable-diffusion_xl_ddim_denoisingunet.py
    In Collection: Stable Diffusion XL
    Name: stable-diffusion_xl_ddim_denoisingunet
    Results:
      - Dataset: '-'
        Metrics: {}
        Task: Text2Image
40 changes: 40 additions & 0 deletions configs/stable_diffusion_xl/stable-diffusion_xl_ddim_denoisingunet.py
@@ -0,0 +1,40 @@
# Use DiffuserWrapper!
stable_diffusion_xl_url = 'stabilityai/stable-diffusion-xl-base-1.0'
unet = dict(
    type='UNet2DConditionModel',
    subfolder='unet',
    from_pretrained=stable_diffusion_xl_url)
vae = dict(
    type='AutoencoderKL',
    from_pretrained=stable_diffusion_xl_url,
    subfolder='vae')

diffusion_scheduler = dict(
    type='EditDDIMScheduler',
    variance_type='learned_range',
    beta_end=0.012,
    beta_schedule='scaled_linear',
    beta_start=0.00085,
    num_train_timesteps=1000,
    set_alpha_to_one=False,
    clip_sample=False)

model = dict(
    type='StableDiffusionXL',
    unet=unet,
    vae=vae,
    enable_xformers=False,
    text_encoder_one=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=stable_diffusion_xl_url,
        subfolder='text_encoder'),
    tokenizer_one=stable_diffusion_xl_url,
    text_encoder_two=dict(
        type='ClipWrapper',
        clip_type='huggingface',
        pretrained_model_name_or_path=stable_diffusion_xl_url,
        subfolder='text_encoder_2'),
    tokenizer_two=stable_diffusion_xl_url,
    scheduler=diffusion_scheduler,
    test_scheduler=diffusion_scheduler)
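
To make the scheduler settings concrete, here is a minimal sketch of the `scaled_linear` beta schedule, assuming it follows the diffusers convention of interpolating linearly in sqrt(beta) and then squaring. This is a plain-Python illustration, not the `EditDDIMScheduler` implementation, and `scaled_linear_betas` is a hypothetical helper name:

```python
import math


def scaled_linear_betas(beta_start, beta_end, num_train_timesteps):
    """Betas spaced linearly in sqrt(beta), then squared (assumed convention)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    step = (hi - lo) / (num_train_timesteps - 1)
    return [(lo + i * step) ** 2 for i in range(num_train_timesteps)]


# Values taken from the diffusion_scheduler dict in the config.
betas = scaled_linear_betas(0.00085, 0.012, 1000)

# Cumulative product of (1 - beta): the fraction of signal left at each step.
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)

print(len(betas), betas[0], betas[-1])
print(alphas_cumprod[0], alphas_cumprod[-1])
```

The schedule starts at `beta_start`, ends at `beta_end`, and grows monotonically in between, so the cumulative product decays from just under 1 toward a small positive value over the 1000 training timesteps.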
3 changes: 2 additions & 1 deletion mmagic/models/editors/__init__.py
@@ -57,6 +57,7 @@
 from .srcnn import SRCNNNet
 from .srgan import SRGAN, ModifiedVGG, MSRResNet
 from .stable_diffusion import StableDiffusion, StableDiffusionInpaint
+from .stable_diffusion_xl import StableDiffusionXL
 from .stylegan1 import StyleGAN1
 from .stylegan2 import StyleGAN2
 from .stylegan3 import StyleGAN3, StyleGAN3Generator
@@ -97,5 +98,5 @@
     'ControlStableDiffusion', 'DreamBooth', 'TextualInversion', 'DeblurGanV2',
     'DeblurGanV2Generator', 'DeblurGanV2Discriminator',
     'StableDiffusionInpaint', 'ViCo', 'FastComposer', 'AnimateDiff',
-    'UNet3DConditionMotionModel'
+    'UNet3DConditionMotionModel', 'StableDiffusionXL'
 ]
2 changes: 1 addition & 1 deletion mmagic/models/editors/dreambooth/dreambooth.py
@@ -44,7 +44,7 @@ class DreamBooth(StableDiffusion):
             Defaults to 3.
         prior_loss_weight (float, optional): The weight for class prior loss.
             Defaults to 0.
-        fine_tune_text_encoder (bool, optional): Whether to fine-tune text
+        finetune_text_encoder (bool, optional): Whether to fine-tune text
             encoder. Defaults to False.
         dtype (str, optional): The dtype for the model. Defaults to 'fp16'.
         enable_xformers (bool, optional): Whether to use xformers.
6 changes: 3 additions & 3 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
@@ -303,9 +303,9 @@ def infer(self,
                 noise_pred = noise_pred_uncond + guidance_scale * (
                     noise_pred_text - noise_pred_uncond)

-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.test_scheduler.step(
-                    noise_pred, t, latents, **extra_step_kwargs)['prev_sample']
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.test_scheduler.step(
+                noise_pred, t, latents, **extra_step_kwargs)['prev_sample']

         # 8. Post-processing
         image = self.decode_latents(latents.to(img_dtype))
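
The guidance line in this hunk is the classifier-free guidance update. As a rough illustration, assuming the noise predictions are plain lists of floats rather than tensors, the sketch below shows how `guidance_scale` blends the two predictions. `apply_guidance` is a hypothetical helper, not a function in MMagic:

```python
def apply_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    """Classifier-free guidance: move from the unconditional prediction
    toward (and possibly past) the text-conditioned one."""
    return [
        uncond + guidance_scale * (text - uncond)
        for uncond, text in zip(noise_pred_uncond, noise_pred_text)
    ]


uncond = [0.0, 1.0, -2.0]
text = [1.0, 1.0, 0.0]

# A scale of 0 reproduces the unconditional prediction.
print(apply_guidance(uncond, text, 0.0))
# A scale of 1 reproduces the text-conditioned prediction.
print(apply_guidance(uncond, text, 1.0))
# Larger scales extrapolate past the text-conditioned prediction.
print(apply_guidance(uncond, text, 7.5))
```

Whatever the scale, the combined prediction still has to be passed to the scheduler step on every timestep to obtain the next latents.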
4 changes: 4 additions & 0 deletions mmagic/models/editors/stable_diffusion_xl/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .stable_diffusion_xl import StableDiffusionXL

__all__ = ['StableDiffusionXL']