Skip to content

Commit

Permalink
Merge pull request #2301 from SciML/precompile_workload
Browse files Browse the repository at this point in the history
Fix recursive structure and setup precompilation
  • Loading branch information
ChrisRackauckas authored Oct 6, 2023
2 parents fb7c3af + 6e795b8 commit 7949761
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 19 deletions.
2 changes: 1 addition & 1 deletion ext/MTKDeepDiffsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using DeepDiffs, ModelingToolkit
using ModelingToolkit.BipartiteGraphs: Label,
BipartiteAdjacencyList, unassigned,
HighlightInt
using ModelingToolkit.SystemStructures: SystemStructure,
using ModelingToolkit: SystemStructure,
MatchedSystemStructure,
SystemStructurePrintMatrix

Expand Down
16 changes: 12 additions & 4 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ using PrecompileTools, Reexport
using MLStyle

using Reexport
using Symbolics
using Symbolics: degree
@reexport using Symbolics
using Symbolics: _parse_vars, value, @derivatives, get_variables,
exprs_occur_in, solve_for, build_expr, unwrap, wrap,
VariableSource, getname, variable, Connection, connect,
Expand All @@ -70,9 +70,10 @@ using PrecompileTools, Reexport
import OrdinaryDiffEq

import Graphs: SimpleDiGraph, add_edge!, incidence_matrix

@reexport using UnPack
end

@reexport using Symbolics
@reexport using UnPack
RuntimeGeneratedFunctions.init(@__MODULE__)

export @derivatives
Expand Down Expand Up @@ -156,7 +157,6 @@ include("systems/dependency_graphs.jl")
include("clock.jl")
include("discretedomain.jl")
include("systems/systemstructure.jl")
using .SystemStructures
include("systems/clock_inference.jl")
include("systems/systems.jl")

Expand All @@ -172,6 +172,14 @@ for S in subtypes(ModelingToolkit.AbstractSystem)
@eval convert_system(::Type{<:$S}, sys::$S) = sys
end

PrecompileTools.@compile_workload begin
using ModelingToolkit
@variables t x(t)
D = Differential(t)
@named sys = ODESystem([D(x) ~ -x])
prob = ODEProblem(structural_simplify(sys), [x => 30.0], (0, 100), [], jac = true)
end

export AbstractTimeDependentSystem,
AbstractTimeIndependentSystem,
AbstractMultivariateSystem
Expand Down
10 changes: 7 additions & 3 deletions src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
IncrementalCycleTracker, add_edge_checked!, topological_sort,
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
fast_substitute, get_fullvars, has_equations
fast_substitute, get_fullvars, has_equations, observed

using ModelingToolkit.BipartiteGraphs
import .BipartiteGraphs: invview, complete
import ModelingToolkit: var_derivative!, var_derivative_graph!
using Graphs
using ModelingToolkit.SystemStructures
using ModelingToolkit.SystemStructures: algeqs, EquationsView
using ModelingToolkit: algeqs, EquationsView,
SystemStructure, TransformationState, TearingState, structural_simplify!,
isdiffvar, isdervar, isalgvar, isdiffeq, algeqs, is_only_discrete,
dervars_range, diffvars_range, algvars_range,
DiffGraph, complete!,
get_fullvars, system_subset

using ModelingToolkit.DiffEqBase
using ModelingToolkit.StaticArrays
Expand Down
11 changes: 3 additions & 8 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
module SystemStructures

using DataStructures
using Symbolics: linear_expansion, unwrap
using Symbolics: linear_expansion, unwrap, Connection
using SymbolicUtils: istree, operation, arguments, Symbolic
using SymbolicUtils: quick_cancel, similarterm
using ..ModelingToolkit
Expand All @@ -10,7 +8,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
isparameter, isconstant,
independent_variables, SparseMatrixCLIL, AbstractSystem,
equations, isirreducible, input_timedomain, TimeDomain,
VariableType, getvariabletype, has_equations
VariableType, getvariabletype, has_equations, ODESystem
using ..BipartiteGraphs
import ..BipartiteGraphs: invview, complete
using Graphs
Expand All @@ -27,8 +25,7 @@ function quick_cancel_expr(expr)
end

export SystemStructure, TransformationState, TearingState, structural_simplify!
export initialize_system_structure, find_linear_equations
export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs, is_only_discrete
export isdiffvar, isdervar, isalgvar, isdiffeq, algeqs, is_only_discrete
export dervars_range, diffvars_range, algvars_range
export DiffGraph, complete!
export get_fullvars, system_subset
Expand Down Expand Up @@ -620,5 +617,3 @@ function _structural_simplify!(state::TearingState, io; simplify = false,
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates)
ModelingToolkit.invalidate_cache!(sys), input_idxs
end

end # module
5 changes: 2 additions & 3 deletions test/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,12 @@ By inference:
=> Shift(x, 0, dt) := (Shift(x, -1, dt) + dt) / (1 - dt) # Discrete system
=#

using ModelingToolkit.SystemStructures
ci, varmap = infer_clocks(sys)
eqmap = ci.eq_domain
tss, inputs = ModelingToolkit.split_system(deepcopy(ci))
sss, = SystemStructures._structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
@test equations(sss) == [D(x) ~ u - x]
sss, = SystemStructures._structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[2]), (inputs[2], ()))
@test isempty(equations(sss))
@test observed(sss) == [yd ~ Sample(t, dt)(y); r ~ 1.0; ud ~ kp * (r - yd)]

Expand Down

0 comments on commit 7949761

Please sign in to comment.