Multi-level masking #125
base: master
Conversation
Is that up to date? What's keeping us from merging?
Alright, it's finally ready to review. Try running it before approving, just to make sure nothing is broken...
Requesting changes because I'd like to go over some of the code live with you @51N84D
omnigan/trainer.py
Outdated
```python
if "m2" in self.opts.tasks:
    prediction = self.G.decoders[update_task](
        torch.cat(
            (self.z, self.label_1[0, :, :, :].unsqueeze(0)),
```
idem
omnigan/trainer.py
Outdated
```python
if update_task == "m2":
    prediction = self.G.decoders["m"](
        torch.cat(
            (self.z, self.label_2[0, :, :, :].unsqueeze(0)),
```
`self.label_0[0, :, :, :].unsqueeze(0)` should be the same as `self.label_0[:1, ...]`
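A quick check of that equivalence (on a generic 4-D tensor, not the actual `self.label_0` from the PR): indexing the first batch element and re-adding the dimension gives the same view as slicing the first row.

```python
import torch

# Illustrative tensor standing in for a (batch, channel, H, W) label map
label = torch.randn(2, 3, 4, 4)

a = label[0, :, :, :].unsqueeze(0)  # index element 0, then restore the batch dim
b = label[:1, ...]                  # slice the first batch element directly

# Both are shape (1, 3, 4, 4) and hold identical values
```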
omnigan/trainer.py
Outdated
```diff
  prediction = prediction.repeat(1, 3, 1, 1)
- task_saves.append(x * (1.0 - prediction))
+ task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1)))

- elif update_task == "d":
+ if update_task == "d":
```
Why not `elif`?
```python
    step_loss += update_loss

elif update_task == "m2":
```
If I read this right, the only things that change between the `if` and the `elif` are `self.label_1[:, 0, 0, 0].squeeze()` vs `self.label_2[:, 0, 0, 0].squeeze()`, and `self.logger.losses.generator.task_loss`. I bet you can refactor this whole block in a much shorter way by putting those in variables and sharing the common code. Cleaner, shorter, less error-prone (you don't need to change two pieces of code when you change the logic), and more versatile (what about more flood levels?).
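One possible shape for that refactor, as a self-contained sketch (the `masked_update` helper and the dict of labels are hypothetical, not code from this PR; the real version would live on the trainer and use `self.z`, `self.label_1`, etc.):

```python
import torch

def masked_update(z, labels, level, decoder):
    """Run one masker forward pass for a given flood level.

    labels maps each level name ("m", "m2", ...) to its label tensor,
    so supporting more flood levels only means adding dict entries,
    not duplicating the if/elif branches.
    """
    label = labels[level]
    # Same conditioning as in the diff: concat latent with the first label slice
    return decoder(torch.cat((z, label[:1, ...]), dim=1))

# Toy usage with an identity "decoder", just to show the call shape
z = torch.zeros(1, 4, 8, 8)
labels = {
    "m": torch.ones(2, 1, 8, 8),
    "m2": torch.full((2, 1, 8, 8), 2.0),
}
out = masked_update(z, labels, "m2", lambda t: t)  # shape (1, 5, 8, 8)
```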
```python
        self.D["m"]["Advent"],
    )
if "m2" in self.opts.tasks:
    # --------ADVENT LOSS---------------
```
Add CoBlock to your VS Code extensions and use Cmd+Shift+K to make nice comment blocks instead of those atrocious, imbalanced hybrid headers.
:p
How about this PR, is it ready @51N84D?
The current approach just adds bit-conditioning to the latent vector and determines which (simulated) ground-truth mask to compute losses with.
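That bit-conditioning can be sketched as appending a constant "level bit" plane to the latent tensor before decoding. This is an illustrative sketch of the idea only (`condition_latent` and the shapes are hypothetical, not the PR's actual layout):

```python
import torch

def condition_latent(z, level_bit):
    """Append a constant conditioning plane to a (B, C, H, W) latent.

    level_bit: 0.0 or 1.0, selecting which flood level the decoder
    should produce a mask for.
    """
    bit_plane = torch.full((z.shape[0], 1, *z.shape[2:]), float(level_bit))
    return torch.cat((z, bit_plane), dim=1)

# Toy usage: an 8-channel latent becomes 9 channels, last plane all ones
z = torch.zeros(2, 8, 4, 4)
z1 = condition_latent(z, 1)
```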