Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support heterogeneous parameters for linearize and remake #2283

Merged
merged 17 commits into from
Oct 4, 2023
Merged
1 change: 1 addition & 0 deletions docs/src/basics/MTKModel_Connector.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ end
`@connector`s accepts begin blocks of `@components`, `@equations`, `@extend`, `@parameters`, `@structural_parameters`, `@variables`. These keywords mean the same as described above for `@mtkmodel`.

!!! note

For more examples of usage, checkout [ModelingToolkitStandardLibrary.jl](https://github.com/SciML/ModelingToolkitStandardLibrary.jl/)

### What's a `structure` dictionary?
Expand Down
38 changes: 30 additions & 8 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@
:metadata
:gui_metadata
:discrete_subsystems
:unknown_states]
:unknown_states
:split_idxs]
fname1 = Symbol(:get_, prop)
fname2 = Symbol(:has_, prop)
@eval begin
Expand Down Expand Up @@ -1273,14 +1274,34 @@
function linearization_function(sys::AbstractSystem, inputs,
outputs; simplify = false,
initialize = true,
op = Dict(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very bad interface, now the user has to provide both p and op even if op already contains p. The new changes are also lacking documentation, how is the user supposed to know that they now have to provide also p for the results to be correct?

I would prefer to think a bit harder about this design, the current state feels like a quick hack that is going to be very confusing to use

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The alternative is being wrong though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The op here doesn't require the correct numerical value. It only cares about types.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The op here doesn't require the correct numerical value. It only cares about types.

What does this mean? op is the user-provided operating point around which to linearize.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we use heterogenous parameters by default now, we need to know the exact type of the parameters to create the linearized function correctly.

p = DiffEqBase.NullParameters(),
zero_dummy_der = false,
kwargs...)
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
ssys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs;
simplify,
kwargs...)
if zero_dummy_der
dummyder = setdiff(states(ssys), states(sys))
defs = Dict(x => 0.0 for x in dummyder)
@set! ssys.defaults = merge(defs, defaults(ssys))
op = merge(defs, op)
end
sys = ssys
x0 = merge(defaults(sys), op)
u0, p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
p, split_idxs = split_parameters_by_type(p)
ps = parameters(sys)
if p isa Tuple
ps = Base.Fix1(getindex, ps).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple

Check warning on line 1297 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1296-L1297

Added lines #L1296 - L1297 were not covered by tests
end

lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = states(sys),
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys),
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys, states(sys), ps; p = p),
h = build_explicit_observed_function(sys, outputs),
chunk = ForwardDiff.Chunk(input_idxs)

Expand Down Expand Up @@ -1599,11 +1620,12 @@
allow_input_derivatives = false,
zero_dummy_der = false,
kwargs...)
lin_fun, ssys = linearization_function(sys, inputs, outputs; kwargs...)
if zero_dummy_der
dummyder = setdiff(states(ssys), states(sys))
op = merge(op, Dict(x => 0.0 for x in dummyder))
end
lin_fun, ssys = linearization_function(sys,
inputs,
outputs;
zero_dummy_der,
op,
kwargs...)
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys
end

Expand Down
10 changes: 7 additions & 3 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
checkbounds = false,
sparsity = false,
analytic = nothing,
split_idxs = nothing,
kwargs...) where {iip, specialize}
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
expression_module = eval_module, checkbounds = checkbounds,
Expand Down Expand Up @@ -508,6 +509,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
nothing
end

@set! sys.split_idxs = split_idxs
ODEFunction{iip, specialize}(f;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
Expand Down Expand Up @@ -765,15 +767,17 @@ Take dictionaries with initial conditions and parameters and convert them to num
"""
function get_u0_p(sys,
u0map,
parammap;
parammap = nothing;
use_union = true,
tofloat = true,
symbolic_u0 = false)
dvs = states(sys)
ps = parameters(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
if parammap !== nothing
defs = mergedefaults(defs, parammap, ps)
end
defs = mergedefaults(defs, u0map, dvs)

if symbolic_u0
Expand Down Expand Up @@ -835,7 +839,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression,
sparse = sparse, eval_expression = eval_expression, split_idxs,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end
Expand Down
10 changes: 7 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,19 @@ struct ODESystem <: AbstractODESystem
used for ODAEProblem.
"""
unknown_states::Union{Nothing, Vector{Any}}
"""
split_idxs: a vector of vectors of indices for the split parameters.
"""
split_idxs::Union{Nothing, Vector{Vector{Int}}}

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
torn_matching, connector_type, preface, cevents,
devents, metadata = nothing, gui_metadata = nothing,
tearing_state = nothing,
substitutions = nothing, complete = false,
discrete_subsystems = nothing, unknown_states = nothing;
checks::Union{Bool, Int} = true)
discrete_subsystems = nothing, unknown_states = nothing,
split_idxs = nothing; checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckComponents) > 0
check_variables(dvs, iv)
check_parameters(ps, iv)
Expand All @@ -161,7 +165,7 @@ struct ODESystem <: AbstractODESystem
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
connector_type, preface, cevents, devents, metadata, gui_metadata,
tearing_state, substitutions, complete, discrete_subsystems,
unknown_states)
unknown_states, split_idxs)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ dprob = DiscreteProblem(js, u₀map, tspan, parammap)
function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothing},
parammap = DiffEqBase.NullParameters();
checkbounds = false,
use_union = false,
use_union = true,
kwargs...)
dvs = states(sys)
ps = parameters(sys)
Expand Down
54 changes: 24 additions & 30 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -657,46 +657,40 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
end
T = eltype(vs)
if Base.isconcretetype(T) && (!tofloat || T === float(T)) # nothing to do
vs
return vs
else
sym_vs = filter(x -> SymbolicUtils.issym(x) || SymbolicUtils.istree(x), vs)
isempty(sym_vs) || throw_missingvars_in_sys(sym_vs)
C = typeof(first(vs))
I = Int8
has_int = false
has_array = false
has_bool = false
array_T = nothing

C = nothing
for v in vs
if v isa AbstractArray
has_array = true
array_T = typeof(v)
E = typeof(v)
if E <: Number
if tofloat
E = float(E)
end
end
E = eltype(v)
C = promote_type(C, E)
if E <: Integer
has_int = true
I = promote_type(I, E)
if C === nothing
C = E
end
if E <: Bool
has_bool = true
if use_union
C = Union{C, E}
else
@assert C==E "`promote_to_concrete` can't make type $E uniform with $C"
C = E
end
end
if tofloat && !has_array
C = float(C)
elseif has_array || (use_union && has_int && C !== I)
if has_array
C = Union{C, array_T}
end
if has_int
C = Union{C, I}
end
if has_bool
C = Union{C, Bool}

y = similar(vs, C)
for i in eachindex(vs)
if (vs[i] isa Number) & tofloat
y[i] = float(vs[i]) #needed because copyto! can't convert Int to Float automatically
else
y[i] = vs[i]
end
return copyto!(similar(vs, C), vs)
end
convert.(C, vs)

return y
end
end

Expand Down
26 changes: 17 additions & 9 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,23 @@ function SciMLBase.process_p_u0_symbolic(prob::Union{SciMLBase.AbstractDEProblem
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
end

# assemble defaults
defs = defaults(prob.f.sys)
defs = mergedefaults(defs, prob.p, parameters(prob.f.sys))
defs = mergedefaults(defs, p, parameters(prob.f.sys))
defs = mergedefaults(defs, prob.u0, states(prob.f.sys))
defs = mergedefaults(defs, u0, states(prob.f.sys))

u0 = varmap_to_vars(u0, states(prob.f.sys); defaults = defs, tofloat = true)
p = varmap_to_vars(p, parameters(prob.f.sys); defaults = defs)
sys = prob.f.sys
defs = defaults(sys)
ps = parameters(sys)
if has_split_idxs(sys) && (split_idxs = get_split_idxs(sys)) !== nothing
for (i, idxs) in enumerate(split_idxs)
defs = mergedefaults(defs, prob.p[i], ps[idxs])
end
else
# assemble defaults
defs = defaults(sys)
defs = mergedefaults(defs, prob.p, ps)
end
defs = mergedefaults(defs, p, ps)
sts = states(sys)
defs = mergedefaults(defs, prob.u0, sts)
defs = mergedefaults(defs, u0, sts)
u0, p, defs = get_u0_p(sys, defs)

return p, u0
end
Expand Down
4 changes: 2 additions & 2 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ using ModelingToolkitStandardLibrary.Mechanical.Rotational
t = ModelingToolkitStandardLibrary.Mechanical.Rotational.t
@named inertia1 = Inertia(; J = 1)
@named inertia2 = Inertia(; J = 1)
@named spring = Spring(; c = 10)
@named damper = Damper(; d = 3)
@named spring = Rotational.Spring(; c = 10)
@named damper = Rotational.Damper(; d = 3)
@named torque = Torque(; use_support = false)
@variables y(t) = 0
eqs = [connect(torque.flange, inertia1.flange_a)
Expand Down
6 changes: 3 additions & 3 deletions test/latexify/10.tex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
\begin{align}
\frac{\mathrm{d} x\left( t \right)}{\mathrm{d}t} =& \frac{\sigma \left( - x\left( t \right) + y\left( t \right) \right) \frac{\mathrm{d}}{\mathrm{d}t} \left( - y\left( t \right) + x\left( t \right) \right)}{\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t}} \\
0 =& - y\left( t \right) + \frac{1}{10} \sigma \left( \rho - z\left( t \right) \right) x\left( t \right) \\
\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t} =& \left( y\left( t \right) \right)^{\frac{2}{3}} x\left( t \right) - \beta z\left( t \right)
\frac{\mathrm{d} x\left( t \right)}{\mathrm{d}t} =& \frac{\left( - x\left( t \right) + y\left( t \right) \right) \frac{\mathrm{d}}{\mathrm{d}t} \left( x\left( t \right) - y\left( t \right) \right) \sigma}{\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t}} \\
0 =& - y\left( t \right) + \frac{1}{10} x\left( t \right) \left( - z\left( t \right) + \rho \right) \sigma \\
\frac{\mathrm{d} z\left( t \right)}{\mathrm{d}t} =& \left( y\left( t \right) \right)^{\frac{2}{3}} x\left( t \right) - z\left( t \right) \beta
\end{align}
4 changes: 2 additions & 2 deletions test/latexify/20.tex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
\begin{align}
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& \left( - u(t)_1 + u(t)_2 \right) p_3 \\
0 =& - u(t)_2 + \frac{1}{10} \left( - u(t)_1 + p_1 \right) p_2 p_3 u(t)_1 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& p_3 \left( - u(t)_1 + u(t)_2 \right) \\
0 =& - u(t)_2 + \frac{1}{10} \left( p_1 - u(t)_1 \right) p_2 p_3 u(t)_1 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_3 =& u(t)_2^{\frac{2}{3}} u(t)_1 - p_3 u(t)_3
\end{align}
4 changes: 2 additions & 2 deletions test/latexify/30.tex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
\begin{align}
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& \left( - u(t)_1 + u(t)_2 \right) p_3 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_2 =& - u(t)_2 + \frac{1}{10} \left( - u(t)_1 + p_1 \right) p_2 p_3 u(t)_1 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_1 =& p_3 \left( - u(t)_1 + u(t)_2 \right) \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_2 =& - u(t)_2 + \frac{1}{10} \left( p_1 - u(t)_1 \right) p_2 p_3 u(t)_1 \\
\frac{\mathrm{d}}{\mathrm{d}t} u(t)_3 =& u(t)_2^{\frac{2}{3}} u(t)_1 - p_3 u(t)_3
\end{align}
9 changes: 5 additions & 4 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ eqs = [D(x) ~ σ * (y - x),
ModelingToolkit.toexpr.(eqs)[1]
@named de = ODESystem(eqs; defaults = Dict(x => 1))
subed = substitute(de, [σ => k])
@test isequal(sort(parameters(subed), by = string), [k, β, ρ])
ssort(eqs) = sort(eqs, by = string)
@test isequal(ssort(parameters(subed)), [k, β, ρ])
@test isequal(equations(subed),
[D(x) ~ k * (y - x)
D(y) ~ (ρ - z) * x - y
Expand Down Expand Up @@ -241,7 +242,7 @@ p2 = (k₁ => 0.04,
k₃ => 1e4)
tspan = (0.0, 100000.0)
prob1 = ODEProblem(sys, u0, tspan, p)
@test prob1.f.sys === sys
@test prob1.f.sys == sys
prob12 = ODEProblem(sys, u0, tspan, [0.04, 3e7, 1e4])
prob13 = ODEProblem(sys, u0, tspan, (0.04, 3e7, 1e4))
prob14 = ODEProblem(sys, u0, tspan, p2)
Expand Down Expand Up @@ -348,14 +349,14 @@ eqs = [0 ~ x + z
D(accumulation_y) ~ y
D(accumulation_z) ~ z
D(x) ~ y]
@test sort(equations(asys), by = string) == eqs
@test ssort(equations(asys)) == ssort(eqs)
@variables ac(t)
asys = add_accumulations(sys, [ac => (x + y)^2])
eqs = [0 ~ x + z
0 ~ x - y
D(ac) ~ (x + y)^2
D(x) ~ y]
@test sort(equations(asys), by = string) == eqs
@test ssort(equations(asys)) == ssort(eqs)

sys2 = ode_order_lowering(sys)
M = ModelingToolkit.calculate_massmatrix(sys2)
Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@ end
@safetestset "FuncAffect Test" include("funcaffect.jl")
@safetestset "Constants Test" include("constants.jl")
# Reference tests go Last
@safetestset "Latexify recipes Test" include("latexify.jl")
if VERSION >= v"1.9"
@safetestset "Latexify recipes Test" include("latexify.jl")
end
Loading
Loading