-
-
Notifications
You must be signed in to change notification settings - Fork 210
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: implement
IfLifting
structural simplification pass
- Loading branch information
1 parent
63d2658
commit d5b773a
Showing
2 changed files
with
392 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,391 @@ | ||
""" | ||
struct CondRewriter | ||
Callable struct used to transform symbolic conditions into conditions involving discrete | ||
variables. | ||
""" | ||
struct CondRewriter | ||
""" | ||
The independent variable which the discrete variables depend on. | ||
""" | ||
iv::BasicSymbolic | ||
""" | ||
A mapping from a discrete variables to a `NamedTuple` containing the condition | ||
determining whether the discrete variable needs to be evaluated and the symbolic | ||
expression the discrete variable represents. The expression is a comparison operation | ||
such that the LHS of the comparison is used as a rootfinding function, and | ||
zero-crossings trigger re-evaluation of the condition (if `dependency` is `true`). | ||
""" | ||
conditions::Dict{Any, @NamedTuple{dependency, expression}} | ||
end | ||
|
||
function CondRewriter(iv) | ||
return CondRewriter(iv, Dict()) | ||
end | ||
|
||
""" | ||
A function which transforms comparison operations of the form `var op var` into | ||
`var - var op 0`. | ||
""" | ||
const COMPARISON_TRANSFORM = unwrap ∘ SymbolicUtils.Rewriters.Chain([ | ||
(@rule (~a) < (~b) => ~a - ~b < 0), | ||
(@rule (~a) > (~b) => ~a - ~b > 0), | ||
(@rule (~a) <= (~b) => ~a - ~b <= 0), | ||
(@rule (~a) >= (~b) => ~a - ~b >= 0), | ||
]) | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
Given a symbolic condition `expr` and the condition `dep` it depends on, update the | ||
mapping in `cw` and generate a new discrete variable if necessary. | ||
""" | ||
function new_cond_sym(cw::CondRewriter, expr, dep) | ||
# check if the same expression exists in the mapping | ||
existing_var = findfirst(p -> isequal(p.expression, expr), cw.conditions) | ||
if existing_var !== nothing | ||
# cache hit | ||
(existing_dep, _) = cw.conditions[existing_var] | ||
# update the dependency condition | ||
cw.conditions[existing_var] = (dependency=(dep | existing_dep), expression=expr) | ||
return existing_var | ||
end | ||
# generate a new condition variable | ||
cvar = gensym("cond") | ||
st = symtype(expr) | ||
iv = cw.iv | ||
cv = first(@parameters $(cvar)(iv)::st = true) # TODO: real init | ||
cw.conditions[cv] = (dependency=dep, expression=expr) | ||
return cv | ||
end | ||
|
||
""" | ||
A list of comparison operations. | ||
""" | ||
const COMPARISONS = Set([Base.:<, Base.:>, Base.:<=, Base.:>=]) | ||
|
||
""" | ||
Utility function for boolean implication. | ||
""" | ||
implies(a, b) = !a & b | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
Recursively rewrite conditions into discrete variables. `expr` is the condition to rewrite, | ||
`dep` is a boolean expression/value which determines when the `expr` is to be evaluated. For | ||
example, if `expr = expr1 | expr2` and `dep = dep1`, then `expr` should only be evaluated if | ||
`dep1` evaluates to `true`. Recursively, `expr1` should only be evaluated if `dep1` is `true`, | ||
and `expr2` should only be evaluated if `dep & !expr1`. | ||
Returns a 3-tuple of the substituted expression, a condition describing when `expr` evaluates | ||
to `true`, and a condition describing when `expr` evaluates to `false`. | ||
""" | ||
function (cw::CondRewriter)(expr, dep) | ||
# single variable, trivial case | ||
if issym(expr) || iscall(expr) && issym(operation(expr)) | ||
return (expr, expr, !expr) | ||
# literal boolean or integer | ||
elseif expr isa Bool | ||
return (expr, expr, !expr) | ||
elseif expr isa Int | ||
return (expr, true, true) | ||
# other singleton symbolic variables | ||
elseif !iscall(expr) | ||
@warn "Automatic conversion of if statments to events requires use of a limited conditional grammar; see the documentation. Skipping due to $expr" | ||
return (expr, true, true) # error case => conservative assumption is that both true and false have to be evaluated | ||
elseif operation(expr) == Base.:(|) # OR of two conditions | ||
a, b = arguments(expr) | ||
(rw_conda, truea, falsea) = cw(a, dep) | ||
# only evaluate second if first is false | ||
(rw_condb, trueb, falseb) = cw(b, dep & falsea) | ||
return (rw_conda | rw_condb, truea | trueb, falsea & falseb) | ||
|
||
elseif operation(expr) == Base.:(&) # AND of two conditions | ||
a, b = arguments(expr) | ||
(rw_conda, truea, falsea) = cw(a, dep) | ||
# only evaluate second if first is true | ||
(rw_condb, trueb, falseb) = cw(b, dep & truea) | ||
return (rw_conda & rw_condb, truea & trueb, falsea | falseb) | ||
elseif operation(expr) == ifelse | ||
c, a, b = arguments(expr) | ||
(rw_cond, ctrue, cfalse) = cw(c, dep) | ||
# only evaluate if condition is true | ||
(rw_conda, truea, falsea) = cw(a, dep & ctrue) | ||
# only evaluate if condition is false | ||
(rw_condb, trueb, falseb) = cw(b, dep & cfalse) | ||
# expression is true if condition is true and THEN branch is true, or condition is false | ||
# and ELSE branch is true | ||
# similarly for expression being false | ||
return (ifelse(rw_cond, rw_conda, rw_condb), implies(ctrue, truea) | implies(cfalse, trueb), implies(ctrue, falsea) | implies(cfalse, falseb)) | ||
elseif operation(expr) == Base.:(!) # NOT of expression | ||
(a,) = arguments(expr) | ||
(rw, ctrue, cfalse) = cw(a, dep) | ||
return (!rw, cfalse, ctrue) | ||
elseif operation(expr) in COMPARISONS # comparison operators | ||
# turn int `var - var op 0` | ||
expr = COMPARISON_TRANSFORM(expr) | ||
# a new discrete variable to represent `var - var op 0` | ||
cv = new_cond_sym(cw, expr, dep) | ||
return (cv, cv, !cv) | ||
elseif operation(expr) == (==) | ||
# we don't touch equality since it's a point discontinuity. It's basically always | ||
# false for continuous variables. In case it's an equality between discrete | ||
# quantities, we don't need to transform it. | ||
return (expr, expr, !expr) | ||
end | ||
error("Unsupported expression form in decision variable computation $expr") | ||
end | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
Acts as the identity function, and prevents transformation of conditional expressions inside it. Useful | ||
if specific `ifelse` or other functions with discontinuous derivatives shouldn't be transformed into | ||
callbacks. | ||
""" | ||
no_if_lift(s) = s | ||
@register_symbolic no_if_lift(s) | ||
|
||
""" | ||
$(TYPEDEF) | ||
A utility struct to search through an expression specifically for `ifelse` terms, and find | ||
all variables used in the condition of such terms. The variables are stored in a field of | ||
the struct. | ||
""" | ||
struct VarsUsedInCondition | ||
""" | ||
Stores variables used in conditions of `ifelse` statements in the expression. | ||
""" | ||
vars::Set{Any} | ||
end | ||
|
||
VarsUsedInCondition() = VarsUsedInCondition(Set()) | ||
|
||
function (v::VarsUsedInCondition)(expr) | ||
expr = Symbolics.unwrap(expr) | ||
if symbolic_type(expr) == NotSymbolic() | ||
is_array_of_symbolics(expr) || return | ||
foreach(v, expr) | ||
return | ||
end | ||
iscall(expr) || return | ||
op = operation(expr) | ||
|
||
# do not search inside no_if_lift to avoid discovering | ||
# redundant variables | ||
op == no_if_lift && return | ||
|
||
args = arguments(expr) | ||
if op == ifelse | ||
cond, branch_a, branch_b = arguments(expr) | ||
vars!(v.vars, cond) | ||
v(branch_a) | ||
v(branch_b) | ||
end | ||
foreach(v, args) | ||
return | ||
end | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
Given an expression `expr` which is to be evaluated if `dep` evaluates to `true`, transform | ||
the conditions of all all `ifelse` statements in `expr` into functions of new discrete | ||
variables. `cw` is used to store the information relevant to these newly introduced variables. | ||
""" | ||
function rewrite_ifs(cw::CondRewriter, expr, dep) | ||
expr = unwrap(expr) | ||
if symbolic_type(expr) == NotSymbolic() | ||
is_array_of_symbolics(expr) || return expr | ||
return map(expr) do ex | ||
rewrite_ifs(cw, ex, dep) | ||
end | ||
end | ||
iscall(expr) || return expr | ||
op = operation(expr) | ||
# don't recurse into singleton variables or places where the user doesn't want if-lifting | ||
(issym(op) || op == no_if_lift) && return expr | ||
args = arguments(expr) | ||
|
||
# transform `ifelse` that don't depend on a single symbolic variable. | ||
if op == ifelse && (!issym(args[1]) || iscall(args[1]) && !issym(operation(args[1]))) | ||
cond, iftrue, iffalse = args | ||
(rw_cond, deptrue, depfalse) = cw(cond, dep) | ||
rw_iftrue = rewrite_ifs(cw, iftrue, deptrue) | ||
rw_iffalse = rewrite_ifs(cw, iffalse, depfalse) | ||
return ifelse(unwrap(rw_cond), rw_iftrue, rw_iffalse) | ||
end | ||
# recursively rewrite | ||
return maketerm(typeof(expr), op, map(x -> rewrite_ifs(cw, x, dep), args), metadata(expr)) | ||
end | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
Return a modified `expr` where functions with known discontinuities or discontinuous | ||
derivatives are transformed into `ifelse` statements. Utilizes the discontinuity API | ||
in Symbolics. See [`Symbolics.rootfunction`](@ref), | ||
[`Symbolics.left_continuous_function`](@ref), [`Symbolics.right_continuous_function`](@ref). | ||
""" | ||
function discontinuities_to_ifelse(expr) | ||
if symbolic_type(expr) == NotSymbolic() | ||
is_array_of_symbolics(expr) || return expr | ||
return map(discontinuities_to_ifelse, expr) | ||
end | ||
iscall(expr) || return expr | ||
op = operation(expr) | ||
# don't transform inside `no_if_lift` | ||
(issym(op) || op === no_if_lift) && return expr | ||
args = arguments(expr) | ||
args = map(discontinuities_to_ifelse, args) | ||
# if the operation is a known discontinuity | ||
if hasmethod(Symbolics.rootfunction, Tuple{typeof(op)}) | ||
rootfn = Symbolics.rootfunction(op) | ||
leftfn = Symbolics.left_continuous_function(op) | ||
rightfn = Symbolics.right_continuous_function(op) | ||
rootexpr = rootfn(args...) < 0 | ||
leftexpr = leftfn(args...) | ||
rightexpr = rightfn(args...) | ||
return ifelse(rootexpr, leftexpr, rightexpr) | ||
end | ||
return maketerm(typeof(expr), op, args, Symbolics.metadata(expr)) | ||
end | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
Generate the symbolic condition for discrete variable `sym`, which represents the condition | ||
of an `ifelse` statement created through [`IfLifting`](@ref). This condition is used to | ||
trigger a callback which updates the value of the condition appropriately. | ||
""" | ||
function generate_condition(cw::CondRewriter, sym) | ||
(dep, uexpr) = cw.conditions[sym] | ||
# `uexpr` is a comparison, the LHS is the zero-crossing function | ||
zero_crossing = arguments(uexpr)[1] | ||
# if we're meant to evaluate the condition, evaluate it. Otherwise, return `NaN`. | ||
# the solvers don't treat the transition from a number to NaN or back as a zero-crossing, | ||
# so it can be used to effectively disable the affect when the condition is not meant to | ||
# be evaluated. | ||
return ifelse(dep, arguments(uexpr)[1], NaN) ~ 0 | ||
end | ||
|
||
""" | ||
$(TYPEDSIGNATURES) | ||
Generate the affect function for discrete variable `sym` involved in `ifelse` statements that | ||
are lifted to callbacks using [`IfLifting`](@ref). `syms` is a condition variable introduced | ||
by `cw`, and is thus a key in `cw.conditions`. `new_cond_vars` is the list of all such new | ||
condition variables, corresponding to the order of vertices in `new_cond_vars_graph`. | ||
`new_cond_vars_graph` is a directed graph where edges denote the condition variables involved | ||
in the dependency expression of the source vertex. | ||
""" | ||
function generate_affect(cw::CondRewriter, sym, new_cond_vars, new_cond_vars_graph) | ||
sym_idx = findfirst(isequal(sym), new_cond_vars) | ||
if sym_idx === nothing | ||
throw(ArgumentError("Expected variable $sym to be a condition variable in $new_cond_vars.")) | ||
end | ||
# use reverse direction of edges because instead of finding the variables it depends | ||
# on, we want the variables that depend on it | ||
parents = bfs_parents(new_cond_vars_graph, sym_idx; dir = :in) | ||
cond_vars_to_update = [new_cond_vars[i] for i in eachindex(parents) if !iszero(parents[i])] | ||
update_syms = Symbol.(cond_vars_to_update) | ||
update_exprs = [last(cw.conditions[sym]) for sym in cond_vars_to_update] | ||
return ImperativeAffect(modified=NamedTuple{(update_syms...,)}(cond_vars_to_update), observed=NamedTuple{(update_syms...,)}(update_exprs), skip_checks=true) do x, o, c, i | ||
x .= o | ||
end | ||
end | ||
|
||
""" | ||
If lifting converts (nested) if statements into a series of continous events + a logically equivalent if statement + parameters. | ||
Lifting proceeds through the following process: | ||
* rewrite comparisons to be of the form eqn [op] 0; subtract the RHS from the LHS | ||
* replace comparisons with generated parameters; for each comparison eqn [op] 0, generate an event (dependent on op) that sets the parameter | ||
""" | ||
function IfLifting(sys::ODESystem) | ||
cw = CondRewriter(get_iv(sys)) | ||
|
||
eqs = copy(equations(sys)) | ||
obs = copy(observed(sys)) | ||
|
||
# get variables used by `eqs` | ||
syms = vars(eqs) | ||
# get observed equations used by `eqs` | ||
obs_idxs = observed_equations_used_by(sys, eqs; involved_vars = syms) | ||
# and the variables used in those equations | ||
for i in obs_idxs | ||
vars!(syms, obs[i]) | ||
end | ||
|
||
# get all integral variables used in conditions | ||
# this is used when performing the transformation on observed equations | ||
# since they are transformed differently depending on whether they are | ||
# discrete variables involved in a condition or not | ||
condition_vars = Set() | ||
# searcher struct | ||
# we can use the same one since it avoids iterating over duplicates | ||
vars_in_condition! = VarsUsedInCondition() | ||
for i in eachindex(eqs) | ||
eq = eqs[i] | ||
vars_in_condition!(eq.rhs) | ||
# also transform the equation | ||
eqs[i] = eq.lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs), true) | ||
end | ||
# also search through relevant observed equations | ||
for i in obs_idxs | ||
vars_in_condition!(obs[i].rhs) | ||
end | ||
# add to `condition_vars` after filtering out differential, parameter, independent and | ||
# non-integral variables | ||
for v in vars_in_condition!.vars | ||
v = unwrap(v) | ||
stype = symtype(v) | ||
if isdifferential(v) || isparameter(v) || isequal(v, get_iv(sys)) | ||
continue | ||
end | ||
stype <: Union{Integer, AbstractArray{Integer}} && push!(condition_vars, v) | ||
end | ||
# transform observed equations | ||
for i in obs_idxs | ||
obs[i] = if obs[i].lhs in condition_vars | ||
obs[i].lhs ~ first(cw(obs[i].rhs, true)) | ||
else | ||
obs[i].lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs), true) | ||
end | ||
end | ||
|
||
# get directed graph where nodes are the new condition variables and edges from each | ||
# node denote the condition variables used in it's dependency expression | ||
|
||
# so we have an ordering for the vertices | ||
new_cond_vars = collect(keys(cw.conditions)) | ||
# "observed" equations | ||
new_cond_dep_eqs = [v ~ cw.conditions[v] for v in new_cond_vars] | ||
# construct the graph as a `DiCMOBiGraph` | ||
new_cond_vars_graph = observed_dependency_graph(new_cond_dep_eqs) | ||
|
||
new_callbacks = continuous_events(sys) | ||
new_defaults = defaults(sys) | ||
new_ps = parameters(sys) | ||
|
||
for var in new_cond_vars | ||
condition = generate_condition(cw, var) | ||
affect = generate_affect(cw, var, new_cond_vars, new_cond_vars_graph) | ||
cb = SymbolicContinuousCallback([condition], affect; affect_neg=affect, initialize=affect, rootfind=SciMLBase.RightRootFind) | ||
|
||
push!(new_callbacks, cb) | ||
new_defaults[var] = getdefault(var) | ||
push!(new_ps, var) | ||
end | ||
|
||
@set! sys.defaults = new_defaults | ||
@set! sys.eqs = eqs | ||
# do not need to topsort because we didn't modify the order | ||
@set! sys.observed = obs | ||
@set! sys.continuous_events = new_callbacks | ||
@set! sys.ps = new_ps | ||
return sys | ||
end | ||
|