Skip to content

Commit

Permalink
Merge pull request #3243 from AayushSabharwal/as/optsys-discover-vari…
Browse files Browse the repository at this point in the history
…ables

feat: add automatic variable discovery for `OptimizationSystem`
  • Loading branch information
ChrisRackauckas authored Nov 28, 2024
2 parents bdb4c03 + 2885815 commit 1eda645
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
27 changes: 27 additions & 0 deletions src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,33 @@ function OptimizationSystem(op, unknowns, ps;
checks = checks)
end

function OptimizationSystem(objective; constraints = [], kwargs...)
allunknowns = OrderedSet()
ps = OrderedSet()
collect_vars!(allunknowns, ps, objective, nothing)
for cons in constraints
collect_vars!(allunknowns, ps, cons, nothing)
end
for ssys in get(kwargs, :systems, OptimizationSystem[])
collect_scoped_vars!(allunknowns, ps, ssys, nothing)
end
new_ps = OrderedSet()
for p in ps
if iscall(p) && operation(p) === getindex
par = arguments(p)[begin]
if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() &&
all(par[i] in ps for i in eachindex(par))
push!(new_ps, par)
else
push!(new_ps, p)
end
else
push!(new_ps, p)
end
end
return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
end

function flatten(sys::OptimizationSystem)
systems = get_systems(sys)
isempty(systems) && return sys
Expand Down
13 changes: 12 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,15 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif
end
end
end
if has_constraints(sys)
for eq in get_constraints(sys)
eqtype_supports_collect_vars(eq) || continue
collect_vars!(unknowns, parameters, eq, iv; depth, op)
end
end
if has_op(sys)
collect_vars!(unknowns, parameters, get_op(sys), iv; depth, op)
end
newdepth = depth == -1 ? depth : depth + 1
for ssys in get_systems(sys)
collect_scoped_vars!(unknowns, parameters, ssys, iv; depth = newdepth, op)
Expand Down Expand Up @@ -544,9 +553,10 @@ Can be dispatched by higher-level libraries to indicate support.
"""
eqtype_supports_collect_vars(eq) = false
eqtype_supports_collect_vars(eq::Equation) = true
eqtype_supports_collect_vars(eq::Inequality) = true
eqtype_supports_collect_vars(eq::Pair) = true

function collect_vars!(unknowns, parameters, eq::Equation, iv;
function collect_vars!(unknowns, parameters, eq::Union{Equation, Inequality}, iv;
depth = 0, op = Differential)
collect_vars!(unknowns, parameters, eq.lhs, iv; depth, op)
collect_vars!(unknowns, parameters, eq.rhs, iv; depth, op)
Expand All @@ -559,6 +569,7 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ
return nothing
end


function collect_var!(unknowns, parameters, var, iv; depth = 0)
isequal(var, iv) && return nothing
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
Expand Down
11 changes: 11 additions & 0 deletions test/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,14 @@ end
prob = OptimizationProblem(sys, [x => 1.0], [p => 1.0, f => (x -> 2x)])
@test abs(prob.f.cons(prob.u0, prob.p)[1]) 1.0
end

@testset "Variable discovery" begin
@variables x1 x2
@parameters p1 p2
@named sys1 = OptimizationSystem(x1^2; constraints = [p1 * x1 2.0])
@named sys2 = OptimizationSystem(x2^2; constraints = [p2 * x2 2.0], systems = [sys1])
@test isequal(only(unknowns(sys1)), x1)
@test isequal(only(parameters(sys1)), p1)
@test all(y -> any(x -> isequal(x, y), unknowns(sys2)), [x2, sys1.x1])
@test all(y -> any(x -> isequal(x, y), parameters(sys2)), [p2, sys1.p1])
end

0 comments on commit 1eda645

Please sign in to comment.