diff --git a/Project.toml b/Project.toml index 13cf4d80..10d70f70 100644 --- a/Project.toml +++ b/Project.toml @@ -24,6 +24,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b" @@ -64,6 +65,7 @@ Random = "1" RecursiveArrayTools = "3.27.0" SciMLBase = "2" SparseArrays = "1" +Static = "1.1.1" StaticArrays = "1.9.4" SteadyStateDiffEq = "2.2.0" StyledStrings = "1.0.3" diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index adf42ca6..21257356 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -1,8 +1,8 @@ module CUDAExt -using NetworkDynamics: Network, NetworkLayer, VertexBatch, EdgeBatch, +using NetworkDynamics: Network, NetworkLayer, ComponentBatch, KAAggregator, AggregationMap, SparseAggregator, LazyGBufProvider, EagerGBufProvider, LazyGBuf, - dispatchT, compf, compg, fftype, iscudacompatible, executionstyle + dispatchT, iscudacompatible, executionstyle using NetworkDynamics.PreallocationTools: DiffCache using NetworkDynamics: KernelAbstractions as KA @@ -84,29 +84,14 @@ end #### #### Adapt VertexBatch/EdgeBatch #### -function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, b::VertexBatch) +function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, b::ComponentBatch) Adapt.adapt_structure(CuArray, b) end -function Adapt.adapt_structure(to::Type{<:CuArray{<:AbstractFloat}}, b::EdgeBatch) - Adapt.adapt_structure(CuArray, b) -end -function Adapt.adapt_structure(to, b::VertexBatch) - idxs = adapt(to, b.indices) - f = compf(b) - g = compg(b) - ff = fftype(b) - VertexBatch{dispatchT(b), typeof.((f,g,ff,idxs))...}( - idxs, f, g, ff, b.statestride, b.outstride, b.pstride, b.aggbufstride) +function Adapt.adapt_structure(to, b::ComponentBatch) + indices = adapt(to, b.indices) + ComponentBatch(dispatchT(b), indices, b.compf, b.compg, b.ff, + b.statestride, b.pstride, b.inbufstride, b.outbufstride) end -function Adapt.adapt_structure(to, b::EdgeBatch) - idxs = adapt(to, b.indices) - f = compf(b) - g = compg(b) - ff = fftype(b) - EdgeBatch{dispatchT(b), typeof.((f,g,ff,idxs))...}( - idxs, f, g, ff, b.statestride, b.outstride, b.pstride, b.gbufstride) -end - #### #### utils diff --git a/src/NetworkDynamics.jl b/src/NetworkDynamics.jl index ca6cc548..ce2e5e48 100644 --- a/src/NetworkDynamics.jl +++ b/src/NetworkDynamics.jl @@ -20,6 +20,7 @@ using FastClosures: @closure using ForwardDiff: ForwardDiff using Printf: @sprintf using Random: Random +using Static: Static, StaticInt @static if VERSION ≥ v"1.11.0-0" using Base: AnnotatedIOBuffer, AnnotatedString diff --git a/src/construction.jl b/src/construction.jl index cb90d2f8..a12525e1 100644 --- a/src/construction.jl +++ b/src/construction.jl @@ -277,76 +277,36 @@ end function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose) components = @view im.vertexm[idxs] + _compT = dispatchT(first(components)) + _compf = compf(first(components)) + _compg = compg(first(components)) + _ff = fftype(first(components)) + _dim = dim(first(components)) + _pdim = pdim(first(components)) + _outdim = outdim(first(components)) - try - # TODO: those checks seems expensive and redundant - # _compT = dispatchT(only(unique(dispatchT, components))) - # _compf = compf(only(unique(compf, components))) - # _compg = compg(only(unique(compg, components))) - # _ff = fftype(only(unique(fftype, components))) - # _dim = dim(only(unique(dim, components))) - # _outdim = outdim(only(unique(outdim, components))) - # _pdim = pdim(only(unique(pdim, components))) - _compT = dispatchT(first(components)) - _compf = compf(first(components)) - _compg = compg(first(components)) - _ff = fftype(first(components)) - _dim = dim(first(components)) - _outdim = outdim(first(components)) - _pdim = pdim(first(components)) - - (statestride, outstride, pstride, aggbufstride) = - register_vertices!(im, _dim, _outdim, _pdim, idxs) + strides = register_vertices!(im, _dim, _outdim, _pdim, idxs) - verbose && - println(" - VertexBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))") + verbose && println(" - VertexBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))") - VertexBatch{_compT, typeof(_compf), typeof(_compg), typeof(_ff), typeof(idxs)}( - idxs, _compf, _compg, _ff, statestride, outstride, pstride, aggbufstride) - catch e - if e isa ArgumentError && startswith(e.msg, "Collection has multiple elements") - throw(ArgumentError("Provided vertex models $idxs use the same function but have different metadata (dim, pdim,type,...)")) - else - rethrow(e) - end - end + ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out) end function EdgeBatch(im::IndexManager, idxs::Vector{Int}; verbose) components = @view im.edgem[idxs] + _compT = dispatchT(first(components)) + _compf = compf(first(components)) + _compg = compg(first(components)) + _ff = fftype(first(components)) + _dim = dim(first(components)) + _pdim = pdim(first(components)) + _outdim = outdim(first(components)) - try - # TODO: those checks seems expensive and redundant - # _compT = dispatchT(only(unique(dispatchT, components))) - # _compf = compf(only(unique(compf, components))) - # _compg = compg(only(unique(compg, components))) - # _ff = fftype(only(unique(fftype, components))) - # _dim = dim(only(unique(dim, components))) - # _outdim = outdim(only(unique(outdim, components))) - # _pdim = pdim(only(unique(pdim, components))) - _compT = dispatchT(first(components)) - _compf = compf(first(components)) - _compg = compg(first(components)) - _ff = fftype(first(components)) - _dim = dim(first(components)) - _outdim = outdim(first(components)) - _pdim = pdim(first(components)) - - (statestride, outstride, pstride, gbufstride) = - register_edges!(im, _dim, _outdim, _pdim, idxs) + strides = register_edges!(im, _dim, _outdim, _pdim, idxs) - verbose && - println(" - EdgeBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))") + verbose && println(" - EdgeBatch: dim=$(_dim), pdim=$(_pdim), length=$(length(idxs))") - EdgeBatch{_compT, typeof(_compf), typeof(_compg), typeof(_ff), typeof(idxs)}( - idxs, _compf, _compg, _ff, statestride, outstride, pstride, gbufstride) - catch e - if e isa ArgumentError && startswith(e.msg, "Collection has multiple elements") - throw(ArgumentError("Provided edge models $idxs use the same function but have different metadata (dim, pdim,type,...)")) - else - rethrow(e) - end - end + ComponentBatch(_compT, idxs, _compf, _compg, _ff, strides.state, strides.p, strides.in, strides.out) end batch_by_idxs(v, idxs::Vector{Vector{Int}}) = [v for batch in idxs] diff --git a/src/coreloop.jl b/src/coreloop.jl index a32ea1c6..861a221e 100644 --- a/src/coreloop.jl +++ b/src/coreloop.jl @@ -142,7 +142,7 @@ end _o = _needs_out(fg, batch) ? view(o, out_range(batch, i)) : nothing _du = _needs_du(fg, batch) ? view(du, state_range(batch, i)) : nothing _u = _needs_u(fg, batch) ? view(u, state_range(batch, i)) : nothing - _agg = _needs_in(fg, batch) ? view(aggbuf, aggbuf_range(batch, i)) : nothing + _agg = _needs_in(fg, batch) ? view(aggbuf, in_range(batch, i)) : nothing _p = _needs_p(fg, batch) && _indexable(p) ? view(p, parameter_range(batch, i)) : p evalf(fg, batch) && apply_compf(compf(batch), _du, _u, (_agg,), _p, t) evalg(fg, batch) && apply_compg(fftype(batch), compg(batch), (_o,), _u, (_agg,), _p, t) diff --git a/src/gbufs.jl b/src/gbufs.jl index 1834a2da..b797a278 100644 --- a/src/gbufs.jl +++ b/src/gbufs.jl @@ -11,10 +11,10 @@ struct EagerGBufProvider{MT,C} <: GBufProvider end function EagerGBufProvider(im::IndexManager, batches) - map = zeros(Int, ne(im.g) * im.vdepth, 2) + map = zeros(Int, ne(im.g) * im.vdepth * 2) for (i, e) in pairs(im.edgevec) - map[im.e_gbufr[i], 1] .= im.v_out[e.src] - map[im.e_gbufr[i], 2] .= im.v_out[e.dst] + map[im.e_gbufr[i].src] .= im.v_out[e.src] + map[im.e_gbufr[i].dst] .= im.v_out[e.dst] end N = ForwardDiff.pickchunksize(max(im.lastidx_dynamic, im.lastidx_p)) @@ -25,9 +25,8 @@ get_gbuf(bufp::EagerGBufProvider, o) = get_tmp(bufp.diffcache, o) gather!(bufp::EagerGBufProvider, gbuf, o) = NNlib.gather!(gbuf, o, bufp.map) Base.@propagate_inbounds function get_src_dst(gbuf::AbstractArray, batch, i) - bufr = @views gbuf_range(batch, i) - src = @views gbuf[bufr, 1] - dst = @views gbuf[bufr, 2] + src = @views gbuf[in_range(batch, i, :src)] + dst = @views gbuf[in_range(batch, i, :dst)] src, dst end diff --git a/src/network_structure.jl b/src/network_structure.jl index ffc98d0b..d998b3e7 100644 --- a/src/network_structure.jl +++ b/src/network_structure.jl @@ -10,7 +10,7 @@ mutable struct IndexManager{G} e_data::Vector{UnitRange{Int}} # e data in flat states e_out::Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}} # e range in output buf e_para::Vector{UnitRange{Int}} # e para in flat para - e_gbufr::Vector{UnitRange{Int}} # e input range in gather buffer + e_gbufr::Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}} # e range in input buf # metadata edepth::Int vdepth::Int @@ -34,7 +34,8 @@ mutable struct IndexManager{G} (Vector{UnitRange{Int}}(undef, nv(g)) for i in 1:4)..., Vector{UnitRange{Int}}(undef, ne(g)), Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}}(undef, ne(g)), - (Vector{UnitRange{Int}}(undef, ne(g)) for i in 1:2)..., + Vector{UnitRange{Int}}(undef, ne(g)), + Vector{@NamedTuple{src::UnitRange{Int},dst::UnitRange{Int}}}(undef, ne(g)), edepth, vdepth, 0, 0, 0, 0, 0, vertexm, edgem, @@ -124,46 +125,30 @@ struct NetworkLayer{GT,ETup,AF} vdepth::Int # potential becomes range for multilayer end -abstract type ComponentBatch{F} end - -struct VertexBatch{T<:VertexModel,F,G,FFT,IV<:AbstractVector{<:Integer}} <: ComponentBatch{T} - "vertex indices contained in batch" - indices::IV - "vertex function" - compf::F - "vertex output function" - compg::G - ff::FFT - "state: dimension and first index" - statestride::BatchStride{1} - "output: dimension and first index" - outstride::BatchStride{1} - "parameter: dimension and first index" - pstride::BatchStride{1} - "aggregation: dimension and first index" - aggbufstride::BatchStride{1} -end - -struct EdgeBatch{T<:EdgeModel,F,G,FFT,IV<:AbstractVector{<:Integer}} <: ComponentBatch{T} - "edge indices (as in edge iterator) contained in batch" +struct ComponentBatch{T,F,G,FFT,DIM,PDIM,INDIMS,OUTDIMS,IV} + "indices contained in batch" indices::IV - "edge function" + "internal function" compf::F - "edge output funciton" + "output function" compg::G ff::FFT "state: dimension and first index" - statestride::BatchStride{1} - "output: dimension and first index" - outstride::BatchStride{2} + statestride::BatchStride{DIM} "parameter: dimension and first index" - pstride::BatchStride{1} - "gathered vector: dimension and first index" - gbufstride::BatchStride{1} + pstride::BatchStride{PDIM} + "inputbuf: dimension(s) and first index" + inbufstride::BatchStride{INDIMS} + "outputbuf: dimension(s) and first index" + outbufstride::BatchStride{OUTDIMS} + function ComponentBatch(dT, i, f, g, ff, ss, ps, is, os) + new{dT,typeof.((f,g,ff))...,stridesT.((ss, ps, is, os))...,typeof(i)}( + i, f, g, ff, ss, ps, is, os) + end end @inline Base.length(cb::ComponentBatch) = Base.length(cb.indices) -@inline dispatchT(::ComponentBatch{F}) where {F} = F +@inline dispatchT(::ComponentBatch{T}) where {T} = T @inline compf(b::ComponentBatch) = b.compf @inline compg(b::ComponentBatch) = b.compg @inline fftype(b::ComponentBatch) = b.ff @@ -172,10 +157,10 @@ end @inline state_range(batch, i) = _range(batch.statestride, i) @inline parameter_range(batch, i) = _range(batch.pstride, i) -@inline out_range(batch, i) = _range(batch.outstride, i) -@inline out_range(batch, i, j) = _range(batch.outstride, i, j) -@inline aggbuf_range(batch, i) = _range(batch.aggbufstride, i) -@inline gbuf_range(batch, i) = _range(batch.gbufstride, i) +@inline out_range(batch, i) = _range(batch.outbufstride, i) +@inline out_range(batch, i, j) = _range(batch.outbufstride, i, j) +@inline in_range(batch, i) = _range(batch.inbufstride, i) +@inline in_range(batch, i, j) = _range(batch.inbufstride, i, j) function register_vertices!(im::IndexManager, dim, outdim, pdim, idxs) for i in idxs @@ -184,10 +169,12 @@ function register_vertices!(im::IndexManager, dim, outdim, pdim, idxs) im.v_para[i] = _nextprange!(im, pdim) im.v_aggr[i] = _nextaggrrange!(im, im.edepth) end - (BatchStride(first(im.v_data[first(idxs)]), dim), - BatchStride(first(im.v_out[first(idxs)]), outdim), - BatchStride(first(im.v_para[first(idxs)]), pdim), - BatchStride(first(im.v_aggr[first(idxs)]), im.edepth)) + (; + state = BatchStride(first(im.v_data[first(idxs)]), dim), + p = BatchStride(first(im.v_para[first(idxs)]), pdim), + in = BatchStride(first(im.v_aggr[first(idxs)]), im.edepth), + out = BatchStride(first(im.v_out[first(idxs)]), outdim), + ) end function register_edges!(im::IndexManager, dim, outdim, pdim, idxs) for i in idxs @@ -196,12 +183,15 @@ function register_edges!(im::IndexManager, dim, outdim, pdim, idxs) im.e_out[i] = (src = _nextoutrange!(im, outdim.src), dst = _nextoutrange!(im, outdim.dst)) im.e_para[i] = _nextprange!(im, pdim) - im.e_gbufr[i] = _nextgbufrange!(im, im.vdepth) + im.e_gbufr[i] = (src = _nextgbufrange!(im, im.vdepth), + dst = _nextgbufrange!(im, im.vdepth)) end - (BatchStride(first(im.e_data[first(idxs)]), dim), - BatchStride(first(flatrange(im.e_out[first(idxs)])), (outdim.src, outdim.dst)), - BatchStride(first(im.e_para[first(idxs)]), pdim), - BatchStride(first(im.e_gbufr[first(idxs)]), im.vdepth)) + (; + state = BatchStride(first(im.e_data[first(idxs)]), dim), + p = BatchStride(first(im.e_para[first(idxs)]), pdim), + in = BatchStride(first(flatrange(im.e_gbufr[first(idxs)])), (;src=im.vdepth, dst=im.vdepth)), + out = BatchStride(first(flatrange(im.e_out[first(idxs)])), (;src=outdim.src, dst=outdim.dst)), + ) end function _nexturange!(im::IndexManager, N) newlast, range = _nextrange(im.lastidx_dynamic, N) diff --git a/src/utils.jl b/src/utils.jl index 445d9c1e..860785f8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,11 +1,15 @@ -struct BatchStride{N} +struct BatchStride{T} first::Int - strides::NTuple{N,Int} + strides::T # either StaticInt, Tuple or NamedTuple of StaticInt end -BatchStride(first::Int, stride::Int) = BatchStride(first, (stride,)) +stridesT(::BatchStride{T}) where {T} = T +function BatchStride(first::Int, strides::Union{Tuple,NamedTuple}) + strides = map(Static.static, strides) + BatchStride{typeof(strides)}(first, strides) +end +BatchStride(first::Int, stride::Int) = BatchStride(first, Static.static(stride)) # get full stride length (all substrides) -_fullstride(bs::BatchStride{1}) = @inbounds bs.strides[1] _fullstride(bs::BatchStride) = sum(bs.strides) # full range for N elements with this stride @@ -14,15 +18,21 @@ _fullstride(bs::BatchStride) = sum(bs.strides) end # range of i-th element -@inline function _range(bs::BatchStride{1}, i) +@inline function _range(bs::BatchStride{<:StaticInt}, i) start = bs.first + (i - 1) * _fullstride(bs) start:start+_fullstride(bs)-1 end # subrange j of i-th element -@inline function _range(bs::BatchStride{N}, i, j) where {N} - start = bs.first + (i - 1) * _fullstride(bs) + sum(bs.strides[1:j-1], init=0) +@inline function _range(bs::BatchStride{<:Union{Tuple, NamedTuple}}, i, j) + start = bs.first + (i - 1) * _fullstride(bs) + _sum_upto_dim(bs.strides, j) start:start+bs.strides[j]-1 end +@inline _sum_upto_dim(t::Tuple, idx, init=0) = sum(t[1:idx-1], init=0) +@inline _sum_upto_dim(t::NamedTuple, idx::Int) = _sum_upto_dim(values(t), idx) +@inline function _sum_upto_dim(t::NamedTuple, name::Symbol) + idx = findfirst(isequal(name), keys(t)) + _sum_upto_dim(values(t), idx) +end flatrange(r::AbstractRange) = r function flatrange(rs::@NamedTuple{src::UnitRange{Int}, dst::UnitRange{Int}}) diff --git a/test/Project.toml b/test/Project.toml index c26e2f34..b6296f61 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,6 +31,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 002c49f8..3022c7da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,10 +30,10 @@ using ExplicitImports NetworkDynamics.CHECK_COMPONENT[] = false @safetestset "construction test" begin include("construction_test.jl") end @safetestset "Aggregation Tests" begin include("aggregators_test.jl") end - @safetestset "Symbolic Indexing Tests" begin include("symbolicindexing_test.jl") end @safetestset "massmatrix test" begin include("massmatrix_test.jl") end NetworkDynamics.CHECK_COMPONENT[] = true + @safetestset "Symbolic Indexing Tests" begin include("symbolicindexing_test.jl") end @safetestset "doctor test" begin include("doctor_test.jl") end @safetestset "Diffusion test" begin include("diffusion_test.jl") end diff --git a/test/utils_test.jl b/test/utils_test.jl index a50d2789..999dc043 100644 --- a/test/utils_test.jl +++ b/test/utils_test.jl @@ -112,4 +112,37 @@ using NetworkDynamics @test str_significant(0.00000122; sigdigits=3) == "1.22e-6" @test str_significant(-123.294191; sigdigits=5) == "-123.29" end + + @testset "batch stride" begin + using NetworkDynamics: BatchStride, _fullstride, _fullrange, _range + using Static + + b1 = BatchStride(2, 3) + b2 = BatchStride(2, (3, 1)) + + @test _fullstride(b1) == 3 + @test _fullstride(b2) == 4 + Test.@inferred _fullstride(b1) + Test.@inferred _fullstride(b2) + + @test _fullrange(b1, 2) == 2:(2*3+2-1) + @test _fullrange(b2, 2) == 2:(2*(3+1)+2-1) + + @test _range(b1, 1) == 2:4 + @test _range(b1, 2) == 5:7 + @test _range(b2, 1, 1) == 2:4 + @test _range(b2, 1, 2) == 5:5 + @test _range(b2, 2, 1) == 6:8 + @test _range(b2, 2, 2) == 9:9 + + # named batch stride + bn = BatchStride(2, (;src=3, dst=1)) + @test _fullstride(bn) == 4 + Test.@inferred _fullstride(bn) + @test _fullrange(bn, 2) == 2:(2*(3+1)+2-1) + @test _range(bn, 1, 1) == 2:4 + @test _range(bn, 1, 2) == 5:5 + @test _range(bn, 2, 1) == 6:8 + @test _range(bn, 2, 2) == 9:9 + end end