Skip to content

Commit

Permalink
Merge pull request #3202 from AayushSabharwal/as/tuple-observed
Browse files Browse the repository at this point in the history
feat: support directly generating observed functions for tuples
  • Loading branch information
ChrisRackauckas authored Nov 13, 2024
2 parents 9cf859f + 8edc6b0 commit 02cbd76
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.31"
SymbolicIndexingInterface = "0.3.35"
SymbolicUtils = "3.7"
Symbolics = "6.15.4"
URIs = "1"
Expand Down
5 changes: 4 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,8 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
end

SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true

function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
Expand All @@ -827,7 +829,8 @@ function SymbolicIndexingInterface.observed(
throw(ArgumentError("Symbol $sym does not exist in the system"))
end
sym = _sym
elseif sym isa AbstractArray && symbolic_type(sym) isa NotSymbolic &&
elseif (sym isa Tuple ||
(sym isa AbstractArray && symbolic_type(sym) isa NotSymbolic)) &&
any(x -> x isa Symbol, sym)
sym = map(sym) do s
if s isa Symbol
Expand Down
13 changes: 12 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,10 @@ function build_explicit_observed_function(sys, ts;
param_only = false,
op = Operator,
throw = true)
is_tuple = ts isa Tuple
if is_tuple
ts = collect(ts)
end
if (isscalar = symbolic_type(ts) !== NotSymbolic())
ts = [ts]
end
Expand Down Expand Up @@ -573,9 +577,16 @@ function build_explicit_observed_function(sys, ts;

# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
return_value = if isscalar
ts[1]
elseif is_tuple
MakeTuple(Tuple(ts))
else
MakeArray(ts, output_type)
end
oop_fn = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
return_value,
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

Expand Down
19 changes: 19 additions & 0 deletions test/symbolic_indexing_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using SciMLStructures: Tunable
eqs = [D(x) ~ a * y + t, D(y) ~ b * t]
@named odesys = ODESystem(eqs, t, [x, y], [a, b]; observed = [xy ~ x + y])
odesys = complete(odesys)
@test SymbolicIndexingInterface.supports_tuple_observed(odesys)
@test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y]))
@test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b]))
@test variable_index.((odesys,), [x, y, a, b, t, 1, 2, :x, :y, :a, :b]) ==
Expand All @@ -33,6 +34,14 @@ using SciMLStructures: Tunable
@test default_values(odesys)[y] == 2.0
@test isequal(default_values(odesys)[xy], x + y)

prob = ODEProblem(odesys, [], (0.0, 1.0), [a => 1.0, b => 2.0])
getter = getu(odesys, (x + 1, x + 2))
@test getter(prob) isa Tuple
@test_nowarn @inferred getter(prob)
getter = getp(odesys, (a + 1, a + 2))
@test getter(prob) isa Tuple
@test_nowarn @inferred getter(prob)

@named odesys = ODESystem(
eqs, t, [x, y], [a, b]; defaults = [xy => 3.0], observed = [xy ~ x + y])
odesys = complete(odesys)
Expand Down Expand Up @@ -99,6 +108,7 @@ end
0 ~ x * y - β * z]
@named ns = NonlinearSystem(eqs, [x, y, z], [σ, ρ, β])
ns = complete(ns)
@test SymbolicIndexingInterface.supports_tuple_observed(ns)
@test !is_time_dependent(ns)
ps = ModelingToolkit.MTKParameters(ns, [σ => 1.0, ρ => 2.0, β => 3.0])
pobs = parameter_observed(ns, σ + ρ)
Expand All @@ -107,6 +117,15 @@ end
pobs = parameter_observed(ns, [σ + ρ, ρ + β])
@test isempty(get_all_timeseries_indexes(ns, [σ + ρ, ρ + β]))
@test pobs(ps) == [3.0, 5.0]

prob = NonlinearProblem(
ns, [x => 1.0, y => 2.0, z => 3.0], [σ => 1.0, ρ => 2.0, β => 3.0])
getter = getu(ns, (x + 1, x + 2))
@test getter(prob) isa Tuple
@test_nowarn @inferred getter(prob)
getter = getp(ns, (σ + 1, σ + 2))
@test getter(prob) isa Tuple
@test_nowarn @inferred getter(prob)
end

@testset "PDESystem" begin
Expand Down

0 comments on commit 02cbd76

Please sign in to comment.