
Commit

try to keep the same state_dict structure
technillogue committed Oct 15, 2024
1 parent 2a6dd2e commit e0703ed
Showing 1 changed file with 38 additions and 11 deletions.
flux/modules/autoencoder.py (49 changes: 38 additions & 11 deletions)
@@ -105,6 +105,16 @@ def forward(self, x: Tensor):
         x = self.conv(x)
         return x
 
+class DownBlock(nn.Module):
+    def __init__(self, block: list, downsample: nn.Module) -> None:
+        super().__init__()
+        # we're doing this instead of a flat nn.Sequential to preserve the keys "block" "downsample"
+        self.block = nn.Sequential(*block)
+        self.downsample = downsample
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.downsample(self.block(x))
+
 
 class Encoder(nn.Module):
     def __init__(
@@ -127,20 +137,25 @@ def __init__(
-
         curr_res = resolution
         in_ch_mult = (1,) + tuple(ch_mult)
-        self.in_ch_mult: tuple[int] = in_ch_mult
+        self.in_ch_mult = in_ch_mult
         down_layers = []
         block_in = self.ch
+        # ideally, this would all append to a single flat nn.Sequential
+        # we cannot do this due to the existing state dict keys
         for i_level in range(self.num_resolutions):
             attn = nn.ModuleList()
             block_in = ch * in_ch_mult[i_level]
             block_out = ch * ch_mult[i_level]
+            block_layers = []
             for _ in range(self.num_res_blocks):
-                down_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                 block_in = block_out # ?
             # originally this provided for attn layers, but those are never actually created
             if i_level != self.num_resolutions - 1:
-                down_layers.append(Downsample(block_in))
+                downsample = Downsample(block_in)
                 curr_res = curr_res // 2
+            else:
+                downsample = nn.Identity()
+            down_layers.append(DownBlock(block_layers, downsample))
         self.down = nn.Sequential(*down_layers)
 
         # middle
@@ -168,6 +183,15 @@ def forward(self, x: Tensor) -> Tensor:
         h = self.conv_out(h)
         return h
 
+class UpBlock(nn.Module):
+    def __init__(self, block: list, upsample: nn.Module) -> None:
+        super().__init__()
+        self.block = nn.Sequential(*block)
+        self.upsample = upsample
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.upsample(self.block(x))
+
 
 class Decoder(nn.Module):
     def __init__(
@@ -203,19 +227,22 @@ def __init__(
         self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
 
         # upsampling
-        up_layers = []
+        up_blocks = []
+        # 3, 2, 1, 0, descending order
         for i_level in reversed(range(self.num_resolutions)):
-            blocks = []
+            level_blocks = []
             block_out = ch * ch_mult[i_level]
             for _ in range(self.num_res_blocks + 1):
-                blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                level_blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                 block_in = block_out
             if i_level != 0:
-                blocks.append(Upsample(block_in))
+                upsample = Upsample(block_in)
                 curr_res = curr_res * 2
-            # ??? gross
-            up_layers = blocks + up_layers # prepend to get consistent order
-        self.up = nn.Sequential(*up_layers)
+            else:
+                upsample = nn.Identity()
+            # 0, 1, 2, 3, ascending order
+            up_blocks.insert(0, UpBlock(level_blocks, upsample)) # prepend to get consistent order
+        self.up = nn.Sequential(*up_blocks)
 
         # end
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
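
A quick note on why the wrapper classes keep the old keys: PyTorch builds state_dict keys from attribute names, so nesting each resolution level's layers under attributes literally named block and downsample (or upsample) reproduces keys of the form down.{i}.block.{j}.* and down.{i}.downsample.*, matching the existing state dict keys the in-code comments refer to, whereas one flat nn.Sequential would renumber everything as down.{k}.*. Below is a minimal sketch of that effect, not code from the repository: plain Conv2d layers stand in for the real ResnetBlock and Downsample, and the names resnet, down, nested and flat are illustrative only.

    import torch
    from torch import nn

    class DownBlock(nn.Module):
        # same shape as the class added in this commit
        def __init__(self, block: list, downsample: nn.Module) -> None:
            super().__init__()
            self.block = nn.Sequential(*block)
            self.downsample = downsample

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.downsample(self.block(x))

    # toy stand-ins, used here only to inspect the resulting key names
    resnet = nn.Conv2d(4, 4, 3, padding=1)
    down = nn.Conv2d(4, 4, 3, stride=2)

    nested = nn.Sequential(DownBlock([resnet], down))  # layout this commit builds
    flat = nn.Sequential(resnet, down)                 # layout of a flat Sequential

    print(list(nested.state_dict()))
    # ['0.block.0.weight', '0.block.0.bias', '0.downsample.weight', '0.downsample.bias']
    print(list(flat.state_dict()))
    # ['0.weight', '0.bias', '1.weight', '1.bias']

Since nn.Identity() holds no parameters, using it as the downsample/upsample of the level that has none adds no extra keys, and building the decoder levels in reverse while inserting each one at position 0 keeps the up.{i_level} indices ascending, so an existing state dict should load without any key remapping.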
