From 40c828a9994fed2420483445833775153993b404 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 21 Sep 2023 11:17:28 -0400 Subject: [PATCH] Add `u0_constructor` kwargs --- docs/src/basics/AbstractSystem.md | 6 ++++-- docs/src/basics/FAQ.md | 16 ++++++++++++++++ src/systems/diffeqs/abstractodesystem.jl | 4 ++++ test/odesystem.jl | 2 ++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/docs/src/basics/AbstractSystem.md b/docs/src/basics/AbstractSystem.md index 2f5627eff0..a45695b8ac 100644 --- a/docs/src/basics/AbstractSystem.md +++ b/docs/src/basics/AbstractSystem.md @@ -56,10 +56,12 @@ The values which are common to all `AbstractSystem`s are: Optionally, a system could have: - `observed(sys)`: All observed equations of the system and its subsystems. + - `independent_variables(sys)`: The independent variables of a system. + - `defaults(sys)`: A `Dict` that maps variables/parameters into their default values for the system and its subsystems. - `get_observed(sys)`: Observed equations of the current-level system. - `get_continuous_events(sys)`: `SymbolicContinuousCallback`s of the current-level system. - - `get_defaults(sys)`: A `Dict` that maps variables into their default values. - - `independent_variables(sys)`: The independent variables of a system. + - `get_defaults(sys)`: A `Dict` that maps variables into their default values + for the current-level system. - `get_noiseeqs(sys)`: Noise equations of the current-level system. - `get_metadata(sys)`: Any metadata about the system or its origin to be used by downstream packages. diff --git a/docs/src/basics/FAQ.md b/docs/src/basics/FAQ.md index 5aeedfb75e..7b6846354f 100644 --- a/docs/src/basics/FAQ.md +++ b/docs/src/basics/FAQ.md @@ -129,3 +129,19 @@ julia> ModelingToolkit.missing_variable_defaults(sys, [1,2,3]) x2ˍtt(t) => 2 x3ˍtt(t) => 3 ``` + +## Change the state vector type + +Use the `u0_constructor` keyword argument to map an array to the desired +container type. For example: + +``` +using ModelingToolkit, StaticArrays +@variables t +sts = @variables x1(t)=0.0 +D = Differential(t) +eqs = [D(x1) ~ 1.1 * x1] +@named sys = ODESystem(eqs, t) +sys = structural_simplify(sys) +prob = ODEProblem{false}(sys, [], (0,1); u0_constructor = x->SVector(x...)) +``` diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 596cbbef4e..8528dd3c12 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -797,6 +797,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; use_union = true, tofloat = true, symbolic_u0 = false, + u0_constructor = identity, kwargs...) eqs = equations(sys) dvs = states(sys) @@ -809,6 +810,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; tofloat, use_union, symbolic_u0) + if u0 !== nothing + u0 = u0_constructor(u0) + end p, split_idxs = split_parameters_by_type(p) if p isa Tuple diff --git a/test/odesystem.jl b/test/odesystem.jl index 6d82e97d46..56d436d76c 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -438,6 +438,8 @@ prob = ODEProblem(sys) sol = solve(prob, Tsit5()) @test sol.t[end] == tspan[end] @test sum(abs, sol[end]) < 1 +prob = ODEProblem{false}(sys; u0_constructor = x -> SVector(x...)) +@test prob.u0 isa SVector # check_eqs_u0 kwarg test @parameters t