From 38bf0f4b4ac7f4d875385cf7a7ae55c3251db854 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20W=C3=BCrfel?= Date: Wed, 27 Nov 2024 12:57:55 +0100 Subject: [PATCH] groundwork for external inputs --- ext/CUDAExt.jl | 10 ++-- src/NetworkDynamics.jl | 1 + src/component_functions.jl | 58 +++++++++++++++------ src/construction.jl | 57 ++++++++++++--------- src/coreloop.jl | 98 ++++++++++++++++++++---------------- src/doctor.jl | 5 +- src/external_inputs.jl | 60 ++++++++++++++++++++++ src/network_structure.jl | 56 +++++++++++++++------ src/show.jl | 6 +++ src/symbolicindexing.jl | 28 ++++++----- src/utils.jl | 6 +++ test/external_inputs_test.jl | 32 ++++++++++++ test/runtests.jl | 2 + 13 files changed, 304 insertions(+), 115 deletions(-) create mode 100644 src/external_inputs.jl create mode 100644 test/external_inputs_test.jl diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 21257356..f931eec2 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -32,12 +32,14 @@ function Adapt.adapt_structure(to, n::Network) mm = adapt(to, n.mass_matrix) gbp = adapt(to, n.gbufprovider) caches = (;output = _adapt_diffcache(to, n.caches.output), - aggregation = _adapt_diffcache(to, n.caches.aggregation)) + aggregation = _adapt_diffcache(to, n.caches.aggregation), + external = _adapt_diffcache(to, n.caches.external)) exT = typeof(executionstyle(n)) gT = typeof(n.im.g) + extmap = adapt(to, n.extmap) - Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp)}( - vb, layer, n.im, caches, mm, gbp) + Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp),typeof(extmap)}( + vb, layer, n.im, caches, mm, gbp, extmap) end Adapt.@adapt_structure NetworkLayer @@ -90,7 +92,7 @@ end function Adapt.adapt_structure(to, b::ComponentBatch) indices = adapt(to, b.indices) ComponentBatch(dispatchT(b), indices, b.compf, b.compg, b.ff, - b.statestride, b.pstride, b.inbufstride, b.outbufstride) + b.statestride, b.pstride, b.inbufstride, b.outbufstride, b.extbufstride) end #### diff --git a/src/NetworkDynamics.jl b/src/NetworkDynamics.jl index ce2e5e48..5552f7ed 100644 --- a/src/NetworkDynamics.jl +++ b/src/NetworkDynamics.jl @@ -55,6 +55,7 @@ include("aggregators.jl") include("gbufs.jl") include("construction.jl") include("coreloop.jl") +include("external_inputs.jl") # XXX: have both, s[:] and uflat(s) ? export VIndex, EIndex, VPIndex, EPIndex, NWState, NWParameter, uflat, pflat diff --git a/src/component_functions.jl b/src/component_functions.jl index f7e387f2..292935d5 100644 --- a/src/component_functions.jl +++ b/src/component_functions.jl @@ -283,13 +283,14 @@ struct VertexModel{F,G,FFT,OF,MM} <: ComponentModel g::G outsym::Vector{Symbol} ff::FFT - # parameters and option input sym + # parameters, optional input sym and optional external inputs psym::Vector{Symbol} insym::Union{Nothing, Vector{Symbol}} + extsym::Vector{SymbolicIndex} # observed obsf::OF obssym::Vector{Symbol} - # optional insyms + # metadata symmetadata::Dict{Symbol,Dict{Symbol, Any}} metadata::Dict{Symbol,Any} # cached symbol collections @@ -334,9 +335,10 @@ struct EdgeModel{F,G,FFT,OF,MM} <: ComponentModel g::G outsym::@NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}} ff::FFT - # parameters and option input sym + # parameters, optional input sym and optional external inputs psym::Vector{Symbol} insym::Union{Nothing, @NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}}} + extsym::Vector{SymbolicIndex} # observed obsf::OF obssym::Vector{Symbol} @@ -515,6 +517,20 @@ Gives the input dimension(s). indim(c::VertexModel)::Int = length(insym(c)) indim(c::EdgeModel)::@NamedTuple{src::Int,dst::Int} = (; src=length(insym(c).src), dst=length(insym(c).dst)) +""" + extsym(c::ComponentModel)::Vector{Symbol} + +Retrieve the external input symbols of the component. +""" +extsym(c::ComponentModel) = c.extsym + +""" + extdim(c::ComponentModel)::Int + +Retrieve the external input dimension of the component. +""" +extdim(c::ComponentModel) = length(extsym(c)) + # return both "observed" outputs (those that do not shadow states) and true observed outsym_flat(c::ComponentModel) = c._outsym_flat obssym_all(c::ComponentModel) = c._obssym_all @@ -570,12 +586,13 @@ dispatchT(T::Type{<:EdgeModel}) = EdgeModel{nothing,nothing,nothing,nothing,noth # TODO: introduce batchequal hash for faster batching of component models batchequal(a, b) = false function batchequal(a::ComponentModel, b::ComponentModel) - compf(a) == compf(b) || return false - compg(a) == compg(b) || return false - fftype(a) == fftype(b) || return false - dim(a) == dim(b) || return false - outdim(a) == outdim(b) || return false - pdim(a) == pdim(b) || return false + compf(a) == compf(b) || return false + compg(a) == compg(b) || return false + fftype(a) == fftype(b) || return false + dim(a) == dim(b) || return false + outdim(a) == outdim(b) || return false + pdim(a) == pdim(b) || return false + extdim(a) == extdim(b) || return false return true end @@ -957,6 +974,15 @@ function _fill_defaults(T, @nospecialize(kwargs)) dict[:_outsym_flat] = _outsym_flat dict[:_obssym_all] = setdiff(_outsym_flat, sym) ∪ obssym + #### + #### Extsym + #### + if haskey(dict, :extsym) + @assert dict[:extsym] isa Vector{<:SymbolicIndex} + else + dict[:extsym] = SymbolicIndex[] + end + # check for name clashes (at the end because only now sym, psym, obssym are initialized) _s = sym @@ -1098,20 +1124,22 @@ function Base.:(==)(cf1::ComponentModel, cf2::ComponentModel) end function compfg(c::VertexModel) - (out, du, u, in, p, t) -> begin + (out, du, u, in, p, t, ext=nothing) -> begin f = compf(c) - isnothing(f) || f(du, u, in, p, t) + fargs = isnothing(ext) ? (du, u, in, p, t) : (du, u, in, p, t, ext) + isnothing(f) || f(fargs...) compg(c)(_gargs(fftype(c), (out,), du, u, (in,), p, t)...) nothing - end + end end function compfg(c::EdgeModel) - (out1, out2, du, u, in1, in2, p, t) -> begin + (out1, out2, du, u, in1, in2, p, t, ext=nothing) -> begin f = compf(c) - isnothing(f) || f(du, u, in1, in2, p, t) + fargs = isnothing(ext) ? (du, u, in1, in2, p, t) : (du, u, in1, in2, p, t, ext) + isnothing(f) || f(fargs...) compg(c)(_gargs(fftype(c), (out1,out2), du, u, (in1,in2), p, t)...) nothing - end + end end _gargs(::PureFeedForward, outs, du, u, ins, p, t) = (outs..., ins..., p, t) _gargs(::FeedForward, outs, du, u, ins, p, t) = (outs..., u, ins..., p, t) diff --git a/src/construction.jl b/src/construction.jl index a12525e1..ad875fb9 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -171,7 +171,8 @@ function Network(g::AbstractGraph, mass_matrix = construct_mass_matrix(im) N = ForwardDiff.pickchunksize(max(im.lastidx_dynamic, im.lastidx_p)) caches = (; output = DiffCache(zeros(im.lastidx_out), N), - aggregation = DiffCache(zeros(im.lastidx_aggr), N)) + aggregation = DiffCache(zeros(im.lastidx_aggr), N), + external = DiffCache(zeros(im.lastidx_extbuf), N)) gbufprovider = if usebuffer(execution) EagerGBufProvider(im, edgebatches) @@ -179,13 +180,17 @@ function Network(g::AbstractGraph, LazyGBufProvider(im, edgebatches) end + # create map for extenral inputs + extmap = ExtMap(im) + nw = Network{typeof(execution),typeof(g),typeof(nl), typeof(vertexbatches), - typeof(mass_matrix),eltype(caches),typeof(gbufprovider)}( + typeof(mass_matrix),eltype(caches),typeof(gbufprovider),typeof(extmap)}( vertexbatches, nl, im, caches, mass_matrix, - gbufprovider + gbufprovider, + extmap, ) end @@ -276,37 +281,39 @@ function aliasgroups(cfs::Vector{T}) where {T<:ComponentModel} end function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose) - components = @view im.vertexm[idxs] - _compT = dispatchT(first(components)) - _compf = compf(first(components)) - _compg = compg(first(components)) - _ff = fftype(first(components)) - _dim = dim(first(components)) - _pdim = pdim(first(components)) - _outdim = outdim(first(components)) - - strides = register_vertices!(im, _dim, _outdim, _pdim, idxs) + first_comp = im.vertexm[first(idxs)] + _compT = dispatchT(first_comp) + _compf = compf(first_comp) + _compg = compg(first_comp) + _ff = fftype(first_comp) + _dim = dim(first_comp) + _pdim = pdim(first_comp) + _outdim = outdim(first_comp) + _extdim = extdim(first_comp) + + strides = register_vertices!(im, idxs, _dim, _outdim, _pdim, _extdim) verbose && println(" - VertexBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))") - ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out) + ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out, strides.ext) end function EdgeBatch(im::IndexManager, idxs::Vector{Int}; verbose) - components = @view im.edgem[idxs] - _compT = dispatchT(first(components)) - _compf = compf(first(components)) - _compg = compg(first(components)) - _ff = fftype(first(components)) - _dim = dim(first(components)) - _pdim = pdim(first(components)) - _outdim = outdim(first(components)) - - strides = register_edges!(im, _dim, _outdim, _pdim, idxs) + first_comp = im.edgem[first(idxs)] + _compT = dispatchT(first_comp) + _compf = compf(first_comp) + _compg = compg(first_comp) + _ff = fftype(first_comp) + _dim = dim(first_comp) + _pdim = pdim(first_comp) + _outdim = outdim(first_comp) + _extdim = extdim(first_comp) + + strides = register_edges!(im, idxs, _dim, _outdim, _pdim, _extdim) verbose && println(" - EdgeBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))") - ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out) + ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out, strides.ext) end batch_by_idxs(v, idxs::Vector{Vector{Int}}) = [v for batch in idxs] diff --git a/src/coreloop.jl b/src/coreloop.jl index 861a221e..014669fa 100644 --- a/src/coreloop.jl +++ b/src/coreloop.jl @@ -2,7 +2,7 @@ function (nw::Network{A,B,C,D,E})(du::dT, u::T, p, t) where {A,B,C,D,E,dT,T} if !(eachindex(du) == eachindex(u) == 1:nw.im.lastidx_dynamic) throw(ArgumentError("du or u does not have expected size $(nw.im.lastidx_dynamic)")) end - if nw.im.lastidx_p > 0 && _indexable(p) && !(eachindex(p) == 1:nw.im.lastidx_p) + if pdim(nw) > 0 && !(eachindex(p) == 1:nw.im.lastidx_p) throw(ArgumentError("p does not has expecte size $(nw.im.lastidx_p)")) end @@ -10,26 +10,31 @@ function (nw::Network{A,B,C,D,E})(du::dT, u::T, p, t) where {A,B,C,D,E,dT,T} fill!(du, zero(eltype(du))) o = get_output_cache(nw, du) fill!(o, convert(eltype(o), NaN)) + extbuf = get_extinput_cache(nw, du) + fill!(o, convert(eltype(extbuf), NaN)) duopt = (du, u, o, p, t) aggbuf = get_aggregation_cache(nw, du) gbuf = get_gbuf(nw.gbufprovider, o) # vg without ff - process_batches!(ex, Val{:g}(), !hasff, nw.vertexbatches, nothing, duopt) + process_batches!(ex, Val{:g}(), !hasff, nw.vertexbatches, (nothing, nothing), duopt) # eg without ff - process_batches!(ex, Val{:g}(), !hasff, nw.layer.edgebatches, nothing, duopt) + process_batches!(ex, Val{:g}(), !hasff, nw.layer.edgebatches, (nothing, nothing), duopt) # process batches might be async so sync before next step ex isa KAExecution && KernelAbstractions.synchronize(get_backend(du)) + # gather the external inputs + collect_externals!(nw.extmap, extbuf, u, o) + # gather the vertex results for edges with ff - gather!(nw.gbufprovider, gbuf, o) # 2.6ms + gather!(nw.gbufprovider, gbuf, o) # execute f for the edges without ff - process_batches!(ex, Val{:f}(), !hasff, nw.layer.edgebatches, gbuf, duopt) + process_batches!(ex, Val{:f}(), !hasff, nw.layer.edgebatches, (gbuf, extbuf), duopt) # execute f&g for edges with ff - process_batches!(ex, Val{:fg}(), hasff, nw.layer.edgebatches, gbuf, duopt) + process_batches!(ex, Val{:fg}(), hasff, nw.layer.edgebatches, (gbuf, extbuf), duopt) # process batches might be async so sync before next step ex isa KAExecution && KernelAbstractions.synchronize(get_backend(du)) @@ -38,7 +43,7 @@ function (nw::Network{A,B,C,D,E})(du::dT, u::T, p, t) where {A,B,C,D,E,dT,T} aggregate!(nw.layer.aggregator, aggbuf, o) # vf for vertices without ff - process_batches!(ex, Val{:f}(), !hasff, nw.vertexbatches, aggbuf, duopt) + process_batches!(ex, Val{:f}(), !hasff, nw.vertexbatches, (aggbuf, extbuf), duopt) # process batches might be async so sync before next step ex isa KAExecution && KernelAbstractions.synchronize(get_backend(du)) @@ -55,15 +60,15 @@ function get_buffers(nw, u, p, t; initbufs) gbuf = get_gbuf(nw.gbufprovider, o) # vg without ff - process_batches!(ex, Val{:g}(), !hasff, nw.vertexbatches, nothing, duopt) + process_batches!(ex, Val{:g}(), !hasff, nw.vertexbatches, (nothing, nothing), duopt) # eg without ff - process_batches!(ex, Val{:g}(), !hasff, nw.layer.edgebatches, nothing, duopt) + process_batches!(ex, Val{:g}(), !hasff, nw.layer.edgebatches, (nothing, nothing), duopt) # process batches might be async so sync before next step ex isa KAExecution && KernelAbstractions.synchronize(get_backend(u)) # gather the vertex results for edges with ff gather!(nw.gbufprovider, gbuf, o) # 2.6ms # execute g for edges with ff - process_batches!(ex, Val{:g}(), hasff, nw.layer.edgebatches, gbuf, duopt) + process_batches!(ex, Val{:g}(), hasff, nw.layer.edgebatches, (gbuf, nothing), duopt) # process batches might be async so sync before next step ex isa KAExecution && KernelAbstractions.synchronize(get_backend(u)) # aggegrate the results @@ -72,37 +77,37 @@ function get_buffers(nw, u, p, t; initbufs) o, aggbuf end -@inline function process_batches!(::SequentialExecution, fg, filt::F, batches, inbuf, duopt) where {F} +@inline function process_batches!(::SequentialExecution, fg, filt::F, batches, inbufs, duopt) where {F} unrolled_foreach(filt, batches) do batch (du, u, o, p, t) = duopt for i in 1:length(batch) _type = dispatchT(batch) - apply_comp!(_type, fg, batch, i, du, u, o, inbuf, p, t) + apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t) end end end -@inline function process_batches!(::ThreadedExecution, fg, filt::F, batches, inbuf, duopt) where {F} +@inline function process_batches!(::ThreadedExecution, fg, filt::F, batches, inbufs, duopt) where {F} unrolled_foreach(filt, batches) do batch (du, u, o, p, t) = duopt Threads.@threads for i in 1:length(batch) _type = dispatchT(batch) - apply_comp!(_type, fg, batch, i, du, u, o, inbuf, p, t) + apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t) end end end -@inline function process_batches!(::PolyesterExecution, fg, filt::F, batches, inbuf, duopt) where {F} +@inline function process_batches!(::PolyesterExecution, fg, filt::F, batches, inbufs, duopt) where {F} unrolled_foreach(filt, batches) do batch (du, u, o, p, t) = duopt Polyester.@batch for i in 1:length(batch) _type = dispatchT(batch) - apply_comp!(_type, fg, batch, i, du, u, o, inbuf, p, t) + apply_comp!(_type, fg, batch, i, du, u, o, inbufs, p, t) end end end -@inline function process_batches!(::KAExecution, fg, filt::F, batches, inbuf, duopt) where {F} +@inline function process_batches!(::KAExecution, fg, filt::F, batches, inbufs, duopt) where {F} _backend = get_backend(duopt[1]) unrolled_foreach(filt, batches) do batch (du, u, o, p, t) = duopt @@ -114,61 +119,70 @@ end elseif evalg(fg, batch) compkernel_g!(_backend) end - isnothing(kernel) || kernel(_type, fg, batch, du, u, o, inbuf, p, t; ndrange=length(batch)) + isnothing(kernel) || kernel(_type, fg, batch, du, u, o, inbufs, p, t; ndrange=length(batch)) end end @kernel function compkernel_f!(::Type{T}, @Const(fg), @Const(batch), - du, @Const(u), @Const(o), @Const(inbuf), @Const(p), @Const(t)) where {T} + du, @Const(u), @Const(o), @Const(inbufs), @Const(p), @Const(t)) where {T} I = @index(Global) - apply_comp!(T, fg, batch, I, du, u, o, inbuf, p, t) + apply_comp!(T, fg, batch, I, du, u, o, inbufs, p, t) nothing end @kernel function compkernel_g!(::Type{T}, @Const(fg), @Const(batch), - @Const(du), @Const(u), o, @Const(inbuf), @Const(p), @Const(t)) where {T} + @Const(du), @Const(u), o, @Const(inbufs), @Const(p), @Const(t)) where {T} I = @index(Global) - apply_comp!(T, fg, batch, I, du, u, o, inbuf, p, t) + apply_comp!(T, fg, batch, I, du, u, o, inbufs, p, t) nothing end @kernel function compkernel_fg!(::Type{T}, @Const(fg), @Const(batch), du, @Const(u), o, @Const(inbuf), @Const(p), @Const(t)) where {T} I = @index(Global) - apply_comp!(T, fg, batch, I, du, u, o, inbuf, p, t) + apply_comp!(T, fg, batch, I, du, u, o, inbufs, p, t) nothing end -@inline function apply_comp!(::Type{<:VertexModel}, fg, batch, i, du, u, o, aggbuf, p, t) +@inline function apply_comp!(::Type{<:VertexModel}, fg, batch, i, du, u, o, inbufs, p, t) @inbounds begin + aggbuf, extbuf = inbufs _o = _needs_out(fg, batch) ? view(o, out_range(batch, i)) : nothing _du = _needs_du(fg, batch) ? view(du, state_range(batch, i)) : nothing _u = _needs_u(fg, batch) ? view(u, state_range(batch, i)) : nothing _agg = _needs_in(fg, batch) ? view(aggbuf, in_range(batch, i)) : nothing - _p = _needs_p(fg, batch) && _indexable(p) ? view(p, parameter_range(batch, i)) : p - evalf(fg, batch) && apply_compf(compf(batch), _du, _u, (_agg,), _p, t) + _p = _needs_p(fg, batch) ? view(p, parameter_range(batch, i)) : nothing + _ext = _needs_ext(fg, batch) ? view(extbuf, extbuf_range(batch, i)) : nothing + evalf(fg, batch) && apply_compf(compf(batch), _du, _u, (_agg,), _p, t, _ext) evalg(fg, batch) && apply_compg(fftype(batch), compg(batch), (_o,), _u, (_agg,), _p, t) end nothing end -@inline function apply_comp!(::Type{<:EdgeModel}, fg, batch, i, du, u, o, gbuf, p, t) +@inline function apply_comp!(::Type{<:EdgeModel}, fg, batch, i, du, u, o, inbufs, p, t) @inbounds begin - _osrc = _needs_out(fg, batch) ? view(o, out_range(batch, i, 1)) : nothing - _odst = _needs_out(fg, batch) ? view(o, out_range(batch, i, 2)) : nothing - _du = _needs_du(fg, batch) ? view(du, state_range(batch, i)) : nothing - _u = _needs_u(fg, batch) ? view(u, state_range(batch, i)) : nothing - _ins = _needs_in(fg, batch) ? get_src_dst(gbuf, batch, i) : nothing - _p = _needs_p(fg, batch) && _indexable(p) ? view(p, parameter_range(batch, i)) : p - evalf(fg, batch) && apply_compf(compf(batch), _du, _u, _ins, _p, t) + gbuf, extbuf = inbufs + _osrc = _needs_out(fg, batch) ? view(o, out_range(batch, i, :src)) : nothing + _odst = _needs_out(fg, batch) ? view(o, out_range(batch, i, :dst)) : nothing + _du = _needs_du(fg, batch) ? view(du, state_range(batch, i)) : nothing + _u = _needs_u(fg, batch) ? view(u, state_range(batch, i)) : nothing + _ins = _needs_in(fg, batch) ? get_src_dst(gbuf, batch, i) : nothing + _p = _needs_p(fg, batch) ? view(p, parameter_range(batch, i)) : nothing + _ext = _needs_ext(fg, batch) ? view(extbuf, extbuf_range(batch, i)) : nothing + evalf(fg, batch) && apply_compf(compf(batch), _du, _u, _ins, _p, t, _ext) evalg(fg, batch) && apply_compg(fftype(batch), compg(batch), (_osrc, _odst), _u, _ins, _p, t) end nothing end -@propagate_inbounds function apply_compf(f::F, du, u, ins, p, t) where {F} +@propagate_inbounds function apply_compf(f::F, du, u, ins, p, t, ::Nothing) where {F} f(du, u, ins..., p, t) nothing end +@propagate_inbounds function apply_compf(f::F, du, u, ins, p, t, ext) where {F} + f(du, u, ins..., p, t, ext) + nothing +end + @propagate_inbounds function apply_compg(::PureFeedForward, g::G, outs, u, ins, p, t) where {G} g(outs..., ins..., p, t) nothing @@ -187,11 +201,12 @@ end end # check if the function arguments are actually used -_needs_du(fg, batch) = evalf(fg, batch) -_needs_u(fg, batch) = evalf(fg, batch) || fftype(batch) != PureFeedForward() +_needs_du(fg, batch) = evalf(fg, batch) +_needs_u(fg, batch) = evalf(fg, batch) || fftype(batch) != PureFeedForward() _needs_out(fg, batch) = evalg(fg, batch) -_needs_in(fg, batch) = evalf(fg, batch) || hasff(batch) -_needs_p(fg, batch) = evalf(fg, batch) || fftype(batch) != PureStateMap() +_needs_in(fg, batch) = evalf(fg, batch) || hasff(batch) +_needs_p(fg, batch) = !iszero(pdim(batch)) && (evalf(fg, batch) || fftype(batch) != PureStateMap()) +_needs_ext(fg, batch) = !iszero(extdim(batch)) && evalf(fg, batch) # check if eval of f or g is necessary evalf(::Val{:f}, batch) = !isnothing(compf(batch)) @@ -200,8 +215,3 @@ evalf(::Val{:fg}, batch) = !isnothing(compf(batch)) evalg(::Val{:f}, _) = false evalg(::Val{:g}, _) = true evalg(::Val{:fg}, _) = true - -# check if indexing into p is necessary -_indexable(::Nothing) = false -_indexable(::SciMLBase.NullParameters) = false -_indexable(::AbstractVector) = true diff --git a/src/doctor.jl b/src/doctor.jl index 8c73de60..41bc4d38 100644 --- a/src/doctor.jl +++ b/src/doctor.jl @@ -99,10 +99,11 @@ function chk_component(c::ComponentModel) Tuple(AccessTracker(rand(indim_guess)) for _ in outdim_normalized(c)) end outs = Tuple(AccessTracker(rand(l)) for l in values(outdim(c))) + ext = AccessTracker(rand(extdim(c))) t = NaN try - compfg(c)(outs..., du, u, ins..., p, t) + compfg(c)(outs..., du, u, ins..., p, t, ext) catch e if e isa MethodError @warn "Encountered MethodError. All arguments are AbstractArrays, make sure to allways index into them: $e" @@ -123,6 +124,7 @@ function chk_component(c::ComponentModel) has_oob(du) && @warn "There is out of bound acces to du: reads $(oob_reads(du)) and writes $(oob_writes(du))! Check dim/sym!" has_oob(u) && @warn "There is out of bound acces to u: reads $(oob_reads(u)) and writes $(oob_writes(u))! Check dim/sym!" has_oob(p) && @warn "There is out of bound acces to p: reads $(oob_reads(p)) and writes $(oob_writes(p))! Check pdim/psim!" + has_oob(ext) && @warn "There is out of bound acces to external input: reads $(oob_reads(ext)) and writes $(oob_writes(ext))!" for (j, o) in enumerate(outs) has_oob(o) && @warn "There is out of bound acces to output#$j: reads $(oob_reads(o)) and writes $(oob_writes(o))!" end @@ -148,6 +150,7 @@ function chk_component(c::ComponentModel) end has_writes(p) && @warn "There is write access to p: $(writes(p))!" + has_writes(ext) && @warn "There is write access to external inputs: $(writes(ext))!" similars = String[] has_similars(du) && push!(similars, "du") diff --git a/src/external_inputs.jl b/src/external_inputs.jl new file mode 100644 index 00000000..f0b2a135 --- /dev/null +++ b/src/external_inputs.jl @@ -0,0 +1,60 @@ +struct StateBufIdx + idx::Int +end +struct OutBufIdx + idx::Int +end +struct ExtMap{M<:AbstractVector{<:Union{StateBufIdx, OutBufIdx}}} + map::M +end + +function ExtMap(im::IndexManager) + map = Vector{Union{StateBufIdx, OutBufIdx}}(undef, im.lastidx_extbuf) + isempty(map) && return ExtMap(map) + + for vm in im.vertexm + for (i, si) in pairs(vm.extsym) + map[im.v_ext[i]] = _symidx_to_extidx(im, si) + end + end + for em in im.edgem + for (i, si) in pairs(em.extsym) + map[im.e_ext[i]] = _symidx_to_extidx(im, si) + end + end + + # narrow down type if all are output or state + if all(e -> typeof(e) == typeof(first(map)), map) + map = Vector{typeof(first(map))}(map) + end + ExtMap(map) +end + +function _symidex_to_extidx(im, sni) + cm = getcomp(im, sni) + if subsym_has_idx(sni.subidx, sym(cm)) + range = getcomprange(im, sni) + return StateBufIdx(range[subsym_to_idx(sni.subidx, sym(cm))]) + elseif subsym_has_idx(sni.subidx, outsym_flat(cm)) + range = getcompoutrange(im, sni) + if hasff(cm) + throw(ArgumentError("Cannot resolve external input $sni! Outputs of feed-forward components are not allowed as external inputs.")) + end + return OutBufIdx(range[subsym_to_idx(sni.subidx, sym(cm))]) + else + throw(ArgumentError("Cannot resolve external input $sni! External inputs musst be states or outputs of non-feed-forward components.")) + end +end + +function collect_externals!(map::ExtMap, extbuf, u, o) + @inbounds for (dst, src) in pairs(map.map) + if src isa StateBufIdx + extbuf[dst] = u[src.idx] + else + extbuf[dst] = o[src.idx] + end + end +end + +has_external_inputs(c::ComponentModel) = !iszero(extdim(c)) +has_external_inputs(cb::ComponentBatch) = !iszero(stridesT(cb.extbufstride)) diff --git a/src/network_structure.jl b/src/network_structure.jl index d998b3e7..9f7037f9 100644 --- a/src/network_structure.jl +++ b/src/network_structure.jl @@ -6,11 +6,13 @@ mutable struct IndexManager{G} v_out::Vector{UnitRange{Int}} # v range in output buf v_para::Vector{UnitRange{Int}} # v para in flat para v_aggr::Vector{UnitRange{Int}} # v input in aggbuf + v_ext::Vector{UnitRange{Int}} # v external inputs in ext buf # positions of edge data e_data::Vector{UnitRange{Int}} # e data in flat states e_out::Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}} # e range in output buf e_para::Vector{UnitRange{Int}} # e para in flat para e_gbufr::Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}} # e range in input buf + e_ext::Vector{UnitRange{Int}} # v external inputs in ext buf # metadata edepth::Int vdepth::Int @@ -19,6 +21,7 @@ mutable struct IndexManager{G} lastidx_p::Int lastidx_aggr::Int lastidx_gbuf::Int + lastidx_extbuf::Int vertexm::Vector{VertexModel} edgem::Vector{EdgeModel} aliased_vertexms::IdDict{VertexModel, @NamedTuple{idxs::Vector{Int}, hash::UInt}} @@ -31,13 +34,14 @@ mutable struct IndexManager{G} unique_vnames = unique_mappings(getproperty.(vertexm, :name), 1:nv(g)) unique_enames = unique_mappings(getproperty.(edgem, :name), 1:ne(g)) new{typeof(g)}(g, collect(edges(g)), - (Vector{UnitRange{Int}}(undef, nv(g)) for i in 1:4)..., + (Vector{UnitRange{Int}}(undef, nv(g)) for i in 1:5)..., Vector{UnitRange{Int}}(undef, ne(g)), Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}}(undef, ne(g)), Vector{UnitRange{Int}}(undef, ne(g)), Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}}(undef, ne(g)), + Vector{UnitRange{Int}}(undef, ne(g)), edepth, vdepth, - 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, vertexm, edgem, aliased_vertexm_hashes, aliased_edgem_hashes, @@ -62,10 +66,9 @@ end dim(im::IndexManager) = im.lastidx_dynamic pdim(im::IndexManager) = im.lastidx_p -sdim(im::IndexManager) = im.lastidx_static - im.lastidx_dynamic -struct Network{EX<:ExecutionStyle,G,NL,VTup,MM,CT,GBT} +struct Network{EX<:ExecutionStyle,G,NL,VTup,MM,CT,GBT,EM} "vertex batches of same function" vertexbatches::VTup "network layer" @@ -73,11 +76,13 @@ struct Network{EX<:ExecutionStyle,G,NL,VTup,MM,CT,GBT} "index manager" im::IndexManager{G} "lazy cache pool" - caches::@NamedTuple{output::CT,aggregation::CT} + caches::@NamedTuple{output::CT,aggregation::CT,external::CT} "mass matrix" mass_matrix::MM "Gather buffer provider (lazy or eager)" gbufprovider::GBT + "map to gather external inputs" + extmap::EM end executionstyle(::Network{ex}) where {ex} = ex() nvbatches(::Network) = length(vertexbatches) @@ -109,6 +114,7 @@ function get_output_cache(nw::Network, T) get_tmp(nw.caches.output, T) end get_aggregation_cache(nw::Network, T) = get_tmp(nw.caches.aggregation, T) +get_extinput_cache(nw::Network, T) = get_tmp(nw.caches.external, T) iscudacompatible(nw::Network) = iscudacompatible(executionstyle(nw)) && iscudacompatible(nw.layer.aggregator) @@ -125,7 +131,7 @@ struct NetworkLayer{GT,ETup,AF} vdepth::Int # potential becomes range for multilayer end -struct ComponentBatch{T,F,G,FFT,DIM,PDIM,INDIMS,OUTDIMS,IV} +struct ComponentBatch{T,F,G,FFT,DIM,PDIM,INDIMS,OUTDIMS,EXTDIM,IV} "indices contained in batch" indices::IV "internal function" @@ -141,9 +147,11 @@ struct ComponentBatch{T,F,G,FFT,DIM,PDIM,INDIMS,OUTDIMS,IV} inbufstride::BatchStride{INDIMS} "outputbuf: dimension(s) and first index" outbufstride::BatchStride{OUTDIMS} - function ComponentBatch(dT, i, f, g, ff, ss, ps, is, os) - new{dT,typeof.((f,g,ff))...,stridesT.((ss, ps, is, os))...,typeof(i)}( - i, f, g, ff, ss, ps, is, os) + "external inputs: dimension and first index" + extbufstride::BatchStride{EXTDIM} + function ComponentBatch(dT, i, f, g, ff, ss, ps, is, os, es) + new{dT,typeof.((f,g,ff))...,stridesT.((ss, ps, is, os, es))...,typeof(i)}( + i, f, g, ff, ss, ps, is, os, es) end end @@ -152,6 +160,8 @@ end @inline compf(b::ComponentBatch) = b.compf @inline compg(b::ComponentBatch) = b.compg @inline fftype(b::ComponentBatch) = b.ff +@inline pdim(b::ComponentBatch) = b.pstride.strides +@inline extdim(b::ComponentBatch) = b.extbufstride.strides @inline state_range(batch) = _fullrange(batch.statestride, length(batch)) @@ -161,22 +171,25 @@ end @inline out_range(batch, i, j) = _range(batch.outbufstride, i, j) @inline in_range(batch, i) = _range(batch.inbufstride, i) @inline in_range(batch, i, j) = _range(batch.inbufstride, i, j) +@inline extbuf_range(batch, i) = _range(batch.extbufstride, i) -function register_vertices!(im::IndexManager, dim, outdim, pdim, idxs) +function register_vertices!(im::IndexManager, idxs, dim, outdim, pdim, extdim) for i in idxs im.v_data[i] = _nexturange!(im, dim) im.v_out[i] = _nextoutrange!(im, outdim) im.v_para[i] = _nextprange!(im, pdim) im.v_aggr[i] = _nextaggrrange!(im, im.edepth) + im.v_ext[i] = _nextextrange!(im, extdim) end (; state = BatchStride(first(im.v_data[first(idxs)]), dim), p = BatchStride(first(im.v_para[first(idxs)]), pdim), in = BatchStride(first(im.v_aggr[first(idxs)]), im.edepth), out = BatchStride(first(im.v_out[first(idxs)]), outdim), + ext = BatchStride(first(im.v_ext[first(idxs)]), extdim), ) end -function register_edges!(im::IndexManager, dim, outdim, pdim, idxs) +function register_edges!(im::IndexManager, idxs, dim, outdim, pdim, extdim) for i in idxs e = im.edgevec[i] im.e_data[i] = _nexturange!(im, dim) @@ -185,12 +198,14 @@ function register_edges!(im::IndexManager, dim, outdim, pdim, idxs) im.e_para[i] = _nextprange!(im, pdim) im.e_gbufr[i] = (src = _nextgbufrange!(im, im.vdepth), dst = _nextgbufrange!(im, im.vdepth)) + im.e_ext[i] = _nextextrange!(im, extdim) end (; state = BatchStride(first(im.e_data[first(idxs)]), dim), p = BatchStride(first(im.e_para[first(idxs)]), pdim), in = BatchStride(first(flatrange(im.e_gbufr[first(idxs)])), (;src=im.vdepth, dst=im.vdepth)), out = BatchStride(first(flatrange(im.e_out[first(idxs)])), (;src=outdim.src, dst=outdim.dst)), + ext = BatchStride(first(im.e_ext[first(idxs)]), extdim), ) end function _nexturange!(im::IndexManager, N) @@ -218,12 +233,18 @@ function _nextgbufrange!(im::IndexManager, N) im.lastidx_gbuf = newlast return range end +function _nextextrange!(im::IndexManager, N) + newlast, range = _nextrange(im.lastidx_extbuf, N) + im.lastidx_extbuf = newlast + return range +end _nextrange(last, N) = last + N, last+1:last+N function isdense(im::IndexManager) - pidxs = Int[] - stateidxs = Int[] - outidxs = Int[] + stateidxs = sizehint!(Int[], im.lastidx_dynamic) + pidxs = sizehint!(Int[], im.lastidx_p) + outidxs = sizehint!(Int[], im.lastidx_out) + extidxs = sizehint!(Int[], im.lastidx_extbuf) for dataranges in (im.v_data, im.e_data) for range in dataranges append!(stateidxs, range) @@ -239,11 +260,18 @@ function isdense(im::IndexManager) append!(outidxs, flatrange(range)) end end + for extranges in (im.v_ext, im.e_ext) + for range in extranges + append!(extidxs, flatrange(range)) + end + end sort!(pidxs) sort!(stateidxs) sort!(outidxs) + sort!(outidxs) @assert pidxs == 1:im.lastidx_p @assert stateidxs == 1:im.lastidx_dynamic @assert outidxs == 1:im.lastidx_out + @assert extidxs == 1:im.lastidx_extbuf return true end diff --git a/src/show.jl b/src/show.jl index 48141210..b4440792 100644 --- a/src/show.jl +++ b/src/show.jl @@ -67,6 +67,12 @@ function print_states_params(io, @nospecialize(c::ComponentModel), styling) num, word = maybe_plural(pdim(c), "param") pdim(c) > 0 && push!(info, styled"$num &$word: &&$(stylesymbolarray(c.psym, pdef(c), pguess(c)))") + if has_external_inputs(c) + num = extdim(c) + arr = match(r"(\[.*\])", repr(extsym(c)))[1] + push!(info, styled"$num &ext in: &&$arr") + end + print_treelike(io, align_strings(info)) end function _inout_string(@nospecialize(c::VertexModel), f, name) diff --git a/src/symbolicindexing.jl b/src/symbolicindexing.jl index fda4463e..3dd7220f 100644 --- a/src/symbolicindexing.jl +++ b/src/symbolicindexing.jl @@ -1,6 +1,3 @@ -abstract type SymbolicIndex{C,S} end -abstract type SymbolicStateIndex{C,S} <: SymbolicIndex{C,S} end -abstract type SymbolicParameterIndex{C,S} <: SymbolicIndex{C,S} end """ VIndex{C,S} <: SymbolicStateIndex{C,S} idx = VIndex(comp, sub) @@ -105,21 +102,28 @@ function SII.getname(x::SymbolicEdgeIndex) Symbol(prefix, Symbol(x.compidx), :₊, Symbol(x.subidx)) end -resolvecompidx(nw::Network, sni::SymbolicIndex{Int}) = sni.compidx -function resolvecompidx(nw::Network, sni::SymbolicIndex{Symbol}) - dict = sni isa SymbolicVertexIndex ? nw.im.unique_vnames : nw.im.unique_enames +resolvecompidx(nw::Network, sni) = resolvecompidx(nw.im, sni) +resolvecompidx(::IndexManager, sni::SymbolicIndex{Int}) = sni.compidx +function resolvecompidx(im::IndexManager, sni::SymbolicIndex{Symbol}) + dict = sni isa SymbolicVertexIndex ? im.unique_vnames : im.unique_enames if haskey(dict, sni.compidx) return dict[sni.compidx] else throw(ArgumentError("Could not resolve component index for $sni, the name might not be unique?")) end end -getcomp(nw::Network, sni::SymbolicEdgeIndex) = nw.im.edgem[resolvecompidx(nw, sni)] -getcomp(nw::Network, sni::SymbolicVertexIndex) = nw.im.vertexm[resolvecompidx(nw, sni)] -getcomprange(nw::Network, sni::VIndex{<:Union{Symbol,Int}}) = nw.im.v_data[resolvecompidx(nw, sni)] -getcomprange(nw::Network, sni::EIndex{<:Union{Symbol,Int}}) = nw.im.e_data[resolvecompidx(nw, sni)] -getcompoutrange(nw::Network, sni::VIndex{<:Union{Symbol,Int}}) = nw.im.v_out[resolvecompidx(nw, sni)] -getcompoutrange(nw::Network, sni::EIndex{<:Union{Symbol,Int}}) = flatrange(nw.im.e_out[resolvecompidx(nw, sni)]) +getcomp(nw::Network, sni) = getcomp(nw.im, sni) +getcomp(im::IndexManager, sni::SymbolicEdgeIndex) = im.edgem[resolvecompidx(im, sni)] +getcomp(im::IndexManager, sni::SymbolicVertexIndex) = im.vertexm[resolvecompidx(im, sni)] + +getcomprange(nw::Network, sni) = getcomprange(nw.im, sni) +getcomprange(im::IndexManager, sni::VIndex{<:Union{Symbol,Int}}) = im.v_data[resolvecompidx(im, sni)] +getcomprange(im::IndexManager, sni::EIndex{<:Union{Symbol,Int}}) = im.e_data[resolvecompidx(im, sni)] + +getcompoutrange(nw::Network, sni) = getcompoutrange(nw.im, sni) +getcompoutrange(im::IndexManager, sni::VIndex{<:Union{Symbol,Int}}) = im.v_out[resolvecompidx(im, sni)] +getcompoutrange(im::IndexManager, sni::EIndex{<:Union{Symbol,Int}}) = flatrange(im.e_out[resolvecompidx(im, sni)]) + getcompprange(nw::Network, sni::VPIndex{<:Union{Symbol,Int}}) = nw.im.v_para[resolvecompidx(nw, sni)] getcompprange(nw::Network, sni::EPIndex{<:Union{Symbol,Int}}) = nw.im.e_para[resolvecompidx(nw, sni)] diff --git a/src/utils.jl b/src/utils.jl index 860785f8..4442cadb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -124,3 +124,9 @@ function rand_inputs_fg(rng, cf) (outs..., du, u, ins..., p, t) end rand_inputs_fg(cf) = rand_inputs_fg(Random.default_rng(), cf) + + +# abstract symbolic index types +abstract type SymbolicIndex{C,S} end +abstract type SymbolicStateIndex{C,S} <: SymbolicIndex{C,S} end +abstract type SymbolicParameterIndex{C,S} <: SymbolicIndex{C,S} end diff --git a/test/external_inputs_test.jl b/test/external_inputs_test.jl new file mode 100644 index 00000000..49fc3ed5 --- /dev/null +++ b/test/external_inputs_test.jl @@ -0,0 +1,32 @@ +using NetworkDynamics, Graphs +using Chairmarks: @b + +using NetworkDynamics: StateBufIdx, OutBufIdx, ExtMap, collect_externals! +using Random + +@testset "test extmap performance" begin + N = 10_000 + u = rand(N) + o = rand(N) + _map = Union{StateBufIdx,OutBufIdx}[rand((StateBufIdx(i), OutBufIdx(i))) for i in 1:N] + map = ExtMap(Random.shuffle(_map)) + + exbuf = zeros(N) + fill!(exbuf, NaN) + b = @b collect_externals!(map, exbuf, u, o) # 6.583μs + @test iszero(b.allocs) + + _map = [StateBufIdx(i) for i in 1:N] + map = ExtMap(Random.shuffle(_map)) + exbuf = zeros(N) + fill!(exbuf, NaN) + b = @b collect_externals!(map, exbuf, u, o) # 5.6μs + @test iszero(b.allocs) +end + +@testset "test construction of elements with external inputs" begin + function fext(dv, v, ein, p, t, ext) + dv[1] = ext[1] + end + c = VertexModel(f=fext, g=1, dim=1, extsym=[VIndex(18,:a)]) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3022c7da..6e09cd62 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,8 @@ using ExplicitImports NetworkDynamics.CHECK_COMPONENT[] = true @safetestset "Symbolic Indexing Tests" begin include("symbolicindexing_test.jl") end + @safetestset "external input test" begin include("external_inputs_test.jl") end + @safetestset "doctor test" begin include("doctor_test.jl") end @safetestset "Diffusion test" begin include("diffusion_test.jl") end