-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add extrapolation to the Bernstein polynomial transformation #37
Add extrapolation to the Bernstein polynomial transformation #37
Conversation
Just mentioning @oduerr here to notify him of the PR. |
zuko/transforms.py
Outdated
rank = self.theta.dim() | ||
if rank > 1: | ||
# add singleton batch dimensions | ||
dims = [...] + [None] * (rank - 1) | ||
x = x[dims] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should not be necessary as torch.distributions.Beta.log_prob
broadcasts.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test fail if I remove this. Couldn't make it work without.
What would you suggest?
zuko/transforms.py
Outdated
|
||
left_bound = x <= self.eps | ||
right_bound = x >= 1 - self.eps | ||
x_safe = torch.where(left_bound | right_bound, 0.5, x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am afraid that the 0.5
might not work with all supported versions of PyTorch / CUDA tensors. To be tested. Also, left_bound
and right_bound
are not great names as they are out-of-bound indicators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this be better:
x_safe = torch.where(lower_bound | upper_bound, 0.5 * torch.ones_like(x), x)
1. When to log-Scale 2. #sigmoid(17) is 1.0! --> clamping to 1e-6 1e+6
adds additional reference to BNF density forecasting paper
3096cb4
to
13ead7d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work and very clean code! I just have a few minor comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think i resolved all open issues. Anything else that holds us back from closing this PR?
Hello @MArpogaus, I built the documentation with the new components and there was a few errors, so I made a few edits + some reformulation. I also limited the domain to [-5, 5] (like NSF) instead of [-10, 10], which should improve the results of Otherwise everything seems fine to me. Great job! I am ready to merge the PR, so tell me if you agree with my edits. |
Hello @francois-rozet, thanks for your final review. Your edits seam absolutely reasonable to me and improve the clarity of the doc strings. |
This adresses #36.