Skip to content

Commit

Permalink
refactor: use Accessors instead of Setfield
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 19, 2024
1 parent 4792360 commit 2f8cf36
Show file tree
Hide file tree
Showing 30 changed files with 173 additions and 167 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "9.59.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Expand Down Expand Up @@ -73,11 +74,12 @@ MTKBifurcationKitExt = "BifurcationKit"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"

[compat]
AbstractTrees = "0.3, 0.4"
Accessors = "0.1.36"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Expand All @@ -23,6 +24,7 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
Accessors = "0.1.36"
BenchmarkTools = "1.3"
BifurcationKit = "0.4"
DataInterpolations = "6.5"
Expand Down
34 changes: 17 additions & 17 deletions docs/src/basics/Events.md
Original file line number Diff line number Diff line change
Expand Up @@ -412,16 +412,16 @@ is below `furnace_on_threshold` and off when above `furnace_off_threshold`, whil
in between. To do this, we create two continuous callbacks:

```@example events
using Setfield
using Accessors
furnace_disable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i
@set! x.furnace_on = false
@reset x.furnace_on = false
end)
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_on_threshold],
ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i
@set! x.furnace_on = true
@reset x.furnace_on = true
end)
```

Expand All @@ -432,7 +432,7 @@ You can also write
```julia
[temp ~ furnace_off_threshold] => ModelingToolkit.ImperativeAffect(modified = (;
furnace_on)) do x, o, i, c
@set! x.furnace_on = false
@reset x.furnace_on = false
end
```

Expand Down Expand Up @@ -462,7 +462,7 @@ f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple
The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions.
In our example, if `furnace_on` is `false`, then the value of the `x` that's passed in as `modified` will be `(furnace_on = false)`.
The modified values should be passed out in the same format: to set `furnace_on` to `true` we need to return a tuple `(furnace_on = true)`.
The examples does this with Setfield, recreating the result tuple before returning it; the returned tuple may optionally be missing values as
The examples does this with Accessors, recreating the result tuple before returning it; the returned tuple may optionally be missing values as
well, in which case those values will not be written back to the problem.

Accordingly, we can now interpret the `ImperativeAffect` definitions to mean that when `temp = furnace_off_threshold` we
Expand Down Expand Up @@ -542,18 +542,18 @@ In our encoder, we interpret this as occlusion or nonocclusion of the sensor, up
```@example events
qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0],
ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, c, i
@set! x.hA = x.qA
@set! x.hB = o.qB
@set! x.qA = 1
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
@reset x.hA = x.qA
@reset x.hB = o.qB
@reset x.qA = 1
@reset x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
x
end,
affect_neg = ModelingToolkit.ImperativeAffect(
(; qA, hA, hB, cnt), (; qB)) do x, o, c, i
@set! x.hA = x.qA
@set! x.hB = o.qB
@set! x.qA = 0
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
@reset x.hA = x.qA
@reset x.hB = o.qB
@reset x.qA = 0
@reset x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
x
end)
```
Expand All @@ -566,10 +566,10 @@ Instead, we can use right root finding:
```@example events
qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0],
ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA, theta)) do x, o, c, i
@set! x.hA = o.qA
@set! x.hB = x.qB
@set! x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0)
@set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
@reset x.hA = o.qA
@reset x.hB = x.qB
@reset x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0)
@reset x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
x
end; rootfind = SciMLBase.RightRootFind)
```
Expand Down
4 changes: 2 additions & 2 deletions ext/MTKBifurcationKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MTKBifurcationKitExt
### Preparations ###

# Imports
using ModelingToolkit, Setfield
using ModelingToolkit, Accessors
import BifurcationKit

### Observable Plotting Handling ###
Expand Down Expand Up @@ -94,7 +94,7 @@ function BifurcationKit.BifurcationProblem(nsys::NonlinearSystem,
if !ModelingToolkit.iscomplete(nsys)
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `BifurcationProblem`")
end
@set! nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
@reset nsys.index_cache = nothing # force usage of a parameter vector instead of `MTKParameters`
# Creates F and J functions.
ofun = NonlinearFunction(nsys; jac = jac)
F = ofun.f
Expand Down
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ using JumpProcesses
using DataStructures
using Base.Threads
using Latexify, Unitful, ArrayInterface
using Setfield, ConstructionBase
import Setfield
using Accessors, ConstructionBase
import Libdl
using DocStringExtensions
using Base: RefValue
Expand Down
2 changes: 1 addition & 1 deletion src/bipartite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using DocStringExtensions
using UnPack
using SparseArrays
using Graphs
using Setfield
using Accessors

### Matching
struct Unassigned
Expand Down
16 changes: 8 additions & 8 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,12 @@ function inputs_to_parameters!(state::TransformationState, io)
@assert new_v > 0
new_var_to_diff[new_i] = new_v
end
@set! structure.var_to_diff = complete(new_var_to_diff)
@set! structure.graph = complete(new_graph)
@reset structure.var_to_diff = complete(new_var_to_diff)
@reset structure.graph = complete(new_graph)

@set! sys.eqs = isempty(input_to_parameters) ? equations(sys) :
@reset sys.eqs = isempty(input_to_parameters) ? equations(sys) :
fast_substitute(equations(sys), input_to_parameters)
@set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
@reset sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters))
ps = parameters(sys)

if io !== nothing
Expand All @@ -334,11 +334,11 @@ function inputs_to_parameters!(state::TransformationState, io)
new_parameters = new_parameters[permutation]
end

@set! sys.ps = [ps; new_parameters]
@reset sys.ps = [ps; new_parameters]

@set! state.sys = sys
@set! state.fullvars = new_fullvars
@set! state.structure = structure
@reset state.sys = sys
@reset state.fullvars = new_fullvars
@reset state.structure = structure
base_params = length(ps)
return state, (base_params + 1):(base_params + length(new_parameters)) # (1:length(new_parameters)) .+ base_params
end
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module StructuralTransformations

using Setfield: @set!, @set
using Accessors: @set, @reset
using UnPack: @unpack

using Symbolics: unwrap, linear_expansion, fast_substitute
Expand Down
2 changes: 1 addition & 1 deletion src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ function build_torn_function(sys;
rhss)

unknown_vars = Any[fullvars[i] for i in unknowns_idxs]
@set! sys.solved_unknowns = unknown_vars
@reset sys.solved_unknowns = unknown_vars

pre = get_postprocess_fbody(sys)
cpre = get_preprocess_constants(rhss)
Expand Down
4 changes: 2 additions & 2 deletions src/structural_transformation/pantelides.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ function pantelides_reassemble(state::TearingState, var_eq_matching)
filter(x -> value(x.lhs) !== nothing,
out_eqs[sort(filter(x -> x !== unassigned, var_eq_matching))]))

@set! sys.eqs = final_eqs
@set! sys.unknowns = final_vars
@reset sys.eqs = final_eqs
@reset sys.unknowns = final_vars
return sys
end

Expand Down
28 changes: 14 additions & 14 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ end

function tearing_substitution(sys::AbstractSystem; kwargs...)
neweqs = full_equations(sys::AbstractSystem; kwargs...)
@set! sys.eqs = neweqs
@set! sys.substitutions = nothing
@set! sys.schedule = nothing
@reset sys.eqs = neweqs
@reset sys.substitutions = nothing
@reset sys.schedule = nothing
end

function tearing_assignments(sys::AbstractSystem)
Expand Down Expand Up @@ -563,10 +563,10 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
diff_to_var = invview(var_to_diff)

old_fullvars = fullvars
@set! state.structure.graph = complete(graph)
@set! state.structure.var_to_diff = var_to_diff
@set! state.structure.eq_to_diff = eq_to_diff
@set! state.fullvars = fullvars = fullvars[invvarsperm]
@reset state.structure.graph = complete(graph)
@reset state.structure.var_to_diff = var_to_diff
@reset state.structure.eq_to_diff = eq_to_diff
@reset state.fullvars = fullvars = fullvars[invvarsperm]
ispresent = let var_to_diff = var_to_diff, graph = graph
i -> (!isempty(𝑑neighbors(graph, i)) ||
(var_to_diff[i] !== nothing && !isempty(𝑑neighbors(graph, var_to_diff[i]))))
Expand All @@ -590,24 +590,24 @@ function tearing_reassemble(state::TearingState, var_eq_matching,
push!(unknowns, old_fullvars[v])
end
end
@set! sys.unknowns = unknowns
@reset sys.unknowns = unknowns

obs, subeqs, deps = cse_and_array_hacks(
obs, subeqs, unknowns, neweqs; cse = cse_hack, array = array_hack)

@set! sys.eqs = neweqs
@set! sys.observed = obs
@reset sys.eqs = neweqs
@reset sys.observed = obs

@set! sys.substitutions = Substitutions(subeqs, deps)
@reset sys.substitutions = Substitutions(subeqs, deps)

# Only makes sense for time-dependent
# TODO: generalize to SDE
if sys isa ODESystem
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
@reset sys.schedule = Schedule(var_eq_matching, dummy_sub)
end
sys = schedule(sys)
@set! state.sys = sys
@set! sys.tearing_state = state
@reset state.sys = sys
@reset sys.tearing_state = state
return invalidate_cache!(sys)
end

Expand Down
Loading

0 comments on commit 2f8cf36

Please sign in to comment.