Implement truncated variables
Ricardo committed Apr 6, 2022
1 parent 0624b6b commit cc14e41
Showing 3 changed files with 183 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aeppl/logprob.py
@@ -110,7 +110,7 @@ def _logcdf(


@singledispatch
-def _icdf(op: Op, rv: TensorVariable, value, *params):
+def _icdf(op: Op, value, *params):
"""Create a graph for the icdf of a ``RandomVariable``.
    This function dispatches on the type of ``op``, which should be a subclass
    of ``RandomVariable``.
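The new signature drops the unused ``rv`` argument, so implementations registered on ``_icdf`` receive only the value and the RV node's inputs. A hypothetical registration under the new signature might look like the sketch below (illustrative only, not part of this commit; the parameter list is assumed to mirror the RV's inputs, as in ``_truncated`` further down):

import aesara.tensor.random.basic as arb

from aeppl.logprob import _icdf


@_icdf.register(arb.UniformRV)
def uniform_icdf(op, value, rng, size, dtype, lower, upper):
    # Quantile function of Uniform(lower, upper) at probability `value`
    return lower + value * (upper - lower)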
134 changes: 133 additions & 1 deletion aeppl/truncation.py
@@ -1,18 +1,26 @@
import warnings
from functools import singledispatch
from typing import List, Optional

import aesara.tensor as at
import aesara.tensor.random.basic as arb
import numpy as np
from aesara import scan
from aesara.compile.builders import OpFromGraph
from aesara.graph import Op
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.raise_op import Assert
from aesara.scalar.basic import Clip
from aesara.scalar.basic import clip as scalar_clip
from aesara.scan import until
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
-from aeppl.logprob import CheckParameterValue, _logcdf, _logprob
+from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, icdf, logcdf
from aeppl.opt import rv_sinking_db


@@ -123,3 +131,127 @@ def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
)

return logprob


@singledispatch
def _truncated(op: Op, lower, upper, *params):
    """Return the truncated equivalent of another ``RandomVariable``.

    The returned variable should also be a ``RandomVariable``.
    """
raise NotImplementedError(
f"{op} does not have an equivalent truncated version implemented"
)


@_truncated.register(arb.UniformRV)
def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
return at.random.uniform(
at.max((lower_orig, lower)),
at.min((upper_orig, upper)),
rng=rng,
size=size,
dtype=dtype,
)
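The uniform specialization is exact because truncating a uniform simply intersects the supports; a quick standalone NumPy check of that identity (illustrative only, unrelated to the Aesara graph above):

import numpy as np

rng = np.random.default_rng(0)
# Uniform(0, 10) truncated to [5, 15] has the same law as
# Uniform(max(0, 5), min(10, 15)) = Uniform(5, 10)
draws = rng.uniform(max(0, 5), min(10, 15), size=1_000)
assert draws.min() >= 5 and draws.max() <= 10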


class TruncatedRV(OpFromGraph):
"""An `Op` constructed from an Aesara graph that represents a truncated univariate RV."""

default_output = 1
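    # Output 0 is the updated RNG state and output 1 holds the draws, so
    # ``default_output = 1`` makes calls to the Op return the draws directly.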


def truncate(rv, lower=None, upper=None, max_n_steps=10_000):
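    """Truncate ``rv`` between ``lower`` and ``upper``.

    Tries a specialized truncated ``Op`` first, then inverse-cdf sampling,
    and falls back to rejection sampling inside a ``Scan``.
    """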

    if not (rv.owner and isinstance(rv.owner.op, RandomVariable)):
        raise ValueError("`rv` must be the output of a RandomVariable")
if rv.owner.op.ndim_supp > 0:
raise NotImplementedError(
"truncation not implemented for multivariate variables"
)
if lower is None and upper is None:
raise ValueError("lower and upper cannot both be None")

lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
graph_inputs = [*rv.owner.inputs, lower, upper]

# Variables with `_` suffix identify dummy inputs for the OpFromGraph
graph_inputs_ = [inp.type() for inp in graph_inputs]
*rv_inputs_, lower_, upper_ = graph_inputs_

try:
# Use specialized Op
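        # Fast path: dispatch to an exact truncated counterpart (e.g. the
        # uniform registration above) when one exists.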
truncated_rv_ = _truncated(rv.owner.op, lower_, upper_, *rv_inputs_)
return TruncatedRV(
inputs=graph_inputs_,
outputs=truncated_rv_.owner.outputs,
inline=True,
)(*graph_inputs)
except NotImplementedError:
try:
# Use inverted cdf sampling
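            # Inverse-transform trick: draw u ~ Uniform(cdf(lower), cdf(upper))
            # and map it through the inverse cdf, which yields exact draws from
            # the truncated distribution whenever an icdf is registered.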
rv_ = rv.owner.op.make_node(*rv_inputs_).default_output()
cdf_lower_ = at.exp(logcdf(rv_, lower_))
cdf_upper_ = at.exp(logcdf(rv_, upper_))
uniform_ = at.random.uniform(
cdf_lower_,
cdf_upper_,
rng=rv_inputs_[0],
size=rv_inputs_[1],
)
truncated_rv_ = icdf(rv_, uniform_)
return TruncatedRV(
inputs=graph_inputs_,
outputs=[uniform_.owner.outputs[0], truncated_rv_],
inline=True,
)(*graph_inputs)
except NotImplementedError:
# Use rejection sampling
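            # Keep accepted values and redraw only the rejected entries on
            # each iteration, giving up after ``max_n_steps`` Scan steps.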

# Scan forces us to use a shared variable for the RNG
rng = rv.owner.inputs[0]
graph_inputs = [*rv.owner.inputs[1:], lower, upper]
graph_inputs_ = [inp.type() for inp in graph_inputs]
*rv_inputs_, lower_, upper_ = (rng, *graph_inputs_)
rv_ = rv.owner.op.make_node(*rv_inputs_).default_output()

def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
# TODO: Very unsure about this way of specifying the new_draws
next_rng, new_truncated_rv = rv.owner.op.make_node(*rv_inputs).outputs

rng = next_rng.owner.inputs[0]
rng.tag.is_rng = True
new_truncated_rv.rng = rng
new_truncated_rv.update = (rng, next_rng)
rng.default_update = next_rng

truncated_rv = at.set_subtensor(
truncated_rv[reject_draws], new_truncated_rv[reject_draws]
)
reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
return (truncated_rv, reject_draws), until(~at.any(reject_draws))

(truncated_rv_, reject_draws_), updates = scan(
loop_fn,
outputs_info=[
at.empty_like(rv_),
at.ones_like(rv_, dtype=bool),
],
non_sequences=[lower_, upper_, *rv_inputs_],
n_steps=max_n_steps,
strict=True,
)

truncated_rv_ = truncated_rv_[-1]
convergence_ = ~at.any(reject_draws_[-1])
            truncated_rv_ = Assert(
                f"truncation did not converge within the predefined {max_n_steps} steps"
            )(truncated_rv_, convergence_)

return TruncatedRV(
inputs=graph_inputs_,
outputs=[tuple(updates.values())[0], truncated_rv_],
inline=True,
strict=True,
)(*graph_inputs)
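For intuition, the inverse-cdf branch above is the classic inverse-transform recipe restricted to [cdf(lower), cdf(upper)]. A standalone NumPy/SciPy sketch of the same idea, mirroring test_truncation_icdf below (illustrative only, independent of the aeppl implementation):

import numpy as np
from scipy import stats

dist = stats.expon(scale=10)  # Exponential base distribution
rng = np.random.default_rng(42)
u = rng.uniform(dist.cdf(1), dist.cdf(2), size=100)  # u ~ Uniform(F(1), F(2))
draws = dist.ppf(u)  # exact draws from the distribution truncated to [1, 2]
assert np.all((draws > 1) & (draws < 2))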
49 changes: 49 additions & 0 deletions tests/test_truncation.py
@@ -7,6 +7,7 @@

from aeppl import factorized_joint_logprob, joint_logprob
from aeppl.transforms import LogTransform, TransformValuesOpt
from aeppl.truncation import TruncatedRV, truncate
from tests.utils import assert_no_rvs


@@ -189,3 +190,51 @@ def test_censored_transform():
)

assert np.isclose(obs_logp, exp_logp)


def test_truncation_specialized_op():
x = at.random.uniform(0, 10, name="x", size=100)
xt = truncate(x, lower=5, upper=15)

assert isinstance(xt.owner.op, TruncatedRV)

xt_fn = aesara.function([], xt, updates={xt.owner.inputs[0]: xt.owner.outputs[0]})
xt_draws = np.array([xt_fn() for _ in range(5)])
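    # The base RV is Uniform(0, 10), so the effective upper bound is min(10, 15) = 10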
assert np.all((xt_draws > 5) & (xt_draws < 10))
assert np.all(np.diff(xt_draws[0], axis=0) != 0)


def test_truncation_icdf():
x = at.random.exponential(10, name="x", size=100)
xt = truncate(x, lower=1, upper=2)

assert isinstance(xt.owner.op, TruncatedRV)

xt_fn = aesara.function([], xt, updates={xt.owner.inputs[0]: xt.owner.outputs[0]})
xt_draws = np.array([xt_fn() for _ in range(5)])
assert np.all((xt_draws > 1) & (xt_draws < 2))
assert np.all(np.diff(xt_draws[0], axis=0) != 0)


def test_truncation_rejection_sampling():
# TODO: Create dummy RV to make sure this test always uses rejection sampling branch
x = at.random.normal(0, 10, name="x", size=100)
xt = truncate(x, lower=-1, upper=1)

assert isinstance(xt.owner.op, TruncatedRV)

xt_fn = aesara.function([], xt, updates={xt.owner.inputs[-1]: xt.owner.outputs[0]})
xt_draws = np.array([xt_fn() for _ in range(5)])
assert np.all(np.abs(xt_draws) < 1)
assert np.all(np.diff(xt_draws[0], axis=0) != 0)


def test_truncation_rejection_sampling_convergence_fails():
# TODO: Create dummy RV to make sure this test always uses rejection sampling branch
x = at.random.normal(0, 10, name="x", size=100)
# TODO: truncate call fails with max_n_steps < 2
xt = truncate(x, lower=-1, upper=1, max_n_steps=2)

xt_fn = aesara.function([], xt, updates={xt.owner.inputs[-1]: xt.owner.outputs[0]})
with pytest.raises(AssertionError):
xt_fn()
