Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve BifurcationKit extension #2337

Merged
merged 5 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 96 additions & 9 deletions ext/MTKBifurcationKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,76 @@
using ModelingToolkit, Setfield
import BifurcationKit

### Observable Plotting Handling ###

# Functor used when the plotting variable is an observable. Keeps track of the required information for computing the observable's value at each point of the bifurcation diagram.
struct ObservableRecordFromSolution{S, T}
# The equations determining the observables values.
obs_eqs::S
# The index of the observable that we wish to plot.
target_obs_idx::Int64
# The final index in subs_vals that contains a state.
state_end_idxs::Int64
# The final index in subs_vals that contains a param.
param_end_idxs::Int64
# The index (in subs_vals) that contain the bifurcation parameter.
bif_par_idx::Int64
# A Vector of pairs (Symbolic => value) with teh default values of all system variables and parameters.
subs_vals::T

function ObservableRecordFromSolution(nsys::NonlinearSystem,

Check warning on line 26 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L26

Added line #L26 was not covered by tests
plot_var,
bif_idx,
u0_vals,
p_vals) where {S, T}
obs_eqs = observed(nsys)
target_obs_idx = findfirst(isequal(plot_var, eq.lhs) for eq in observed(nsys))
state_end_idxs = length(states(nsys))
param_end_idxs = state_end_idxs + length(parameters(nsys))

Check warning on line 34 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L31-L34

Added lines #L31 - L34 were not covered by tests

bif_par_idx = state_end_idxs + bif_idx

Check warning on line 36 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L36

Added line #L36 was not covered by tests
# Gets the (base) substitution values for states.
subs_vals_states = Pair.(states(nsys), u0_vals)

Check warning on line 38 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L38

Added line #L38 was not covered by tests
# Gets the (base) substitution values for parameters.
subs_vals_params = Pair.(parameters(nsys), p_vals)

Check warning on line 40 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L40

Added line #L40 was not covered by tests
# Gets the (base) substitution values for observables.
subs_vals_obs = [obs.lhs => substitute(obs.rhs,

Check warning on line 42 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L42

Added line #L42 was not covered by tests
[subs_vals_states; subs_vals_params]) for obs in observed(nsys)]
# Sometimes observables depend on other observables, hence we make a second upate to this vector.
subs_vals_obs = [obs.lhs => substitute(obs.rhs,

Check warning on line 45 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L45

Added line #L45 was not covered by tests
[subs_vals_states; subs_vals_params; subs_vals_obs]) for obs in observed(nsys)]
# During the bifurcation process, teh value of some states, parameters, and observables may vary (and are calculated in each step). Those that are not are stored in this vector
subs_vals = [subs_vals_states; subs_vals_params; subs_vals_obs]

Check warning on line 48 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L48

Added line #L48 was not covered by tests

param_end_idxs = state_end_idxs + length(parameters(nsys))
new{typeof(obs_eqs), typeof(subs_vals)}(obs_eqs,

Check warning on line 51 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L50-L51

Added lines #L50 - L51 were not covered by tests
target_obs_idx,
state_end_idxs,
param_end_idxs,
bif_par_idx,
subs_vals)
end
end
# Functor function that computes the value.
function (orfs::ObservableRecordFromSolution)(x, p)

Check warning on line 60 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L60

Added line #L60 was not covered by tests
# Updates the state values (in subs_vals).
for state_idx in 1:(orfs.state_end_idxs)
orfs.subs_vals[state_idx] = orfs.subs_vals[state_idx][1] => x[state_idx]
end

Check warning on line 64 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L62-L64

Added lines #L62 - L64 were not covered by tests

# Updates the bifurcation parameters value (in subs_vals).
orfs.subs_vals[orfs.bif_par_idx] = orfs.subs_vals[orfs.bif_par_idx][1] => p

Check warning on line 67 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L67

Added line #L67 was not covered by tests

# Updates the observable values (in subs_vals).
for (obs_idx, obs_eq) in enumerate(orfs.obs_eqs)
orfs.subs_vals[orfs.param_end_idxs + obs_idx] = orfs.subs_vals[orfs.param_end_idxs + obs_idx][1] => substitute(obs_eq.rhs,

Check warning on line 71 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L70-L71

Added lines #L70 - L71 were not covered by tests
orfs.subs_vals)
end

Check warning on line 73 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L73

Added line #L73 was not covered by tests

# Substitutes in the value for all states, parameters, and observables into the equation for the designated observable.
return substitute(orfs.obs_eqs[orfs.target_obs_idx].rhs, orfs.subs_vals)

Check warning on line 76 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L76

Added line #L76 was not covered by tests
end

### Creates BifurcationProblem Overloads ###

# When input is a NonlinearSystem.
Expand All @@ -23,20 +93,37 @@
F = ofun.f
J = jac ? ofun.jac : nothing

# Computes bifurcation parameter and plot var indexes.
# Converts the input state guess.
u0_bif_vals = ModelingToolkit.varmap_to_vars(u0_bif,
states(nsys);
defaults = nsys.defaults)
p_vals = ModelingToolkit.varmap_to_vars(ps, parameters(nsys); defaults = nsys.defaults)

# Computes bifurcation parameter and the plotting function.
bif_idx = findfirst(isequal(bif_par), parameters(nsys))
if !isnothing(plot_var)
plot_idx = findfirst(isequal(plot_var), states(nsys))
record_from_solution = (x, p) -> x[plot_idx]
end
# If the plot var is a normal state.
if any(isequal(plot_var, var) for var in states(nsys))
plot_idx = findfirst(isequal(plot_var), states(nsys))
record_from_solution = (x, p) -> x[plot_idx]

# Converts the input state guess.
u0_bif = ModelingToolkit.varmap_to_vars(u0_bif, states(nsys))
ps = ModelingToolkit.varmap_to_vars(ps, parameters(nsys))
# If the plot var is an observed state.
elseif any(isequal(plot_var, eq.lhs) for eq in observed(nsys))
record_from_solution = ObservableRecordFromSolution(nsys,

Check warning on line 112 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L111-L112

Added lines #L111 - L112 were not covered by tests
plot_var,
bif_idx,
u0_bif_vals,
p_vals)

# If neither an variable nor observable, throw an error.
else
error("The plot variable ($plot_var) was neither recognised as a system state nor observable.")

Check warning on line 120 in ext/MTKBifurcationKitExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MTKBifurcationKitExt.jl#L120

Added line #L120 was not covered by tests
end
end

return BifurcationKit.BifurcationProblem(F,
u0_bif,
ps,
u0_bif_vals,
p_vals,
(@lens _[bif_idx]),
args...;
record_from_solution = record_from_solution,
Expand Down
121 changes: 111 additions & 10 deletions test/extensions/bifurcationkit.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,136 @@
using BifurcationKit, ModelingToolkit, Test

# Checks pitchfork diagram and that there are the correct number of branches (a main one and two children)
# Simple pitchfork diagram, compares solution to native BifurcationKit, checks they are identical.
# Checks using `jac=false` option.
let
# Creates model.
@variables t x(t) y(t)
@parameters μ α
eqs = [0 ~ μ * x - x^3 + α * y,
0 ~ -y]
@named nsys = NonlinearSystem(eqs, [x, y], [μ, α])

# Creates BifurcationProblem
bif_par = μ
p_start = [μ => -1.0, α => 1.0]
u0_guess = [x => 1.0, y => 1.0]
plot_var = x

using BifurcationKit
bprob = BifurcationProblem(nsys,
u0_guess,
p_start,
bif_par;
plot_var = plot_var,
jac = false)

# Conputes bifurcation diagram.
p_span = (-4.0, 6.0)
opts_br = ContinuationPar(max_steps = 500, p_min = p_span[1], p_max = p_span[2])
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)

# Computes bifurcation diagram using BifurcationKit directly (without going through MTK).
function f_BK(u, p)
x, y = u
μ, α = p
return [μ * x - x^3 + α * y, -y]
end
bprob_BK = BifurcationProblem(f_BK,
[1.0, 1.0],
[-1.0, 1.0],
(@lens _[1]);
record_from_solution = (x, p) -> x[1])
bif_dia_BK = bifurcationdiagram(bprob_BK,
PALC(),
2,
(args...) -> opts_br;
bothside = true)

# Compares results.
@test getfield.(bif_dia.γ.branch, :x) ≈ getfield.(bif_dia_BK.γ.branch, :x)
@test getfield.(bif_dia.γ.branch, :param) ≈ getfield.(bif_dia_BK.γ.branch, :param)
@test bif_dia.γ.specialpoint[1].x == bif_dia_BK.γ.specialpoint[1].x
@test bif_dia.γ.specialpoint[1].param == bif_dia_BK.γ.specialpoint[1].param
@test bif_dia.γ.specialpoint[1].type == bif_dia_BK.γ.specialpoint[1].type
end

# Lotka–Volterra model, checks exact position of bifurcation variable and bifurcation points.
# Checks using ODESystem input.
let
# Creates a Lotka–Volterra model.
@parameters α a b
@variables t x(t) y(t) z(t)
D = Differential(t)
eqs = [D(x) ~ -x + a * y + x^2 * y,
D(y) ~ b - a * y - x^2 * y]
@named sys = ODESystem(eqs)

# Creates BifurcationProblem
bprob = BifurcationProblem(sys,
[x => 1.5, y => 1.0],
[a => 0.1, b => 0.5],
b;
plot_var = x)

# Computes bifurcation diagram.
p_span = (0.0, 2.0)
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 2000)
opts_br = ContinuationPar(dsmax = 0.05,
max_steps = 500,
newton_options = opt_newton,
p_min = p_span[1],
p_max = p_span[2],
n_inversion = 4)
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)

# Tests that the diagram has the correct values (x = b)
all([b.x ≈ b.param for b in bif_dia.γ.branch])

# Tests that we get two Hopf bifurcations at the correct positions.
hopf_points = sort(getfield.(filter(sp -> sp.type == :hopf, bif_dia.γ.specialpoint),
:x);
by = x -> x[1])
@test length(hopf_points) == 2
@test hopf_points[1] ≈ [0.41998733080424205, 1.5195495712453098]
@test hopf_points[2] ≈ [0.7899715592573977, 1.0910379583813192]
end

# Simple fold bifurcation model, checks exact position of bifurcation variable and bifurcation points.
# Checks that default parameter values are accounted for.
# Checks that observables (that depend on other observables, as in this case) are accounted for.
let
# Creates model, and uses `structural_simplify` to generate observables.
@parameters μ p=2
@variables t x(t) y(t) z(t)
D = Differential(t)
eqs = [0 ~ μ - x^3 + 2x^2,
0 ~ p * μ - y,
0 ~ y - z]
@named nsys = NonlinearSystem(eqs, [x, y, z], [μ, p])
nsys = structural_simplify(nsys)

# Creates BifurcationProblem.
bif_par = μ
p_start = [μ => 1.0]
u0_guess = [x => 1.0, y => 0.1, z => 0.1]
plot_var = x
bprob = BifurcationProblem(nsys, u0_guess, p_start, bif_par; plot_var = plot_var)

# Computes bifurcation diagram.
p_span = (-4.3, 12.0)
opt_newton = NewtonPar(tol = 1e-9, max_iterations = 20)
opts_br = ContinuationPar(dsmin = 0.001, dsmax = 0.05, ds = 0.01,
max_steps = 100, nev = 2, newton_options = opt_newton,
p_min = p_span[1], p_max = p_span[2],
detect_bifurcation = 3, n_inversion = 4, tol_bisection_eigenvalue = 1e-8,
dsmin_bisection = 1e-9)
opts_br = ContinuationPar(dsmax = 0.05,
max_steps = 500,
newton_options = opt_newton,
p_min = p_span[1],
p_max = p_span[2],
n_inversion = 4)
bif_dia = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)

bf = bifurcationdiagram(bprob, PALC(), 2, (args...) -> opts_br; bothside = true)
# Tests that the diagram has the correct values (x = b)
all([b.x ≈ 2 * b.param for b in bif_dia.γ.branch])

@test length(bf.child) == 2
# Tests that we get two fold bifurcations at the correct positions.
fold_points = sort(getfield.(filter(sp -> sp.type == :bp, bif_dia.γ.specialpoint),
:param))
@test length(fold_points) == 2
@test fold_points ≈ [-1.1851851706940317, -5.6734983580551894e-6] # test that they occur at the correct parameter values).
end
Loading