From 41600b14e7ddea22cf17e529f005eb85aa98ce31 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 12 Nov 2024 23:53:53 -0800 Subject: [PATCH 1/3] Add a simple mechanism to add passes to structural simplify --- src/systems/systems.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 97d22cf4cf..29cc96d28f 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -26,7 +26,7 @@ topological sort of the observed equations in `sys`. + `fully_determined=true` controls whether or not an error will be thrown if the number of equations don't match the number of inputs, outputs, and equations. """ function structural_simplify( - sys::AbstractSystem, io = nothing; simplify = false, split = true, + sys::AbstractSystem, io = nothing; additional_passes = [], simplify = false, split = true, allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true, kwargs...) isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) @@ -49,6 +49,9 @@ function structural_simplify( if newsys isa ODESystem || has_parent(newsys) @set! newsys.parent = complete(sys; split, flatten = false) end + for pass in additional_passes + newsys = pass(newsys) + end newsys = complete(newsys; split) if has_defaults(newsys) && (defs = get_defaults(newsys)) !== nothing ks = collect(keys(defs)) # take copy to avoid mutating defs while iterating. From c0abc56decc7691b25254af716f20cc9979f73cf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 12 Dec 2024 14:59:55 +0530 Subject: [PATCH 2/3] fix: run additional passes before setting the parent of the system --- src/systems/systems.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 29cc96d28f..47acd81a82 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -46,12 +46,12 @@ function structural_simplify( not yet supported. """) end + for pass in additional_passes + newsys = pass(newsys) + end if newsys isa ODESystem || has_parent(newsys) @set! newsys.parent = complete(sys; split, flatten = false) end - for pass in additional_passes - newsys = pass(newsys) - end newsys = complete(newsys; split) if has_defaults(newsys) && (defs = get_defaults(newsys)) !== nothing ks = collect(keys(defs)) # take copy to avoid mutating defs while iterating. From daf93edfde0a60ec6fb460f83e977c248cfeb849 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 12 Dec 2024 15:00:08 +0530 Subject: [PATCH 3/3] test: test additional passes mechanism --- test/structural_transformation/utils.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 2704559f72..863e091aad 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -152,3 +152,12 @@ end end end end + +@testset "additional passes" begin + @variables x(t) y(t) + @named sys = ODESystem([D(x) ~ x, y ~ x + t], t) + value = Ref(0) + pass(sys; kwargs...) = (value[] += 1; return sys) + structural_simplify(sys; additional_passes = [pass]) + @test value[] == 1 +end