-
Notifications
You must be signed in to change notification settings - Fork 0
/
Custom_SDXL_Inference
82 lines (67 loc) · 3.48 KB
/
Custom_SDXL_Inference
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import Optional, Union, Tuple, List, Callable, Dict
from tqdm.notebook import tqdm
import torch
from diffusers import StableDiffusionXLPipeline
import torch.nn.functional as nnf
import numpy as np
import DDIM_inversion
def latent2image(model: StableDiffusionXLPipeline, latents):
    """Decode latents with the pipeline's VAE and post-process them to PIL output.

    Args:
        model: Pipeline providing ``vae``, ``upcast_vae()`` and ``image_processor``.
        latents: Latent tensor; it is divided by ``vae.config.scaling_factor``
            before being handed to the decoder.

    Returns:
        Whatever ``model.image_processor.postprocess(..., output_type='pil')``
        yields (presumably a list of PIL images -- confirm against diffusers).
    """
    # A float16 VAE overflows while decoding, so run it in float32 when the
    # VAE config asks for it.
    half_vae_needs_fp32 = model.vae.dtype == torch.float16 and model.vae.config.force_upcast

    if half_vae_needs_fp32:
        model.upcast_vae()
        # Match the latents to the dtype the decoder weights now have.
        reference_param = next(iter(model.vae.post_quant_conv.parameters()))
        latents = latents.to(reference_param.dtype)

    unscaled = latents / model.vae.config.scaling_factor
    decoded = model.vae.decode(unscaled, return_dict=False)[0]

    if half_vae_needs_fp32:
        # Restore the VAE to half precision once decoding is done.
        model.vae.to(dtype=torch.float16)

    return model.image_processor.postprocess(decoded, output_type='pil')
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: Guided noise prediction tensor. The standard deviation is taken per sample,
            over every dimension except the first.
        noise_pred_text: Text-conditioned noise prediction whose per-sample standard deviation
            is the target for the rescaled prediction.
        guidance_rescale: Blend factor between the original `noise_cfg` (0.0) and the fully
            rescaled prediction (1.0).

    Returns:
        The blended prediction. When `guidance_rescale` is a numeric zero, `noise_cfg` itself is
        returned unchanged (same tensor object, no copy).
    """
    # With a zero factor the rescaled term carries no weight, so skip the work entirely. This also
    # keeps the result finite when a sample has zero variance: std_text / std_cfg would be 0 / 0,
    # and 0.0 * NaN would otherwise turn the whole output into NaN.
    if isinstance(guidance_rescale, (int, float)) and guidance_rescale == 0:
        return noise_cfg

    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
def diffusion_step(model: StableDiffusionXLPipeline, latent, context, t, guidance_scale, added_cond_kwargs,
                   guidance_rescale: float = 0.0):
    """Run one classifier-free-guidance denoising step and return the updated latent.

    Args:
        model: Pipeline providing ``unet`` and ``scheduler``.
        latent: Current latent tensor. It is duplicated along dim 0 so both guidance
            branches go through the UNet in a single call.
        context: Encoder hidden states for the duplicated batch. The first half of the
            UNet output is treated as the unconditional prediction and the second half
            as the text-conditioned one, so ``context`` is presumably ordered
            ``[uncond, cond]`` -- confirm against the caller.
        t: Scheduler timestep for this step.
        guidance_scale: Weight applied to ``(text - uncond)`` when combining predictions.
        added_cond_kwargs: Extra conditioning forwarded to the UNet unchanged.
        guidance_rescale: Optional factor passed to ``rescale_noise_cfg``. The default
            ``0.0`` leaves the guided prediction untouched.

    Returns:
        The first element of ``model.scheduler.step(...)`` for this timestep.
    """
    latents_input = torch.cat([latent] * 2)
    latents_input = model.scheduler.scale_model_input(latents_input, t)
    noise_pred = model.unet(
        latents_input,
        t,
        encoder_hidden_states=context,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )[0]
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # Only rescale when requested: with a zero factor the call is a pure no-op that still
    # costs two std reductions per step and can produce NaN on zero-variance predictions.
    if guidance_rescale > 0.0:
        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
    out_latent = model.scheduler.step(noise_pred, t, latent, return_dict=False)[0]
    return out_latent
@torch.no_grad()
def text2image_ldm_stable(
    model: StableDiffusionXLPipeline,
    prompt: List[str],
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    uncond_embeddings=None,
    start_time=50,
    return_type='image'
):
    """Denoise ``latent`` under ``prompt`` and optionally decode the result to an image.

    Args:
        model: Pipeline providing ``unet``, ``scheduler``, ``vae`` and ``image_processor``.
        prompt: Prompts; only the first entry is used. A bare string is accepted and
            used as-is.
        num_inference_steps: Value passed to ``model.scheduler.set_timesteps``.
        guidance_scale: Classifier-free-guidance weight forwarded to ``diffusion_step``.
        generator: Currently unused; kept for interface compatibility.
        latent: Starting latent tensor. Required -- this function does not sample
            initial noise itself.
        uncond_embeddings: Optional per-step unconditional embeddings, indexed by the
            loop counter. When given, entry ``i`` replaces the unconditional half of
            the prompt embeddings at step ``i``.
        start_time: Number of trailing scheduler timesteps to run
            (``timesteps[-start_time:]``). NOTE(review): ``0`` selects every timestep
            because ``-0 == 0``; confirm whether that is intended.
        return_type: ``'image'`` decodes the final latent with ``latent2image``; any
            other value returns the latent in the image slot.

    Returns:
        Tuple ``(image, latent)`` where ``latent`` is the final denoised latent.

    Raises:
        ValueError: If ``latent`` is ``None``.
    """
    # Fail fast with a clear message instead of a TypeError from torch.cat after the
    # prompt has been encoded and the scheduler state has already been changed.
    if latent is None:
        raise ValueError("text2image_ldm_stable requires an initial `latent` tensor, got None")

    # Indexing a plain string would silently use only its first character as the prompt.
    prompt_text = prompt if isinstance(prompt, str) else prompt[0]

    added_cond_kwargs, prompt_embedds = DDIM_inversion._encode_text_sdxl_with_negative(model, prompt_text)
    model.scheduler.set_timesteps(num_inference_steps)

    # The conditional half does not change between steps, so split once up front.
    _, cond_emb = prompt_embedds.chunk(2)

    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        if uncond_embeddings is not None:
            context = torch.cat([uncond_embeddings[i], cond_emb])
        else:
            context = prompt_embedds
        latent = diffusion_step(model, latent, context, t, guidance_scale, added_cond_kwargs)

    if return_type == 'image':
        image = latent2image(model, latent)
    else:
        image = latent
    return image, latent