Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add capability to trace MTK dynamics with InfiniteOpt #3220

Merged
merged 3 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"

[extensions]
Expand All @@ -72,6 +73,7 @@ MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"

[compat]
AbstractTrees = "0.3, 0.4"
Expand Down Expand Up @@ -104,6 +106,7 @@ FunctionWrappers = "1.1"
FunctionWrappersWrappers = "0.1"
Graphs = "1.5.2"
HomotopyContinuation = "2.11"
InfiniteOpt = "0.5"
InteractiveUtils = "1"
JuliaFormatter = "1.0.47"
JumpProcesses = "9.13.1"
Expand Down
26 changes: 26 additions & 0 deletions ext/MTKInfiniteOptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module MTKInfiniteOptExt
import ModelingToolkit
import SymbolicUtils
import NaNMath
import InfiniteOpt
import InfiniteOpt: JuMP, GeneralVariableRef

# This file contains method definitions to make it possible to trace through functions generated by MTK using JuMP variables

for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
f = nameof(ff)
# These need to be defined so that JuMP can trace through functions built by Symbolics
@eval NaNMath.$f(x::GeneralVariableRef) = Base.$f(x)
end

# JuMP variables and Symbolics variables never compare equal. When tracing through dynamics, a function argument can be either a JuMP variable or A Symbolics variable, it can never be both.
function Base.isequal(::SymbolicUtils.Symbolic,
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr})
false
end
function Base.isequal(
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr},
::SymbolicUtils.Symbolic)
false
end
end
3 changes: 3 additions & 0 deletions test/extensions/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Nemo = "2edaba10-b0f1-5616-af89-8c11ac63239a"
Expand Down
102 changes: 102 additions & 0 deletions test/extensions/test_infiniteopt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using ModelingToolkit, InfiniteOpt, JuMP, Ipopt
using ModelingToolkit: D_nounits as D, t_nounits as t, varmap_to_vars

@mtkmodel Pendulum begin
@parameters begin
g = 9.8
L = 0.4
K = 1.2
m = 0.3
end
@variables begin
θ(t) # state
ω(t) # state
τ(t) = 0 # input
y(t) # output
end
@equations begin
D(θ) ~ ω
D(ω) ~ -g / L * sin(θ) - K / m * ω + τ / m / L^2
y ~ θ * 180 / π
end
end
@named model = Pendulum()
model = complete(model)

inputs = [model.τ]
(f_oop, f_ip), dvs, psym, io_sys = ModelingToolkit.generate_control_function(
model, inputs, split = false)

outputs = [model.y]
f_obs = ModelingToolkit.build_explicit_observed_function(io_sys, outputs; inputs = inputs)

expected_state_order = [model.θ, model.ω]
permutation = [findfirst(isequal(x), expected_state_order) for x in dvs] # This maps our expected state order to the actual state order

##

ub = varmap_to_vars([model.θ => 2pi, model.ω => 10], dvs)
lb = varmap_to_vars([model.θ => -2pi, model.ω => -10], dvs)
xf = varmap_to_vars([model.θ => pi, model.ω => 0], dvs)
nx = length(dvs)
nu = length(inputs)
ny = length(outputs)

##
m = InfiniteModel(optimizer_with_attributes(Ipopt.Optimizer,
"print_level" => 0, "acceptable_tol" => 1e-3, "constr_viol_tol" => 1e-5, "max_iter" => 1000,
"tol" => 1e-5, "mu_strategy" => "monotone", "nlp_scaling_method" => "gradient-based",
"alpha_for_y" => "safer-min-dual-infeas", "bound_mult_init_method" => "mu-based", "print_user_options" => "yes"));

@infinite_parameter(m, τ in [0, 1], num_supports=51,
derivative_method=OrthogonalCollocation(4)) # Time variable
guess_xs = [t -> pi, t -> 0.1][permutation]
guess_us = [t -> 0.1]
InfiniteOpt.@variables(m,
begin
# state variables
(lb[i] <= x[i = 1:nx] <= ub[i], Infinite(τ), start = guess_xs[i]) # state variables
-10 <= u[i = 1:nu] <= 10, Infinite(τ), (start = guess_us[i]) # control variables
0 <= tf <= 10, (start = 5) # Final time
0.2 <= L <= 0.6, (start = 0.4) # Length parameter
end)

# Trace the dynamics
x0, p = ModelingToolkit.get_u0_p(io_sys, [model.θ => 0, model.ω => 0], [model.L => L])

xp = f_oop(x, u, p, τ)
cp = f_obs(x, u, p, τ) # Test that it's possible to trace through an observed function

@objective(m, Min, tf)
@constraint(m, [i = 1:nx], x[i](0)==x0[i]) # Initial condition
@constraint(m, [i = 1:nx], x[i](1)==xf[i]) # Terminal state

x_scale = varmap_to_vars([model.θ => 1
model.ω => 1], dvs)

# Add dynamics constraints
@constraint(m, [i = 1:nx], (∂(x[i], τ) - tf * xp[i]) / x_scale[i]==0)

optimize!(m)

# Extract the optimal solution
opt_tf = value(tf)
opt_time = opt_tf * value(τ)
opt_x = [value(x[i]) for i in permutation]
opt_u = [value(u[i]) for i in 1:nu]
opt_L = value(L)

# Plot the results
# using Plots
# plot(opt_time, opt_x[1], label = "θ", xlabel = "Time [s]", layout=3)
# plot!(opt_time, opt_x[2], label = "ω", sp=2)
# plot!(opt_time, opt_u[1], label = "τ", sp=3)

using Test
@test opt_x[1][end]≈pi atol=1e-3
@test opt_x[2][end]≈0 atol=1e-3

@test opt_x[1][1]≈0 atol=1e-3
@test opt_x[2][1]≈0 atol=1e-3

@test opt_L≈0.2 atol=1e-3 # Smallest permissible length is optimal
2 changes: 1 addition & 1 deletion test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,6 @@ end
@parameters p[1:2] = [1.0, 2.0]
@mtkbuild sys = ODESystem([D(x) ~ x, y^2 ~ x + sum(p)], t)
prob = DAEProblem(sys, [D(x) => x, D(y) => D(x) / 2y], [], (0.0, 1.0))
sol = solve(prob, DFBDF(), abstol=1e-8, reltol=1e-8)
sol = solve(prob, DFBDF(), abstol = 1e-8, reltol = 1e-8)
@test sol[x]sol[y^2 - sum(p)] atol=1e-5
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,6 @@ end
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
@safetestset "LabelledArrays Test" include("labelledarrays.jl")
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
@safetestset "InfiniteOpt Extension Test" include("extensions/test_infiniteopt.jl")
end
end
Loading