Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix downstream indexing tests #2927

Merged
merged 8 commits into from
Aug 6, 2024
5 changes: 4 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,11 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
if idx.portion isa SciMLStructures.Discrete &&
idx.idx[2] == idx.idx[3] == nothing
return nothing
elseif idx.portion isa SciMLStructures.Tunable
return ParameterIndex(
idx.portion, idx.idx[arguments(sym)[(begin + 1):end]...])
else
ParameterIndex(
return ParameterIndex(
idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
end
else
Expand Down
4 changes: 2 additions & 2 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ namespace_affects(::Nothing, s) = nothing
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
SymbolicContinuousCallback(
namespace_equation.(equations(cb), (s,)),
namespace_affects(affects(cb), s),
namespace_affects(affect_negs(cb), s))
namespace_affects(affects(cb), s);
affect_neg = namespace_affects(affect_negs(cb), s))
end

"""
Expand Down
138 changes: 0 additions & 138 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,141 +195,3 @@ function split_system(ci::ClockInference{S}) where {S}
end
return tss, inputs, continuous_id, id_to_clock
end

function generate_discrete_affect(
osys::AbstractODESystem, syss, inputs, continuous_id, id_to_clock;
checkbounds = true,
eval_module = @__MODULE__, eval_expression = false)
@static if VERSION < v"1.7"
error("The `generate_discrete_affect` function requires at least Julia 1.7")
end
has_index_cache(osys) && get_index_cache(osys) !== nothing ||
error("Hybrid systems require `split = true`")
out = Sym{Any}(:out)
appended_parameters = full_parameters(syss[continuous_id])
offset = length(appended_parameters)
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
for p in appended_parameters)
affect_funs = []
clocks = TimeDomain[]
for (i, (sys, input)) in enumerate(zip(syss, inputs))
i == continuous_id && continue
push!(clocks, id_to_clock[i])
subs = get_substitutions(sys)
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
let_body = SetArray(!checkbounds, out, rhss(equations(sys)))
let_block = Let(assignments, let_body, false)
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
# TODO: filter the needed ones
fullvars = Set{Any}(eq.lhs for eq in observed(sys))
for s in unknowns(sys)
push!(fullvars, s)
end
needed_disc_to_cont_obs = []
disc_to_cont_idxs = ParameterIndex[]
for v in inputs[continuous_id]
_v = arguments(v)[1]
if _v in fullvars
push!(needed_disc_to_cont_obs, _v)
push!(disc_to_cont_idxs, param_to_idx[v])
continue
end

# If the held quantity is calculated through observed
# it will be shifted forward by 1
_v = Shift(get_iv(sys), 1)(_v)
if _v in fullvars
push!(needed_disc_to_cont_obs, _v)
push!(disc_to_cont_idxs, param_to_idx[v])
continue
end
end
append!(appended_parameters, input)
cont_to_disc_obs = build_explicit_observed_function(
osys,
needed_cont_to_disc_obs,
throw = false,
expression = true,
output_type = SVector)
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
throw = false,
expression = true,
output_type = SVector,
op = Shift,
ps = reorder_parameters(osys, appended_parameters))
ni = length(input)
ns = length(unknowns(sys))
disc = Func(
[
out,
DestructuredArgs(unknowns(osys)),
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))...,
get_iv(sys)
],
[],
let_block) |> toexpr
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
save_expr = :($(SciMLBase.save_discretes!)(integrator, $i))
empty_disc = isempty(disc_range)

# @show disc_to_cont_idxs
# @show cont_to_disc_idxs
# @show disc_range
affect! = :(function (integrator)
@unpack u, p, t = integrator
c2d_obs = $cont_to_disc_obs
d2c_obs = $disc_to_cont_obs
# TODO: find a way to do this without allocating
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
disc = $disc

# Write continuous into to discrete: handles `Sample`
# Write discrete into to continuous
# Update discrete unknowns

# At a tick, c2d must come first
# state update comes in the middle
# d2c comes last
# @show t
# @show "incoming", p
result = c2d_obs(u, p..., t)
for (val, i) in zip(result, $cont_to_disc_idxs)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end
$(if !empty_disc
quote
disc(disc_unknowns, u, p..., t)
for (val, i) in zip(disc_unknowns, $disc_range)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end
end
end)
# @show "after c2d", p
# @show "after state update", p
result = d2c_obs(disc_unknowns, p..., t)
for (val, i) in zip(result, $disc_to_cont_idxs)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end

$save_expr

# @show "after d2c", p
discretes, repack, _ = $(SciMLStructures.canonicalize)(
$(SciMLStructures.Discrete()), p)
repack(discretes)
end)

push!(affect_funs, affect!)
end
if eval_expression
affects = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), affect_funs)
else
affects = map(affect_funs) do a
drop_expr(RuntimeGeneratedFunction(
eval_module, eval_module, toexpr(LiteralExpr(a))))
end
end
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
return affects, clocks, appended_parameters, defaults
end
115 changes: 3 additions & 112 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -782,12 +782,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
varlist = collect(map(unwrap, dvs))
missingvars = setdiff(varlist, collect(keys(varmap)))

# Append zeros to the variables which are determined by the initialization system
# This essentially bypasses the check for if initial conditions are defined for DAEs
# since they will be checked in the initialization problem's construction
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
ci = infer_clocks!(ClockInference(TearingState(sys)))

if eltype(parammap) <: Pair
parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap))
elseif parammap isa AbstractArray
Expand All @@ -798,38 +792,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
end
end

if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
clockedparammap = Dict()
defs = ModelingToolkit.get_defaults(sys)
for v in ps
v = unwrap(v)
is_discrete_domain(v) || continue
op = operation(v)
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
haskey(parammap, v)
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
end
shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v))
if parammap != SciMLBase.NullParameters() &&
(val = get(parammap, shiftedv, nothing)) !== nothing
clockedparammap[v] = val
elseif op isa Shift
root = arguments(v)[1]
haskey(defs, root) || error("Initial condition for $v not provided.")
clockedparammap[v] = defs[root]
end
end
parammap = if parammap == SciMLBase.NullParameters()
clockedparammap
else
merge(parammap, clockedparammap)
end
end
# TODO: make it work with clocks
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
if sys isa ODESystem && build_initializeprob &&
(((implicit_dae || !isempty(missingvars)) &&
all(==(Continuous), ci.var_domain) &&
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
!isempty(initialization_equations(sys))) && t !== nothing
if eltype(u0map) <: Number
Expand Down Expand Up @@ -1010,29 +975,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
inits = []
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks = ModelingToolkit.generate_discrete_affect(
sys, dss...; eval_expression, eval_module)
discrete_cbs = map(affects, clocks) do affect, clock
@match clock begin
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
final_affect = true, initial_affect = true)
&SolverStepClock => DiscreteCallback(Returns(true), affect,
initialize = (c, u, t, integrator) -> affect(integrator))
_ => error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs...)
end
end

kwargs = filter_kwargs(kwargs)
pt = something(get_metadata(sys), StandardODEProblem())

Expand Down Expand Up @@ -1112,40 +1055,14 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h(p, t) = h_oop(p, t)
h(p::MTKParameters, t) = h_oop(p..., t)
u0 = h(p, tspan[1])

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks = ModelingToolkit.generate_discrete_affect(
sys, dss...; eval_expression, eval_module)
discrete_cbs = map(affects, clocks) do affect, clock
@match clock begin
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
final_affect = true, initial_affect = true)
&SolverStepClock => DiscreteCallback(Returns(true), affect,
initialize = (c, u, t, integrator) -> affect(integrator))
_ => error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
end

Expand Down Expand Up @@ -1175,40 +1092,14 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h(p::MTKParameters, t) = h_oop(p..., t)
h(out, p::MTKParameters, t) = h_iip(out, p..., t)
u0 = h(p, tspan[1])

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks = ModelingToolkit.generate_discrete_affect(
sys, dss...; eval_expression, eval_module)
discrete_cbs = map(affects, clocks) do affect, clock
@match clock begin
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
final_affect = true, initial_affect = true)
&SolverStepClock => DiscreteCallback(Returns(true), affect,
initialize = (c, u, t, integrator) -> affect(integrator))
_ => error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end

noiseeqs = get_noiseeqs(sys)
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
Expand Down
9 changes: 0 additions & 9 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,6 @@ function build_explicit_observed_function(sys, ts;
dep_vars = scalarize(setdiff(vars, ivs))

obs = param_only ? Equation[] : observed(sys)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
# each subsystem is topologically sorted independently. We can append the
# equations to override the `lhs ~ 0` equations in `observed(sys)`
syss, _, continuous_id, _... = dss
for (i, subsys) in enumerate(syss)
i == continuous_id && continue
append!(obs, observed(subsys))
end
end

cs = collect_constants(obs)
if !isempty(cs) > 0
Expand Down
Loading
Loading