Add variable batch to SD compilation #782

Open · wants to merge 1 commit into `main`
17 changes: 13 additions & 4 deletions examples/05_stable_diffusion/README.md
@@ -42,7 +42,8 @@ Build the AIT modules by running `compile.py`.

Set correct width and height depending on the model variant
```
-python3 scripts/compile.py --width 512 --height 512
+python3 scripts/compile.py --batch-size 1 8 \
+    --width 512 --height 512
```
It generates three folders: `./tmp/CLIPTextModel`, `./tmp/UNet2DConditionModel`, `./tmp/AutoencoderKL`. In each folder, there is a `test.so` file which is the generated AIT module for the model.

@@ -98,17 +99,25 @@ Run AIT models with an example image:

Set correct width and height depending on the model variant
```
-python3 scripts/demo.py --width 512 --height 512
+python3 scripts/demo.py --batch 8 \
+    --width 512 --height 512 \
+    --prompt "a photo of an astronaut riding a horse on mars"
```

+Check the resulting images: `example_ait_[0..7].png`

Img2img demo:

Internally, the img2img demo downloads and uses [sketch-mountains-input.jpg](https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg).

Set correct width and height depending on the model variant
```
-python3 scripts/demo_img2img.py --width 512 --height 512
+python3 scripts/demo_img2img.py --batch 8 \
+    --width 512 --height 512 \
+    --prompt "A fantasy landscape, trending on artstation"
```

-Check the resulted image: `example_ait.png`
+Check the resulting images: `example_ait_[0..7].png`


### Sample outputs
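Note: with this change each module is compiled once with a dynamic batch dimension, so a single `test.so` serves every batch size in the compiled range. A minimal sketch of loading one of the generated modules directly, assuming the default `./tmp` workdir and AITemplate's standard `Model` runtime loader (illustrative, not part of the PR):

```
# Minimal sketch, not part of this PR: load a compiled AIT module directly.
# Assumes compile.py ran with its default ./tmp workdir.
from aitemplate.compiler import Model

# Each folder generated by compile.py contains a standalone test.so.
vae_module = Model("./tmp/AutoencoderKL/test.so")

# Because the batch dimension is an IntVar (see the compile_lib diffs below),
# any batch size inside the compiled [min, max] range is accepted at run time,
# with no recompilation needed between batch 1 and batch 8.
```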
21 changes: 18 additions & 3 deletions examples/05_stable_diffusion/scripts/compile.py
@@ -36,11 +36,22 @@
)
@click.option("--width", default=512, help="Width of generated image")
@click.option("--height", default=512, help="Height of generated image")
@click.option("--batch-size", default=1, help="batch size")
@click.option(
"--batch-size",
default=(1, 8),
type=(int, int),
nargs=2,
help="Minimum and maximum batch size",
)
@click.option("--use-fp16-acc", default=True, help="use fp16 accumulation")
@click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm")
def compile_diffusers(
local_dir, width, height, batch_size, use_fp16_acc=True, convert_conv_to_gemm=True
local_dir,
width,
height,
batch_size,
use_fp16_acc=True,
convert_conv_to_gemm=True
):
logging.getLogger().setLevel(logging.INFO)
torch.manual_seed(4896)
@@ -73,10 +84,14 @@ def compile_diffusers(
        dim=pipe.text_encoder.config.hidden_size,
        act_layer=pipe.text_encoder.config.hidden_act,
    )
+
    # UNet
    compile_unet(
        pipe.unet,
-        batch_size=batch_size * 2,
+        batch_size=(
+            batch_size[0] * 2,
+            batch_size[1] * 2,
+        ),  # double batch size for unet
        width=ww,
        height=hh,
        use_fp16_acc=use_fp16_acc,
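Note: with `type=(int, int)` and `nargs=2`, click delivers `--batch-size 1 8` to the function as the tuple `(1, 8)`. A self-contained sketch of that parsing behavior (the command name and output text are illustrative):

```
# Standalone sketch of the two-value click option added above; only the
# option declaration is taken from the diff, the rest is illustrative.
import click

@click.command()
@click.option(
    "--batch-size",
    default=(1, 8),
    type=(int, int),
    nargs=2,
    help="Minimum and maximum batch size",
)
def main(batch_size):
    lo, hi = batch_size  # click passes both values as one tuple
    click.echo(f"compiling for dynamic batch in [{lo}, {hi}]")

if __name__ == "__main__":
    main()  # e.g. `--batch-size 1 4` prints: compiling for dynamic batch in [1, 4]
```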
8 changes: 5 additions & 3 deletions examples/05_stable_diffusion/scripts/demo_img2img.py
@@ -36,13 +36,14 @@
)
@click.option("--width", default=512, help="Width of generated image")
@click.option("--height", default=512, help="Height of generated image")
@click.option("--batch", default=1, help="Batch size of generated image")
@click.option(
"--prompt", default="A fantasy landscape, trending on artstation", help="prompt"
)
@click.option(
"--benchmark", type=bool, default=False, help="run stable diffusion e2e benchmark"
)
def run(local_dir, width, height, prompt, benchmark):
def run(local_dir, width, height, batch, prompt, benchmark):
# load the pipeline
device = "cuda"
pipe = StableDiffusionImg2ImgAITPipeline.from_pretrained(
@@ -60,6 +61,7 @@ def run(local_dir, width, height, prompt, benchmark):
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((height, width))

+    prompt = [prompt] * batch
    with torch.autocast("cuda"):
        images = pipe(
            prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5
@@ -68,8 +70,8 @@ def run(local_dir, width, height, prompt, benchmark):
        args = (prompt, init_image)
        t = benchmark_torch_function(10, pipe, *args)
        print(f"sd e2e: {t} ms")
-
-    images[0].save("fantasy_landscape_ait.png")
+    for i, image in enumerate(images):
+        image.save(f"example_ait_{i}.png")


if __name__ == "__main__":
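Note: the demo batches by repeating the prompt. Diffusers-style pipelines return one image per prompt in the list, which is what `prompt = [prompt] * batch` relies on. A hedged sketch of the pattern, where `pipe` and `init_image` stand in for the objects built in the demo:

```
# Illustrative only: `pipe` and `init_image` stand in for the pipeline and
# sketch image constructed in demo_img2img.py.
batch = 4
prompt = ["A fantasy landscape, trending on artstation"] * batch

# One output image is produced per prompt in the list.
images = pipe(
    prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5
).images

for i, image in enumerate(images):
    image.save(f"example_ait_{i}.png")
```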
7 changes: 4 additions & 3 deletions examples/05_stable_diffusion/src/compile_lib/compile_clip.py
@@ -44,7 +44,7 @@ def map_clip_params(pt_mod):

def compile_clip(
    pt_mod,
-    batch_size=1,
+    batch_size=(1, 8),
    seqlen=64,
    dim=768,
    num_heads=12,
@@ -70,8 +70,9 @@

    pt_mod = pt_mod.eval()
    params_ait = map_clip_params(pt_mod)
-    batch_size_d = IntVar(values=[1, max(8, batch_size)], name="batch_size")
-
+    # The batch lower bound should always be 1 and the upper bound at least 8;
+    # otherwise the output image is garbled on T4 GPUs.
+    batch_size_d = IntVar(values=[1, max(8, batch_size[1])], name="batch_size")
    input_ids_ait = Tensor(
        [batch_size_d, seqlen], name="input0", dtype="int64", is_input=True
    )
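Note: `IntVar` is what makes the batch dimension dynamic; it declares a `[min, max]` range in place of a fixed size. A distilled version of the pattern above (values mirror the CLIP diff, with `seqlen` at its default of 64):

```
# Distilled from the diff above; standalone apart from the AITemplate imports.
from aitemplate.frontend import IntVar, Tensor

seqlen = 64
requested_max = 4  # as if the user passed --batch-size 1 4

# Lower bound pinned to 1, upper bound floored at 8, per the comment above
# about garbled output on T4 GPUs otherwise.
batch_size_d = IntVar(values=[1, max(8, requested_max)], name="batch_size")

# The IntVar replaces a fixed integer in the input shape.
input_ids_ait = Tensor(
    [batch_size_d, seqlen], name="input0", dtype="int64", is_input=True
)
```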
5 changes: 3 additions & 2 deletions examples/05_stable_diffusion/src/compile_lib/compile_unet.py
@@ -15,7 +15,7 @@
import torch

from aitemplate.compiler import compile_model
-from aitemplate.frontend import Tensor
+from aitemplate.frontend import IntVar, Tensor
from aitemplate.testing import detect_target

from ..modeling.unet_2d_condition import (
@@ -50,7 +50,7 @@ def map_unet_params(pt_mod, dim):

def compile_unet(
    pt_mod,
-    batch_size=2,
+    batch_size=(2, 16),
    height=64,
    width=64,
    dim=320,
@@ -72,6 +72,7 @@
    # set AIT parameters
    pt_mod = pt_mod.eval()
    params_ait = map_unet_params(pt_mod, dim)
+    batch_size = IntVar(values=list(batch_size), name="batch_size")

    latent_model_input_ait = Tensor(
        [batch_size, height, width, 4], name="input0", is_input=True
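Note: compile.py doubles the range it passes here (`batch_size[0] * 2, batch_size[1] * 2`) because classifier-free guidance runs the unconditional and text-conditioned inputs through the UNet as a single concatenated batch. A sketch of that step as diffusers-style pipelines commonly implement it (shapes are illustrative):

```
# Illustrative shapes showing why the UNet batch range is doubled:
# classifier-free guidance stacks both halves into one UNet call.
import torch

user_batch = 8
latents = torch.randn(user_batch, 4, 64, 64)    # pipeline-facing batch
uncond_emb = torch.randn(user_batch, 77, 768)   # empty-prompt embeddings
text_emb = torch.randn(user_batch, 77, 768)     # prompt embeddings

latent_model_input = torch.cat([latents] * 2)              # batch 16
encoder_hidden_states = torch.cat([uncond_emb, text_emb])  # batch 16

# So a demo batch range of [1, 8] needs a UNet compiled for [2, 16].
```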
5 changes: 3 additions & 2 deletions examples/05_stable_diffusion/src/compile_lib/compile_vae.py
@@ -15,7 +15,7 @@

import torch
from aitemplate.compiler import compile_model
-from aitemplate.frontend import Tensor
+from aitemplate.frontend import IntVar, Tensor
from aitemplate.testing import detect_target

from ..modeling.vae import AutoencoderKL as ait_AutoencoderKL
@@ -118,7 +118,7 @@ def map_vae(pt_module, device="cuda", dtype="float16"):

def compile_vae(
    pt_mod,
-    batch_size=1,
+    batch_size=(1, 8),
    height=64,
    width=64,
    use_fp16_acc=False,
@@ -159,6 +159,7 @@
        latent_channels=latent_channels,
        sample_size=sample_size,
    )
+    batch_size = IntVar(values=list(batch_size), name="batch_size")

    ait_input = Tensor(
        shape=[batch_size, height, width, latent_channels],
Expand Down