Skip to content

Commit

Permalink
feat: implement IfLifting structural simplification pass
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 13, 2024
1 parent 63d2658 commit d5b773a
Show file tree
Hide file tree
Showing 2 changed files with 392 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ include("discretedomain.jl")
include("systems/systemstructure.jl")
include("systems/clock_inference.jl")
include("systems/systems.jl")
include("systems/if_lifting.jl")

include("debugging.jl")
include("systems/alias_elimination.jl")
Expand Down
391 changes: 391 additions & 0 deletions src/systems/if_lifting.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,391 @@
"""
struct CondRewriter
Callable struct used to transform symbolic conditions into conditions involving discrete
variables.
"""
struct CondRewriter
"""
The independent variable which the discrete variables depend on.
"""
iv::BasicSymbolic
"""
A mapping from a discrete variables to a `NamedTuple` containing the condition
determining whether the discrete variable needs to be evaluated and the symbolic
expression the discrete variable represents. The expression is a comparison operation
such that the LHS of the comparison is used as a rootfinding function, and
zero-crossings trigger re-evaluation of the condition (if `dependency` is `true`).
"""
conditions::Dict{Any, @NamedTuple{dependency, expression}}
end

function CondRewriter(iv)
return CondRewriter(iv, Dict())
end

"""
A function which transforms comparison operations of the form `var op var` into
`var - var op 0`.
"""
const COMPARISON_TRANSFORM = unwrap SymbolicUtils.Rewriters.Chain([
(@rule (~a) < (~b) => ~a - ~b < 0),
(@rule (~a) > (~b) => ~a - ~b > 0),
(@rule (~a) <= (~b) => ~a - ~b <= 0),
(@rule (~a) >= (~b) => ~a - ~b >= 0),
])

"""
$(TYPEDSIGNATURES)
Given a symbolic condition `expr` and the condition `dep` it depends on, update the
mapping in `cw` and generate a new discrete variable if necessary.
"""
function new_cond_sym(cw::CondRewriter, expr, dep)
# check if the same expression exists in the mapping
existing_var = findfirst(p -> isequal(p.expression, expr), cw.conditions)
if existing_var !== nothing
# cache hit
(existing_dep, _) = cw.conditions[existing_var]
# update the dependency condition
cw.conditions[existing_var] = (dependency=(dep | existing_dep), expression=expr)
return existing_var
end
# generate a new condition variable
cvar = gensym("cond")
st = symtype(expr)
iv = cw.iv
cv = first(@parameters $(cvar)(iv)::st = true) # TODO: real init
cw.conditions[cv] = (dependency=dep, expression=expr)
return cv
end

"""
A list of comparison operations.
"""
const COMPARISONS = Set([Base.:<, Base.:>, Base.:<=, Base.:>=])

"""
Utility function for boolean implication.
"""
implies(a, b) = !a & b

"""
$(TYPEDSIGNATURES)
Recursively rewrite conditions into discrete variables. `expr` is the condition to rewrite,
`dep` is a boolean expression/value which determines when the `expr` is to be evaluated. For
example, if `expr = expr1 | expr2` and `dep = dep1`, then `expr` should only be evaluated if
`dep1` evaluates to `true`. Recursively, `expr1` should only be evaluated if `dep1` is `true`,
and `expr2` should only be evaluated if `dep & !expr1`.
Returns a 3-tuple of the substituted expression, a condition describing when `expr` evaluates
to `true`, and a condition describing when `expr` evaluates to `false`.
"""
function (cw::CondRewriter)(expr, dep)
# single variable, trivial case
if issym(expr) || iscall(expr) && issym(operation(expr))
return (expr, expr, !expr)
# literal boolean or integer
elseif expr isa Bool
return (expr, expr, !expr)
elseif expr isa Int
return (expr, true, true)
# other singleton symbolic variables
elseif !iscall(expr)
@warn "Automatic conversion of if statments to events requires use of a limited conditional grammar; see the documentation. Skipping due to $expr"

Check warning on line 95 in src/systems/if_lifting.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"statments" should be "statements".
return (expr, true, true) # error case => conservative assumption is that both true and false have to be evaluated
elseif operation(expr) == Base.:(|) # OR of two conditions
a, b = arguments(expr)
(rw_conda, truea, falsea) = cw(a, dep)
# only evaluate second if first is false
(rw_condb, trueb, falseb) = cw(b, dep & falsea)
return (rw_conda | rw_condb, truea | trueb, falsea & falseb)

elseif operation(expr) == Base.:(&) # AND of two conditions
a, b = arguments(expr)
(rw_conda, truea, falsea) = cw(a, dep)
# only evaluate second if first is true
(rw_condb, trueb, falseb) = cw(b, dep & truea)
return (rw_conda & rw_condb, truea & trueb, falsea | falseb)
elseif operation(expr) == ifelse
c, a, b = arguments(expr)
(rw_cond, ctrue, cfalse) = cw(c, dep)
# only evaluate if condition is true
(rw_conda, truea, falsea) = cw(a, dep & ctrue)
# only evaluate if condition is false
(rw_condb, trueb, falseb) = cw(b, dep & cfalse)
# expression is true if condition is true and THEN branch is true, or condition is false
# and ELSE branch is true
# similarly for expression being false
return (ifelse(rw_cond, rw_conda, rw_condb), implies(ctrue, truea) | implies(cfalse, trueb), implies(ctrue, falsea) | implies(cfalse, falseb))
elseif operation(expr) == Base.:(!) # NOT of expression
(a,) = arguments(expr)
(rw, ctrue, cfalse) = cw(a, dep)
return (!rw, cfalse, ctrue)
elseif operation(expr) in COMPARISONS # comparison operators
# turn int `var - var op 0`
expr = COMPARISON_TRANSFORM(expr)
# a new discrete variable to represent `var - var op 0`
cv = new_cond_sym(cw, expr, dep)
return (cv, cv, !cv)
elseif operation(expr) == (==)
# we don't touch equality since it's a point discontinuity. It's basically always
# false for continuous variables. In case it's an equality between discrete
# quantities, we don't need to transform it.
return (expr, expr, !expr)
end
error("Unsupported expression form in decision variable computation $expr")
end

"""
$(TYPEDSIGNATURES)
Acts as the identity function, and prevents transformation of conditional expressions inside it. Useful
if specific `ifelse` or other functions with discontinuous derivatives shouldn't be transformed into
callbacks.
"""
no_if_lift(s) = s
@register_symbolic no_if_lift(s)

"""
$(TYPEDEF)
A utility struct to search through an expression specifically for `ifelse` terms, and find
all variables used in the condition of such terms. The variables are stored in a field of
the struct.
"""
struct VarsUsedInCondition
"""
Stores variables used in conditions of `ifelse` statements in the expression.
"""
vars::Set{Any}
end

VarsUsedInCondition() = VarsUsedInCondition(Set())

function (v::VarsUsedInCondition)(expr)
expr = Symbolics.unwrap(expr)
if symbolic_type(expr) == NotSymbolic()
is_array_of_symbolics(expr) || return
foreach(v, expr)
return
end
iscall(expr) || return
op = operation(expr)

# do not search inside no_if_lift to avoid discovering
# redundant variables
op == no_if_lift && return

args = arguments(expr)
if op == ifelse
cond, branch_a, branch_b = arguments(expr)
vars!(v.vars, cond)
v(branch_a)
v(branch_b)
end
foreach(v, args)
return
end

"""
$(TYPEDSIGNATURES)
Given an expression `expr` which is to be evaluated if `dep` evaluates to `true`, transform
the conditions of all all `ifelse` statements in `expr` into functions of new discrete
variables. `cw` is used to store the information relevant to these newly introduced variables.
"""
function rewrite_ifs(cw::CondRewriter, expr, dep)
expr = unwrap(expr)
if symbolic_type(expr) == NotSymbolic()
is_array_of_symbolics(expr) || return expr
return map(expr) do ex
rewrite_ifs(cw, ex, dep)
end
end
iscall(expr) || return expr
op = operation(expr)
# don't recurse into singleton variables or places where the user doesn't want if-lifting
(issym(op) || op == no_if_lift) && return expr
args = arguments(expr)

# transform `ifelse` that don't depend on a single symbolic variable.
if op == ifelse && (!issym(args[1]) || iscall(args[1]) && !issym(operation(args[1])))
cond, iftrue, iffalse = args
(rw_cond, deptrue, depfalse) = cw(cond, dep)
rw_iftrue = rewrite_ifs(cw, iftrue, deptrue)
rw_iffalse = rewrite_ifs(cw, iffalse, depfalse)
return ifelse(unwrap(rw_cond), rw_iftrue, rw_iffalse)
end
# recursively rewrite
return maketerm(typeof(expr), op, map(x -> rewrite_ifs(cw, x, dep), args), metadata(expr))
end

"""
$(TYPEDSIGNATURES)
Return a modified `expr` where functions with known discontinuities or discontinuous
derivatives are transformed into `ifelse` statements. Utilizes the discontinuity API
in Symbolics. See [`Symbolics.rootfunction`](@ref),
[`Symbolics.left_continuous_function`](@ref), [`Symbolics.right_continuous_function`](@ref).
"""
function discontinuities_to_ifelse(expr)
if symbolic_type(expr) == NotSymbolic()
is_array_of_symbolics(expr) || return expr
return map(discontinuities_to_ifelse, expr)
end
iscall(expr) || return expr
op = operation(expr)
# don't transform inside `no_if_lift`
(issym(op) || op === no_if_lift) && return expr
args = arguments(expr)
args = map(discontinuities_to_ifelse, args)
# if the operation is a known discontinuity
if hasmethod(Symbolics.rootfunction, Tuple{typeof(op)})
rootfn = Symbolics.rootfunction(op)
leftfn = Symbolics.left_continuous_function(op)
rightfn = Symbolics.right_continuous_function(op)
rootexpr = rootfn(args...) < 0
leftexpr = leftfn(args...)
rightexpr = rightfn(args...)
return ifelse(rootexpr, leftexpr, rightexpr)
end
return maketerm(typeof(expr), op, args, Symbolics.metadata(expr))
end

"""
$(TYPEDSIGNATURES)
Generate the symbolic condition for discrete variable `sym`, which represents the condition
of an `ifelse` statement created through [`IfLifting`](@ref). This condition is used to
trigger a callback which updates the value of the condition appropriately.
"""
function generate_condition(cw::CondRewriter, sym)
(dep, uexpr) = cw.conditions[sym]
# `uexpr` is a comparison, the LHS is the zero-crossing function
zero_crossing = arguments(uexpr)[1]
# if we're meant to evaluate the condition, evaluate it. Otherwise, return `NaN`.
# the solvers don't treat the transition from a number to NaN or back as a zero-crossing,
# so it can be used to effectively disable the affect when the condition is not meant to
# be evaluated.
return ifelse(dep, arguments(uexpr)[1], NaN) ~ 0
end

"""
$(TYPEDSIGNATURES)
Generate the affect function for discrete variable `sym` involved in `ifelse` statements that
are lifted to callbacks using [`IfLifting`](@ref). `syms` is a condition variable introduced
by `cw`, and is thus a key in `cw.conditions`. `new_cond_vars` is the list of all such new
condition variables, corresponding to the order of vertices in `new_cond_vars_graph`.
`new_cond_vars_graph` is a directed graph where edges denote the condition variables involved
in the dependency expression of the source vertex.
"""
function generate_affect(cw::CondRewriter, sym, new_cond_vars, new_cond_vars_graph)
sym_idx = findfirst(isequal(sym), new_cond_vars)
if sym_idx === nothing
throw(ArgumentError("Expected variable $sym to be a condition variable in $new_cond_vars."))
end
# use reverse direction of edges because instead of finding the variables it depends
# on, we want the variables that depend on it
parents = bfs_parents(new_cond_vars_graph, sym_idx; dir = :in)
cond_vars_to_update = [new_cond_vars[i] for i in eachindex(parents) if !iszero(parents[i])]
update_syms = Symbol.(cond_vars_to_update)
update_exprs = [last(cw.conditions[sym]) for sym in cond_vars_to_update]
return ImperativeAffect(modified=NamedTuple{(update_syms...,)}(cond_vars_to_update), observed=NamedTuple{(update_syms...,)}(update_exprs), skip_checks=true) do x, o, c, i
x .= o
end
end

"""
If lifting converts (nested) if statements into a series of continous events + a logically equivalent if statement + parameters.

Check warning on line 301 in src/systems/if_lifting.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"continous" should be "continuous".
Lifting proceeds through the following process:
* rewrite comparisons to be of the form eqn [op] 0; subtract the RHS from the LHS
* replace comparisons with generated parameters; for each comparison eqn [op] 0, generate an event (dependent on op) that sets the parameter
"""
function IfLifting(sys::ODESystem)
cw = CondRewriter(get_iv(sys))

eqs = copy(equations(sys))
obs = copy(observed(sys))

# get variables used by `eqs`
syms = vars(eqs)
# get observed equations used by `eqs`
obs_idxs = observed_equations_used_by(sys, eqs; involved_vars = syms)
# and the variables used in those equations
for i in obs_idxs
vars!(syms, obs[i])
end

# get all integral variables used in conditions
# this is used when performing the transformation on observed equations
# since they are transformed differently depending on whether they are
# discrete variables involved in a condition or not
condition_vars = Set()
# searcher struct
# we can use the same one since it avoids iterating over duplicates
vars_in_condition! = VarsUsedInCondition()
for i in eachindex(eqs)
eq = eqs[i]
vars_in_condition!(eq.rhs)
# also transform the equation
eqs[i] = eq.lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs), true)
end
# also search through relevant observed equations
for i in obs_idxs
vars_in_condition!(obs[i].rhs)
end
# add to `condition_vars` after filtering out differential, parameter, independent and
# non-integral variables
for v in vars_in_condition!.vars
v = unwrap(v)
stype = symtype(v)
if isdifferential(v) || isparameter(v) || isequal(v, get_iv(sys))
continue
end
stype <: Union{Integer, AbstractArray{Integer}} && push!(condition_vars, v)
end
# transform observed equations
for i in obs_idxs
obs[i] = if obs[i].lhs in condition_vars
obs[i].lhs ~ first(cw(obs[i].rhs, true))
else
obs[i].lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs), true)
end
end

# get directed graph where nodes are the new condition variables and edges from each
# node denote the condition variables used in it's dependency expression

# so we have an ordering for the vertices
new_cond_vars = collect(keys(cw.conditions))
# "observed" equations
new_cond_dep_eqs = [v ~ cw.conditions[v] for v in new_cond_vars]
# construct the graph as a `DiCMOBiGraph`
new_cond_vars_graph = observed_dependency_graph(new_cond_dep_eqs)

new_callbacks = continuous_events(sys)
new_defaults = defaults(sys)
new_ps = parameters(sys)

for var in new_cond_vars
condition = generate_condition(cw, var)
affect = generate_affect(cw, var, new_cond_vars, new_cond_vars_graph)
cb = SymbolicContinuousCallback([condition], affect; affect_neg=affect, initialize=affect, rootfind=SciMLBase.RightRootFind)

push!(new_callbacks, cb)
new_defaults[var] = getdefault(var)
push!(new_ps, var)
end

@set! sys.defaults = new_defaults
@set! sys.eqs = eqs
# do not need to topsort because we didn't modify the order
@set! sys.observed = obs
@set! sys.continuous_events = new_callbacks
@set! sys.ps = new_ps
return sys
end

0 comments on commit d5b773a

Please sign in to comment.