Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extrapolation to the Bernstein polynomial transformation #37

Merged
merged 36 commits into from
Apr 11, 2024

Conversation

MArpogaus
Copy link
Contributor

This adresses #36.

@MArpogaus
Copy link
Contributor Author

Just mentioning @oduerr here to notify him of the PR.

zuko/transforms.py Outdated Show resolved Hide resolved
zuko/transforms.py Outdated Show resolved Hide resolved
Comment on lines 647 to 658
rank = self.theta.dim()
if rank > 1:
# add singleton batch dimensions
dims = [...] + [None] * (rank - 1)
x = x[dims]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be necessary as torch.distributions.Beta.log_prob broadcasts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test fail if I remove this. Couldn't make it work without.
What would you suggest?

zuko/transforms.py Outdated Show resolved Hide resolved
zuko/transforms.py Outdated Show resolved Hide resolved
zuko/flows/polynomial.py Outdated Show resolved Hide resolved
zuko/transforms.py Show resolved Hide resolved
zuko/transforms.py Outdated Show resolved Hide resolved

left_bound = x <= self.eps
right_bound = x >= 1 - self.eps
x_safe = torch.where(left_bound | right_bound, 0.5, x)
Copy link
Member

@francois-rozet francois-rozet Feb 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am afraid that the 0.5 might not work with all supported versions of PyTorch / CUDA tensors. To be tested. Also, left_bound and right_bound are not great names as they are out-of-bound indicators.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this be better:

        x_safe = torch.where(lower_bound | upper_bound, 0.5 * torch.ones_like(x), x)

zuko/transforms.py Outdated Show resolved Hide resolved
@francois-rozet francois-rozet changed the title Bpf extrapolation Bernstein polynomial flow extrapolation Feb 5, 2024
Copy link
Member

@francois-rozet francois-rozet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work and very clean code! I just have a few minor comments.

zuko/transforms.py Outdated Show resolved Hide resolved
zuko/transforms.py Show resolved Hide resolved
zuko/transforms.py Outdated Show resolved Hide resolved
zuko/transforms.py Show resolved Hide resolved
zuko/transforms.py Show resolved Hide resolved
zuko/flows/polynomial.py Outdated Show resolved Hide resolved
zuko/flows/polynomial.py Outdated Show resolved Hide resolved
Copy link
Contributor Author

@MArpogaus MArpogaus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think i resolved all open issues. Anything else that holds us back from closing this PR?

zuko/transforms.py Show resolved Hide resolved
@francois-rozet
Copy link
Member

Hello @MArpogaus, I built the documentation with the new components and there was a few errors, so I made a few edits + some reformulation. I also limited the domain to [-5, 5] (like NSF) instead of [-10, 10], which should improve the results of BPF for standardized (zero mean, unit variance) data.

Otherwise everything seems fine to me. Great job! I am ready to merge the PR, so tell me if you agree with my edits.

@MArpogaus
Copy link
Contributor Author

Hello @francois-rozet,

thanks for your final review. Your edits seam absolutely reasonable to me and improve the clarity of the doc strings.
I am happy to close this PR now, and exited to see if the results improve with the new bounds.

@francois-rozet francois-rozet merged commit 763a924 into probabilists:master Apr 11, 2024
5 checks passed
@francois-rozet francois-rozet changed the title Bernstein polynomial flow extrapolation Add extrapolation to the Bernstein polynomial transformation Apr 11, 2024
@francois-rozet francois-rozet linked an issue Apr 11, 2024 that may be closed by this pull request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add Extrapolation to Bernstein Polynomial Flow
3 participants