Skip to content

Commit

Permalink
Merge pull request #3200 from BenChung/explicit-observed-func
Browse files Browse the repository at this point in the history
Add docstring and oop array construction method for build_explicit_observed_function
  • Loading branch information
ChrisRackauckas authored Nov 15, 2024
2 parents 24a7c2e + 14996d7 commit 2c19234
Showing 1 changed file with 45 additions and 7 deletions.
52 changes: 45 additions & 7 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,48 @@ end
ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs...)

"""
$(SIGNATURES)
Build the observed function assuming the observed equations are all explicit,
i.e. there are no cycles.
build_explicit_observed_function(sys, ts; kwargs...) -> Function(s)
Generates a function that computes the observed value(s) `ts` in the system `sys`, while making the assumption that there are no cycles in the equations.
## Arguments
- `sys`: The system for which to generate the function
- `ts`: The symbolic observed values whose value should be computed
## Keywords
- `return_inplace = false`: If true and the observed value is a vector, then return both the in place and out of place methods.
- `expression = false`: Generates a Julia `Expr`` computing the observed value if `expression` is true
- `eval_expression = false`: If true and `expression = false`, evaluates the returned function in the module `eval_module`
- `output_type = Array` the type of the array generated by a out-of-place vector-valued function
- `param_only = false` if true, only allow the generated function to access system parameters
- `inputs = nothing` additinoal symbolic variables that should be provided to the generated function
- `checkbounds = true` checks bounds if true when destructuring parameters
- `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
- `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist.
- `mkarray`; only used if the output is an array (that is, `!isscalar(ts)` and `ts` is not a tuple, in which case the result will always be a tuple). Called as `mkarray(ts, output_type)` where `ts` are the expressions to put in
the array and `output_type` is the argument of the same name passed to build_explicit_observed_function.
## Returns
The return value will be either:
* a single function `f_oop` if the input is a scalar or if the input is a Vector but `return_inplace` is false
* the out of place and in-place functions `(f_ip, f_oop)` if `return_inplace` is true and the input is a `Vector`
The function(s) `f_oop` (and potentially `f_ip`) will be:
* `RuntimeGeneratedFunction`s by default,
* A Julia `Expr` if `expression` is true,
* A directly evaluated Julia function in the module `eval_module` if `eval_expression` is true and `expression` is false.
The signatures will be of the form `g(...)` with arguments:
- `output` for in-place functions
- `unknowns` if `param_only` is `false`
- `inputs` if `inputs` is an array of symbolic inputs that should be available in `ts`
- `p...` unconditionally; note that in the case of `MTKParameters` more than one parameters argument may be present, so it must be splatted
- `t` if the system is time-dependent; for example `NonlinearSystem` will not have `t`
For example, a function `g(op, unknowns, p..., inputs, t)` will be the in-place function generated if `return_inplace` is true, `ts` is a vector,
an array of inputs `inputs` is given, and `param_only` is false for a time-dependent system.
"""
function build_explicit_observed_function(sys, ts;
inputs = nothing,
Expand All @@ -421,12 +459,12 @@ function build_explicit_observed_function(sys, ts;
eval_module = @__MODULE__,
output_type = Array,
checkbounds = true,
drop_expr = drop_expr,
ps = parameters(sys),
return_inplace = false,
param_only = false,
op = Operator,
throw = true)
throw = true,
mkarray = MakeArray)
is_tuple = ts isa Tuple
if is_tuple
ts = collect(ts)
Expand Down Expand Up @@ -582,7 +620,7 @@ function build_explicit_observed_function(sys, ts;
elseif is_tuple
MakeTuple(Tuple(ts))
else
MakeArray(ts, output_type)
mkarray(ts, output_type)
end
oop_fn = Func(args, [],
pre(Let(obsexprs,
Expand Down

0 comments on commit 2c19234

Please sign in to comment.