Multi-level masking #125

Open · wants to merge 8 commits into master
Conversation

@51N84D (Member) commented Aug 24, 2020

The current approach just adds bit-conditioning to the latent vector, and uses that bit to determine which (simulated) ground-truth mask to compute losses with.
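For concreteness, a minimal sketch of what such bit-conditioning could look like (all names here are stand-ins for illustration, not the repo's actual API):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a latent tensor, a 1-channel condition plane,
# and one shared mask decoder.
z = torch.randn(1, 4, 8, 8)                       # latent features
decoder = nn.Conv2d(4 + 1, 1, kernel_size=1)      # decoder fed z + condition bit
gt_masks = {0: torch.zeros(1, 1, 8, 8),           # simulated GT mask, level 0
            1: torch.ones(1, 1, 8, 8)}            # simulated GT mask, level 1
loss_fn = nn.BCEWithLogitsLoss()

level = 1                                         # the conditioning bit
bit = torch.full((1, 1, 8, 8), float(level))      # broadcast the bit to a plane
prediction = decoder(torch.cat((z, bit), dim=1))  # condition the latent on the bit
loss = loss_fn(prediction, gt_masks[level])       # the same bit selects the GT mask
```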

@vict0rsch (Contributor)

Is that up to date? What's keeping us from merging?

@51N84D 51N84D marked this pull request as ready for review September 22, 2020 02:12
@51N84D 51N84D requested review from vict0rsch, melisandeteng and tianyu-z and removed request for melisandeteng September 22, 2020 02:12
@51N84D (Member, Author) commented Sep 22, 2020

Alright, it's finally ready to review. Try running it before approving, just to make sure nothing is broken...

@vict0rsch (Contributor) left a review

Requesting changes because I'd like to go over some of the code live with you @51N84D

omnigan/discriminator.py (review thread resolved)
omnigan/trainer.py (review thread resolved)
if "m2" in self.opts.tasks:
prediction = self.G.decoders[update_task](
torch.cat(
(self.z, self.label_1[0, :, :, :].unsqueeze(0)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idem

```python
if update_task == "m2":
    prediction = self.G.decoders["m"](
        torch.cat(
            (self.z, self.label_2[0, :, :, :].unsqueeze(0)),
```

@vict0rsch (Contributor): `self.label_0[0, :, :, :].unsqueeze(0)` should be the same as `self.label_0[:1, ...]`
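That equivalence is easy to sanity-check in isolation (standalone snippet; any 4-D tensor works):

```python
import torch

x = torch.randn(4, 3, 8, 8)        # stand-in for self.label_0
a = x[0, :, :, :].unsqueeze(0)     # indexing drops the batch dim, unsqueeze re-adds it
b = x[:1, ...]                     # slicing keeps the batch dim directly
assert a.shape == b.shape == (1, 3, 8, 8)
assert torch.equal(a, b)           # same values, same shape
```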

```diff
  prediction = prediction.repeat(1, 3, 1, 1)
  task_saves.append(x * (1.0 - prediction))
  task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1)))

- elif update_task == "d":
+ if update_task == "d":
```

@vict0rsch (Contributor): why not elif?
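In isolation, the difference the question is pointing at (a standalone sketch; with mutually exclusive values the two forms behave the same, `elif` just skips the redundant test):

```python
update_task = "m2"

if update_task == "m2":
    ...                   # this branch runs
if update_task == "d":    # still evaluated even though the branch above matched;
    ...                   # an `elif` would skip this (redundant) test
```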


```python
step_loss += update_loss

elif update_task == "m2":
```

@vict0rsch (Contributor): If I read this right, the only things that change between the `if` and the `elif` are `self.label_1[:, 0, 0, 0].squeeze()` vs `self.label_2[:, 0, 0, 0].squeeze()` and `self.logger.losses.generator.task_loss`. I bet you can refactor this whole block much more concisely by introducing variables and common code that depends on them. Cleaner, shorter, less error-prone (no need to change two pieces of code when the logic changes), and more versatile (what about more flood levels?). See the sketch below.
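A toy version of that refactor (hypothetical stand-ins so the pattern runs end to end; the real decoder, labels, and logger would replace them):

```python
import torch
import torch.nn as nn

# Stand-ins for self.z, self.label_1 / self.label_2, the mask decoder,
# the task loss, and the loss logger.
z = torch.randn(1, 4, 8, 8)
labels = {"m": torch.zeros(2, 1, 8, 8), "m2": torch.ones(2, 1, 8, 8)}
decoder = nn.Conv2d(5, 1, kernel_size=1)
target = torch.rand(1, 1, 8, 8)
loss_fn = nn.BCEWithLogitsLoss()
logged = {}

step_loss = 0.0
for update_task in ("m", "m2"):                # adding a flood level = one dict entry
    label = labels[update_task]
    prediction = decoder(torch.cat((z, label[:1, ...]), dim=1))
    update_loss = loss_fn(prediction, target)
    logged[update_task] = update_loss.item()   # single logging site for every level
    step_loss = step_loss + update_loss
```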

self.D["m"]["Advent"],
)
if "m2" in self.opts.tasks:
# --------ADVENT LOSS---------------
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add CoBlock to your vscode extension and use cmd+shift+k to make nice comment blocks instead of those atrocious imbalanced hybrid headers
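For reference, roughly the kind of balanced block such a tool produces (the exact output format is an assumption):

```python
# -------------------------
# -----  Advent loss  -----
# -------------------------
```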

@vict0rsch (Contributor): :p

@vict0rsch (Contributor)

How about this PR, is it ready @51N84D?
