Skip to content

Commit

Permalink
lazily create up_descending after state dict is already loaded, but only do it once
Browse files Browse the repository at this point in the history
  • Loading branch information
technillogue committed Oct 15, 2024
1 parent e0703ed commit 7333f34
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion flux/modules/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,16 @@ def __init__(
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

# this is a hack to get something like property but only evaluate it once
# we're doing it like this so that up_descending isn't in the state_dict keys
# without adding anything conditional to the main flow
def __getattr__(self, name):
    """Lazily build ``up_descending`` on first attribute access.

    ``__getattr__`` only fires when normal lookup fails, so this runs at
    most until ``up_descending`` exists. Deferring the build means the
    reversed Sequential is created only after the state dict has been
    loaded, and its parameters never appear as separate
    ``up_descending.*`` keys in ``state_dict()``.
    """
    if name == "up_descending":
        # Wrap the existing `self.up` blocks in reverse order; the
        # underlying modules (and their weights) are shared, not copied.
        self.up_descending = nn.Sequential(*reversed(self.up))
        # Restore the stock nn.Module lookup so this hook only ever
        # runs once.
        # NOTE(review): this patches the *class*, not the instance — a
        # second Decoder instantiated after this point would never get
        # the lazy build. Confirm only one Decoder is created per process.
        Decoder.__getattr__ = nn.Module.__getattr__
        return self.up_descending
    # Any other missing name falls through to nn.Module's normal
    # parameter/buffer/submodule resolution.
    return super().__getattr__(name)

def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
Expand All @@ -258,7 +268,7 @@ def forward(self, z: Tensor) -> Tensor:
h = self.mid.block_2(h)

# upsampling
h = self.up(h)
h = self.up_descending(h)

# end
h = self.norm_out(h)
Expand Down

0 comments on commit 7333f34

Please sign in to comment.