Skip to content

Commit

Permalink
Document build_explicit_observed_function and allow user-defined arra…
Browse files Browse the repository at this point in the history
…y construction
  • Loading branch information
BenChung committed Nov 13, 2024
1 parent b52bce7 commit e91738b
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,32 @@ ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs...
"""
$(SIGNATURES)
Build the observed function assuming the observed equations are all explicit,
i.e. there are no cycles.
Generates a function that computes the observed value(s) `ts` in the system `sys` assuming that there are no cycles in the equations.
The return value will be either:
* a single function if the input is a scalar or if the input is a Vector but `return_inplace` is false
* the out of place and in-place functions `(ip, oop)` if `return_inplace` is true and the input is a `Vector`
The function(s) will be:
* `RuntimeGeneratedFunction`s by default,
* A Julia `Expr` if `expression` is true,
* A directly evaluated Julia function in the module `eval_module` if `eval_expression` is true
The signatures will be of the form `g(...)` with arguments:
* `output` for in-place functions
* `unknowns` if `params_only` is `false`
* `inputs` if `inputs` is an array of symbolic inputs that should be available in `ts`
* `p...` unconditionally; note that in the case of `MTKParameters` more than one parameters argument may be present, so it must be splatted
* `t` if the system is time-dependent; for example `NonlinearSystem` will not have `t`
For example, a function `g(op, unknowns, p, inputs, t)` will be the in-place function generated if `return_inplace` is true, `ts` is a vector, an array of inputs `inputs` is given, and `params_only` is false for a time-dependent system.
Options not otherwise specified are:
* `output_type = Array` the type of the array generated by the out-of-place vector-valued function
* `checkbounds = true` checks bounds if true when destructuring parameters
* `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
* `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist
* `drop_expr` is deprecated.
* `mkarray`; only used if the output is an array (that is, `!isscalar(ts)`). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
"""
function build_explicit_observed_function(sys, ts;
inputs = nothing,
Expand All @@ -426,7 +450,8 @@ function build_explicit_observed_function(sys, ts;
return_inplace = false,
param_only = false,
op = Operator,
throw = true)
throw = true,
mkarray = MakeArray)
if (isscalar = symbolic_type(ts) !== NotSymbolic())
ts = [ts]
end
Expand Down Expand Up @@ -571,12 +596,11 @@ function build_explicit_observed_function(sys, ts;
oop_mtkp_wrapper = mtkparams_wrapper
end

output_expr = isscalar ? ts[1] : mkarray(ts, output_type)
# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
oop_fn = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |>
oop_mtkp_wrapper |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

if !isscalar
Expand Down

0 comments on commit e91738b

Please sign in to comment.