diff --git a/flux/modules/autoencoder.py b/flux/modules/autoencoder.py index d26d2a0..dc9a08c 100644 --- a/flux/modules/autoencoder.py +++ b/flux/modules/autoencoder.py @@ -248,6 +248,16 @@ def __init__( self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + # this is a hack to get something like property but only evaluate it once + # we're doing it like this so that up_descending isn't in the state_dict keys + # without adding anything conditional to the main flow + def __getattr__(self, name): + if name == "up_descending": + self.up_descending = nn.Sequential(*reversed(self.up)) + Decoder.__getattr__ = nn.Module.__getattr__ + return self.up_descending + return super().__getattr__(name) + def forward(self, z: Tensor) -> Tensor: # z to block_in h = self.conv_in(z) @@ -258,7 +268,7 @@ def forward(self, z: Tensor) -> Tensor: h = self.mid.block_2(h) # upsampling - h = self.up(h) + h = self.up_descending(h) # end h = self.norm_out(h)