Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for torch.sdp attention #4

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions guided_diffusion/script_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def create_model_and_diffusion(
use_fp16,
use_new_attention_order,
use_neighborhood_attention,
use_torch_sdp_attention,
):
model = create_model(
image_size,
Expand All @@ -116,6 +117,7 @@ def create_model_and_diffusion(
use_fp16=use_fp16,
use_new_attention_order=use_new_attention_order,
use_neighborhood_attention=use_neighborhood_attention,
use_torch_sdp_attention=use_torch_sdp_attention,
)
diffusion = create_gaussian_diffusion(
steps=diffusion_steps,
Expand Down Expand Up @@ -148,6 +150,7 @@ def create_model(
use_fp16=False,
use_new_attention_order=False,
use_neighborhood_attention=False,
use_torch_sdp_attention=False,
):
if channel_mult == "":
if image_size == 512:
Expand Down Expand Up @@ -186,6 +189,7 @@ def create_model(
resblock_updown=resblock_updown,
use_new_attention_order=use_new_attention_order,
use_neighborhood_attention=use_neighborhood_attention,
use_torch_sdp_attention=use_torch_sdp_attention,
)


Expand Down
41 changes: 37 additions & 4 deletions guided_diffusion/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .fp16_util import convert_module_to_f16, convert_module_to_f32
from .nn import (
Expand Down Expand Up @@ -274,6 +275,7 @@ def __init__(
use_new_attention_order=False,
spatial=None,
use_neighborhood_attention=False,
use_torch_sdp_attention=False,
):
super().__init__()
self.channels = channels
Expand All @@ -287,12 +289,14 @@ def __init__(
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_neighborhood_attention and use_torch_sdp_attention:
raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads, spatial, use_neighborhood_attention)
self.attention = QKVAttention(self.num_heads, spatial, use_neighborhood_attention, use_torch_sdp_attention)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads, spatial, use_neighborhood_attention)
self.attention = QKVAttentionLegacy(self.num_heads, spatial, use_neighborhood_attention, use_torch_sdp_attention)

self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

Expand Down Expand Up @@ -333,11 +337,14 @@ class QKVAttentionLegacy(nn.Module):
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""

def __init__(self, n_heads, spatial=None, use_neighborhood_attention=False):
def __init__(self, n_heads, spatial=None, use_neighborhood_attention=False, use_torch_sdp_attention=False):
super().__init__()
self.n_heads = n_heads
self.spatial = tuple(spatial)
self.use_neighborhood_attention = use_neighborhood_attention
self.use_torch_sdp_attention = use_torch_sdp_attention
if use_neighborhood_attention and use_torch_sdp_attention:
raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')
self.neighborhood_mask_cache = {}

def get_neighborhood_mask(self, spatial, device):
Expand All @@ -359,6 +366,11 @@ def forward(self, qkv, spatial=None):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
if self.use_torch_sdp_attention:
q, k, v = rearrange(qkv, "n (h p c) t -> p n h t c", p=3, c=ch).contiguous().unbind()
a = th.nn.functional.scaled_dot_product_attention(q, k, v)
a = rearrange(a, 'n h t c -> n (h c) t')
return a
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
Expand All @@ -380,11 +392,14 @@ class QKVAttention(nn.Module):
A module which performs QKV attention and splits in a different order.
"""

def __init__(self, n_heads, spatial=None, use_neighborhood_attention=False):
def __init__(self, n_heads, spatial=None, use_neighborhood_attention=False, use_torch_sdp_attention=False):
super().__init__()
self.n_heads = n_heads
self.spatial = tuple(spatial)
self.use_neighborhood_attention = use_neighborhood_attention
self.use_torch_sdp_attention = use_torch_sdp_attention
if use_neighborhood_attention and use_torch_sdp_attention:
raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')
self.neighborhood_mask_cache = {}

def get_neighborhood_mask(self, spatial, device):
Expand All @@ -406,6 +421,11 @@ def forward(self, qkv, spatial=None):
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
if self.use_torch_sdp_attention:
q, k, v = rearrange(qkv, "n (p h c) t -> p n h t c", p=3, c=ch).contiguous().unbind()
a = th.nn.functional.scaled_dot_product_attention(q, k, v)
a = rearrange(a, 'n h t c -> n (h c) t')
return a
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
Expand Down Expand Up @@ -477,6 +497,7 @@ def __init__(
resblock_updown=False,
use_new_attention_order=False,
use_neighborhood_attention=False,
use_torch_sdp_attention=False,
):
super().__init__()

Expand All @@ -499,6 +520,9 @@ def __init__(
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_neighborhood_attention = use_neighborhood_attention
self.use_torch_sdp_attention = use_torch_sdp_attention
if use_neighborhood_attention and use_torch_sdp_attention:
raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')

time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
Expand Down Expand Up @@ -541,6 +565,7 @@ def __init__(
use_new_attention_order=use_new_attention_order,
spatial=(image_size // ds, image_size // ds),
use_neighborhood_attention=use_neighborhood_attention,
use_torch_sdp_attention=use_torch_sdp_attention,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
Expand Down Expand Up @@ -588,6 +613,7 @@ def __init__(
use_new_attention_order=use_new_attention_order,
spatial=(image_size // ds, image_size // ds),
use_neighborhood_attention=use_neighborhood_attention,
use_torch_sdp_attention=use_torch_sdp_attention,
),
ResBlock(
ch,
Expand Down Expand Up @@ -626,6 +652,7 @@ def __init__(
use_new_attention_order=use_new_attention_order,
spatial=(image_size // ds, image_size // ds),
use_neighborhood_attention=use_neighborhood_attention,
use_torch_sdp_attention=use_torch_sdp_attention,
)
)
if level and i == num_res_blocks:
Expand Down Expand Up @@ -748,6 +775,7 @@ def __init__(
use_new_attention_order=False,
pool="adaptive",
use_neighborhood_attention=False,
use_torch_sdp_attention=False,
):
super().__init__()

Expand All @@ -768,6 +796,9 @@ def __init__(
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_neighborhood_attention = use_neighborhood_attention
self.use_torch_sdp_attention = use_torch_sdp_attention
if use_neighborhood_attention and use_torch_sdp_attention:
raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')

time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
Expand Down Expand Up @@ -807,6 +838,7 @@ def __init__(
use_new_attention_order=use_new_attention_order,
spatial=(image_size // ds, image_size // ds),
use_neighborhood_attention=use_neighborhood_attention,
use_torch_sdp_attention=use_torch_sdp_attention,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
Expand Down Expand Up @@ -846,6 +878,7 @@ def __init__(
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
use_neighborhood_attention=use_neighborhood_attention,
use_torch_sdp_attention=use_torch_sdp_attention,
),
AttentionBlock(
ch,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
name="guided-diffusion",
packages=["guided_diffusion"],
py_modules=["guided_diffusion"],
install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
install_requires=["blobfile>=1.0.5", "torch", "tqdm", "einops"],
)