Skip to content

Commit

Permalink
Merge pull request #2283 from SciML/bgc/split_params_bug
Browse files Browse the repository at this point in the history
Support heterogeneous parameters for linearize and remake
  • Loading branch information
YingboMa authored Oct 4, 2023
2 parents 8886a91 + ae05041 commit 2309f9f
Show file tree
Hide file tree
Showing 15 changed files with 236 additions and 92 deletions.
1 change: 1 addition & 0 deletions docs/src/basics/MTKModel_Connector.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ end
`@connector`s accepts begin blocks of `@components`, `@equations`, `@extend`, `@parameters`, `@structural_parameters`, `@variables`. These keywords mean the same as described above for `@mtkmodel`.

!!! note

For more examples of usage, checkout [ModelingToolkitStandardLibrary.jl](https://github.com/SciML/ModelingToolkitStandardLibrary.jl/)

### What's a `structure` dictionary?
Expand Down
38 changes: 30 additions & 8 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ for prop in [:eqs
:metadata
:gui_metadata
:discrete_subsystems
:unknown_states]
:unknown_states
:split_idxs]
fname1 = Symbol(:get_, prop)
fname2 = Symbol(:has_, prop)
@eval begin
Expand Down Expand Up @@ -1273,14 +1274,34 @@ See also [`linearize`](@ref) which provides a higher-level interface.
function linearization_function(sys::AbstractSystem, inputs,
outputs; simplify = false,
initialize = true,
op = Dict(),
p = DiffEqBase.NullParameters(),
zero_dummy_der = false,
kwargs...)
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
ssys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs;
simplify,
kwargs...)
if zero_dummy_der
dummyder = setdiff(states(ssys), states(sys))
defs = Dict(x => 0.0 for x in dummyder)
@set! ssys.defaults = merge(defs, defaults(ssys))
op = merge(defs, op)
end
sys = ssys
x0 = merge(defaults(sys), op)
u0, p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
p, split_idxs = split_parameters_by_type(p)
ps = parameters(sys)
if p isa Tuple
ps = Base.Fix1(getindex, ps).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
end

lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = states(sys),
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys),
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys, states(sys), ps; p = p),
h = build_explicit_observed_function(sys, outputs),
chunk = ForwardDiff.Chunk(input_idxs)

Expand Down Expand Up @@ -1599,11 +1620,12 @@ function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,
allow_input_derivatives = false,
zero_dummy_der = false,
kwargs...)
lin_fun, ssys = linearization_function(sys, inputs, outputs; kwargs...)
if zero_dummy_der
dummyder = setdiff(states(ssys), states(sys))
op = merge(op, Dict(x => 0.0 for x in dummyder))
end
lin_fun, ssys = linearization_function(sys,
inputs,
outputs;
zero_dummy_der,
op,
kwargs...)
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys
end

Expand Down
10 changes: 7 additions & 3 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
checkbounds = false,
sparsity = false,
analytic = nothing,
split_idxs = nothing,
kwargs...) where {iip, specialize}
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
expression_module = eval_module, checkbounds = checkbounds,
Expand Down Expand Up @@ -508,6 +509,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
nothing
end

@set! sys.split_idxs = split_idxs
ODEFunction{iip, specialize}(f;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
Expand Down Expand Up @@ -765,15 +767,17 @@ Take dictionaries with initial conditions and parameters and convert them to num
"""
function get_u0_p(sys,
u0map,
parammap;
parammap = nothing;
use_union = true,
tofloat = true,
symbolic_u0 = false)
dvs = states(sys)
ps = parameters(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
if parammap !== nothing
defs = mergedefaults(defs, parammap, ps)
end
defs = mergedefaults(defs, u0map, dvs)

if symbolic_u0
Expand Down Expand Up @@ -835,7 +839,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression,
sparse = sparse, eval_expression = eval_expression, split_idxs,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end
Expand Down
10 changes: 7 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,19 @@ struct ODESystem <: AbstractODESystem
used for ODAEProblem.
"""
unknown_states::Union{Nothing, Vector{Any}}
"""
split_idxs: a vector of vectors of indices for the split parameters.
"""
split_idxs::Union{Nothing, Vector{Vector{Int}}}

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
torn_matching, connector_type, preface, cevents,
devents, metadata = nothing, gui_metadata = nothing,
tearing_state = nothing,
substitutions = nothing, complete = false,
discrete_subsystems = nothing, unknown_states = nothing;
checks::Union{Bool, Int} = true)
discrete_subsystems = nothing, unknown_states = nothing,
split_idxs = nothing; checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckComponents) > 0
check_variables(dvs, iv)
check_parameters(ps, iv)
Expand All @@ -161,7 +165,7 @@ struct ODESystem <: AbstractODESystem
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
connector_type, preface, cevents, devents, metadata, gui_metadata,
tearing_state, substitutions, complete, discrete_subsystems,
unknown_states)
unknown_states, split_idxs)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
parammap = DiffEqBase.NullParameters();
checkbounds = false,
use_union = false,
use_union = true,
kwargs...)
dvs = states(sys)
ps = parameters(sys)
Expand Down
54 changes: 24 additions & 30 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -657,46 +657,40 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
end
T = eltype(vs)
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
vs
return vs
else
sym_vs = filter(x -> SymbolicUtils.issym(x) || SymbolicUtils.istree(x), vs)
isempty(sym_vs) || throw_missingvars_in_sys(sym_vs)
C = typeof(first(vs))
I = Int8
has_int = false
has_array = false
has_bool = false
array_T = nothing

C = nothing
for v in vs
if v isa AbstractArray
has_array = true
array_T = typeof(v)
E = typeof(v)
if E <: Number
if tofloat
E = float(E)
end
end
E = eltype(v)
C = promote_type(C, E)
if E <: Integer
has_int = true
I = promote_type(I, E)
if C === nothing
C = E
end
if E <: Bool
has_bool = true
if use_union
C = Union{C, E}
else
@assert C==E "`promote_to_concrete` can't make type $E uniform with $C"
C = E
end
end
if tofloat && !has_array
C = float(C)
elseif has_array || (use_union && has_int && C !== I)
if has_array
C = Union{C, array_T}
end
if has_int
C = Union{C, I}
end
if has_bool
C = Union{C, Bool}

y = similar(vs, C)
for i in eachindex(vs)
if (vs[i] isa Number) & tofloat
y[i] = float(vs[i]) #needed because copyto! can't convert Int to Float automatically
else
y[i] = vs[i]
end
return copyto!(similar(vs, C), vs)
end
convert.(C, vs)

return y
end
end

Expand Down
26 changes: 17 additions & 9 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,23 @@ function SciMLBase.process_p_u0_symbolic(prob::Union{SciMLBase.AbstractDEProblem
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
end

# assemble defaults
defs = defaults(prob.f.sys)
defs = mergedefaults(defs, prob.p, parameters(prob.f.sys))
defs = mergedefaults(defs, p, parameters(prob.f.sys))
defs = mergedefaults(defs, prob.u0, states(prob.f.sys))
defs = mergedefaults(defs, u0, states(prob.f.sys))

u0 = varmap_to_vars(u0, states(prob.f.sys); defaults = defs, tofloat = true)
p = varmap_to_vars(p, parameters(prob.f.sys); defaults = defs)
sys = prob.f.sys
defs = defaults(sys)
ps = parameters(sys)
if has_split_idxs(sys) && (split_idxs = get_split_idxs(sys)) !== nothing
for (i, idxs) in enumerate(split_idxs)
defs = mergedefaults(defs, prob.p[i], ps[idxs])
end
else
# assemble defaults
defs = defaults(sys)
defs = mergedefaults(defs, prob.p, ps)
end
defs = mergedefaults(defs, p, ps)
sts = states(sys)
defs = mergedefaults(defs, prob.u0, sts)
defs = mergedefaults(defs, u0, sts)
u0, p, defs = get_u0_p(sys, defs)

return p, u0
end
Expand Down
4 changes: 2 additions & 2 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ using ModelingToolkitStandardLibrary.Mechanical.Rotational
t = ModelingToolkitStandardLibrary.Mechanical.Rotational.t
@named inertia1 = Inertia(; J = 1)
@named inertia2 = Inertia(; J = 1)
@named spring = Spring(; c = 10)
@named damper = Damper(; d = 3)
@named spring = Rotational.Spring(; c = 10)
@named damper = Rotational.Damper(; d = 3)
@named torque = Torque(; use_support = false)
@variables y(t) = 0
eqs = [connect(torque.flange, inertia1.flange_a)
Expand Down
6 changes: 3 additions & 3 deletions test/latexify/10.tex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
\begin{align}
\frac{\mathrm{d} x\left( t \right)}{\mathrm{d}t} =& \frac{\sigma \left( - x\left( t \right) + y\left( t \right) \right) \frac{\mathrm{d}}{\mathrm{d}t} \left( - y\left( t \right) + x\left( t \right) \right)}{\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t}} \\
0 =& - y\left( t \right) + \frac{1}{10} \sigma \left( \rho - z\left( t \right) \right) x\left( t \right) \\
\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t} =& \left( y\left( t \right) \right)^{\frac{2}{3}} x\left( t \right) - \beta z\left( t \right)
\frac{\mathrm{d} x\left( t \right)}{\mathrm{d}t} =& \frac{\left( - x\left( t \right) + y\left( t \right) \right) \frac{\mathrm{d}}{\mathrm{d}t} \left( x\left( t \right) - y\left( t \right) \right) \sigma}{\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t}} \\
0 =& - y\left( t \right) + \frac{1}{10} x\left( t \right) \left( - z\left( t \right) + \rho \right) \sigma \\
\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t} =& \left( y\left( t \right) \right)^{\frac{2}{3}} x\left( t \right) - z\left( t \right) \beta
\end{align}
4 changes: 2 additions & 2 deletions test/latexify/20.tex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
\begin{align}
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& \left( - u(t)_1 + u(t)_2 \right) p_3 \\
0 =& - u(t)_2 + \frac{1}{10} \left( - u(t)_1 + p_1 \right) p_2 p_3 u(t)_1 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& p_3 \left( - u(t)_1 + u(t)_2 \right) \\
0 =& - u(t)_2 + \frac{1}{10} \left( p_1 - u(t)_1 \right) p_2 p_3 u(t)_1 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_3 =& u(t)_2^{\frac{2}{3}} u(t)_1 - p_3 u(t)_3
\end{align}
4 changes: 2 additions & 2 deletions test/latexify/30.tex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
\begin{align}
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& \left( - u(t)_1 + u(t)_2 \right) p_3 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_2 =& - u(t)_2 + \frac{1}{10} \left( - u(t)_1 + p_1 \right) p_2 p_3 u(t)_1 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& p_3 \left( - u(t)_1 + u(t)_2 \right) \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_2 =& - u(t)_2 + \frac{1}{10} \left( p_1 - u(t)_1 \right) p_2 p_3 u(t)_1 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_3 =& u(t)_2^{\frac{2}{3}} u(t)_1 - p_3 u(t)_3
\end{align}
9 changes: 5 additions & 4 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ eqs = [D(x) ~ σ * (y - x),
ModelingToolkit.toexpr.(eqs)[1]
@named de = ODESystem(eqs; defaults = Dict(x => 1))
subed = substitute(de, [σ => k])
@test isequal(sort(parameters(subed), by = string), [k, β, ρ])
ssort(eqs) = sort(eqs, by = string)
@test isequal(ssort(parameters(subed)), [k, β, ρ])
@test isequal(equations(subed),
[D(x) ~ k * (y - x)
D(y) ~- z) * x - y
Expand Down Expand Up @@ -241,7 +242,7 @@ p2 = (k₁ => 0.04,
k₃ => 1e4)
tspan = (0.0, 100000.0)
prob1 = ODEProblem(sys, u0, tspan, p)
@test prob1.f.sys === sys
@test prob1.f.sys == sys
prob12 = ODEProblem(sys, u0, tspan, [0.04, 3e7, 1e4])
prob13 = ODEProblem(sys, u0, tspan, (0.04, 3e7, 1e4))
prob14 = ODEProblem(sys, u0, tspan, p2)
Expand Down Expand Up @@ -348,14 +349,14 @@ eqs = [0 ~ x + z
D(accumulation_y) ~ y
D(accumulation_z) ~ z
D(x) ~ y]
@test sort(equations(asys), by = string) == eqs
@test ssort(equations(asys)) == ssort(eqs)
@variables ac(t)
asys = add_accumulations(sys, [ac => (x + y)^2])
eqs = [0 ~ x + z
0 ~ x - y
D(ac) ~ (x + y)^2
D(x) ~ y]
@test sort(equations(asys), by = string) == eqs
@test ssort(equations(asys)) == ssort(eqs)

sys2 = ode_order_lowering(sys)
M = ModelingToolkit.calculate_massmatrix(sys2)
Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@ end
@safetestset "FuncAffect Test" include("funcaffect.jl")
@safetestset "Constants Test" include("constants.jl")
# Reference tests go Last
@safetestset "Latexify recipes Test" include("latexify.jl")
if VERSION >= v"1.9"
@safetestset "Latexify recipes Test" include("latexify.jl")
end
Loading

0 comments on commit 2309f9f

Please sign in to comment.