From c4a3b09a36fb22b949dc7d56f447206d5fd3b0d5 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 5 Aug 2022 14:42:03 +0200 Subject: [PATCH] [UNet2DConditionModel] add cross_attention_dim as an argument (#155) add cross_attention_dim as an argument --- src/diffusers/models/unet_2d_condition.py | 4 ++++ src/diffusers/models/unet_blocks.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ae82e202b..a39223811 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -28,6 +28,7 @@ def __init__( act_fn="silu", norm_num_groups=32, norm_eps=1e-5, + cross_attention_dim=1280, attention_head_dim=8, ): super().__init__() @@ -64,6 +65,7 @@ def __init__( add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, ) @@ -77,6 +79,7 @@ def __init__( resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, resnet_groups=norm_num_groups, ) @@ -101,6 +104,7 @@ def __init__( add_upsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, ) self.up_blocks.append(up_block) diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 034e662ed..66357d78e 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -31,6 +31,7 @@ def get_down_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + cross_attention_dim=None, downsample_padding=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type @@ -58,6 +59,8 @@ def get_down_block( 
attn_num_head_channels=attn_num_head_channels, ) elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") return CrossAttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, @@ -67,6 +70,7 @@ resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, ) elif down_block_type == "SkipDownBlock2D": @@ -115,6 +119,7 @@ def get_up_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + cross_attention_dim=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -129,6 +134,8 @@ def get_up_block( resnet_act_fn=resnet_act_fn, ) elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, @@ -138,6 +145,7 @@ add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, ) elif up_block_type == "AttnUpBlock2D":