Skip to content

Commit

Permalink
Merge pull request #3220 from SciML/infiniteopt
Browse files Browse the repository at this point in the history
add capability to trace MTK dynamics with InfiniteOpt
  • Loading branch information
ChrisRackauckas authored Nov 30, 2024
2 parents 8e48f65 + 7200999 commit e8a9f04
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"

[extensions]
Expand All @@ -72,6 +73,7 @@ MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"

[compat]
AbstractTrees = "0.3, 0.4"
Expand Down Expand Up @@ -104,6 +106,7 @@ FunctionWrappers = "1.1"
FunctionWrappersWrappers = "0.1"
Graphs = "1.5.2"
HomotopyContinuation = "2.11"
InfiniteOpt = "0.5"
InteractiveUtils = "1"
JuliaFormatter = "1.0.47"
JumpProcesses = "9.13.1"
Expand Down
26 changes: 26 additions & 0 deletions ext/MTKInfiniteOptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module MTKInfiniteOptExt
import ModelingToolkit
import SymbolicUtils
import NaNMath
import InfiniteOpt
import InfiniteOpt: JuMP, GeneralVariableRef

# This file contains method definitions to make it possible to trace through functions generated by MTK using JuMP variables

for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
f = nameof(ff)
# These need to be defined so that JuMP can trace through functions built by Symbolics
@eval NaNMath.$f(x::GeneralVariableRef) = Base.$f(x)
end

# JuMP variables and Symbolics variables never compare equal. When tracing through dynamics, a function argument can be either a JuMP variable or A Symbolics variable, it can never be both.
function Base.isequal(::SymbolicUtils.Symbolic,
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr})
false
end
function Base.isequal(
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr},
::SymbolicUtils.Symbolic)
false
end
end
3 changes: 3 additions & 0 deletions test/extensions/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
Expand Down
102 changes: 102 additions & 0 deletions test/extensions/test_infiniteopt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using ModelingToolkit, InfiniteOpt, JuMP, Ipopt
using ModelingToolkit: D_nounits as D, t_nounits as t, varmap_to_vars

@mtkmodel Pendulum begin
@parameters begin
g = 9.8
L = 0.4
K = 1.2
m = 0.3
end
@variables begin
θ(t) # state
ω(t) # state
τ(t) = 0 # input
y(t) # output
end
@equations begin
D(θ) ~ ω
D(ω) ~ -g / L * sin(θ) - K / m * ω + τ / m / L^2
y ~ θ * 180 / π
end
end
@named model = Pendulum()
model = complete(model)

inputs = [model.τ]
(f_oop, f_ip), dvs, psym, io_sys = ModelingToolkit.generate_control_function(
model, inputs, split = false)

outputs = [model.y]
f_obs = ModelingToolkit.build_explicit_observed_function(io_sys, outputs; inputs = inputs)

expected_state_order = [model.θ, model.ω]
permutation = [findfirst(isequal(x), expected_state_order) for x in dvs] # This maps our expected state order to the actual state order

##

ub = varmap_to_vars([model.θ => 2pi, model.ω => 10], dvs)
lb = varmap_to_vars([model.θ => -2pi, model.ω => -10], dvs)
xf = varmap_to_vars([model.θ => pi, model.ω => 0], dvs)
nx = length(dvs)
nu = length(inputs)
ny = length(outputs)

##
m = InfiniteModel(optimizer_with_attributes(Ipopt.Optimizer,
"print_level" => 0, "acceptable_tol" => 1e-3, "constr_viol_tol" => 1e-5, "max_iter" => 1000,
"tol" => 1e-5, "mu_strategy" => "monotone", "nlp_scaling_method" => "gradient-based",
"alpha_for_y" => "safer-min-dual-infeas", "bound_mult_init_method" => "mu-based", "print_user_options" => "yes"));

@infinite_parameter(m, τ in [0, 1], num_supports=51,
derivative_method=OrthogonalCollocation(4)) # Time variable
guess_xs = [t -> pi, t -> 0.1][permutation]
guess_us = [t -> 0.1]
InfiniteOpt.@variables(m,
begin
# state variables
(lb[i] <= x[i = 1:nx] <= ub[i], Infinite(τ), start = guess_xs[i]) # state variables
-10 <= u[i = 1:nu] <= 10, Infinite(τ), (start = guess_us[i]) # control variables
0 <= tf <= 10, (start = 5) # Final time
0.2 <= L <= 0.6, (start = 0.4) # Length parameter
end)

# Trace the dynamics
x0, p = ModelingToolkit.get_u0_p(io_sys, [model.θ => 0, model.ω => 0], [model.L => L])

xp = f_oop(x, u, p, τ)
cp = f_obs(x, u, p, τ) # Test that it's possible to trace through an observed function

@objective(m, Min, tf)
@constraint(m, [i = 1:nx], x[i](0)==x0[i]) # Initial condition
@constraint(m, [i = 1:nx], x[i](1)==xf[i]) # Terminal state

x_scale = varmap_to_vars([model.θ => 1
model.ω => 1], dvs)

# Add dynamics constraints
@constraint(m, [i = 1:nx], ((x[i], τ) - tf * xp[i]) / x_scale[i]==0)

optimize!(m)

# Extract the optimal solution
opt_tf = value(tf)
opt_time = opt_tf * value(τ)
opt_x = [value(x[i]) for i in permutation]
opt_u = [value(u[i]) for i in 1:nu]
opt_L = value(L)

# Plot the results
# using Plots
# plot(opt_time, opt_x[1], label = "θ", xlabel = "Time [s]", layout=3)
# plot!(opt_time, opt_x[2], label = "ω", sp=2)
# plot!(opt_time, opt_u[1], label = "τ", sp=3)

using Test
@test opt_x[1][end]pi atol=1e-3
@test opt_x[2][end]0 atol=1e-3

@test opt_x[1][1]0 atol=1e-3
@test opt_x[2][1]0 atol=1e-3

@test opt_L0.2 atol=1e-3 # Smallest permissible length is optimal
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,6 @@ end
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
@safetestset "LabelledArrays Test" include("labelledarrays.jl")
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
@safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl")
end
end

0 comments on commit e8a9f04

Please sign in to comment.