From 7333f3461763c802b9476782417977e8d6659aad Mon Sep 17 00:00:00 2001
From: technillogue
Date: Tue, 15 Oct 2024 02:23:17 -0400
Subject: [PATCH] lazily create up_descending after state dict is already
 loaded, but only do it once

---
 flux/modules/autoencoder.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/flux/modules/autoencoder.py b/flux/modules/autoencoder.py
index d26d2a0..dc9a08c 100644
--- a/flux/modules/autoencoder.py
+++ b/flux/modules/autoencoder.py
@@ -248,6 +248,16 @@ def __init__(
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
         self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
 
+    # this is a hack to get something like property but only evaluate it once
+    # we're doing it like this so that up_descending isn't in the state_dict keys
+    # without adding anything conditional to the main flow
+    def __getattr__(self, name):
+        if name == "up_descending":
+            self.up_descending = nn.Sequential(*reversed(self.up))
+            Decoder.__getattr__ = nn.Module.__getattr__
+            return self.up_descending
+        return super().__getattr__(name)
+
     def forward(self, z: Tensor) -> Tensor:
         # z to block_in
         h = self.conv_in(z)
@@ -258,7 +268,7 @@ def forward(self, z: Tensor) -> Tensor:
         h = self.mid.block_2(h)
 
         # upsampling
-        h = self.up(h)
+        h = self.up_descending(h)
 
         # end
         h = self.norm_out(h)
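
For context beyond the patch itself, here is a minimal, self-contained sketch of
the same one-shot lazy __getattr__ trick on a toy module. The Lazy, heavy, and
linear names are illustrative assumptions, not part of flux; the point is that
the attribute is built on first access, after load_state_dict would already have
run, and the class then falls back to the stock nn.Module.__getattr__ so later
lookups take the normal path.

    import torch
    from torch import nn

    class Lazy(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def __getattr__(self, name):
            if name == "heavy":
                # built lazily, only on first access
                self.heavy = nn.Sequential(self.linear, nn.ReLU())
                # restore the default lookup so this branch never runs again
                Lazy.__getattr__ = nn.Module.__getattr__
                return self.heavy
            return super().__getattr__(name)

    m = Lazy()
    print(list(m.state_dict().keys()))   # ['linear.weight', 'linear.bias']
    m.heavy(torch.randn(2, 4))           # first access builds m.heavy
    print(list(m.state_dict().keys()))   # heavy.* now appears as well

As in the patch, the lazily built attribute only stays out of the state_dict
keys for a state dict loaded before the first access; once created, it is
registered as an ordinary submodule.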