From e0703ed7fdb865d5036b0eebfd25bae37665f246 Mon Sep 17 00:00:00 2001
From: technillogue
Date: Mon, 14 Oct 2024 17:57:52 -0400
Subject: [PATCH] try to keep the same state_dict structure

---
 flux/modules/autoencoder.py | 49 ++++++++++++++++++++++++++++---------
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/flux/modules/autoencoder.py b/flux/modules/autoencoder.py
index 83965a7..d26d2a0 100644
--- a/flux/modules/autoencoder.py
+++ b/flux/modules/autoencoder.py
@@ -105,6 +105,16 @@ def forward(self, x: Tensor):
         x = self.conv(x)
         return x
 
+class DownBlock(nn.Module):
+    def __init__(self, block: list[nn.Module], downsample: nn.Module) -> None:
+        super().__init__()
+        # we're doing this instead of a flat nn.Sequential to preserve the state_dict keys "block" and "downsample"
+        self.block = nn.Sequential(*block)
+        self.downsample = downsample
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.downsample(self.block(x))
+
 
 class Encoder(nn.Module):
     def __init__(
@@ -127,20 +137,25 @@ def __init__(
         curr_res = resolution
         in_ch_mult = (1,) + tuple(ch_mult)
-        self.in_ch_mult: tuple[int] = in_ch_mult
+        self.in_ch_mult = in_ch_mult
         down_layers = []
         block_in = self.ch
+        # ideally, this would all append to a single flat nn.Sequential
+        # we cannot do this due to the existing state_dict keys
         for i_level in range(self.num_resolutions):
-            attn = nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
+            block_layers = []
             for _ in range(self.num_res_blocks):
-                down_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                 block_in = block_out
             # ?
             # originally this provided for attn layers, but those are never actually created
             if i_level != self.num_resolutions - 1:
-                down_layers.append(Downsample(block_in))
+                downsample = Downsample(block_in)
                 curr_res = curr_res // 2
+            else:
+                downsample = nn.Identity()
+            down_layers.append(DownBlock(block_layers, downsample))
         self.down = nn.Sequential(*down_layers)
 
         # middle
@@ -168,6 +183,15 @@ def forward(self, x: Tensor) -> Tensor:
         h = self.conv_out(h)
         return h
 
+class UpBlock(nn.Module):
+    def __init__(self, block: list[nn.Module], upsample: nn.Module) -> None:
+        super().__init__()
+        self.block = nn.Sequential(*block)
+        self.upsample = upsample
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.upsample(self.block(x))
+
 
 class Decoder(nn.Module):
     def __init__(
@@ -203,19 +227,22 @@ def __init__(
         self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
 
         # upsampling
-        up_layers = []
+        up_blocks = []
+        # 3, 2, 1, 0, descending order
         for i_level in reversed(range(self.num_resolutions)):
-            blocks = []
+            level_blocks = []
             block_out = ch * ch_mult[i_level]
             for _ in range(self.num_res_blocks + 1):
-                blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                level_blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                 block_in = block_out
             if i_level != 0:
-                blocks.append(Upsample(block_in))
+                upsample = Upsample(block_in)
                 curr_res = curr_res * 2
-            # ??? gross
-            up_layers = blocks + up_layers  # prepend to get consistent order
-        self.up = nn.Sequential(*up_layers)
+            else:
+                upsample = nn.Identity()
+            # 0, 1, 2, 3, ascending order
+            up_blocks.insert(0, UpBlock(level_blocks, upsample))  # prepend to get consistent order
+        self.up = nn.Sequential(*up_blocks)
 
         # end
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
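
Note on why this preserves checkpoint compatibility: a flat nn.Sequential
renumbers every layer positionally ("0", "1", ...), while named submodules
keep the "block"/"downsample" prefixes the existing state_dict uses. A
minimal sketch (separate from the patch; a plain Conv2d stands in for
ResnetBlock/Downsample purely for illustration):

    import torch
    from torch import nn

    class DownBlock(nn.Module):
        def __init__(self, block: list[nn.Module], downsample: nn.Module) -> None:
            super().__init__()
            self.block = nn.Sequential(*block)
            self.downsample = downsample

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.downsample(self.block(x))

    # flat nn.Sequential: keys are renumbered positionally
    flat = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.Conv2d(4, 4, 3, stride=2))
    print(list(flat.state_dict()))  # ['0.weight', '0.bias', '1.weight', '1.bias']

    # DownBlock: keys keep the names the checkpoint expects
    down = DownBlock([nn.Conv2d(4, 4, 3, padding=1)], nn.Conv2d(4, 4, 3, stride=2))
    print(list(down.state_dict()))  # ['block.0.weight', 'block.0.bias', 'downsample.weight', 'downsample.bias']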
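
A hypothetical smoke test for the same point, reusing the DownBlock sketch
above: build one level shaped like the original flux Encoder (a bare
nn.Module carrying "block", an empty "attn" ModuleList that contributes no
parameters, and "downsample") and check that the key sets match, so
load_state_dict(..., strict=True) keeps working:

    old_level = nn.Module()
    old_level.block = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1)])
    old_level.attn = nn.ModuleList()  # empty in the original, so no keys
    old_level.downsample = nn.Conv2d(4, 4, 3, stride=2)

    new_level = DownBlock([nn.Conv2d(4, 4, 3, padding=1)], nn.Conv2d(4, 4, 3, stride=2))
    assert set(old_level.state_dict()) == set(new_level.state_dict())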
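
On the Decoder's prepend: the checkpoint indexes up.{i_level} by resolution
level, not by build order, and the loop builds levels high-to-low. Inserting
each UpBlock at the front leaves self.up[i] holding level i; for example
(assuming num_resolutions == 4, as the comments in the patch do):

    levels = []
    for i_level in reversed(range(4)):        # build order: 3, 2, 1, 0
        levels.insert(0, f"level {i_level}")  # stored order: 0, 1, 2, 3
    print(levels)  # ['level 0', 'level 1', 'level 2', 'level 3']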