Skip to content

Commit

Permalink
Merge pull request #3241 from aml5600/pass-checks-opt-prob
Browse files Browse the repository at this point in the history
OptimizationProblem updates
  • Loading branch information
ChrisRackauckas authored Nov 27, 2024
2 parents eda23d4 + d6add0e commit 1b46612
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
linenumbers = true, parallel = SerialForm(),
eval_expression = false, eval_module = @__MODULE__,
use_union = false,
checks = true,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed `OptimizationSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `OptimizationProblem`")
Expand Down Expand Up @@ -393,12 +394,17 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

if length(cstr) > 0
@named cons_sys = ConstraintsSystem(cstr, dvs, ps)
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)
cons_sys = complete(cons_sys)
cons, lcons_, ucons_ = generate_function(cons_sys, checkbounds = checkbounds,
linenumbers = linenumbers,
expression = Val{true})
cons = eval_or_rgf.(cons; eval_expression, eval_module)
cons = let (cons_oop, cons_iip) = eval_or_rgf.(cons; eval_expression, eval_module)
_cons(u, p) = cons_oop(u, p)
_cons(resid, u, p) = cons_iip(resid, u, p)
_cons(u, p::MTKParameters) = cons_oop(u, p...)
_cons(resid, u, p::MTKParameters) = cons_iip(resid, u, p...)
end
if cons_j
_cons_j = let (cons_jac_oop, cons_jac_iip) = eval_or_rgf.(
generate_jacobian(cons_sys;
Expand Down Expand Up @@ -464,7 +470,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
grad = _grad,
hess = _hess,
hess_prototype = hess_prototype,
cons = cons[2],
cons = cons,
cons_j = _cons_j,
cons_h = _cons_h,
cons_jac_prototype = cons_jac_prototype,
Expand Down
9 changes: 9 additions & 0 deletions test/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,12 @@ end
@test is_variable(sys, x[2])
@test is_variable(sys, x[3])
end

@testset "Constraints work with nonnumeric parameters" begin
@variables x
@parameters p f(::Real)
@mtkbuild sys = OptimizationSystem(
x^2 + f(x) * p, [x], [f, p]; constraints = [2.0 f(x) + p])
prob = OptimizationProblem(sys, [x => 1.0], [p => 1.0, f => (x -> 2x)])
@test abs(prob.f.cons(prob.u0, prob.p)[1]) 1.0
end

0 comments on commit 1b46612

Please sign in to comment.