From 6e795b8574c4de02681b7a771c1ab3e3f1185926 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 6 Oct 2023 14:41:01 +0200 Subject: [PATCH] Fix recursive structure and setup precompilation The recursive structure of SystemStructure was giving precompilation an issue so I removed that module. Then I setup precompilation on the basic interface. Test case: ```julia @time using ModelingToolkit @time using DifferentialEquations @time using Plots @time begin @variables t x(t) D = Differential(t) @named sys = ODESystem([D(x) ~ -x]) end; @time prob = ODEProblem(structural_simplify(sys), [x => 30.0], (0, 100), [], jac = true); @time sol = solve(prob); @time plot(sol, idxs=[x]); ``` Before: ``` 11.082586 seconds (19.32 M allocations: 1.098 GiB, 4.09% gc time, 0.71% compilation time: 87% of which was recompilation) 0.639738 seconds (661.39 k allocations: 101.321 MiB, 4.33% gc time, 6.46% compilation time) 3.703724 seconds (5.71 M allocations: 322.840 MiB, 5.22% gc time, 9.92% compilation time: 86% of which was recompilation) 7.795297 seconds (8.25 M allocations: 483.041 MiB, 2.50% gc time, 99.88% compilation time) 21.719376 seconds (44.11 M allocations: 2.485 GiB, 5.68% gc time, 99.48% compilation time) 2.602250 seconds (4.04 M allocations: 253.058 MiB, 4.60% gc time, 99.90% compilation time) 2.450509 seconds (5.17 M allocations: 332.101 MiB, 5.89% gc time, 99.41% compilation time: 30% of which was recompilation) ``` After: ``` 9.129141 seconds (22.77 M allocations: 1.291 GiB, 4.65% gc time, 0.62% compilation time: 87% of which was recompilation) 0.784464 seconds (667.59 k allocations: 101.524 MiB, 3.95% gc time, 4.16% compilation time) 3.111142 seconds (5.42 M allocations: 305.594 MiB, 3.82% gc time, 6.39% compilation time: 82% of which was recompilation) 0.105567 seconds (157.39 k allocations: 10.522 MiB, 8.81% gc time, 95.49% compilation time: 74% of which was recompilation) 1.993642 seconds (4.03 M allocations: 218.310 MiB, 2.69% gc time, 96.95% compilation time: 82% of which was recompilation) 1.806758 seconds (4.06 M allocations: 254.371 MiB, 4.44% gc time, 99.91% compilation time) 1.694666 seconds (5.27 M allocations: 339.088 MiB, 6.18% gc time, 99.39% compilation time: 31% of which was recompilation) ``` And that's on v1.9, so v1.10 should be even better. 20 seconds off hehe. --- ext/MTKDeepDiffsExt.jl | 2 +- src/ModelingToolkit.jl | 16 ++++++++++++---- .../StructuralTransformations.jl | 10 +++++++--- src/systems/systemstructure.jl | 11 +++-------- test/clock.jl | 5 ++--- 5 files changed, 25 insertions(+), 19 deletions(-) diff --git a/ext/MTKDeepDiffsExt.jl b/ext/MTKDeepDiffsExt.jl index 92bc9ba2b9..1d361f96a3 100644 --- a/ext/MTKDeepDiffsExt.jl +++ b/ext/MTKDeepDiffsExt.jl @@ -4,7 +4,7 @@ using DeepDiffs, ModelingToolkit using ModelingToolkit.BipartiteGraphs: Label, BipartiteAdjacencyList, unassigned, HighlightInt -using ModelingToolkit.SystemStructures: SystemStructure, +using ModelingToolkit: SystemStructure, MatchedSystemStructure, SystemStructurePrintMatrix diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index d5c4172340..1a6e0356a5 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -51,8 +51,8 @@ using PrecompileTools, Reexport using MLStyle using Reexport + using Symbolics using Symbolics: degree - @reexport using Symbolics using Symbolics: _parse_vars, value, @derivatives, get_variables, exprs_occur_in, solve_for, build_expr, unwrap, wrap, VariableSource, getname, variable, Connection, connect, @@ -70,9 +70,10 @@ using PrecompileTools, Reexport import OrdinaryDiffEq import Graphs: SimpleDiGraph, add_edge!, incidence_matrix - - @reexport using UnPack end + +@reexport using Symbolics +@reexport using UnPack RuntimeGeneratedFunctions.init(@__MODULE__) export @derivatives @@ -156,7 +157,6 @@ include("systems/dependency_graphs.jl") include("clock.jl") include("discretedomain.jl") include("systems/systemstructure.jl") -using .SystemStructures include("systems/clock_inference.jl") include("systems/systems.jl") @@ -172,6 +172,14 @@ for S in subtypes(ModelingToolkit.AbstractSystem) @eval convert_system(::Type{<:$S}, sys::$S) = sys end +PrecompileTools.@compile_workload begin + using ModelingToolkit + @variables t x(t) + D = Differential(t) + @named sys = ODESystem([D(x) ~ -x]) + prob = ODEProblem(structural_simplify(sys), [x => 30.0], (0, 100), [], jac = true) +end + export AbstractTimeDependentSystem, AbstractTimeIndependentSystem, AbstractMultivariateSystem diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 47b140b7d1..7289df4232 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -23,14 +23,18 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di IncrementalCycleTracker, add_edge_checked!, topological_sort, invalidate_cache!, Substitutions, get_or_construct_tearing_state, filter_kwargs, lower_varname, setio, SparseMatrixCLIL, - fast_substitute, get_fullvars, has_equations + fast_substitute, get_fullvars, has_equations, observed using ModelingToolkit.BipartiteGraphs import .BipartiteGraphs: invview, complete import ModelingToolkit: var_derivative!, var_derivative_graph! using Graphs -using ModelingToolkit.SystemStructures -using ModelingToolkit.SystemStructures: algeqs, EquationsView +using ModelingToolkit: algeqs, EquationsView, + SystemStructure, TransformationState, TearingState, structural_simplify!, + isdiffvar, isdervar, isalgvar, isdiffeq, algeqs, is_only_discrete, + dervars_range, diffvars_range, algvars_range, + DiffGraph, complete!, + get_fullvars, system_subset using ModelingToolkit.DiffEqBase using ModelingToolkit.StaticArrays diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 4e330619d5..18e50afc9b 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -1,7 +1,5 @@ -module SystemStructures - using DataStructures -using Symbolics: linear_expansion, unwrap +using Symbolics: linear_expansion, unwrap, Connection using SymbolicUtils: istree, operation, arguments, Symbolic using SymbolicUtils: quick_cancel, similarterm using ..ModelingToolkit @@ -10,7 +8,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, isparameter, isconstant, independent_variables, SparseMatrixCLIL, AbstractSystem, equations, isirreducible, input_timedomain, TimeDomain, - VariableType, getvariabletype, has_equations + VariableType, getvariabletype, has_equations, ODESystem using ..BipartiteGraphs import ..BipartiteGraphs: invview, complete using Graphs @@ -27,8 +25,7 @@ function quick_cancel_expr(expr) end export SystemStructure, TransformationState, TearingState, structural_simplify! -export initialize_system_structure, find_linear_equations -export isdiffvar, isdervar, isalgvar, isdiffeq, isalgeq, algeqs, is_only_discrete +export isdiffvar, isdervar, isalgvar, isdiffeq, algeqs, is_only_discrete export dervars_range, diffvars_range, algvars_range export DiffGraph, complete! export get_fullvars, system_subset @@ -620,5 +617,3 @@ function _structural_simplify!(state::TearingState, io; simplify = false, @set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates) ModelingToolkit.invalidate_cache!(sys), input_idxs end - -end # module diff --git a/test/clock.jl b/test/clock.jl index 42074b65ff..74946d21d7 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -62,13 +62,12 @@ By inference: => Shift(x, 0, dt) := (Shift(x, -1, dt) + dt) / (1 - dt) # Discrete system =# -using ModelingToolkit.SystemStructures ci, varmap = infer_clocks(sys) eqmap = ci.eq_domain tss, inputs = ModelingToolkit.split_system(deepcopy(ci)) -sss, = SystemStructures._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) +sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ())) @test equations(sss) == [D(x) ~ u - x] -sss, = SystemStructures._structural_simplify!(deepcopy(tss[2]), (inputs[2], ())) +sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[2]), (inputs[2], ())) @test isempty(equations(sss)) @test observed(sss) == [yd ~ Sample(t, dt)(y); r ~ 1.0; ud ~ kp * (r - yd)]