Skip to content

Commit

Permalink
Use RandomStream in truncate
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo authored and brandonwillard committed Jan 17, 2023
1 parent 2c3d461 commit 564569f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 32 deletions.
60 changes: 35 additions & 25 deletions aeppl/truncation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from functools import singledispatch
from typing import Tuple
from typing import Optional, Tuple

import aesara.tensor as at
import aesara.tensor.random.basic as arb
import numpy as np
from aesara import scan, shared
from aesara import scan
from aesara.compile.builders import OpFromGraph
from aesara.graph.op import Op
from aesara.raise_op import CheckAndRaise
from aesara.scan import until
from aesara.tensor.random import RandomStream
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant, TensorVariable

Expand Down Expand Up @@ -68,7 +69,11 @@ def __str__(self):


def truncate(
rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None
rv: TensorVariable,
lower=None,
upper=None,
max_n_steps: int = 10_000,
srng: Optional[RandomStream] = None,
) -> Tuple[TensorVariable, Tuple[TensorVariable, TensorVariable]]:
"""Truncate a univariate `RandomVariable` between `lower` and `upper`.
Expand Down Expand Up @@ -99,13 +104,13 @@ def truncate(
lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)

if rng is None:
rng = shared(np.random.RandomState(), borrow=True)
if srng is None:
srng = RandomStream()

# Try to use specialized Op
try:
truncated_rv, updates = _truncated(
rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:]
rv.owner.op, lower, upper, srng, *rv.owner.inputs[1:]
)
return truncated_rv, updates
except NotImplementedError:
Expand All @@ -116,8 +121,8 @@ def truncate(
# though it would not be necessary for the icdf OpFromGraph
graph_inputs = [*rv.owner.inputs[1:], lower, upper]
graph_inputs_ = [inp.type() for inp in graph_inputs]
*rv_inputs_, lower_, upper_ = graph_inputs_
rv_ = rv.owner.op.make_node(rng, *rv_inputs_).default_output()
size_, dtype_, *rv_inputs_, lower_, upper_ = graph_inputs_
rv_ = srng.gen(rv.owner.op, *rv_inputs_, size=size_, dtype=dtype_)

# Try to use inverted cdf sampling
try:
Expand All @@ -126,11 +131,10 @@ def truncate(
lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
cdf_lower_ = at.exp(logcdf(rv_, lower_value))
cdf_upper_ = at.exp(logcdf(rv_, upper_))
uniform_ = at.random.uniform(
uniform_ = srng.uniform(
cdf_lower_,
cdf_upper_,
rng=rng,
size=rv_inputs_[0],
size=size_,
)
truncated_rv_ = icdf(rv_, uniform_)
truncated_rv = TruncatedRV(
Expand All @@ -146,27 +150,23 @@ def truncate(

# Fallback to rejection sampling
# TODO: Handle potential broadcast by lower / upper
def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs
def loop_fn(truncated_rv, reject_draws, lower, upper, size, dtype, *rv_inputs):
new_truncated_rv = srng.gen(rv.owner.op, *rv_inputs, size=size, dtype=dtype) # type: ignore
truncated_rv = at.set_subtensor(
truncated_rv[reject_draws],
new_truncated_rv[reject_draws],
)
reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))

return (
(truncated_rv, reject_draws),
[(rng, next_rng)],
until(~at.any(reject_draws)),
)
return (truncated_rv, reject_draws), until(~at.any(reject_draws))

(truncated_rv_, reject_draws_), updates = scan(
loop_fn,
outputs_info=[
at.zeros_like(rv_),
at.ones_like(rv_, dtype=bool),
],
non_sequences=[lower_, upper_, rng, *rv_inputs_],
non_sequences=[lower_, upper_, size_, dtype_, *rv_inputs_],
n_steps=max_n_steps,
strict=True,
)
Expand All @@ -180,18 +180,28 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
truncated_rv = TruncatedRV(
base_rv_op=rv.owner.op,
inputs=graph_inputs_,
outputs=[truncated_rv_, tuple(updates.values())[0]],
# This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
outputs=[truncated_rv_, rv_.owner.outputs[0], tuple(updates.values())[0]],
inline=True,
)(*graph_inputs)
updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
    # TODO: Is the order of multiple shared variables deterministic?
assert truncated_rv.owner.inputs[-2] is rv_.owner.inputs[0]
updates = {
truncated_rv.owner.inputs[-2]: truncated_rv.owner.outputs[-2],
truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1],
}
return truncated_rv, updates


@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
(value,) = values

*rv_inputs, lower_bound, upper_bound, rng = inputs
# Rejection sample graph has two rngs
if len(op.shared_inputs) == 2:
*rv_inputs, lower_bound, upper_bound, _, rng = inputs
else:
*rv_inputs, lower_bound, upper_bound, rng = inputs
rv_inputs = [rng, *rv_inputs]

base_rv_op = op.base_rv_op
Expand Down Expand Up @@ -242,11 +252,11 @@ def truncated_logprob(op, values, *inputs, **kwargs):


@_truncated.register(arb.UniformRV)
def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
truncated_uniform = at.random.uniform(
def uniform_truncated(op, lower, upper, srng, size, dtype, lower_orig, upper_orig):
truncated_uniform = srng.gen(
op,
at.max((lower_orig, lower)),
at.min((upper_orig, upper)),
rng=rng,
size=size,
dtype=dtype,
)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def _icdf_not_implemented(*args, **kwargs):
def test_truncation_specialized_op():
x = at.random.uniform(0, 10, name="x", size=100)

rng = aesara.shared(np.random.RandomState())
xt, _ = truncate(x, lower=5, upper=15, rng=rng)
srng = at.random.RandomStream()
xt, _ = truncate(x, lower=5, upper=15, srng=srng)
assert isinstance(xt.owner.op, UniformRV)
assert xt.owner.inputs[0] is rng
assert xt.owner.inputs[0] is srng.updates()[0][0]

lower_upper = at.stack(xt.owner.inputs[3:])
assert np.all(lower_upper.eval() == [5, 10])
Expand All @@ -68,10 +68,10 @@ def test_truncation_continuous_random(op_type, lower, upper):
normal_op = icdf_normal if op_type == "icdf" else rejection_normal
x = normal_op(loc, scale, name="x", size=100)

rng = aesara.shared(np.random.RandomState())
xt, xt_update = truncate(x, lower=lower, upper=upper, rng=rng)
srng = at.random.RandomStream()
xt, xt_update = truncate(x, lower=lower, upper=upper, srng=srng)
assert isinstance(xt.owner.op, TruncatedRV)
assert xt.owner.inputs[-1] is rng
assert xt.owner.inputs[-1] is srng.updates()[1 if op_type == "icdf" else 2][0]
assert xt.type.dtype == x.type.dtype
assert xt.type.ndim == x.type.ndim

Expand All @@ -94,7 +94,7 @@ def test_truncation_continuous_random(op_type, lower, upper):
assert scipy.stats.cramervonmises(xt_draws.ravel(), ref_xt.cdf).pvalue > 0.001

# Test max_n_steps
xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=1)
xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=2)
xt_fn = aesara.function([], xt, updates=xt_update)
if op_type == "icdf":
xt_draws = xt_fn()
Expand Down

0 comments on commit 564569f

Please sign in to comment.