Skip to content

Commit

Permalink
Merge pull request #186 from JuliaDynamics/hw/extinputs
Browse files Browse the repository at this point in the history
add external input feature
  • Loading branch information
hexaeder authored Nov 27, 2024
2 parents 953624a + d6c336d commit dd7f5e9
Show file tree
Hide file tree
Showing 11 changed files with 124 additions and 143 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
Expand Down Expand Up @@ -64,6 +65,7 @@ Random = "1"
RecursiveArrayTools = "3.27.0"
SciMLBase = "2"
SparseArrays = "1"
Static = "1.1.1"
StaticArrays = "1.9.4"
SteadyStateDiffEq = "2.2.0"
StyledStrings = "1.0.3"
Expand Down
29 changes: 7 additions & 22 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module CUDAExt
using NetworkDynamics: Network, NetworkLayer, VertexBatch, EdgeBatch,
using NetworkDynamics: Network, NetworkLayer, ComponentBatch,
KAAggregator, AggregationMap, SparseAggregator,
LazyGBufProvider, EagerGBufProvider, LazyGBuf,
dispatchT, compf, compg, fftype, iscudacompatible, executionstyle
dispatchT, iscudacompatible, executionstyle
using NetworkDynamics.PreallocationTools: DiffCache
using NetworkDynamics: KernelAbstractions as KA

Expand Down Expand Up @@ -84,29 +84,14 @@ end
####
#### Adapt VertexBatch/EdgeBatch
####
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, b::VertexBatch)
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, b::ComponentBatch)
Adapt.adapt_structure(CuArray, b)
end
function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, b::EdgeBatch)
Adapt.adapt_structure(CuArray, b)
end
function Adapt.adapt_structure(to, b::VertexBatch)
idxs = adapt(to, b.indices)
f = compf(b)
g = compg(b)
ff = fftype(b)
VertexBatch{dispatchT(b), typeof.((f,g,ff,idxs))...}(
idxs, f, g, ff, b.statestride, b.outstride, b.pstride, b.aggbufstride)
function Adapt.adapt_structure(to, b::ComponentBatch)
indices = adapt(to, b.indices)
ComponentBatch(dispatchT(b), indices, b.compf, b.compg, b.ff,
b.statestride, b.pstride, b.inbufstride, b.outbufstride)
end
function Adapt.adapt_structure(to, b::EdgeBatch)
idxs = adapt(to, b.indices)
f = compf(b)
g = compg(b)
ff = fftype(b)
EdgeBatch{dispatchT(b), typeof.((f,g,ff,idxs))...}(
idxs, f, g, ff, b.statestride, b.outstride, b.pstride, b.gbufstride)
end


####
#### utils
Expand Down
1 change: 1 addition & 0 deletions src/NetworkDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using FastClosures: @closure
using ForwardDiff: ForwardDiff
using Printf: @sprintf
using Random: Random
using Static: Static, StaticInt

@static if VERSION v"1.11.0-0"
using Base: AnnotatedIOBuffer, AnnotatedString
Expand Down
80 changes: 20 additions & 60 deletions src/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,76 +277,36 @@ end

function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose)
components = @view im.vertexm[idxs]
_compT = dispatchT(first(components))
_compf = compf(first(components))
_compg = compg(first(components))
_ff = fftype(first(components))
_dim = dim(first(components))
_pdim = pdim(first(components))
_outdim = outdim(first(components))

try
# TODO: those checks seems expensive and redundant
# _compT = dispatchT(only(unique(dispatchT, components)))
# _compf = compf(only(unique(compf, components)))
# _compg = compg(only(unique(compg, components)))
# _ff = fftype(only(unique(fftype, components)))
# _dim = dim(only(unique(dim, components)))
# _outdim = outdim(only(unique(outdim, components)))
# _pdim = pdim(only(unique(pdim, components)))
_compT = dispatchT(first(components))
_compf = compf(first(components))
_compg = compg(first(components))
_ff = fftype(first(components))
_dim = dim(first(components))
_outdim = outdim(first(components))
_pdim = pdim(first(components))

(statestride, outstride, pstride, aggbufstride) =
register_vertices!(im, _dim, _outdim, _pdim, idxs)
strides = register_vertices!(im, _dim, _outdim, _pdim, idxs)

verbose &&
println(" - VertexBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))")
verbose && println(" - VertexBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))")

VertexBatch{_compT, typeof(_compf), typeof(_compg), typeof(_ff), typeof(idxs)}(
idxs, _compf, _compg, _ff, statestride, outstride, pstride, aggbufstride)
catch e
if e isa ArgumentError && startswith(e.msg, "Collection has multiple elements")
throw(ArgumentError("Provided vertex models $idxs use the same function but have different metadata (dim, pdim,type,...)"))
else
rethrow(e)
end
end
ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out)
end

function EdgeBatch(im::IndexManager, idxs::Vector{Int}; verbose)
components = @view im.edgem[idxs]
_compT = dispatchT(first(components))
_compf = compf(first(components))
_compg = compg(first(components))
_ff = fftype(first(components))
_dim = dim(first(components))
_pdim = pdim(first(components))
_outdim = outdim(first(components))

try
# TODO: those checks seems expensive and redundant
# _compT = dispatchT(only(unique(dispatchT, components)))
# _compf = compf(only(unique(compf, components)))
# _compg = compg(only(unique(compg, components)))
# _ff = fftype(only(unique(fftype, components)))
# _dim = dim(only(unique(dim, components)))
# _outdim = outdim(only(unique(outdim, components)))
# _pdim = pdim(only(unique(pdim, components)))
_compT = dispatchT(first(components))
_compf = compf(first(components))
_compg = compg(first(components))
_ff = fftype(first(components))
_dim = dim(first(components))
_outdim = outdim(first(components))
_pdim = pdim(first(components))

(statestride, outstride, pstride, gbufstride) =
register_edges!(im, _dim, _outdim, _pdim, idxs)
strides = register_edges!(im, _dim, _outdim, _pdim, idxs)

verbose &&
println(" - EdgeBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))")
verbose && println(" - EdgeBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))")

EdgeBatch{_compT, typeof(_compf), typeof(_compg), typeof(_ff), typeof(idxs)}(
idxs, _compf, _compg, _ff, statestride, outstride, pstride, gbufstride)
catch e
if e isa ArgumentError && startswith(e.msg, "Collection has multiple elements")
throw(ArgumentError("Provided edge models $idxs use the same function but have different metadata (dim, pdim,type,...)"))
else
rethrow(e)
end
end
ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out)
end

batch_by_idxs(v, idxs::Vector{Vector{Int}}) = [v for batch in idxs]
Expand Down
2 changes: 1 addition & 1 deletion src/coreloop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ end
_o = _needs_out(fg, batch) ? view(o, out_range(batch, i)) : nothing
_du = _needs_du(fg, batch) ? view(du, state_range(batch, i)) : nothing
_u = _needs_u(fg, batch) ? view(u, state_range(batch, i)) : nothing
_agg = _needs_in(fg, batch) ? view(aggbuf, aggbuf_range(batch, i)) : nothing
_agg = _needs_in(fg, batch) ? view(aggbuf, in_range(batch, i)) : nothing
_p = _needs_p(fg, batch) && _indexable(p) ? view(p, parameter_range(batch, i)) : p
evalf(fg, batch) && apply_compf(compf(batch), _du, _u, (_agg,), _p, t)
evalg(fg, batch) && apply_compg(fftype(batch), compg(batch), (_o,), _u, (_agg,), _p, t)
Expand Down
11 changes: 5 additions & 6 deletions src/gbufs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ struct EagerGBufProvider{MT,C} <: GBufProvider
end

function EagerGBufProvider(im::IndexManager, batches)
map = zeros(Int, ne(im.g) * im.vdepth, 2)
map = zeros(Int, ne(im.g) * im.vdepth * 2)
for (i, e) in pairs(im.edgevec)
map[im.e_gbufr[i], 1] .= im.v_out[e.src]
map[im.e_gbufr[i], 2] .= im.v_out[e.dst]
map[im.e_gbufr[i].src] .= im.v_out[e.src]
map[im.e_gbufr[i].dst] .= im.v_out[e.dst]
end

N = ForwardDiff.pickchunksize(max(im.lastidx_dynamic, im.lastidx_p))
Expand All @@ -25,9 +25,8 @@ get_gbuf(bufp::EagerGBufProvider, o) = get_tmp(bufp.diffcache, o)
gather!(bufp::EagerGBufProvider, gbuf, o) = NNlib.gather!(gbuf, o, bufp.map)

Base.@propagate_inbounds function get_src_dst(gbuf::AbstractArray, batch, i)
bufr = @views gbuf_range(batch, i)
src = @views gbuf[bufr, 1]
dst = @views gbuf[bufr, 2]
src = @views gbuf[in_range(batch, i, :src)]
dst = @views gbuf[in_range(batch, i, :dst)]
src, dst
end

Expand Down
82 changes: 36 additions & 46 deletions src/network_structure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mutable struct IndexManager{G}
e_data::Vector{UnitRange{Int}} # e data in flat states
e_out::Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}} # e range in output buf
e_para::Vector{UnitRange{Int}} # e para in flat para
e_gbufr::Vector{UnitRange{Int}} # e input range in gather buffer
e_gbufr::Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}} # e range in input buf
# metadata
edepth::Int
vdepth::Int
Expand All @@ -34,7 +34,8 @@ mutable struct IndexManager{G}
(Vector{UnitRange{Int}}(undef, nv(g)) for i in 1:4)...,
Vector{UnitRange{Int}}(undef, ne(g)),
Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}}(undef, ne(g)),
(Vector{UnitRange{Int}}(undef, ne(g)) for i in 1:2)...,
Vector{UnitRange{Int}}(undef, ne(g)),
Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}}(undef, ne(g)),
edepth, vdepth,
0, 0, 0, 0, 0,
vertexm, edgem,
Expand Down Expand Up @@ -124,46 +125,30 @@ struct NetworkLayer{GT,ETup,AF}
vdepth::Int # potential becomes range for multilayer
end

abstract type ComponentBatch{F} end

struct VertexBatch{T<:VertexModel,F,G,FFT,IV<:AbstractVector{<:Integer}} <: ComponentBatch{T}
"vertex indices contained in batch"
indices::IV
"vertex function"
compf::F
"vertex output function"
compg::G
ff::FFT
"state: dimension and first index"
statestride::BatchStride{1}
"output: dimension and first index"
outstride::BatchStride{1}
"parameter: dimension and first index"
pstride::BatchStride{1}
"aggregation: dimension and first index"
aggbufstride::BatchStride{1}
end

struct EdgeBatch{T<:EdgeModel,F,G,FFT,IV<:AbstractVector{<:Integer}} <: ComponentBatch{T}
"edge indices (as in edge iterator) contained in batch"
struct ComponentBatch{T,F,G,FFT,DIM,PDIM,INDIMS,OUTDIMS,IV}
"indices contained in batch"
indices::IV
"edge function"
"internal function"
compf::F
"edge output funciton"
"output function"
compg::G
ff::FFT
"state: dimension and first index"
statestride::BatchStride{1}
"output: dimension and first index"
outstride::BatchStride{2}
statestride::BatchStride{DIM}
"parameter: dimension and first index"
pstride::BatchStride{1}
"gathered vector: dimension and first index"
gbufstride::BatchStride{1}
pstride::BatchStride{PDIM}
"inputbuf: dimension(s) and first index"
inbufstride::BatchStride{INDIMS}
"outputbuf: dimension(s) and first index"
outbufstride::BatchStride{OUTDIMS}
function ComponentBatch(dT, i, f, g, ff, ss, ps, is, os)
new{dT,typeof.((f,g,ff))...,stridesT.((ss, ps, is, os))...,typeof(i)}(
i, f, g, ff, ss, ps, is, os)
end
end

@inline Base.length(cb::ComponentBatch) = Base.length(cb.indices)
@inline dispatchT(::ComponentBatch{F}) where {F} = F
@inline dispatchT(::ComponentBatch{T}) where {T} = T
@inline compf(b::ComponentBatch) = b.compf
@inline compg(b::ComponentBatch) = b.compg
@inline fftype(b::ComponentBatch) = b.ff
Expand All @@ -172,10 +157,10 @@ end

@inline state_range(batch, i) = _range(batch.statestride, i)
@inline parameter_range(batch, i) = _range(batch.pstride, i)
@inline out_range(batch, i) = _range(batch.outstride, i)
@inline out_range(batch, i, j) = _range(batch.outstride, i, j)
@inline aggbuf_range(batch, i) = _range(batch.aggbufstride, i)
@inline gbuf_range(batch, i) = _range(batch.gbufstride, i)
@inline out_range(batch, i) = _range(batch.outbufstride, i)
@inline out_range(batch, i, j) = _range(batch.outbufstride, i, j)
@inline in_range(batch, i) = _range(batch.inbufstride, i)
@inline in_range(batch, i, j) = _range(batch.inbufstride, i, j)

function register_vertices!(im::IndexManager, dim, outdim, pdim, idxs)
for i in idxs
Expand All @@ -184,10 +169,12 @@ function register_vertices!(im::IndexManager, dim, outdim, pdim, idxs)
im.v_para[i] = _nextprange!(im, pdim)
im.v_aggr[i] = _nextaggrrange!(im, im.edepth)
end
(BatchStride(first(im.v_data[first(idxs)]), dim),
BatchStride(first(im.v_out[first(idxs)]), outdim),
BatchStride(first(im.v_para[first(idxs)]), pdim),
BatchStride(first(im.v_aggr[first(idxs)]), im.edepth))
(;
state = BatchStride(first(im.v_data[first(idxs)]), dim),
p = BatchStride(first(im.v_para[first(idxs)]), pdim),
in = BatchStride(first(im.v_aggr[first(idxs)]), im.edepth),
out = BatchStride(first(im.v_out[first(idxs)]), outdim),
)
end
function register_edges!(im::IndexManager, dim, outdim, pdim, idxs)
for i in idxs
Expand All @@ -196,12 +183,15 @@ function register_edges!(im::IndexManager, dim, outdim, pdim, idxs)
im.e_out[i] = (src = _nextoutrange!(im, outdim.src),
dst = _nextoutrange!(im, outdim.dst))
im.e_para[i] = _nextprange!(im, pdim)
im.e_gbufr[i] = _nextgbufrange!(im, im.vdepth)
im.e_gbufr[i] = (src = _nextgbufrange!(im, im.vdepth),
dst = _nextgbufrange!(im, im.vdepth))
end
(BatchStride(first(im.e_data[first(idxs)]), dim),
BatchStride(first(flatrange(im.e_out[first(idxs)])), (outdim.src, outdim.dst)),
BatchStride(first(im.e_para[first(idxs)]), pdim),
BatchStride(first(im.e_gbufr[first(idxs)]), im.vdepth))
(;
state = BatchStride(first(im.e_data[first(idxs)]), dim),
p = BatchStride(first(im.e_para[first(idxs)]), pdim),
in = BatchStride(first(flatrange(im.e_gbufr[first(idxs)])), (;src=im.vdepth, dst=im.vdepth)),
out = BatchStride(first(flatrange(im.e_out[first(idxs)])), (;src=outdim.src, dst=outdim.dst)),
)
end
function _nexturange!(im::IndexManager, N)
newlast, range = _nextrange(im.lastidx_dynamic, N)
Expand Down
24 changes: 17 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
struct BatchStride{N}
struct BatchStride{T}
first::Int
strides::NTuple{N,Int}
strides::T # either StaticInt, Tuple or NamedTuple of StaticInt
end
BatchStride(first::Int, stride::Int) = BatchStride(first, (stride,))
stridesT(::BatchStride{T}) where {T} = T
function BatchStride(first::Int, strides::Union{Tuple,NamedTuple})
strides = map(Static.static, strides)
BatchStride{typeof(strides)}(first, strides)
end
BatchStride(first::Int, stride::Int) = BatchStride(first, Static.static(stride))

# get full stride length (all substrides)
_fullstride(bs::BatchStride{1}) = @inbounds bs.strides[1]
_fullstride(bs::BatchStride) = sum(bs.strides)

# full range for N elements with this stride
Expand All @@ -14,15 +18,21 @@ _fullstride(bs::BatchStride) = sum(bs.strides)
end

# range of i-th element
@inline function _range(bs::BatchStride{1}, i)
@inline function _range(bs::BatchStride{<:StaticInt}, i)
start = bs.first + (i - 1) * _fullstride(bs)
start:start+_fullstride(bs)-1
end
# subrange j of i-th element
@inline function _range(bs::BatchStride{N}, i, j) where {N}
start = bs.first + (i - 1) * _fullstride(bs) + sum(bs.strides[1:j-1], init=0)
@inline function _range(bs::BatchStride{<:Union{Tuple, NamedTuple}}, i, j)
start = bs.first + (i - 1) * _fullstride(bs) + _sum_upto_dim(bs.strides, j)
start:start+bs.strides[j]-1
end
@inline _sum_upto_dim(t::Tuple, idx, init=0) = sum(t[1:idx-1], init=0)
@inline _sum_upto_dim(t::NamedTuple, idx::Int) = _sum_upto_dim(values(t), idx)
@inline function _sum_upto_dim(t::NamedTuple, name::Symbol)
idx = findfirst(isequal(name), keys(t))
_sum_upto_dim(values(t), idx)
end

flatrange(r::AbstractRange) = r
function flatrange(rs::@NamedTuple{src::UnitRange{Int}, dst::UnitRange{Int}})
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
Loading

0 comments on commit dd7f5e9

Please sign in to comment.