Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use Accessors instead of Setfield #3279

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "9.59.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Expand Down Expand Up @@ -73,11 +74,12 @@ MTKBifurcationKitExt = "BifurcationKit"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"

[compat]
AbstractTrees = "0.3, 0.4"
Accessors = "0.1.36"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Expand All @@ -23,6 +24,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
Accessors = "0.1.36"
BenchmarkTools = "1.3"
BifurcationKit = "0.4"
DataInterpolations = "6.5"
Expand Down
34 changes: 17 additions & 17 deletions docs/src/basics/Events.md
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,16 @@ is below `furnace_on_threshold` and off when above `furnace_off_threshold`, whil
in between. To do this, we create two continuous callbacks:

```@example events
using Setfield
using Accessors
furnace_disable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i
@set! x.furnace_on = false
@reset x.furnace_on = false
end)
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_on_threshold],
ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i
@set! x.furnace_on = true
@reset x.furnace_on = true
end)
```

Expand All @@ -432,7 +432,7 @@ You can also write
```julia
[temp ~ furnace_off_threshold] => ModelingToolkit.ImperativeAffect(modified = (;
furnace_on)) do x, o, i, c
@set! x.furnace_on = false
@reset x.furnace_on = false
end
```

Expand Down Expand Up @@ -462,7 +462,7 @@ f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple
The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions.
In our example, if `furnace_on` is `false`, then the value of the `x` that's passed in as `modified` will be `(furnace_on = false)`.
The modified values should be passed out in the same format: to set `furnace_on` to `true` we need to return a tuple `(furnace_on = true)`.
The examples does this with Setfield, recreating the result tuple before returning it; the returned tuple may optionally be missing values as
The examples does this with Accessors, recreating the result tuple before returning it; the returned tuple may optionally be missing values as
well, in which case those values will not be written back to the problem.

Accordingly, we can now interpret the `ImperativeAffect` definitions to mean that when `temp = furnace_off_threshold` we
Expand Down Expand Up @@ -542,18 +542,18 @@ In our encoder, we interpret this as occlusion or nonocclusion of the sensor, up
```@example events
qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0],
ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, c, i
@set! x.hA = x.qA
@set! x.hB = o.qB
@set! x.qA = 1
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
@reset x.hA = x.qA
@reset x.hB = o.qB
@reset x.qA = 1
@reset x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
x
end,
affect_neg = ModelingToolkit.ImperativeAffect(
(; qA, hA, hB, cnt), (; qB)) do x, o, c, i
@set! x.hA = x.qA
@set! x.hB = o.qB
@set! x.qA = 0
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
@reset x.hA = x.qA
@reset x.hB = o.qB
@reset x.qA = 0
@reset x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
x
end)
```
Expand All @@ -566,10 +566,10 @@ Instead, we can use right root finding:
```@example events
qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0],
ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA, theta)) do x, o, c, i
@set! x.hA = o.qA
@set! x.hB = x.qB
@set! x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0)
@set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
@reset x.hA = o.qA
@reset x.hB = x.qB
@reset x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0)
@reset x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
x
end; rootfind = SciMLBase.RightRootFind)
```
Expand Down
4 changes: 2 additions & 2 deletions ext/MTKBifurcationKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MTKBifurcationKitExt
### Preparations ###

# Imports
using ModelingToolkit, Setfield
using ModelingToolkit, Accessors
import BifurcationKit

### Observable Plotting Handling ###
Expand Down Expand Up @@ -94,7 +94,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
if !ModelingToolkit.iscomplete(nsys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
end
@set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
@reset nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
# Creates F and J functions.
ofun = NonlinearFunction(nsys; jac = jac)
F = ofun.f
Expand Down
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ using JumpProcesses
using DataStructures
using Base.Threads
using Latexify, Unitful, ArrayInterface
using Setfield, ConstructionBase
import Setfield
using Accessors, ConstructionBase
import Libdl
using DocStringExtensions
using Base: RefValue
Expand Down
2 changes: 1 addition & 1 deletion src/bipartite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using DocStringExtensions
using UnPack
using SparseArrays
using Graphs
using Setfield
using Accessors

### Matching
struct Unassigned
Expand Down
16 changes: 8 additions & 8 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,12 @@ function inputs_to_parameters!(state::TransformationState, io)
@assert new_v > 0
new_var_to_diff[new_i] = new_v
end
@set! structure.var_to_diff = complete(new_var_to_diff)
@set! structure.graph = complete(new_graph)
@reset structure.var_to_diff = complete(new_var_to_diff)
@reset structure.graph = complete(new_graph)

@set! sys.eqs = isempty(input_to_parameters) ? equations(sys) :
@reset sys.eqs = isempty(input_to_parameters) ? equations(sys) :
fast_substitute(equations(sys), input_to_parameters)
@set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
@reset sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
ps = parameters(sys)

if io !== nothing
Expand All @@ -334,11 +334,11 @@ function inputs_to_parameters!(state::TransformationState, io)
new_parameters = new_parameters[permutation]
end

@set! sys.ps = [ps; new_parameters]
@reset sys.ps = [ps; new_parameters]

@set! state.sys = sys
@set! state.fullvars = new_fullvars
@set! state.structure = structure
@reset state.sys = sys
@reset state.fullvars = new_fullvars
@reset state.structure = structure
base_params = length(ps)
return state, (base_params + 1):(base_params + length(new_parameters)) # (1:length(new_parameters)) .+ base_params
end
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module StructuralTransformations

using Setfield: @set!, @set
using Accessors: @set, @reset
using UnPack: @unpack

using Symbolics: unwrap, linear_expansion, fast_substitute
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ function build_torn_function(sys;
rhss)

unknown_vars = Any[fullvars[i] for i in unknowns_idxs]
@set! sys.solved_unknowns = unknown_vars
@reset sys.solved_unknowns = unknown_vars

pre = get_postprocess_fbody(sys)
cpre = get_preprocess_constants(rhss)
Expand Down
4 changes: 2 additions & 2 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
filter(x -> value(x.lhs) !== nothing,
out_eqs[sort(filter(x -> x !== unassigned, var_eq_matching))]))

@set! sys.eqs = final_eqs
@set! sys.unknowns = final_vars
@reset sys.eqs = final_eqs
@reset sys.unknowns = final_vars
return sys
end

Expand Down
28 changes: 14 additions & 14 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ end

function tearing_substitution(sys::AbstractSystem; kwargs...)
neweqs = full_equations(sys::AbstractSystem; kwargs...)
@set! sys.eqs = neweqs
@set! sys.substitutions = nothing
@set! sys.schedule = nothing
@reset sys.eqs = neweqs
@reset sys.substitutions = nothing
@reset sys.schedule = nothing
end

function tearing_assignments(sys::AbstractSystem)
Expand Down Expand Up @@ -563,10 +563,10 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
diff_to_var = invview(var_to_diff)

old_fullvars = fullvars
@set! state.structure.graph = complete(graph)
@set! state.structure.var_to_diff = var_to_diff
@set! state.structure.eq_to_diff = eq_to_diff
@set! state.fullvars = fullvars = fullvars[invvarsperm]
@reset state.structure.graph = complete(graph)
@reset state.structure.var_to_diff = var_to_diff
@reset state.structure.eq_to_diff = eq_to_diff
@reset state.fullvars = fullvars = fullvars[invvarsperm]
ispresent = let var_to_diff = var_to_diff, graph = graph
i -> (!isempty(𝑑neighbors(graph, i)) ||
(var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
Expand All @@ -590,24 +590,24 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
push!(unknowns, old_fullvars[v])
end
end
@set! sys.unknowns = unknowns
@reset sys.unknowns = unknowns

obs, subeqs, deps = cse_and_array_hacks(
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)

@set! sys.eqs = neweqs
@set! sys.observed = obs
@reset sys.eqs = neweqs
@reset sys.observed = obs

@set! sys.substitutions = Substitutions(subeqs, deps)
@reset sys.substitutions = Substitutions(subeqs, deps)

# Only makes sense for time-dependent
# TODO: generalize to SDE
if sys isa ODESystem
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
@reset sys.schedule = Schedule(var_eq_matching, dummy_sub)
end
sys = schedule(sys)
@set! state.sys = sys
@set! sys.tearing_state = state
@reset state.sys = sys
@reset sys.tearing_state = state
return invalidate_cache!(sys)
end

Expand Down
Loading
Loading