Skip to content

Commit

Permalink
Respect parameter types
Browse files Browse the repository at this point in the history
Fix #2296
  • Loading branch information
YingboMa committed Oct 6, 2023
1 parent fb7c3af commit d89d460
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
12 changes: 12 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ function split_parameters_by_type(ps)
if length(split_ps) == 1 #Tuple not needed, only 1 type
return split_ps[1], split_idxs
else
T = Float16
idx = 0
for (i, p) in enumerate(split_ps)
E = eltype(p)
if E == promote_type(T, E)
T = E
idx = i
end
end
if idx != 0 && idx != 1
split_ps[idx], split_ps[1] = split_ps[1], split_ps[idx]
end
return (split_ps...,), split_idxs
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ function get_u0_p(sys,
u0map,
parammap = nothing;
use_union = true,
tofloat = true,
tofloat = false,
symbolic_u0 = false)
dvs = states(sys)
ps = parameters(sys)
Expand Down Expand Up @@ -799,7 +799,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
linenumbers = true, parallel = SerialForm(),
eval_expression = true,
use_union = true,
tofloat = true,
tofloat = false,
symbolic_u0 = false,
u0_constructor = identity,
kwargs...)
Expand Down
8 changes: 7 additions & 1 deletion test/split_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ sys = structural_simplify(model)
tspan = (0.0, t_end)
prob = ODEProblem(sys, [], tspan, [])

@test prob.p isa Vector{Float64}
@test prob.p isa Tuple{Vector{Float64}, Vector{Int}}
sol = solve(prob, ImplicitEuler());
@test sol.retcode == ReturnCode.Success

Expand Down Expand Up @@ -184,3 +184,9 @@ connections = [[state_feedback.input.u[i] ~ model_outputs[i] for i in 1:4]
connect(add.output, :u, model.torque.tau)]
@named closed_loop = ODESystem(connections, t, systems = [model, state_feedback, add, d])
S = get_sensitivity(closed_loop, :u)

@variables t
@parameters a b c
@named s = ODESystem(Equation[], t, [], [a, b, c])
prob = ODEProblem(s, nothing, (0.0, 1.0), Pair[a => 1, b => 1, c => 1.0])
@test prob.p == ([1.0], [1, 1])

0 comments on commit d89d460

Please sign in to comment.