Skip to content

Commit

Permalink
Merge pull request #3273 from SciML/disturbance_args
Browse files Browse the repository at this point in the history
add option to include disturbance args in `generate_control_function`
  • Loading branch information
ChrisRackauckas authored Dec 16, 2024
2 parents ba842c2 + a79f4c2 commit 31f7a54
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 30 deletions.
25 changes: 17 additions & 8 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
# Build control function

"""
(f_oop, f_ip), x_sym, p, io_sys = generate_control_function(
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
sys::AbstractODESystem,
inputs = unbound_inputs(sys),
disturbance_inputs = nothing;
Expand All @@ -177,8 +177,7 @@ f_ip : (xout,x,u,p,t) -> nothing
The return values also include the chosen state-realization (the remaining unknowns) `x_sym` and parameters, in the order they appear as arguments to `f`.
If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement.
See [`add_input_disturbance`](@ref) for a higher-level interface to this functionality.
If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will (by default) not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement. To add an input argument corresponding to the disturbance inputs, either include the disturbance inputs among the control inputs, or set `disturbance_argument=true`, in which case an additional input argument `w` is added to the generated function `(x,u,p,t,w)->rhs`.
!!! note "Un-simplified system"
This function expects `sys` to be un-simplified, i.e., `structural_simplify` or `@mtkbuild` should not be called on the system before passing it into this function. `generate_control_function` calls a special version of `structural_simplify` internally.
Expand All @@ -196,6 +195,7 @@ f[1](x, inputs, p, t)
"""
function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys),
disturbance_inputs = disturbances(sys);
disturbance_argument = false,
implicit_dae = false,
simplify = false,
eval_expression = false,
Expand All @@ -219,10 +219,11 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
# ps = [ps; disturbance_inputs]
end
inputs = map(x -> time_varying_as_func(value(x), sys), inputs)
disturbance_inputs = unwrap.(disturbance_inputs)

eqs = [eq for eq in full_equations(sys)]
eqs = map(subs_constants, eqs)
if disturbance_inputs !== nothing
if disturbance_inputs !== nothing && !disturbance_argument
# Set all disturbance *inputs* to zero (we just want to keep the disturbance state)
subs = Dict(disturbance_inputs .=> 0)
eqs = [eq.lhs ~ substitute(eq.rhs, subs) for eq in eqs]
Expand All @@ -239,16 +240,24 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
t = get_iv(sys)

# pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys)

args = (u, inputs, p..., t)
if disturbance_argument
args = (u, inputs, p..., t, disturbance_inputs)
else
args = (u, inputs, p..., t)
end
if implicit_dae
ddvs = map(Differential(get_iv(sys)), dvs)
args = (ddvs, args...)
end
process = get_postprocess_fbody(sys)
wrapped_arrays_vars = disturbance_argument ?
wrap_array_vars(
sys, rhss; dvs, ps, inputs, extra_args = (disturbance_inputs,)) :
wrap_array_vars(sys, rhss; dvs, ps, inputs)
f = build_function(rhss, args...; postprocess_fbody = process,
expression = Val{true}, wrap_code = wrap_mtkparameters(sys, false, 3) .∘
wrap_array_vars(sys, rhss; dvs, ps, inputs) .∘
expression = Val{true}, wrap_code = wrap_mtkparameters(
sys, false, 3, Int(disturbance_argument) + 1) .∘
wrapped_arrays_vars .∘
wrap_parameter_dependencies(sys, false),
kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
Expand Down
40 changes: 35 additions & 5 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,33 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
end

"""
$(TYPEDSIGNATURES)
Add the necessary assignment statements to allow use of unscalarized array variables
in the generated code. `expr` is the expression returned by the function. `dvs` and
`ps` are the unknowns and parameters of the system `sys` to use in the generated code.
`inputs` can be specified as an array of symbolics if the generated function has inputs.
If `history == true`, the generated function accepts a history function. `cachesyms` are
extra variables (arrays of variables) stored in the cache array(s) of the parameter
object. `extra_args` are extra arguments appended to the end of the argument list.
The function is assumed to have the signature `f(du, u, h, x, p, cache_syms..., t, extra_args...)`
Where:
- `du` is the optional buffer to write to for in-place functions.
- `u` is the list of unknowns. This argument is not present if `dvs === nothing`.
- `h` is the optional history function, present if `history == true`.
- `x` is the array of inputs, present only if `inputs !== nothing`. Values are assumed
to be in the order of variables passed to `inputs`.
- `p` is the parameter object.
- `cache_syms` are the cache variables. These are part of the splatted parameter object.
- `t` is time, present only if the system is time dependent.
- `extra_args` are the extra arguments passed to the function, present only if
`extra_args` is non-empty.
"""
function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
inputs = nothing, history = false, cachesyms::Tuple = ())
inputs = nothing, history = false, cachesyms::Tuple = (), extra_args::Tuple = ())
isscalar = !(exprs isa AbstractArray)
var_to_arridxs = Dict()

Expand All @@ -252,6 +276,10 @@ function wrap_array_vars(
if inputs !== nothing
rps = (inputs, rps...)
end
if has_iv(sys)
rps = (rps..., get_iv(sys))
end
rps = (rps..., extra_args...)
for sym in reduce(vcat, rps; init = [])
iscall(sym) && operation(sym) == getindex || continue
arg = arguments(sym)[1]
Expand Down Expand Up @@ -332,7 +360,7 @@ end
const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___)

"""
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2, offset = Int(is_time_dependent(sys)))
Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
Expand All @@ -342,12 +370,14 @@ the first parameter vector in the out-of-place version of the function. For exam
if a history function (DDEs) was passed before `p`, then the function before wrapping
would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
`offset` is the number of arguments at the end of the argument list to ignore. Defaults
to 1 if the system is time-dependent (to ignore `t`) and 0 otherwise.
The returned function is `identity` if the system does not have an `IndexCache`.
"""
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2,
offset = Int(is_time_dependent(sys)))
if has_index_cache(sys) && get_index_cache(sys) !== nothing
offset = Int(is_time_dependent(sys))

if isscalar
function (expr)
param_args = expr.args[p_start:(end - offset)]
Expand Down
76 changes: 59 additions & 17 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,24 +153,66 @@ if VERSION >= v"1.8" # :opaque_closure not supported before
end

## Code generation with unbound inputs
@testset "generate_control_function with disturbance inputs" begin
for split in [true, false]
simplify = true

@variables x(t)=0 u(t)=0 [input = true]
eqs = [
D(x) ~ -x + u
]

@named sys = ODESystem(eqs, t)
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split)

@test isequal(dvs[], x)
@test isempty(ps)

p = nothing
x = [rand()]
u = [rand()]
@test f[1](x, u, p, 1) == -x + u

# With disturbance inputs
@variables x(t)=0 u(t)=0 [input = true] d(t)=0
eqs = [
D(x) ~ -x + u + d^2
]

@named sys = ODESystem(eqs, t)
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(
sys, [u], [d]; simplify, split)

@test isequal(dvs[], x)
@test isempty(ps)

p = nothing
x = [rand()]
u = [rand()]
@test f[1](x, u, p, 1) == -x + u

## With added d argument
@variables x(t)=0 u(t)=0 [input = true] d(t)=0
eqs = [
D(x) ~ -x + u + d^2
]

@named sys = ODESystem(eqs, t)
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(
sys, [u], [d]; simplify, split, disturbance_argument = true)

@test isequal(dvs[], x)
@test isempty(ps)

p = nothing
x = [rand()]
u = [rand()]
d = [rand()]
@test f[1](x, u, p, t, d) == -x + u + [d[]^2]
end
end

@variables x(t)=0 u(t)=0 [input = true]
eqs = [
D(x) ~ -x + u
]

@named sys = ODESystem(eqs, t)
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)

@test isequal(dvs[], x)
@test isempty(ps)

p = nothing
x = [rand()]
u = [rand()]
@test f[1](x, u, p, 1) == -x + u

# more complicated system
## more complicated system

@variables u(t) [input = true]

Expand Down

0 comments on commit 31f7a54

Please sign in to comment.