Skip to content

Commit

Permalink
fixup! feat: support inplace parameter observed
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 27, 2024
1 parent c79851f commit 7c3817e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
13 changes: 12 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,17 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
end
end

function wrap_assignments(isscalar, assignments; let_block = false)
function wrapper(expr)
Func(expr.args, [], Let(assignments, expr.body, let_block))
end
if isscalar
wrapper
else
wrapper, wrapper
end
end

function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
Expand Down Expand Up @@ -505,7 +516,7 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
ts_idx = nothing
end
rawobs = build_explicit_observed_function(
sys, sym; param_only = true, return_inplace = true)
sys, sym; param_only = true, return_inplace = true)
if rawobs isa Tuple
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1(p::MTKParameters, t) = oop(p..., t)
Expand Down
12 changes: 9 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -487,17 +487,23 @@ function build_explicit_observed_function(sys, ts;
if inputs === nothing
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]
else
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
ipts = DestructuredArgs(unwrap.(inputs), inbounds = !checkbounds)
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
end
pre = get_postprocess_fbody(sys)
res = build_function(isscalar ? ts[1] : ts, args...; get_postprocess_fbody = pre, wrap_code = wrap_array_vars(sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)), expression = Val{expression})
res = build_function(isscalar ? ts[1] : ts,
args...;
postprocess_fbody = pre,
wrap_code = wrap_array_vars(
sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)) .∘
wrap_assignments(isscalar, obsexprs),
expression = Val{expression})
if isscalar || return_inplace
return res
else
return res[1]
end

ex = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
Expand Down

0 comments on commit 7c3817e

Please sign in to comment.