Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add automatic variable discovery for OptimizationSystem #3243

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,33 @@ function OptimizationSystem(op, unknowns, ps;
checks = checks)
end

function OptimizationSystem(objective; constraints = [], kwargs...)
allunknowns = OrderedSet()
ps = OrderedSet()
collect_vars!(allunknowns, ps, objective, nothing)
for cons in constraints
collect_vars!(allunknowns, ps, cons, nothing)
end
for ssys in get(kwargs, :systems, OptimizationSystem[])
collect_scoped_vars!(allunknowns, ps, ssys, nothing)
end
new_ps = OrderedSet()
for p in ps
if iscall(p) && operation(p) === getindex
par = arguments(p)[begin]
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
all(par[i] in ps for i in eachindex(par))
push!(new_ps, par)
else
push!(new_ps, p)
end
else
push!(new_ps, p)
end
end
return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
end

function flatten(sys::OptimizationSystem)
systems = get_systems(sys)
isempty(systems) && return sys
Expand Down
13 changes: 12 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,15 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif
end
end
end
if has_constraints(sys)
for eq in get_constraints(sys)
eqtype_supports_collect_vars(eq) || continue
collect_vars!(unknowns, parameters, eq, iv; depth, op)
end
end
if has_op(sys)
collect_vars!(unknowns, parameters, get_op(sys), iv; depth, op)
end
newdepth = depth == -1 ? depth : depth + 1
for ssys in get_systems(sys)
collect_scoped_vars!(unknowns, parameters, ssys, iv; depth = newdepth, op)
Expand Down Expand Up @@ -544,9 +553,10 @@ Can be dispatched by higher-level libraries to indicate support.
"""
eqtype_supports_collect_vars(eq) = false
eqtype_supports_collect_vars(eq::Equation) = true
eqtype_supports_collect_vars(eq::Inequality) = true
eqtype_supports_collect_vars(eq::Pair) = true

function collect_vars!(unknowns, parameters, eq::Equation, iv;
function collect_vars!(unknowns, parameters, eq::Union{Equation, Inequality}, iv;
depth = 0, op = Differential)
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
Expand All @@ -559,6 +569,7 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ
return nothing
end


function collect_var!(unknowns, parameters, var, iv; depth = 0)
isequal(var, iv) && return nothing
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
Expand Down
11 changes: 11 additions & 0 deletions test/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,14 @@ end
prob = OptimizationProblem(sys, [x => 1.0], [p => 1.0, f => (x -> 2x)])
@test abs(prob.f.cons(prob.u0, prob.p)[1]) ≈ 1.0
end

@testset "Variable discovery" begin
@variables x1 x2
@parameters p1 p2
@named sys1 = OptimizationSystem(x1^2; constraints = [p1 * x1 ≲ 2.0])
@named sys2 = OptimizationSystem(x2^2; constraints = [p2 * x2 ≲ 2.0], systems = [sys1])
@test isequal(only(unknowns(sys1)), x1)
@test isequal(only(parameters(sys1)), p1)
@test all(y -> any(x -> isequal(x, y), unknowns(sys2)), [x2, sys1.x1])
@test all(y -> any(x -> isequal(x, y), parameters(sys2)), [p2, sys1.p1])
end
Loading