Skip to content

Commit

Permalink
groundwork for external inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
hexaeder committed Nov 27, 2024
1 parent dd7f5e9 commit 38bf0f4
Show file tree
Hide file tree
Showing 13 changed files with 304 additions and 115 deletions.
10 changes: 6 additions & 4 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ function Adapt.adapt_structure(to, n::Network)
mm = adapt(to, n.mass_matrix)
gbp = adapt(to, n.gbufprovider)
caches = (;output = _adapt_diffcache(to, n.caches.output),
aggregation = _adapt_diffcache(to, n.caches.aggregation))
aggregation = _adapt_diffcache(to, n.caches.aggregation),
external = _adapt_diffcache(to, n.caches.external))
exT = typeof(executionstyle(n))
gT = typeof(n.im.g)
extmap = adapt(to, n.extmap)

Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp)}(
vb, layer, n.im, caches, mm, gbp)
Network{exT,gT,typeof(layer),typeof(vb),typeof(mm),eltype(caches),typeof(gbp),typeof(extmap)}(
vb, layer, n.im, caches, mm, gbp, extmap)
end

Adapt.@adapt_structure NetworkLayer
Expand Down Expand Up @@ -90,7 +92,7 @@ end
function Adapt.adapt_structure(to, b::ComponentBatch)
indices = adapt(to, b.indices)
ComponentBatch(dispatchT(b), indices, b.compf, b.compg, b.ff,
b.statestride, b.pstride, b.inbufstride, b.outbufstride)
b.statestride, b.pstride, b.inbufstride, b.outbufstride, b.extbufstride)
end

####
Expand Down
1 change: 1 addition & 0 deletions src/NetworkDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ include("aggregators.jl")
include("gbufs.jl")
include("construction.jl")
include("coreloop.jl")
include("external_inputs.jl")

# XXX: have both, s[:] and uflat(s) ?
export VIndex, EIndex, VPIndex, EPIndex, NWState, NWParameter, uflat, pflat
Expand Down
58 changes: 43 additions & 15 deletions src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,14 @@ struct VertexModel{F,G,FFT,OF,MM} <: ComponentModel
g::G
outsym::Vector{Symbol}
ff::FFT
# parameters and option input sym
# parameters, optional input sym and optional external inputs
psym::Vector{Symbol}
insym::Union{Nothing, Vector{Symbol}}
extsym::Vector{SymbolicIndex}
# observed
obsf::OF
obssym::Vector{Symbol}
# optional insyms
# metadata
symmetadata::Dict{Symbol,Dict{Symbol, Any}}
metadata::Dict{Symbol,Any}
# cached symbol collections
Expand Down Expand Up @@ -334,9 +335,10 @@ struct EdgeModel{F,G,FFT,OF,MM} <: ComponentModel
g::G
outsym::@NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}}
ff::FFT
# parameters and option input sym
# parameters, optional input sym and optional external inputs
psym::Vector{Symbol}
insym::Union{Nothing, @NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}}}
extsym::Vector{SymbolicIndex}
# observed
obsf::OF
obssym::Vector{Symbol}
Expand Down Expand Up @@ -515,6 +517,20 @@ Gives the input dimension(s).
indim(c::VertexModel)::Int = length(insym(c))
indim(c::EdgeModel)::@NamedTuple{src::Int,dst::Int} = (; src=length(insym(c).src), dst=length(insym(c).dst))

"""
extsym(c::ComponentModel)::Vector{Symbol}
Retrieve the external input symbols of the component.
"""
extsym(c::ComponentModel) = c.extsym

"""
extdim(c::ComponentModel)::Int
Retrieve the external input dimension of the component.
"""
extdim(c::ComponentModel) = length(extsym(c))

# return both "observed" outputs (those that do not shadow states) and true observed
outsym_flat(c::ComponentModel) = c._outsym_flat
obssym_all(c::ComponentModel) = c._obssym_all
Expand Down Expand Up @@ -570,12 +586,13 @@ dispatchT(T::Type{<:EdgeModel}) = EdgeModel{nothing,nothing,nothing,nothing,noth
# TODO: introduce batchequal hash for faster batching of component models
batchequal(a, b) = false
function batchequal(a::ComponentModel, b::ComponentModel)
compf(a) == compf(b) || return false
compg(a) == compg(b) || return false
fftype(a) == fftype(b) || return false
dim(a) == dim(b) || return false
outdim(a) == outdim(b) || return false
pdim(a) == pdim(b) || return false
compf(a) == compf(b) || return false
compg(a) == compg(b) || return false
fftype(a) == fftype(b) || return false
dim(a) == dim(b) || return false
outdim(a) == outdim(b) || return false
pdim(a) == pdim(b) || return false
extdim(a) == extdim(b) || return false
return true
end

Expand Down Expand Up @@ -957,6 +974,15 @@ function _fill_defaults(T, @nospecialize(kwargs))
dict[:_outsym_flat] = _outsym_flat
dict[:_obssym_all] = setdiff(_outsym_flat, sym) obssym

####
#### Extsym
####
if haskey(dict, :extsym)
@assert dict[:extsym] isa Vector{<:SymbolicIndex}
else
dict[:extsym] = SymbolicIndex[]
end


# check for name clashes (at the end because only now sym, psym, obssym are initialized)
_s = sym
Expand Down Expand Up @@ -1098,20 +1124,22 @@ function Base.:(==)(cf1::ComponentModel, cf2::ComponentModel)
end

function compfg(c::VertexModel)
(out, du, u, in, p, t) -> begin
(out, du, u, in, p, t, ext=nothing) -> begin
f = compf(c)
isnothing(f) || f(du, u, in, p, t)
fargs = isnothing(ext) ? (du, u, in, p, t) : (du, u, in, p, t, ext)
isnothing(f) || f(fargs...)
compg(c)(_gargs(fftype(c), (out,), du, u, (in,), p, t)...)
nothing
end
end
end
function compfg(c::EdgeModel)
(out1, out2, du, u, in1, in2, p, t) -> begin
(out1, out2, du, u, in1, in2, p, t, ext=nothing) -> begin
f = compf(c)
isnothing(f) || f(du, u, in1, in2, p, t)
fargs = isnothing(ext) ? (du, u, in1, in2, p, t) : (du, u, in1, in2, p, t, ext)
isnothing(f) || f(fargs...)
compg(c)(_gargs(fftype(c), (out1,out2), du, u, (in1,in2), p, t)...)
nothing
end
end
end
_gargs(::PureFeedForward, outs, du, u, ins, p, t) = (outs..., ins..., p, t)
_gargs(::FeedForward, outs, du, u, ins, p, t) = (outs..., u, ins..., p, t)
Expand Down
57 changes: 32 additions & 25 deletions src/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,21 +171,26 @@ function Network(g::AbstractGraph,
mass_matrix = construct_mass_matrix(im)
N = ForwardDiff.pickchunksize(max(im.lastidx_dynamic, im.lastidx_p))
caches = (; output = DiffCache(zeros(im.lastidx_out), N),
aggregation = DiffCache(zeros(im.lastidx_aggr), N))
aggregation = DiffCache(zeros(im.lastidx_aggr), N),
external = DiffCache(zeros(im.lastidx_extbuf), N))

gbufprovider = if usebuffer(execution)
EagerGBufProvider(im, edgebatches)
else
LazyGBufProvider(im, edgebatches)
end

# create map for extenral inputs
extmap = ExtMap(im)

nw = Network{typeof(execution),typeof(g),typeof(nl), typeof(vertexbatches),
typeof(mass_matrix),eltype(caches),typeof(gbufprovider)}(
typeof(mass_matrix),eltype(caches),typeof(gbufprovider),typeof(extmap)}(
vertexbatches,
nl, im,
caches,
mass_matrix,
gbufprovider
gbufprovider,
extmap,
)

end
Expand Down Expand Up @@ -276,37 +281,39 @@ function aliasgroups(cfs::Vector{T}) where {T<:ComponentModel}
end

function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose)
components = @view im.vertexm[idxs]
_compT = dispatchT(first(components))
_compf = compf(first(components))
_compg = compg(first(components))
_ff = fftype(first(components))
_dim = dim(first(components))
_pdim = pdim(first(components))
_outdim = outdim(first(components))

strides = register_vertices!(im, _dim, _outdim, _pdim, idxs)
first_comp = im.vertexm[first(idxs)]
_compT = dispatchT(first_comp)
_compf = compf(first_comp)
_compg = compg(first_comp)
_ff = fftype(first_comp)
_dim = dim(first_comp)
_pdim = pdim(first_comp)
_outdim = outdim(first_comp)
_extdim = extdim(first_comp)

strides = register_vertices!(im, idxs, _dim, _outdim, _pdim, _extdim)

verbose && println(" - VertexBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))")

ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out)
ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out, strides.ext)
end

function EdgeBatch(im::IndexManager, idxs::Vector{Int}; verbose)
components = @view im.edgem[idxs]
_compT = dispatchT(first(components))
_compf = compf(first(components))
_compg = compg(first(components))
_ff = fftype(first(components))
_dim = dim(first(components))
_pdim = pdim(first(components))
_outdim = outdim(first(components))

strides = register_edges!(im, _dim, _outdim, _pdim, idxs)
first_comp = im.edgem[first(idxs)]
_compT = dispatchT(first_comp)
_compf = compf(first_comp)
_compg = compg(first_comp)
_ff = fftype(first_comp)
_dim = dim(first_comp)
_pdim = pdim(first_comp)
_outdim = outdim(first_comp)
_extdim = extdim(first_comp)

strides = register_edges!(im, idxs, _dim, _outdim, _pdim, _extdim)

verbose && println(" - EdgeBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))")

ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out)
ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out, strides.ext)
end

batch_by_idxs(v, idxs::Vector{Vector{Int}}) = [v for batch in idxs]
Expand Down
Loading

0 comments on commit 38bf0f4

Please sign in to comment.