Add extrapolation to the Bernstein polynomial transformation #37

Merged · 36 commits · Apr 11, 2024
Changes from 1 commit

Commits (36)
9fb77a0
first draft of BPF extrapolation
MArpogaus Jan 22, 2024
8ef2f47
fix calculation of jacobian for linear and non linear data scaling
MArpogaus Jan 23, 2024
279f84b
add simple test script for extrapolation
MArpogaus Jan 23, 2024
788cb0b
fix extrapolation for batched data
MArpogaus Jan 23, 2024
ebb7939
use analytic jacobian in call_and_ladj
MArpogaus Jan 24, 2024
8b8c0cd
add bound as optional argument to BPF
MArpogaus Jan 24, 2024
3c2c469
fix extrapolation on gpu
Jan 25, 2024
801a3d9
add eps scaling to ladj
MArpogaus Jan 26, 2024
bcf0f9a
Fixed numerical issue involving log(0)
oduerr Jan 26, 2024
96db9c9
fix order of log and abs in log_abs_det_jacobian
MArpogaus Jan 29, 2024
b0ad995
fix ladj for linear case
MArpogaus Jan 29, 2024
cb703c0
remove unneeded eps scaling
MArpogaus Jan 29, 2024
500319a
remove files not yet merged in master
MArpogaus Jan 29, 2024
3d194f0
add additional reference to BNF density forecasting paper
MArpogaus Jan 29, 2024
ae9f2e9
remove sigmoid and analytical log_abs_det_jacobian
MArpogaus Feb 3, 2024
72d3d96
add inverse of extrapolation
MArpogaus Feb 3, 2024
397b2a9
Add `smooth_bounds` argument to ensure a smooth transition into extra…
MArpogaus Feb 5, 2024
1f45095
Updates doc string
MArpogaus Feb 13, 2024
8f2d26b
enable smooth_bounds per default
MArpogaus Feb 13, 2024
97c7234
remove bound argument from BPF
MArpogaus Feb 13, 2024
18f2b0e
use dim=-1 for concatenation
MArpogaus Feb 13, 2024
a299b70
remove unneeded attributes
MArpogaus Feb 13, 2024
4ba1c72
fix BernsteinTransform tests
MArpogaus Feb 13, 2024
bd86479
refactor small methods into f
MArpogaus Feb 13, 2024
13ead7d
add optional argument to keep the transformation inside the bounds
MArpogaus Feb 14, 2024
2668600
fix: correct typ annotation for bound `int` -> `float`
MArpogaus Feb 21, 2024
0b1a988
Merge branch 'master' into bpf_extrapolation
MArpogaus Mar 22, 2024
38f0c84
Add a bounded Bernstein Polynomial for simple chaining
MArpogaus Mar 27, 2024
83727f5
Use bounded version of Bernstein polynomial for BPF
MArpogaus Mar 27, 2024
9fb5f73
update doc strings of Bernstein transformations.
MArpogaus Mar 27, 2024
5da217e
use fiexed values for offset and slope in BoundedBernsteinTransform
MArpogaus Mar 27, 2024
19d02d9
reverts removal of torch transform star import
MArpogaus Mar 27, 2024
ecf0773
remove unnecessary use of partial and add original paper as reference
MArpogaus Mar 27, 2024
67293e2
uses single quotes for strings
MArpogaus Mar 27, 2024
8ff4386
remove redundant description from BoundedBernsteinTransform doc string
MArpogaus Mar 27, 2024
8633d6b
update some docstrings
francois-rozet Apr 3, 2024
8 changes: 4 additions & 4 deletions zuko/flows/polynomial.py
@@ -61,12 +61,12 @@ class BPF(MAF):
r"""Creates a Bernstein polynomial flow (BPF).

Warning:
Invertibility is only guaranteed for features within the interval :math:`[-10,
10]`. It is recommended to standardize features (zero mean, unit variance)
before training.
The Bernstein polynomial is bounded to the interval :math:`[-5, 5]`. Any feature
outside of this domain is not transformed. It is recommended to standardize
features (zero mean, unit variance) before training.

See also:
:class:`zuko.transforms.BernsteinTransform`
:class:`zuko.transforms.BoundedBernsteinTransform`

References:
| Deep transformation models: Tackling complex regression problems with neural network based transformation models (Sick et al., 2020)
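Side note (not part of the diff): since features outside :math:`[-5, 5]` pass through untransformed, standardizing inputs before training matters in practice. A minimal sketch, assuming zuko's usual lazy-flow interface and arbitrary feature/batch sizes:

import torch
import zuko

x = torch.randn(256, 3) * 4 + 2               # raw features on an arbitrary scale
x = (x - x.mean(0)) / x.std(0)                # standardize: zero mean, unit variance

flow = zuko.flows.BPF(features=3, context=0)  # constructor arguments assumed
loss = -flow().log_prob(x).mean()             # maximum-likelihood objective
loss.backward()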
71 changes: 37 additions & 34 deletions zuko/transforms.py
@@ -604,14 +604,13 @@ def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
class BernsteinTransform(MonotonicTransform):
r"""Creates a monotonic Bernstein polynomial transformation.

The transformation function is defined as:

.. math:: f(x) = \frac{1}{M + 1} \sum_{i=0}^{M} b_{i+1,M-i+1}(x) \, \theta_i
.. math:: f(x) = \frac{1}{M + 1} \sum_{i=0}^{M} b_{i+1,M-i+1} \left( \frac{x + B}{2B} \right) \, \theta_i

where :math:`b_{i,j}` are the Bernstein basis polynomials.

Since :math:`f` is only defined for :math:`x \in [0, 1]`, it is linearly extrapolated outside this interval.
The second-order derivative is enforced to be zero at the bounds for smooth transitions.
As the polynomial :math:`f(x)` is only defined for :math:`x \in [-B, B]`, the
transformation linearly extrapolates it outside this domain. The second derivative
at the bounds is enforced to be zero for smooth extrapolation.

References:
| Deep transformation models: Tackling complex regression problems with neural network based transformation models (Sick et al., 2020)
@@ -625,7 +624,8 @@ class BernsteinTransform(MonotonicTransform):

Arguments:
theta: The unconstrained polynomial coefficients :math:`\theta`,
with shape :math:`(*, M - 1)`.
with shape :math:`(*, M - 2)`.
bound: The polynomial's domain bound :math:`B`.
kwargs: Keyword arguments passed to :class:`MonotonicTransform`.
"""

@@ -634,29 +634,31 @@ class BernsteinTransform(MonotonicTransform):
bijective = True
sign = +1

def __init__(self, theta: Tensor, **kwargs):
super().__init__(None, phi=(theta,), **kwargs)
def __init__(self, theta: Tensor, bound: float = 5.0, **kwargs):
super().__init__(None, phi=(theta,), bound=bound, **kwargs)

self.theta = self._constrain_theta(theta)
self.basis = self._bernstein_basis(self.order, device=theta.device, dtype=theta.dtype)

# save slope on boundaries for interpolation
self.offset, self.slope = self._calculate_offset_and_slope()
self.offset, self.slope = self._offset_and_slope()

@property
def order(self) -> int:
return self.theta.shape[-1] - 1

def _calculate_offset_and_slope(self) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
def _offset_and_slope(self) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
r"""Calculates the offsets and slopes at the domain bounds for extrapolation."""

dtheta = self.order * (self.theta[..., 1:] - self.theta[..., :-1])
dbasis = self._bernstein_basis(
self.order - 1, device=self.theta.device, dtype=self.theta.dtype
)

bounds = [
torch.tensor(self.eps, device=self.theta.device, dtype=self.theta.dtype),
torch.tensor(1 - self.eps, device=self.theta.device, dtype=self.theta.dtype),
self.theta.new_tensor(self.eps),
self.theta.new_tensor(1 - self.eps),
]

offset = [self._bernstein_poly(x, self.theta, self.basis) for x in bounds]
slope = [self._bernstein_poly(x, dtheta, dbasis) for x in bounds]
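These cached offsets and slopes are the polynomial's values and derivatives at the clipped bounds; outside the domain, the transformation then follows the corresponding tangent lines. A generic sketch of that idea (it mirrors the role of `self.offset` and `self.slope`, but the exact bookkeeping and coordinate scaling in the library may differ):

import torch

def linear_extension(z, poly, offset, slope, eps=1e-6):
    # evaluate the polynomial inside [eps, 1 - eps] and extend it linearly
    # outside, using the boundary values (offset) and derivatives (slope)
    inside = poly(z.clamp(eps, 1 - eps))
    left = offset[0] + slope[0] * (z - eps)
    right = offset[1] + slope[1] * (z - (1 - eps))
    return torch.where(z < eps, left, torch.where(z > 1 - eps, right, inside))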

@@ -665,6 +667,7 @@ def _calculate_offset_and_slope(self) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
@staticmethod
def _constrain_theta(unconstrained_theta: Tensor) -> Tensor:
"""Processes the unconstrained output of the hyper-network to be increasing."""

shift = math.log(2.0) * unconstrained_theta.shape[-1] / 2

theta_min = unconstrained_theta[..., :1]
@@ -739,34 +742,33 @@ def _inverse(self, y: Tensor) -> Tensor:


class BoundedBernsteinTransform(BernsteinTransform):
r"""Bounded version of :py:`BernsteinTransform`, optimized for chained Flows.

This subclass scales the Bernstein coefficients so that the transformation's domain and
codomain match the interval :math:`[-B, B]`, where :math:`B` represents the bounds of the
base class. It also ensures that the derivative at the boundaries is equal to 1
(:math:`Be'(0,1) = 1 \to M \cdot (\theta_1 - \theta_0) = M \cdot\Delta_0 \to \Delta_0 = 1/M`)
and that the second order derivative is zero
(:math:`Be''(0,1) = 0 \propto (\Delta_1 - \Delta_0) = M \Delta_0 = \Delta_1`),
ensuring a smooth transition to the identity function outside the bounds.

These conditions make the transformation particularly suitable for chaining
and are hence used in :py:`BPF`.
r"""Creates a bounded Bernstein polynomial transformation.

This subclass of :class:`BernsteinTransform` scales the Bernstein coefficients so
that the polynomial's domain and codomain match the interval :math:`[-B, B]`. It
also enforces that the first derivative at the bounds is one and that the second
derivative is zero, ensuring a smooth transition to the identity function outside
the bounds.

These constraints make the transformation suitable for chaining in flows.

Arguments:
theta: The unconstrained polynomial coefficients :math:`\theta`,
with shape :math:`(*, M - 5)`.
kwargs: Keyword arguments passed to :class:`BernsteinTransform`.
"""

def _constrain_theta(self, unconstrained_theta: Tensor) -> Tensor:
"""Processes the unconstrained output of the hyper-network to be increasing."""

theta_min = -self.bound * torch.ones_like(unconstrained_theta[..., :1])

diff_on_bounds = (2 * self.bound) / (unconstrained_theta.shape[-1] + 4)

diffs = torch.nn.functional.softmax(unconstrained_theta, dim=-1) * (
2 * self.bound - 4 * diff_on_bounds
)

# ensure identity on bounds by enforcing Be'(0,1) = 1 and Be''(0,1) = 0
# Be'(0) = order * theta_1 - theta_0 = order * diff_0 -> diff_0 = 1 / order
# Be''(0) = (order - 1) * (diff_1 - diff_0) -> diff_0 == diff_1
# Be''(0) = (order - 1) * (diff_1 - diff_0) -> diff_0 = diff_1
diffs = torch.cat(
(
theta_min,
@@ -779,14 +781,15 @@ def _constrain_theta(self, unconstrained_theta: Tensor) -> Tensor:

return torch.cumsum(diffs, dim=-1)
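The concatenation above is truncated in this view, so here is a hypothetical stand-alone reconstruction of the constraint (assuming two fixed increments of `diff_on_bounds` at each end, as the comment suggests), checking that the resulting coefficients are increasing and span exactly :math:`[-B, B]`:

import torch

bound, n = 5.0, 8                                   # n unconstrained parameters
u = torch.randn(n)

theta_min = torch.full((1,), -bound)
diff_on_bounds = (2 * bound) / (n + 4)
inner = torch.softmax(u, dim=-1) * (2 * bound - 4 * diff_on_bounds)
edge = torch.full((2,), diff_on_bounds)
theta = torch.cumsum(torch.cat((theta_min, edge, inner, edge)), dim=-1)

print(theta[0].item(), theta[-1].item())            # -5.0 and 5.0 (up to float error)
print(bool((torch.diff(theta) > 0).all()))          # True: strictly increasing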

def _calculate_offset_and_slope(self) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
def _offset_and_slope(self) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
offset = (
torch.tensor(-self.bound, device=self.theta.device, dtype=self.theta.dtype),
torch.tensor(self.bound, device=self.theta.device, dtype=self.theta.dtype),
self.theta.new_tensor(-self.bound),
self.theta.new_tensor(self.bound),
)

slope = (
torch.tensor(2 * self.bound, device=self.theta.device, dtype=self.theta.dtype),
torch.tensor(2 * self.bound, device=self.theta.device, dtype=self.theta.dtype),
self.theta.new_tensor(2 * self.bound),
self.theta.new_tensor(2 * self.bound),
)

return offset, slope
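One way to see where the constant slope :math:`2B` above comes from (a worked step, not part of the diff): with the rescaling :math:`z = (x + B) / (2B)` used by the parent class, the chain rule gives

.. math:: \frac{\mathrm{d}}{\mathrm{d}x} f(x) = \frac{1}{2B} \, \mathrm{Be}'(z),

so requiring unit slope at :math:`x = \pm B` (an identity-like transition) fixes the polynomial's boundary slope to :math:`\mathrm{Be}'(0) = \mathrm{Be}'(1) = 2B`, which is exactly the value returned here.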