Add extrapolation to the Bernstein polynomial transformation #37

Merged
merged 36 commits into master from bpf_extrapolation on Apr 11, 2024

Changes from 29 commits

Commits (36)
9fb77a0
first draft of BPF extrapolation
MArpogaus Jan 22, 2024
8ef2f47
fix calculation of jacobian for linear and non linear data scaling
MArpogaus Jan 23, 2024
279f84b
add simple test script for extrapolation
MArpogaus Jan 23, 2024
788cb0b
fix extrapolation for batched data
MArpogaus Jan 23, 2024
ebb7939
use analytic jacobian in call_and_ladj
MArpogaus Jan 24, 2024
8b8c0cd
add bound as optional argument to BPF
MArpogaus Jan 24, 2024
3c2c469
fix extrapolation on gpu
Jan 25, 2024
801a3d9
add eps scaling to ladj
MArpogaus Jan 26, 2024
bcf0f9a
Fixed numerical issue involving log(0)
oduerr Jan 26, 2024
96db9c9
fix order of log and abs in log_abs_det_jacobian
MArpogaus Jan 29, 2024
b0ad995
fix ladj for linear case
MArpogaus Jan 29, 2024
cb703c0
remove unneeded eps scaling
MArpogaus Jan 29, 2024
500319a
remove files not yet merged in master
MArpogaus Jan 29, 2024
3d194f0
add additional reference to BNF density forecasting paper
MArpogaus Jan 29, 2024
ae9f2e9
remove sigmoid and analytical log_abs_det_jacobian
MArpogaus Feb 3, 2024
72d3d96
add inverse of extrapolation
MArpogaus Feb 3, 2024
397b2a9
Add `smooth_bounds` argument to ensure a smooth transition into extra…
MArpogaus Feb 5, 2024
1f45095
Updates doc string
MArpogaus Feb 13, 2024
8f2d26b
enable smooth_bounds per default
MArpogaus Feb 13, 2024
97c7234
remove bound argument from BPF
MArpogaus Feb 13, 2024
18f2b0e
use dim=-1 for concatenation
MArpogaus Feb 13, 2024
a299b70
remove unneeded attributes
MArpogaus Feb 13, 2024
4ba1c72
fix BernsteinTransform tests
MArpogaus Feb 13, 2024
bd86479
refactor small methods into f
MArpogaus Feb 13, 2024
13ead7d
add optional argument to keep the transformation inside the bounds
MArpogaus Feb 14, 2024
2668600
fix: correct type annotation for bound `int` -> `float`
MArpogaus Feb 21, 2024
0b1a988
Merge branch 'master' into bpf_extrapolation
MArpogaus Mar 22, 2024
38f0c84
Add a bounded Bernstein Polynomial for simple chaining
MArpogaus Mar 27, 2024
83727f5
Use bounded version of Bernstein polynomial for BPF
MArpogaus Mar 27, 2024
9fb5f73
update doc strings of Bernstein transformations.
MArpogaus Mar 27, 2024
5da217e
use fixed values for offset and slope in BoundedBernsteinTransform
MArpogaus Mar 27, 2024
19d02d9
reverts removal of torch transform star import
MArpogaus Mar 27, 2024
ecf0773
remove unnecessary use of partial and add original paper as reference
MArpogaus Mar 27, 2024
67293e2
uses single quotes for strings
MArpogaus Mar 27, 2024
8ff4386
remove redundant description from BoundedBernsteinTransform doc string
MArpogaus Mar 27, 2024
8633d6b
update some docstrings
francois-rozet Apr 3, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -90,7 +90,7 @@ For more information, check out the documentation and tutorials at [zuko.readthe
| `UNAF` | 2019 | [Unconstrained Monotonic Neural Networks](https://arxiv.org/abs/1908.05164) |
| `CNF` | 2018 | [Neural Ordinary Differential Equations](https://arxiv.org/abs/1806.07366) |
| `GF` | 2020 | [Gaussianization Flows](https://arxiv.org/abs/2003.01941) |
-| `BPF` | 2020 | [Bernstein-Polynomial Normalizing Flows](https://arxiv.org/abs/2204.13939) |
+| `BPF` | 2020 | [Bernstein-Polynomial Normalizing Flows](https://arxiv.org/abs/2004.00464) |

## Contributing

2 changes: 1 addition & 1 deletion tests/test_transforms.py
@@ -25,7 +25,7 @@ def test_univariate_transforms(batched: bool):
MonotonicRQSTransform(randn(*batch, 8), randn(*batch, 8), randn(*batch, 7)),
MonotonicTransform(lambda x: x**3),
BernsteinTransform(randn(*batch, 16)),
-BernsteinTransform(randn(*batch, 16), linear=True),
+BoundedBernsteinTransform(randn(*batch, 16)),
GaussianizationTransform(randn(*batch, 8), randn(*batch, 8)),
UnconstrainedMonotonicTransform(lambda x: torch.exp(-(x**2)) + 1e-2, randn(batch)),
SOSPolynomialTransform(randn(*batch, 3, 5), randn(batch)),
15 changes: 4 additions & 11 deletions zuko/flows/polynomial.py
@@ -1,16 +1,16 @@
r"""Polynomial flows."""

__all__ = [
-'SOSPF',
-'BPF',
+"SOSPF",
+"BPF",
]

from functools import partial

# isort: local
from .autoregressive import MAF
from .core import Unconditional
-from ..transforms import BernsteinTransform, SoftclipTransform, SOSPolynomialTransform
+from ..transforms import BoundedBernsteinTransform, SoftclipTransform, SOSPolynomialTransform


class SOSPF(MAF):
@@ -77,7 +77,6 @@ class BPF(MAF):
features: The number of features.
context: The number of context features.
degree: The degree :math:`M` of the Bernstein polynomial.
-linear: Whether to use a linear or sigmoid mapping to :math:`[0, 1]`.
kwargs: Keyword arguments passed to :class:`zuko.flows.autoregressive.MAF`.
"""

@@ -86,18 +85,12 @@ def __init__(
features: int,
context: int = 0,
degree: int = 16,
-linear: bool = False,
**kwargs,
):
super().__init__(
features=features,
context=context,
-univariate=partial(BernsteinTransform, linear=linear),
+univariate=partial(BoundedBernsteinTransform),
shapes=[(degree + 1,)],
**kwargs,
)

transforms = self.transform.transforms

for i in reversed(range(1, len(transforms))):
transforms.insert(i, Unconditional(SoftclipTransform, bound=11.0))
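For orientation (editorial note, not part of the diff): a minimal usage sketch of the flow after this change, assuming zuko's usual lazy-flow interface where calling the flow on a context returns a distribution:

```python
import torch
import zuko

# A Bernstein-polynomial flow over 3 features with 5 context features.
flow = zuko.flows.BPF(features=3, context=5, degree=16)

x = torch.randn(3)
c = torch.randn(5)

log_p = flow(c).log_prob(x)  # conditional log-density log p(x | c)
sample = flow(c).sample()    # a sample from p(x | c)
```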
151 changes: 122 additions & 29 deletions zuko/transforms.py
@@ -13,6 +13,7 @@
'MonotonicRQSTransform',
'MonotonicTransform',
'BernsteinTransform',
+'BoundedBernsteinTransform',
'GaussianizationTransform',
'UnconstrainedMonotonicTransform',
'SOSPolynomialTransform',
@@ -30,8 +31,7 @@

from textwrap import indent
from torch import BoolTensor, LongTensor, Size, Tensor
-from torch.distributions import Transform, constraints
-from torch.distributions.transforms import *  # noqa: F403
+from torch.distributions import Distribution, Transform, constraints
from torch.distributions.utils import _sum_rightmost
from typing import Any, Callable, Iterable, Tuple, Union

@@ -603,12 +603,15 @@ def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
class BernsteinTransform(MonotonicTransform):
r"""Creates a monotonic Bernstein polynomial transformation.

-.. math:: f(x) = \frac{1}{M + 1} \sum_{i=0}^{M} b_{i+1,M-i+1}(\sigma(x)) \, \theta_i
+.. math:: f(x) = \frac{1}{M + 1} \sum_{i=0}^{M} b_{i+1,M-i+1}(x) \, \theta_i

-where :math:`b_{i,j}` are the Bernstein basis polynomials and :math:`\sigma(x)` is
-the sigmoid function.
+where :math:`b_{i,j}` are the Bernstein basis polynomials.
+Since :math:`f` is only defined for :math:`x \in [0, 1]`, it is linearly extrapolated outside this interval.

References:
+| Deep transformation models: Tackling complex regression problems with neural network based transformation models (Sick et al., 2020)
+| https://arxiv.org/abs/2004.00464

| Short-Term Density Forecasting of Low-Voltage Load using Bernstein-Polynomial Normalizing Flows (Arpogaus et al., 2022)
| https://arxiv.org/abs/2204.13939
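Editorial aside, not part of the diff: the basis functions :math:`b_{i+1,M-i+1}` above are Beta densities, which equal :math:`(M + 1)` times the classical Bernstein basis polynomials, so the :math:`\frac{1}{M + 1}` prefactor makes :math:`f` exactly the Bernstein polynomial with coefficients :math:`\theta`. A small sanity check of this identity, with the degree chosen arbitrarily:

```python
import math
import torch

M = 5  # degree, chosen arbitrarily
i = torch.arange(M + 1)
x = torch.tensor(0.3)

# Classical Bernstein basis: b_{i,M}(x) = C(M, i) x^i (1 - x)^(M - i)
binom = torch.tensor([math.comb(M, k) for k in range(M + 1)], dtype=torch.float)
bernstein = binom * x**i * (1 - x) ** (M - i)

# The Beta(i + 1, M - i + 1) density equals (M + 1) * b_{i,M}(x).
beta = torch.distributions.Beta(i + 1.0, M - i + 1.0)
assert torch.allclose(beta.log_prob(x).exp(), (M + 1) * bernstein)
```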

@@ -618,9 +621,10 @@ class BernsteinTransform(MonotonicTransform):
Arguments:
theta: The unconstrained polynomial coefficients :math:`\theta`,
with shape :math:`(*, M + 1)`.
-linear: Whether to replace the sigmoid function with a linear mapping
-:math:`\frac{x + B}{2B}`. If :py:`True`, input features are assumed to be
-in :math:`[-B, B]`. Failing to satisfy this constraint will result in NaNs.
+smooth_bounds: When :py:`True`, the second-order derivative is set to zero at the
+bounds to ensure a smooth transition into extrapolation.
+keep_in_bounds: Enforces the transformation to stay within the bounds :math:`[-B, B]`
+by setting :math:`\theta_0 = -B` and :math:`\theta_M = B`.
kwargs: Keyword arguments passed to :class:`MonotonicTransform`.
"""

@@ -629,42 +633,131 @@ class BernsteinTransform(MonotonicTransform):
bijective = True
sign = +1

-def __init__(self, theta: Tensor, linear: bool = False, **kwargs):
+def __init__(self, theta: Tensor, **kwargs):
super().__init__(None, phi=(theta,), **kwargs)

-    self.theta = self._increasing(theta)
-    self.linear = linear
+    self.theta = self._constrain_theta(theta)
+    self.order = self.theta.shape[-1] - 1

+    dtheta = self.order * (self.theta[..., 1:] - self.theta[..., :-1])

-    degree = theta.shape[-1]
-    alpha = torch.arange(1, degree + 1, device=theta.device, dtype=theta.dtype)
-    beta = torch.arange(degree, 0, -1, device=theta.device, dtype=theta.dtype)
-    self.basis = torch.distributions.Beta(alpha, beta)
+    self.basis = self._bernstein_basis(self.order, device=theta.device, dtype=theta.dtype)
+    dbasis = self._bernstein_basis(self.order - 1, device=theta.device, dtype=theta.dtype)

+    # save slope on boundaries for extrapolation
+    x = torch.tensor([self.eps, 1 - self.eps], device=theta.device, dtype=theta.dtype)
+    rank = self.theta.dim()
+    if rank > 1:
+        # add singleton batch dimensions
+        dims = [...] + [None] * (rank - 1)
+        x = x[dims]
[Inline review on the lines above]
Member: This should not be necessary as torch.distributions.Beta.log_prob broadcasts.
Contributor Author: Tests fail if I remove this; I couldn't make it work without it. What would you suggest?
+    self.offset = self._bernstein_poly(x, self.theta, self.basis)
+    self.slope = self._bernstein_poly(x, dtheta, dbasis)
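For reference (editorial note, not part of the diff), a minimal illustration of the broadcasting behaviour the reviewer refers to, with arbitrary shapes:

```python
import torch

# A Beta "basis" with 5 components, as in the constructor above.
basis = torch.distributions.Beta(torch.arange(1.0, 6.0), torch.arange(5.0, 0.0, -1.0))

x = 0.8 * torch.rand(3, 1) + 0.1  # batch of 3 points with a trailing singleton dim
print(basis.log_prob(x).shape)    # torch.Size([3, 5]): broadcasts over the batch
```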

@staticmethod
-def _increasing(theta: Tensor) -> Tensor:
-    r"""Processes the unconstrained output of the hyper-network to be increasing."""
+def _constrain_theta(unconstrained_theta: Tensor) -> Tensor:
+    """Processes the unconstrained output of the hyper-network to be increasing."""
+    shift = math.log(2.0) * unconstrained_theta.shape[-1] / 2

+    theta_min = unconstrained_theta[..., :1]
+    unconstrained_theta = unconstrained_theta[..., 1:]

+    # ensure smooth bounds
+    unconstrained_theta = torch.cat(
+        (
+            unconstrained_theta[..., :1],
+            unconstrained_theta,
+            unconstrained_theta[..., -1:],
+        ),
+        dim=-1,
+    )

+    diffs = torch.nn.functional.softplus(unconstrained_theta)
+    diffs = torch.cat((theta_min, diffs), dim=-1)

-    shift = math.log(2.0) * theta.shape[-1] / 2
-    widths = torch.nn.functional.softplus(theta[..., 1:])
-    widths = torch.cat((theta[..., :1], widths), dim=-1)
+    theta = torch.cumsum(diffs, dim=-1) - shift
+
+    return theta
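Editorial sketch, not part of the diff: the core of the softplus-and-cumsum construction above, simplified by omitting the duplicated edge increments that enforce smooth bounds; the helper name `constrain_theta` is illustrative only:

```python
import math
import torch

def constrain_theta(u: torch.Tensor) -> torch.Tensor:
    # First entry sets the minimum; softplus makes the remaining
    # entries positive increments, so the cumulative sum is increasing.
    shift = math.log(2.0) * u.shape[-1] / 2
    diffs = torch.nn.functional.softplus(u[..., 1:])
    diffs = torch.cat((u[..., :1], diffs), dim=-1)
    return torch.cumsum(diffs, dim=-1) - shift

theta = constrain_theta(torch.randn(16))
assert (theta.diff() > 0).all()  # coefficients are strictly increasing
```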

+@staticmethod
+def _bernstein_basis(order: int, **kwargs):
+    alpha = torch.arange(1, order + 2, **kwargs)
+    beta = torch.arange(order + 1, 0, -1, **kwargs)
+    basis = torch.distributions.Beta(alpha, beta)
+    return basis

-    return torch.cumsum(widths, dim=-1) - shift
+@staticmethod
+def _bernstein_poly(x: Tensor, theta: Tensor, basis: Distribution):
+    b = basis.log_prob(x.unsqueeze(-1)).exp()
+    y = torch.mean(b * theta, dim=-1)
+    return y

def f(self, x: Tensor) -> Tensor:
-    if self.linear:
-        x = (x + self.bound) / (2 * self.bound)  # map [-B, B] to [0, 1]
-    else:
-        x = torch.nn.functional.sigmoid(x)  # map [-inf, inf] to [0, 1]
+    x = (x + self.bound) / (2 * self.bound)  # map [-B, B] to [0, 1]

+    lower_bound = x <= self.eps
+    upper_bound = x >= 1 - self.eps
+    x_safe = torch.where(lower_bound | upper_bound, 0.5 * torch.ones_like(x), x)

-    x = x * (1 - 2e-6) + 1e-6
-    x = x.unsqueeze(-1)
-    b = self.basis.log_prob(x).exp()
-    y = torch.mean(b * self.theta, dim=-1)
+    y = self._bernstein_poly(x_safe, self.theta, self.basis)

+    # f'(eps) * (x - eps) + f(eps)
+    y0 = self.slope[0] * (x - self.eps) + self.offset[0]

+    # f'(1-eps) * (x - 1 + eps) + f(1-eps)
+    y1 = self.slope[1] * (x - 1 + self.eps) + self.offset[1]

+    y = torch.where(lower_bound, y0, y)
+    y = torch.where(upper_bound, y1, y)

return y
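Editorial sketch, not part of the diff: outside the bounds the transform is linear by construction, which is easy to check numerically (this assumes the default `bound` of `MonotonicTransform` is small enough that 10 lies beyond it; the tolerance is a guess):

```python
import torch
from zuko.transforms import BernsteinTransform

t = BernsteinTransform(torch.randn(16))

# Equally spaced points beyond the bound fall in the linear tail,
# so their images are equally spaced too.
x = torch.linspace(10.0, 20.0, 5)
y = t(x)
assert torch.allclose(y.diff(), y.diff()[:1].expand(4), atol=1e-4)
```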

+def _inverse(self, y: Tensor):
+    left_bound = y <= self.offset[0]
+    right_bound = y >= self.offset[1]
+
+    x = super()._inverse(y)
+    x0 = (y - self.offset[0]) / self.slope[0] + self.eps
+    x1 = (y - self.offset[1]) / self.slope[1] - self.eps + 1
+
+    # map [0, 1] to [-B, B]
+    x0 = x0 * 2 * self.bound - self.bound
+    x1 = x1 * 2 * self.bound - self.bound
+
+    x = torch.where(left_bound, x0, x)
+    x = torch.where(right_bound, x1, x)
+
+    return x
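Similarly (editorial, not part of the diff), the closed-form tail inverse can be smoke-tested with a round trip; the tolerance is a guess, since interior points go through `MonotonicTransform`'s numerical inversion:

```python
import torch
from zuko.transforms import BernsteinTransform

t = BernsteinTransform(torch.randn(16))

# Round trip through the inverse, covering both linear tails and the interior.
x = torch.linspace(-10.0, 10.0, 9)
assert torch.allclose(t.inv(t(x)), x, atol=1e-3)
```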


+class BoundedBernsteinTransform(BernsteinTransform):
+    def _constrain_theta(self, unconstrained_theta: Tensor) -> Tensor:
+        """Processes the unconstrained output of the hyper-network to be increasing."""
+
+        theta_min = -self.bound * torch.ones_like(unconstrained_theta[..., :1])
+
+        diff_on_bounds = (2 * self.bound) / (unconstrained_theta.shape[-1] + 4)
+
+        diffs = torch.nn.functional.softmax(unconstrained_theta, dim=-1) * (
+            2 * self.bound - 4 * diff_on_bounds
+        )
+
+        # ensure identity on the bounds by enforcing Be'(0, 1) = 2B and Be''(0, 1) = 0
+        # Be'(0) = order * (theta_1 - theta_0) = order * diff_0 -> diff_0 = 2B / order
+        # Be''(0) = order * (order - 1) * (diff_1 - diff_0) -> diff_0 == diff_1
+        diffs = torch.cat(
+            (
+                theta_min,
+                diff_on_bounds * torch.ones_like(diffs[..., :2]),
+                diffs,
+                diff_on_bounds * torch.ones_like(diffs[..., :2]),
+            ),
+            dim=-1,
+        )
+
+        return torch.cumsum(diffs, dim=-1)
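Editorial sketch, not part of the diff: the bounded variant should map :math:`[-B, B]` into itself, which can be spot-checked as follows (small tolerance for floating-point error):

```python
import torch
from zuko.transforms import BoundedBernsteinTransform

t = BoundedBernsteinTransform(torch.randn(16))

x = torch.linspace(-t.bound, t.bound, 101)
y = t(x)

# Outputs stay within [-B, B], up to floating-point error.
assert (y >= -t.bound - 1e-4).all() and (y <= t.bound + 1e-4).all()
```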


class GaussianizationTransform(MonotonicTransform):
r"""Creates a gaussianization transformation.