Skip to content

Commit

Permalink
Suppose heterogeneous parameters for linearize and remake
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Sep 26, 2023
1 parent ee4deb9 commit a9c56b6
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ PrecompileTools = "1"
RecursiveArrayTools = "2.3"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.0.1"
SciMLBase = "1, 2.0.1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
Expand Down
18 changes: 15 additions & 3 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ for prop in [:eqs
:metadata
:gui_metadata
:discrete_subsystems
:unknown_states]
:unknown_states
:split_idxs]
fname1 = Symbol(:get_, prop)
fname2 = Symbol(:has_, prop)
@eval begin
Expand Down Expand Up @@ -1274,14 +1275,25 @@ See also [`linearize`](@ref) which provides a higher-level interface.
function linearization_function(sys::AbstractSystem, inputs,
outputs; simplify = false,
initialize = true,
op = Dict(),
p = DiffEqBase.NullParameters(),
kwargs...)
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
kwargs...)
x0 = merge(defaults(sys), op)
u0, p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
p, split_idxs = split_parameters_by_type(p)
ps = parameters(sys)
if p isa Tuple
ps = Base.Fix1(getindex, ps).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple

Check warning on line 1289 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1283-L1289

Added lines #L1283 - L1289 were not covered by tests
end

lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = states(sys),
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys),
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys, states(sys), ps; p = p),
h = build_explicit_observed_function(sys, outputs),
chunk = ForwardDiff.Chunk(input_idxs)

Expand Down Expand Up @@ -1600,11 +1612,11 @@ function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,
allow_input_derivatives = false,
zero_dummy_der = false,
kwargs...)
lin_fun, ssys = linearization_function(sys, inputs, outputs; kwargs...)
if zero_dummy_der
dummyder = setdiff(states(ssys), states(sys))
op = merge(op, Dict(x => 0.0 for x in dummyder))
end
lin_fun, ssys = linearization_function(sys, inputs, outputs; op, kwargs...)

Check warning on line 1619 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L1619

Added line #L1619 was not covered by tests
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys
end

Expand Down
10 changes: 7 additions & 3 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
checkbounds = false,
sparsity = false,
analytic = nothing,
split_idxs = nothing,
kwargs...) where {iip, specialize}
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
expression_module = eval_module, checkbounds = checkbounds,
Expand Down Expand Up @@ -508,6 +509,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
nothing
end

@set! sys.split_idxs = split_idxs
ODEFunction{iip, specialize}(f;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
Expand Down Expand Up @@ -765,15 +767,17 @@ Take dictionaries with initial conditions and parameters and convert them to num
"""
function get_u0_p(sys,
u0map,
parammap;
parammap = nothing;
use_union = true,
tofloat = true,
symbolic_u0 = false)
dvs = states(sys)
ps = parameters(sys)

defs = defaults(sys)
defs = mergedefaults(defs, parammap, ps)
if parammap !== nothing
defs = mergedefaults(defs, parammap, ps)
end
defs = mergedefaults(defs, u0map, dvs)

if symbolic_u0
Expand Down Expand Up @@ -835,7 +839,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression,
sparse = sparse, eval_expression = eval_expression, split_idxs,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end
Expand Down
10 changes: 7 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,19 @@ struct ODESystem <: AbstractODESystem
used for ODAEProblem.
"""
unknown_states::Union{Nothing, Vector{Any}}
"""
split_idxs: a vector of vectors of indices for the split parameters.
"""
split_idxs::Union{Nothing, Vector{Vector{Int}}}

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
torn_matching, connector_type, preface, cevents,
devents, metadata = nothing, gui_metadata = nothing,
tearing_state = nothing,
substitutions = nothing, complete = false,
discrete_subsystems = nothing, unknown_states = nothing;
checks::Union{Bool, Int} = true)
discrete_subsystems = nothing, unknown_states = nothing,
split_idxs = nothing; checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckComponents) > 0
check_variables(dvs, iv)
check_parameters(ps, iv)
Expand All @@ -161,7 +165,7 @@ struct ODESystem <: AbstractODESystem
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
connector_type, preface, cevents, devents, metadata, gui_metadata,
tearing_state, substitutions, complete, discrete_subsystems,
unknown_states)
unknown_states, split_idxs)
end
end

Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
else
sym_vs = filter(x -> SymbolicUtils.issym(x) || SymbolicUtils.istree(x), vs)
isempty(sym_vs) || throw_missingvars_in_sys(sym_vs)

C = nothing

Check warning on line 665 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L665

Added line #L665 was not covered by tests
for v in vs
E = typeof(v)
Expand All @@ -676,7 +676,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
if use_union
C = Union{C, E}

Check warning on line 677 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L676-L677

Added lines #L676 - L677 were not covered by tests
else
@assert C == E "`promote_to_concrete` can't make type $E uniform with $C"
@assert C==E "`promote_to_concrete` can't make type $E uniform with $C"
C = E

Check warning on line 680 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L679-L680

Added lines #L679 - L680 were not covered by tests
end
end
Expand All @@ -686,7 +686,7 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
if (vs[i] isa Number) & tofloat
y[i] = float(vs[i]) #needed because copyto! can't convert Int to Float automatically

Check warning on line 687 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L684-L687

Added lines #L684 - L687 were not covered by tests
else
y[i] = vs[i]
y[i] = vs[i]

Check warning on line 689 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L689

Added line #L689 was not covered by tests
end
end

Expand Down
26 changes: 17 additions & 9 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,23 @@ function SciMLBase.process_p_u0_symbolic(prob::Union{SciMLBase.AbstractDEProblem
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
end

# assemble defaults
defs = defaults(prob.f.sys)
defs = mergedefaults(defs, prob.p, parameters(prob.f.sys))
defs = mergedefaults(defs, p, parameters(prob.f.sys))
defs = mergedefaults(defs, prob.u0, states(prob.f.sys))
defs = mergedefaults(defs, u0, states(prob.f.sys))

u0 = varmap_to_vars(u0, states(prob.f.sys); defaults = defs, tofloat = true)
p = varmap_to_vars(p, parameters(prob.f.sys); defaults = defs)
sys = prob.f.sys
defs = defaults(sys)
ps = parameters(sys)
if has_split_idxs(sys) && (split_idxs = get_split_idxs(sys)) !== nothing
for (i, idxs) in enumerate(split_idxs)
defs = mergedefaults(defs, prob.p[i], ps[idxs])
end

Check warning on line 154 in src/variables.jl

View check run for this annotation

Codecov / codecov/patch

src/variables.jl#L148-L154

Added lines #L148 - L154 were not covered by tests
else
# assemble defaults
defs = defaults(sys)
defs = mergedefaults(defs, prob.p, ps)

Check warning on line 158 in src/variables.jl

View check run for this annotation

Codecov / codecov/patch

src/variables.jl#L157-L158

Added lines #L157 - L158 were not covered by tests
end
defs = mergedefaults(defs, p, ps)
sts = states(sys)
defs = mergedefaults(defs, prob.u0, sts)
defs = mergedefaults(defs, u0, sts)
u0, p, defs = get_u0_p(sys, defs)

Check warning on line 164 in src/variables.jl

View check run for this annotation

Codecov / codecov/patch

src/variables.jl#L160-L164

Added lines #L160 - L164 were not covered by tests

return p, u0
end
Expand Down
90 changes: 35 additions & 55 deletions test/split_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,29 @@ using ModelingToolkit, Test
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEq


x = [1, 2.0, false, [1,2,3], Parameter(1.0)]
x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)]

y = ModelingToolkit.promote_to_concrete(x)
@test eltype(y) == Union{Float64, Parameter{Float64}, Vector{Int64}}

y = ModelingToolkit.promote_to_concrete(x; tofloat=false)
y = ModelingToolkit.promote_to_concrete(x; tofloat = false)
@test eltype(y) == Union{Bool, Float64, Int64, Parameter{Float64}, Vector{Int64}}


x = [1, 2.0, false, [1,2,3]]
x = [1, 2.0, false, [1, 2, 3]]
y = ModelingToolkit.promote_to_concrete(x)
@test eltype(y) == Union{Float64, Vector{Int64}}

x = Any[1, 2.0, false]
y = ModelingToolkit.promote_to_concrete(x; tofloat=false)
y = ModelingToolkit.promote_to_concrete(x; tofloat = false)
@test eltype(y) == Union{Bool, Float64, Int64}

y = ModelingToolkit.promote_to_concrete(x; use_union=false)
y = ModelingToolkit.promote_to_concrete(x; use_union = false)
@test eltype(y) == Float64

x = Float16[1., 2., 3.]
x = Float16[1.0, 2.0, 3.0]
y = ModelingToolkit.promote_to_concrete(x)
@test eltype(y) == Float16




# ------------------------ Mixed Single Values and Vector

dt = 4e-4
Expand Down Expand Up @@ -74,7 +69,7 @@ eqs = [y ~ src.output.u
@named sys = ODESystem(eqs, t, vars, []; systems = [int, src])
s = complete(sys)
sys = structural_simplify(sys)
prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; tofloat=false)
prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; tofloat = false)
@test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}}
sol = solve(prob, ImplicitEuler());
@test sol.retcode == ReturnCode.Success
Expand All @@ -83,18 +78,15 @@ sol = solve(prob, ImplicitEuler());
#TODO: remake becomes more complicated now, how to improve?
defs = ModelingToolkit.defaults(sys)
defs[s.src.data] = 2x
p′ = ModelingToolkit.varmap_to_vars(defs, parameters(sys); tofloat=false)
p′ = ModelingToolkit.varmap_to_vars(defs, parameters(sys); tofloat = false)
p′, = ModelingToolkit.split_parameters_by_type(p′) #NOTE: we need to ensure this is called now before calling remake()
prob′ = remake(prob; p=p′)
prob′ = remake(prob; p = p′)
sol = solve(prob′, ImplicitEuler());
@test sol.retcode == ReturnCode.Success
@test sol[y][end] == 2x[end]

prob′′ = remake(prob; p=[s.src.data => x])
@test prob′′.p isa Tuple



prob′′ = remake(prob; p = [s.src.data => x])
@test_broken prob′′.p isa Tuple

# ------------------------ Mixed Type Converted to float (default behavior)

Expand Down Expand Up @@ -122,11 +114,6 @@ prob = ODEProblem(sys, [], tspan, []; tofloat = false)
sol = solve(prob, ImplicitEuler());
@test sol.retcode == ReturnCode.Success






# ------------------------- Bug
using ModelingToolkit, LinearAlgebra
using ModelingToolkitStandardLibrary.Mechanical.Rotational
Expand All @@ -136,51 +123,48 @@ using ModelingToolkit: connect

"A wrapper function to make symbolic indexing easier"
function wr(sys)
ODESystem(Equation[], ModelingToolkit.get_iv(sys), systems=[sys], name=:a_wrapper)
ODESystem(Equation[], ModelingToolkit.get_iv(sys), systems = [sys], name = :a_wrapper)
end
indexof(sym,syms) = findfirst(isequal(sym),syms)
indexof(sym, syms) = findfirst(isequal(sym), syms)

# Parameters
m1 = 1.
m2 = 1.
k = 10. # Spring stiffness
c = 3. # Damping coefficient
m1 = 1.0
m2 = 1.0
k = 10.0 # Spring stiffness
c = 3.0 # Damping coefficient

@named inertia1 = Inertia(; J = m1)
@named inertia2 = Inertia(; J = m2)
@named spring = Spring(; c = k)
@named damper = Damper(; d = c)
@named torque = Torque(use_support=false)
@named torque = Torque(use_support = false)

function SystemModel(u=nothing; name=:model)
eqs = [
connect(torque.flange, inertia1.flange_a)
function SystemModel(u = nothing; name = :model)
eqs = [connect(torque.flange, inertia1.flange_a)
connect(inertia1.flange_b, spring.flange_a, damper.flange_a)
connect(inertia2.flange_a, spring.flange_b, damper.flange_b)
]
connect(inertia2.flange_a, spring.flange_b, damper.flange_b)]
if u !== nothing
push!(eqs, connect(torque.tau, u.output))
return @named model = ODESystem(eqs, t; systems = [torque, inertia1, inertia2, spring, damper, u])
return @named model = ODESystem(eqs,
t;
systems = [torque, inertia1, inertia2, spring, damper, u])
end
ODESystem(eqs, t; systems = [torque, inertia1, inertia2, spring, damper], name)
end


model = SystemModel() # Model with load disturbance
@named d = Step(start_time=1., duration=10., offset=0., height=1.) # Disturbance
@named d = Step(start_time = 1.0, duration = 10.0, offset = 0.0, height = 1.0) # Disturbance
model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.inertia2.phi] # This is the state realization we want to control
inputs = [model.torque.tau.u]
matrices, ssys = ModelingToolkit.linearize(wr(model), inputs, model_outputs)

# Design state-feedback gain using LQR
# Define cost matrices
x_costs = [
model.inertia1.w => 1.
model.inertia2.w => 1.
model.inertia1.phi => 1.
model.inertia2.phi => 1.
]
L = randn(1,4) # Post-multiply by `C` to get the correct input to the controller
x_costs = [model.inertia1.w => 1.0
model.inertia2.w => 1.0
model.inertia1.phi => 1.0
model.inertia2.phi => 1.0]
L = randn(1, 4) # Post-multiply by `C` to get the correct input to the controller

# This old definition of MatrixGain will work because the parameter space does not include K (an Array term)
# @component function MatrixGainAlt(K::AbstractArray; name)
Expand All @@ -191,16 +175,12 @@ L = randn(1,4) # Post-multiply by `C` to get the correct input to the controller
# compose(ODESystem(eqs, t, [], []; name = name), [input, output])
# end

@named state_feedback = MatrixGain(K=-L) # Build negative feedback into the feedback matrix
@named add = Add(;k1=1., k2=1.) # To add the control signal and the disturbance
@named state_feedback = MatrixGain(K = -L) # Build negative feedback into the feedback matrix
@named add = Add(; k1 = 1.0, k2 = 1.0) # To add the control signal and the disturbance

connections = [
[state_feedback.input.u[i] ~ model_outputs[i] for i in 1:4]
connections = [[state_feedback.input.u[i] ~ model_outputs[i] for i in 1:4]
connect(d.output, :d, add.input1)
connect(add.input2, state_feedback.output)
connect(add.output, :u, model.torque.tau)
]
closed_loop = ODESystem(connections, t, systems=[model, state_feedback, add, d], name=:closed_loop)
connect(add.output, :u, model.torque.tau)]
@named closed_loop = ODESystem(connections, t, systems = [model, state_feedback, add, d])
S = get_sensitivity(closed_loop, :u)


0 comments on commit a9c56b6

Please sign in to comment.