Skip to content

Commit

Permalink
Add u0_constructor kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Sep 21, 2023
1 parent 8b58588 commit 40c828a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 2 deletions.
6 changes: 4 additions & 2 deletions docs/src/basics/AbstractSystem.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ The values which are common to all `AbstractSystem`s are:
Optionally, a system could have:

- `observed(sys)`: All observed equations of the system and its subsystems.
- `independent_variables(sys)`: The independent variables of a system.
- `defaults(sys)`: A `Dict` that maps variables/parameters into their default values for the system and its subsystems.
- `get_observed(sys)`: Observed equations of the current-level system.
- `get_continuous_events(sys)`: `SymbolicContinuousCallback`s of the current-level system.
- `get_defaults(sys)`: A `Dict` that maps variables into their default values.
- `independent_variables(sys)`: The independent variables of a system.
- `get_defaults(sys)`: A `Dict` that maps variables into their default values
for the current-level system.
- `get_noiseeqs(sys)`: Noise equations of the current-level system.
- `get_metadata(sys)`: Any metadata about the system or its origin to be used by downstream packages.

Expand Down
16 changes: 16 additions & 0 deletions docs/src/basics/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,19 @@ julia> ModelingToolkit.missing_variable_defaults(sys, [1,2,3])
x2ˍtt(t) => 2
x3ˍtt(t) => 3
```

## Change the state vector type

Use the `u0_constructor` keyword argument to map an array to the desired
container type. For example:

```
using ModelingToolkit, StaticArrays
@variables t
sts = @variables x1(t)=0.0
D = Differential(t)
eqs = [D(x1) ~ 1.1 * x1]
@named sys = ODESystem(eqs, t)
sys = structural_simplify(sys)
prob = ODEProblem{false}(sys, [], (0,1); u0_constructor = x->SVector(x...))
```
4 changes: 4 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
use_union = true,
tofloat = true,
symbolic_u0 = false,
u0_constructor = identity,
kwargs...)
eqs = equations(sys)
dvs = states(sys)
Expand All @@ -809,6 +810,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
tofloat,
use_union,
symbolic_u0)
if u0 !== nothing
u0 = u0_constructor(u0)
end

p, split_idxs = split_parameters_by_type(p)
if p isa Tuple
Expand Down
2 changes: 2 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ prob = ODEProblem(sys)
sol = solve(prob, Tsit5())
@test sol.t[end] == tspan[end]
@test sum(abs, sol[end]) < 1
prob = ODEProblem{false}(sys; u0_constructor = x -> SVector(x...))
@test prob.u0 isa SVector

# check_eqs_u0 kwarg test
@parameters t
Expand Down

0 comments on commit 40c828a

Please sign in to comment.