Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implementation of new SymbolicIndexingInterface #2345

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ MacroTools = "0.5"
NaNMath = "0.3, 1"
OrdinaryDiffEq = "6"
PrecompileTools = "1"
RecursiveArrayTools = "2.3"
RecursiveArrayTools = "2.3, 3"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.0.1"
Expand All @@ -98,7 +98,7 @@ SimpleNonlinearSolve = "0.1.0, 1"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.1, 0.2"
SymbolicIndexingInterface = "0.3"
SymbolicUtils = "1.0"
Symbolics = "5.7"
URIs = "1"
Expand Down
3 changes: 1 addition & 2 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ using PrecompileTools, Reexport

using RecursiveArrayTools

import SymbolicIndexingInterface
import SymbolicIndexingInterface: independent_variables, states, parameters
using SymbolicIndexingInterface
export independent_variables, states, parameters
import SymbolicUtils
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
Expand Down
135 changes: 109 additions & 26 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,18 @@ function independent_variable(sys::AbstractSystem)
isdefined(sys, :iv) ? getfield(sys, :iv) : nothing
end

#Treat the result as a vector of symbols always
function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
systype = typeof(sys)
@warn "Please declare ($systype) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
function independent_variables(sys::AbstractTimeDependentSystem)
return [getfield(sys, :iv)]
end

independent_variables(::AbstractTimeIndependentSystem) = []

function independent_variables(sys::AbstractMultivariateSystem)
return getfield(sys, :ivs)
end

function independent_variables(sys::AbstractSystem)
@warn "Please declare ($(typeof(sys))) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
if isdefined(sys, :iv)
return [getfield(sys, :iv)]
elseif isdefined(sys, :ivs)
Expand All @@ -174,14 +182,102 @@ function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
end
end

function SymbolicIndexingInterface.independent_variables(sys::AbstractTimeDependentSystem)
[getfield(sys, :iv)]
#Treat the result as a vector of symbols always
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems really ugly.

Copy link
Member Author

@AayushSabharwal AayushSabharwal Nov 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it is. It's the only way I could find to cover all the cases

  • we need the isa Int check because [x, 1] wraps 1 in a Num (e.g. sol(1.0, idxs=[1, x])
  • The next one is necessary since sol[x + y] is a possibility, and getname errors on x+y
  • The getname case is necessary for namespaced variables

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does unwrap do?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It gets the value inside a Num, which is usually either a SymbolicUtils.BasicSymbolic for symbolic variables or a number in the case of something like Num(1)

return unwrap(sym) in 1:length(unknown_states(sys))
end
return any(isequal(sym), unknown_states(sys)) || hasname(sym) && is_variable(sys, getname(sym))
end

function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(unknown_states(sys))) || count('₊', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1
end
SymbolicIndexingInterface.independent_variables(sys::AbstractTimeIndependentSystem) = []
function SymbolicIndexingInterface.independent_variables(sys::AbstractMultivariateSystem)
getfield(sys, :ivs)

function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym)
end
idx = findfirst(isequal(sym), unknown_states(sys))
if idx === nothing && hasname(sym)
idx = variable_index(sys, getname(sym))
end
return idx
end

function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
idx = findfirst(isequal(sym), getname.(unknown_states(sys)))
if idx !== nothing
return idx
elseif count('₊', string(sym)) == 1
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys))))
end
return nothing
end

function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem)
return unknown_states(sys)
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usym

return unwrap(sym) in 1:length(parameters(sys))
end

return any(isequal(sym), parameters(sys)) || hasname(sym) && is_parameter(sys, getname(sym))
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(parameters(sys))) ||
count('₊', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym)
end
idx = findfirst(isequal(sym), parameters(sys))
if idx === nothing && hasname(sym)
idx = parameter_index(sys, getname(sym))
end
return idx
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
idx = findfirst(isequal(sym), getname.(parameters(sys)))
if idx !== nothing
return idx
elseif count('₊', string(sym)) == 1
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys))))
end
return nothing
end

function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
return parameters(sys)
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
return any(isequal(sym), independent_variables(sys))
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(independent_variables(sys)))
end

function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSystem)
return independent_variables(sys)
end

function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
return !is_variable(sys, sym) && !is_parameter(sys, sym) && !is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
end

SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false

SymbolicIndexingInterface.constant_structure(::AbstractSystem) = true

iscomplete(sys::AbstractSystem) = isdefined(sys, :complete) && getfield(sys, :complete)

"""
Expand Down Expand Up @@ -534,12 +630,15 @@ function states(sys::AbstractSystem)
[sts; reduce(vcat, namespace_variables.(systems))])
end

function SymbolicIndexingInterface.parameters(sys::AbstractSystem)
function parameters(sys::AbstractSystem)
ps = get_ps(sys)
systems = get_systems(sys)
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
end

# required in `src/connectors.jl:437`
parameters(_) = []

function controls(sys::AbstractSystem)
ctrls = get_ctrls(sys)
systems = get_systems(sys)
Expand Down Expand Up @@ -638,8 +737,6 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
return x
end

SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))

"""
$(SIGNATURES)

Expand All @@ -653,20 +750,6 @@ function unknown_states(sys::AbstractSystem)
return sts
end

function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal(sym), unknown_states(sys))
end
function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym)
!isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym))
end

function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys))
end
function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym)
!isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym))
end

###
### System utils
###
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
tgrad = _tgrad === nothing ? nothing : _tgrad,
mass_matrix = _M,
jac_prototype = jac_prototype,
syms = Symbol.(states(sys)),
syms = collect(Symbol.(states(sys))),
indepsym = Symbol(get_iv(sys)),
paramsyms = Symbol.(ps),
paramsyms = collect(Symbol.(ps)),
observed = observedfun,
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
analytic = analytic)
Expand Down
Loading