Implement truncated variables
Ricardo committed Apr 6, 2022
1 parent 0624b6b commit 773dc50
Showing 3 changed files with 189 additions and 2 deletions.
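The commit's user-facing entry point is the new `truncate` helper in `aeppl/truncation.py`. A minimal usage sketch, mirroring `test_truncation_rejection_sampling` below (the sketch itself is illustrative and not part of the diff):

import aesara
import aesara.tensor as at
import numpy as np

from aeppl.truncation import truncate

x = at.random.normal(0, 10, name="x", size=100)
# Normal has no specialized truncated Op registered here and no icdf
# dispatch in this branch, so this falls back to rejection sampling.
xt = truncate(x, lower=-1, upper=1)

# xt.owner.inputs[-1] is the shared RNG (appended by OpFromGraph);
# xt.owner.outputs[0] is the updated RNG state.
fn = aesara.function([], xt, updates={xt.owner.inputs[-1]: xt.owner.outputs[0]})
draws = np.array([fn() for _ in range(5)])
assert np.all(np.abs(draws) < 1)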
2 changes: 1 addition & 1 deletion aeppl/logprob.py
@@ -110,7 +110,7 @@ def _logcdf(


@singledispatch
-def _icdf(op: Op, rv: TensorVariable, value, *params):
+def _icdf(op: Op, value, *params):
    """Create a graph for the icdf of a ``RandomVariable``.

    This function dispatches on the type of ``op``, which should be a subclass
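The `_icdf` signature change drops the unused `rv` argument, so dispatch implementations now receive the op, the value, and the node's inputs, in line with `_logprob` and `_logcdf`. A hypothetical registration under the new signature (the exponential example and its exact parameter order are assumptions for illustration, not code from this commit):

import aesara.tensor as at
import aesara.tensor.random.basic as arb

from aeppl.logprob import _icdf

@_icdf.register(arb.ExponentialRV)
def exponential_icdf(op, value, rng, size, dtype, scale):
    # Exponential(scale) quantile function: F^{-1}(q) = -scale * log(1 - q)
    return -scale * at.log1p(-value)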
139 changes: 138 additions & 1 deletion aeppl/truncation.py
@@ -1,18 +1,26 @@
import warnings
from functools import singledispatch
from typing import List, Optional

import aesara.tensor as at
import aesara.tensor.random.basic as arb
import numpy as np
from aesara import scan
from aesara.compile.builders import OpFromGraph
from aesara.graph import Op
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.raise_op import Assert
from aesara.scalar.basic import Clip
from aesara.scalar.basic import clip as scalar_clip
from aesara.scan import until
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
-from aeppl.logprob import CheckParameterValue, _logcdf, _logprob
+from aeppl.logprob import CheckParameterValue, _logcdf, _logprob, icdf, logcdf
from aeppl.opt import rv_sinking_db


@@ -123,3 +131,132 @@ def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
    )

    return logprob


class TruncatedRV(OpFromGraph):
    """An `Op` constructed from an Aesara graph that represents a truncated univariate RV."""

    # Outputs are ``[rng_update, truncated_rv]``; expose the draw by default.
    default_output = 1
    base_rv_op = None

    def __init__(self, base_rv_op: Op, *args, **kwargs):
        self.base_rv_op = base_rv_op
        super().__init__(*args, **kwargs)


@singledispatch
def _truncated(op: Op, lower, upper, *params):
    """Return the truncated equivalent of another ``RandomVariable``."""
    raise NotImplementedError(
        f"{op} does not have an equivalent truncated version implemented"
    )


@_truncated.register(arb.UniformRV)
def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
    return at.random.uniform(
        at.max((lower_orig, lower)),
        at.min((upper_orig, upper)),
        rng=rng,
        size=size,
        dtype=dtype,
    )


def truncate(rv, lower=None, upper=None, max_n_steps=10_000):
    """Truncate the distribution of `rv` between `lower` and `upper` bounds."""
    # Check for missing bounds before they are replaced by infinite constants,
    # otherwise this branch can never be reached.
    if lower is None and upper is None:
        raise ValueError("lower and upper cannot both be None")

    lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
    upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)

    # Try to use specialized Op
    try:
        return _truncated(rv.owner.op, lower, upper, *rv.owner.inputs)
    except NotImplementedError:
        pass

    if not isinstance(rv.owner.op, RandomVariable):
        raise ValueError(f"truncation not implemented for Op {rv.owner.op}")

    if rv.owner.op.ndim_supp > 0:
        raise NotImplementedError(
            "truncation not implemented for multivariate variables"
        )

    # Variables with `_` suffix identify dummy inputs for the OpFromGraph
    graph_inputs = [*rv.owner.inputs, lower, upper]
    graph_inputs_ = [inp.type() for inp in graph_inputs]
    *rv_inputs_, lower_, upper_ = graph_inputs_

    # Try to use inverted cdf sampling
    try:
        rv_ = rv.owner.op.make_node(*rv_inputs_).default_output()
        cdf_lower_ = at.exp(logcdf(rv_, lower_))
        cdf_upper_ = at.exp(logcdf(rv_, upper_))
        uniform_ = at.random.uniform(
            cdf_lower_,
            cdf_upper_,
            rng=rv_inputs_[0],
            size=rv_inputs_[1],
        )
        truncated_rv_ = icdf(rv_, uniform_)
        return TruncatedRV(
            base_rv_op=rv.owner.op,
            inputs=graph_inputs_,
            outputs=[uniform_.owner.outputs[0], truncated_rv_],
            inline=True,
        )(*graph_inputs)
    except NotImplementedError:
        pass

    # Fallback to rejection sampling

    # Scan forces us to use a shared variable for the RNG
    rng = rv.owner.inputs[0]
    graph_inputs = [*rv.owner.inputs[1:], lower, upper]
    graph_inputs_ = [inp.type() for inp in graph_inputs]
    *rv_inputs_, lower_, upper_ = (rng, *graph_inputs_)
    # TODO: Handle potential broadcast by lower / upper
    rv_ = rv.owner.op.make_node(*rv_inputs_).default_output()

    def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
        # TODO: Very unsure about this way of specifying the new_draws
        next_rng, new_truncated_rv = rv.owner.op.make_node(*rv_inputs).outputs

        rng = next_rng.owner.inputs[0]
        rng.tag.is_rng = True
        rng.default_update = next_rng
        new_truncated_rv.rng = rng
        new_truncated_rv.update = (rng, next_rng)

        truncated_rv = at.set_subtensor(
            truncated_rv[reject_draws], new_truncated_rv[reject_draws]
        )
        reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
        return (truncated_rv, reject_draws), until(~at.any(reject_draws))

    (truncated_rv_, reject_draws_), updates = scan(
        loop_fn,
        outputs_info=[
            at.empty_like(rv_),
            at.ones_like(rv_, dtype=bool),
        ],
        non_sequences=[lower_, upper_, *rv_inputs_],
        n_steps=max_n_steps,
        strict=True,
    )

    truncated_rv_ = truncated_rv_[-1]
    convergence_ = ~at.any(reject_draws_[-1])
    truncated_rv_ = Assert(
        f"truncation did not converge in predefined {max_n_steps} steps"
    )(truncated_rv_, convergence_)

    return TruncatedRV(
        base_rv_op=rv.owner.op,
        inputs=graph_inputs_,
        outputs=[tuple(updates.values())[0], truncated_rv_],
        inline=True,
        strict=True,
    )(*graph_inputs)
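The icdf branch above is standard inverse-CDF sampling for truncation: draw u ~ Uniform(F(lower), F(upper)) and map it through the quantile function F⁻¹. A quick scipy sanity check of that identity (not part of the commit), using the same exponential distribution as `test_truncation_icdf` below:

import numpy as np
from scipy.stats import expon

dist = expon(scale=10)
lower, upper = 1, 2
rng = np.random.default_rng(123)
u = rng.uniform(dist.cdf(lower), dist.cdf(upper), size=100)
draws = dist.ppf(u)  # the quantile function maps u into the truncated support
assert np.all((draws > lower) & (draws < upper))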
50 changes: 50 additions & 0 deletions tests/test_truncation.py
@@ -4,9 +4,11 @@
import pytest
import scipy as sp
import scipy.stats as st
from aesara.tensor.random.basic import UniformRV

from aeppl import factorized_joint_logprob, joint_logprob
from aeppl.transforms import LogTransform, TransformValuesOpt
from aeppl.truncation import TruncatedRV, truncate
from tests.utils import assert_no_rvs


@@ -189,3 +191,51 @@ def test_censored_transform():
    )

    assert np.isclose(obs_logp, exp_logp)


def test_truncation_specialized_op():
    x = at.random.uniform(0, 10, name="x", size=100)
    xt = truncate(x, lower=5, upper=15)

    assert isinstance(xt.owner.op, UniformRV)

    # Inputs are (rng, size, dtype, lower, upper); the adjusted bounds are
    # max(0, 5) = 5 and min(10, 15) = 10.
    lower_upper = at.stack(xt.owner.inputs[3:])
    assert np.all(lower_upper.eval() == [5, 10])


# TODO: Compare with expected CDF
def test_truncation_icdf():
    x = at.random.exponential(10, name="x", size=100)
    # max_n_steps=0 would make the rejection fallback fail immediately,
    # so passing here shows the icdf branch was taken.
    xt = truncate(x, lower=1, upper=2, max_n_steps=0)

    assert isinstance(xt.owner.op, TruncatedRV)

    xt_fn = aesara.function([], xt, updates={xt.owner.inputs[0]: xt.owner.outputs[0]})
    xt_draws = np.array([xt_fn() for _ in range(5)])
    assert np.all((xt_draws > 1) & (xt_draws < 2))
    assert np.unique(xt_draws).size == xt_draws.size


# TODO: Compare with expected CDF
def test_truncation_rejection_sampling():
    # TODO: Create dummy RV to make sure this test always uses rejection sampling branch
    x = at.random.normal(0, 10, name="x", size=100)
    xt = truncate(x, lower=-1, upper=1)

    assert isinstance(xt.owner.op, TruncatedRV)

    xt_fn = aesara.function([], xt, updates={xt.owner.inputs[-1]: xt.owner.outputs[0]})
    xt_draws = np.array([xt_fn() for _ in range(5)])
    assert np.all(np.abs(xt_draws) < 1)
    assert np.unique(xt_draws).size == xt_draws.size


def test_truncation_rejection_sampling_convergence_check():
    # TODO: Create dummy RV to make sure this test always uses rejection sampling branch
    x = at.random.normal(0, 10, name="x", size=100)
    # TODO: truncate call fails with max_n_steps < 2
    xt = truncate(x, lower=-1, upper=1, max_n_steps=2)

    xt_fn = aesara.function([], xt, updates={xt.owner.inputs[-1]: xt.owner.outputs[0]})
    with pytest.raises(AssertionError):
        xt_fn()

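Why `max_n_steps=2` reliably trips the convergence `Assert` in the last test: each scan step redraws only the rejected entries, and for N(0, 10) the per-draw acceptance probability on [-1, 1] is only about 0.08, so with 100 entries the chance that all converge within two steps is effectively zero. A back-of-the-envelope check (not part of the commit):

import numpy as np
from scipy.stats import norm

p_accept = norm.cdf(1, loc=0, scale=10) - norm.cdf(-1, loc=0, scale=10)  # ~0.0797
p_entry_done = 1 - (1 - p_accept) ** 2  # entry accepted within 2 draws, ~0.153
p_all_done = p_entry_done ** 100        # all 100 entries accepted, ~3e-82
print(p_accept, p_entry_done, p_all_done)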