diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 4a99ec11f5..6bdcac6dd4 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -160,7 +160,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex)) # Build control function """ - (f_oop, f_ip), x_sym, p, io_sys = generate_control_function( + (f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function( sys::AbstractODESystem, inputs = unbound_inputs(sys), disturbance_inputs = nothing; @@ -177,8 +177,7 @@ f_ip : (xout,x,u,p,t) -> nothing The return values also include the chosen state-realization (the remaining unknowns) `x_sym` and parameters, in the order they appear as arguments to `f`. -If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement. -See [`add_input_disturbance`](@ref) for a higher-level interface to this functionality. +If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will (by default) not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement. To add an input argument corresponding to the disturbance inputs, either include the disturbance inputs among the control inputs, or set `disturbance_argument=true`, in which case an additional input argument `w` is added to the generated function `(x,u,p,t,w)->rhs`. !!! note "Un-simplified system" This function expects `sys` to be un-simplified, i.e., `structural_simplify` or `@mtkbuild` should not be called on the system before passing it into this function. `generate_control_function` calls a special version of `structural_simplify` internally. @@ -196,6 +195,7 @@ f[1](x, inputs, p, t) """ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys), disturbance_inputs = disturbances(sys); + disturbance_argument = false, implicit_dae = false, simplify = false, eval_expression = false, @@ -219,10 +219,11 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu # ps = [ps; disturbance_inputs] end inputs = map(x -> time_varying_as_func(value(x), sys), inputs) + disturbance_inputs = unwrap.(disturbance_inputs) eqs = [eq for eq in full_equations(sys)] eqs = map(subs_constants, eqs) - if disturbance_inputs !== nothing + if disturbance_inputs !== nothing && !disturbance_argument # Set all disturbance *inputs* to zero (we just want to keep the disturbance state) subs = Dict(disturbance_inputs .=> 0) eqs = [eq.lhs ~ substitute(eq.rhs, subs) for eq in eqs] @@ -239,16 +240,24 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu t = get_iv(sys) # pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys) - - args = (u, inputs, p..., t) + if disturbance_argument + args = (u, inputs, p..., t, disturbance_inputs) + else + args = (u, inputs, p..., t) + end if implicit_dae ddvs = map(Differential(get_iv(sys)), dvs) args = (ddvs, args...) end process = get_postprocess_fbody(sys) + wrapped_arrays_vars = disturbance_argument ? + wrap_array_vars( + sys, rhss; dvs, ps, inputs, extra_args = (disturbance_inputs,)) : + wrap_array_vars(sys, rhss; dvs, ps, inputs) f = build_function(rhss, args...; postprocess_fbody = process, - expression = Val{true}, wrap_code = wrap_mtkparameters(sys, false, 3) .∘ - wrap_array_vars(sys, rhss; dvs, ps, inputs) .∘ + expression = Val{true}, wrap_code = wrap_mtkparameters( + sys, false, 3, Int(disturbance_argument) + 1) .∘ + wrapped_arrays_vars .∘ wrap_parameter_dependencies(sys, false), kwargs...) f = eval_or_rgf.(f; eval_expression, eval_module) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index e44f250a7f..168260ae69 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -230,9 +230,33 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar) wrap_assignments(isscalar, [eq.lhs ← eq.rhs for eq in parameter_dependencies(sys)]) end +""" + $(TYPEDSIGNATURES) + +Add the necessary assignment statements to allow use of unscalarized array variables +in the generated code. `expr` is the expression returned by the function. `dvs` and +`ps` are the unknowns and parameters of the system `sys` to use in the generated code. +`inputs` can be specified as an array of symbolics if the generated function has inputs. +If `history == true`, the generated function accepts a history function. `cachesyms` are +extra variables (arrays of variables) stored in the cache array(s) of the parameter +object. `extra_args` are extra arguments appended to the end of the argument list. + +The function is assumed to have the signature `f(du, u, h, x, p, cache_syms..., t, extra_args...)` +Where: +- `du` is the optional buffer to write to for in-place functions. +- `u` is the list of unknowns. This argument is not present if `dvs === nothing`. +- `h` is the optional history function, present if `history == true`. +- `x` is the array of inputs, present only if `inputs !== nothing`. Values are assumed + to be in the order of variables passed to `inputs`. +- `p` is the parameter object. +- `cache_syms` are the cache variables. These are part of the splatted parameter object. +- `t` is time, present only if the system is time dependent. +- `extra_args` are the extra arguments passed to the function, present only if + `extra_args` is non-empty. +""" function wrap_array_vars( sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), - inputs = nothing, history = false, cachesyms::Tuple = ()) + inputs = nothing, history = false, cachesyms::Tuple = (), extra_args::Tuple = ()) isscalar = !(exprs isa AbstractArray) var_to_arridxs = Dict() @@ -252,6 +276,10 @@ function wrap_array_vars( if inputs !== nothing rps = (inputs, rps...) end + if has_iv(sys) + rps = (rps..., get_iv(sys)) + end + rps = (rps..., extra_args...) for sym in reduce(vcat, rps; init = []) iscall(sym) && operation(sym) == getindex || continue arg = arguments(sym)[1] @@ -332,7 +360,7 @@ end const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___) """ - wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2) + wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2, offset = Int(is_time_dependent(sys))) Return function(s) to be passed to the `wrap_code` keyword of `build_function` which allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters` @@ -342,12 +370,14 @@ the first parameter vector in the out-of-place version of the function. For exam if a history function (DDEs) was passed before `p`, then the function before wrapping would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`. +`offset` is the number of arguments at the end of the argument list to ignore. Defaults +to 1 if the system is time-dependent (to ignore `t`) and 0 otherwise. + The returned function is `identity` if the system does not have an `IndexCache`. """ -function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2) +function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2, + offset = Int(is_time_dependent(sys))) if has_index_cache(sys) && get_index_cache(sys) !== nothing - offset = Int(is_time_dependent(sys)) - if isscalar function (expr) param_args = expr.args[p_start:(end - offset)] diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 9550a87f31..de6fc92b5c 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -153,24 +153,66 @@ if VERSION >= v"1.8" # :opaque_closure not supported before end ## Code generation with unbound inputs +@testset "generate_control_function with disturbance inputs" begin + for split in [true, false] + simplify = true + + @variables x(t)=0 u(t)=0 [input = true] + eqs = [ + D(x) ~ -x + u + ] + + @named sys = ODESystem(eqs, t) + f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split) + + @test isequal(dvs[], x) + @test isempty(ps) + + p = nothing + x = [rand()] + u = [rand()] + @test f[1](x, u, p, 1) == -x + u + + # With disturbance inputs + @variables x(t)=0 u(t)=0 [input = true] d(t)=0 + eqs = [ + D(x) ~ -x + u + d^2 + ] + + @named sys = ODESystem(eqs, t) + f, dvs, ps, io_sys = ModelingToolkit.generate_control_function( + sys, [u], [d]; simplify, split) + + @test isequal(dvs[], x) + @test isempty(ps) + + p = nothing + x = [rand()] + u = [rand()] + @test f[1](x, u, p, 1) == -x + u + + ## With added d argument + @variables x(t)=0 u(t)=0 [input = true] d(t)=0 + eqs = [ + D(x) ~ -x + u + d^2 + ] + + @named sys = ODESystem(eqs, t) + f, dvs, ps, io_sys = ModelingToolkit.generate_control_function( + sys, [u], [d]; simplify, split, disturbance_argument = true) + + @test isequal(dvs[], x) + @test isempty(ps) + + p = nothing + x = [rand()] + u = [rand()] + d = [rand()] + @test f[1](x, u, p, t, d) == -x + u + [d[]^2] + end +end -@variables x(t)=0 u(t)=0 [input = true] -eqs = [ - D(x) ~ -x + u -] - -@named sys = ODESystem(eqs, t) -f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true) - -@test isequal(dvs[], x) -@test isempty(ps) - -p = nothing -x = [rand()] -u = [rand()] -@test f[1](x, u, p, 1) == -x + u - -# more complicated system +## more complicated system @variables u(t) [input = true]