diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 801e7b05f3..3425216b5f 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -144,6 +144,33 @@ function OptimizationSystem(op, unknowns, ps; checks = checks) end +function OptimizationSystem(objective; constraints = [], kwargs...) + allunknowns = OrderedSet() + ps = OrderedSet() + collect_vars!(allunknowns, ps, objective, nothing) + for cons in constraints + collect_vars!(allunknowns, ps, cons, nothing) + end + for ssys in get(kwargs, :systems, OptimizationSystem[]) + collect_scoped_vars!(allunknowns, ps, ssys, nothing) + end + new_ps = OrderedSet() + for p in ps + if iscall(p) && operation(p) === getindex + par = arguments(p)[begin] + if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && + all(par[i] in ps for i in eachindex(par)) + push!(new_ps, par) + else + push!(new_ps, p) + end + else + push!(new_ps, p) + end + end + return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...) +end + function flatten(sys::OptimizationSystem) systems = get_systems(sys) isempty(systems) && return sys diff --git a/src/utils.jl b/src/utils.jl index 830ec98e44..1555cd624e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -516,6 +516,15 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif end end end + if has_constraints(sys) + for eq in get_constraints(sys) + eqtype_supports_collect_vars(eq) || continue + collect_vars!(unknowns, parameters, eq, iv; depth, op) + end + end + if has_op(sys) + collect_vars!(unknowns, parameters, get_op(sys), iv; depth, op) + end newdepth = depth == -1 ? depth : depth + 1 for ssys in get_systems(sys) collect_scoped_vars!(unknowns, parameters, ssys, iv; depth = newdepth, op) @@ -544,9 +553,10 @@ Can be dispatched by higher-level libraries to indicate support. """ eqtype_supports_collect_vars(eq) = false eqtype_supports_collect_vars(eq::Equation) = true +eqtype_supports_collect_vars(eq::Inequality) = true eqtype_supports_collect_vars(eq::Pair) = true -function collect_vars!(unknowns, parameters, eq::Equation, iv; +function collect_vars!(unknowns, parameters, eq::Union{Equation, Inequality}, iv; depth = 0, op = Differential) collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op) collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op) @@ -559,6 +569,7 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ return nothing end + function collect_var!(unknowns, parameters, var, iv; depth = 0) isequal(var, iv) && return nothing check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing diff --git a/test/optimizationsystem.jl b/test/optimizationsystem.jl index bb59fb09d9..c20613441a 100644 --- a/test/optimizationsystem.jl +++ b/test/optimizationsystem.jl @@ -377,3 +377,14 @@ end prob = OptimizationProblem(sys, [x => 1.0], [p => 1.0, f => (x -> 2x)]) @test abs(prob.f.cons(prob.u0, prob.p)[1]) ≈ 1.0 end + +@testset "Variable discovery" begin + @variables x1 x2 + @parameters p1 p2 + @named sys1 = OptimizationSystem(x1^2; constraints = [p1 * x1 ≲ 2.0]) + @named sys2 = OptimizationSystem(x2^2; constraints = [p2 * x2 ≲ 2.0], systems = [sys1]) + @test isequal(only(unknowns(sys1)), x1) + @test isequal(only(parameters(sys1)), p1) + @test all(y -> any(x -> isequal(x, y), unknowns(sys2)), [x2, sys1.x1]) + @test all(y -> any(x -> isequal(x, y), parameters(sys2)), [p2, sys1.p1]) +end