Skip to content

Commit

Permalink
Add task_local for duplicating scratch data for concurrent use
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikekre committed Nov 8, 2024
1 parent 5153307 commit 8fd03fb
Show file tree
Hide file tree
Showing 20 changed files with 250 additions and 111 deletions.
29 changes: 16 additions & 13 deletions docs/src/literate-howto/threaded_assembly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,12 @@ end
# purpose. Finally, for the assembler we call `start_assemble` to create a new assembler but
# note that we set `fillzero = false` because we don't want to risk that a task that starts
# a bit later will zero out data that another task has already assembled.
function ScratchData(dh::DofHandler, K::SparseMatrixCSC, f::Vector, cellvalues::CellValues)
cell_cache = CellCache(dh)
n = ndofs_per_cell(dh)
Ke = zeros(n, n)
fe = zeros(n)
asm = start_assemble(K, f; fillzero = false)
return ScratchData(cell_cache, copy(cellvalues), Ke, fe, asm)
function create_scratch(scratch::ScratchData)
    ## Unpack the template scratch and hand every field to `task_local` so that the
    ## returned `ScratchData` can be used concurrently with the template.
    (; cell_cache, cellvalues, Ke, fe, assembler) = scratch
    return ScratchData(
        task_local(cell_cache), task_local(cellvalues),
        task_local(Ke), task_local(fe), task_local(assembler)
    )
end
nothing # hide

Expand Down Expand Up @@ -220,13 +219,17 @@ using OhMyThreads, TaskLocalValues

function assemble_global!(
K::SparseMatrixCSC, f::Vector, dh::DofHandler, colors,
cellvalues_template::CellValues; ntasks = Threads.nthreads()
cellvalues::CellValues; ntasks = Threads.nthreads()
)
## Zero-out existing data in K and f
_ = start_assemble(K, f)
## Body force and material stiffness
b = Vec{3}((0.0, 0.0, -1.0))
C = create_material_stiffness()
## Scratch data
scratch = ScratchData(
CellCache(dh), cellvalues,
zeros(ndofs_per_cell(dh), ndofs_per_cell(dh)), zeros(ndofs_per_cell(dh)),
start_assemble(K, f)
)
## Loop over the colors
for color in colors
## Dynamic scheduler spawning `ntasks` tasks where each task will process a chunk of
Expand All @@ -237,8 +240,8 @@ function assemble_global!(
## Tell the @tasks loop to use the scheduler defined above
@set scheduler = scheduler
## Obtain a task local scratch and unpack it
@local scratch = ScratchData(dh, K, f, cellvalues_template)
(; cell_cache, cellvalues, Ke, fe, assembler) = scratch
@local scratch = create_scratch(scratch)
local (; cell_cache, cellvalues, Ke, fe, assembler) = scratch
## Reinitialize the cell cache and then the cellvalues
reinit!(cell_cache, cellidx)
reinit!(cellvalues, cell_cache)
Expand All @@ -259,7 +262,7 @@ nothing # hide
# ```julia
# # using TaskLocalValues
# scratches = TaskLocalValue() do
# ScratchData(dh, K, f, cellvalues)
# create_scratch(scratch)
# end
# OhMyThreads.tforeach(color; scheduler) do cellidx
# # Obtain a task local scratch and unpack it
Expand Down
7 changes: 5 additions & 2 deletions src/FEValues/CellValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ function CellValues(::Type{T}, qr::QuadratureRule, ip::Interpolation, ip_geo::Ve
return CellValues(T, qr, ip, ip_geo, ValuesUpdateFlags(ip; kwargs...))
end

function Base.copy(cv::CellValues)
return CellValues(copy(cv.fun_values), copy(cv.geo_mapping), copy(cv.qr), _copy_or_nothing(cv.detJdV))
# Build a new `CellValues` for another task by calling `task_local` on each of the four
# stored components (function values, geometry mapping, quadrature rule and `detJdV`).
function task_local(cv::CellValues)
    fun_values = task_local(cv.fun_values)
    geo_mapping = task_local(cv.geo_mapping)
    qr = task_local(cv.qr)
    detJdV = task_local(cv.detJdV)
    return CellValues(fun_values, geo_mapping, qr, detJdV)
end

# Access geometry values
Expand Down
10 changes: 6 additions & 4 deletions src/FEValues/FacetValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ function FacetValues(::Type{T}, qr::FacetQuadratureRule, ip::Interpolation, ip_g
return FacetValues(T, qr, ip, ip_geo, ValuesUpdateFlags(ip; kwargs...))
end

function Base.copy(fv::FacetValues)
fun_values = map(copy, fv.fun_values)
geo_mapping = map(copy, fv.geo_mapping)
return FacetValues(fun_values, geo_mapping, copy(fv.fqr), copy(fv.detJdV), copy(fv.normals), fv.current_facet)
# Build a new `FacetValues` for another task. `fun_values` and `geo_mapping` are
# collections, so `task_local` is mapped over their entries; the remaining fields are
# passed through `task_local` one by one.
function task_local(fv::FacetValues)
    fun_values = map(task_local, fv.fun_values)
    geo_mapping = map(task_local, fv.geo_mapping)
    fqr = task_local(fv.fqr)
    detJdV = task_local(fv.detJdV)
    normals = task_local(fv.normals)
    current_facet = task_local(fv.current_facet)
    return FacetValues(fun_values, geo_mapping, fqr, detJdV, normals, current_facet)
end

getngeobasefunctions(fv::FacetValues) = getngeobasefunctions(get_geo_mapping(fv))
Expand Down
15 changes: 7 additions & 8 deletions src/FEValues/FunctionValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,13 @@ function precompute_values!(fv::FunctionValues{2}, qr_points::AbstractVector{<:V
return reference_shape_hessians_gradients_and_values!(fv.d2Ndξ2, fv.dNdξ, fv.Nξ, fv.ip, qr_points)
end

function Base.copy(v::FunctionValues)
Nξ_copy = copy(v.Nξ)
Nx_copy = v.Nξ === v.Nx ? Nξ_copy : copy(v.Nx) # Preserve aliasing
dNdx_copy = _copy_or_nothing(v.dNdx)
dNdξ_copy = _copy_or_nothing(v.dNdξ)
d2Ndx2_copy = _copy_or_nothing(v.d2Ndx2)
d2Ndξ2_copy = _copy_or_nothing(v.d2Ndξ2)
return FunctionValues(copy(v.ip), Nx_copy, Nξ_copy, dNdx_copy, dNdξ_copy, d2Ndx2_copy, d2Ndξ2_copy)
# Build a new `FunctionValues` for another task by calling `task_local` on every field.
function task_local(v::FunctionValues)
    Nξ = task_local(v.Nξ)
    # Preserve aliasing: when `Nx` and `Nξ` are the very same array in the original they
    # must also be the same array in the duplicate.
    Nx = v.Nξ === v.Nx ? Nξ : task_local(v.Nx)
    return FunctionValues(
        task_local(v.ip), Nx, Nξ, task_local(v.dNdx), task_local(v.dNdξ),
        task_local(v.d2Ndx2), task_local(v.d2Ndξ2)
    )
end

getnbasefunctions(funvals::FunctionValues) = size(funvals.Nx, 1)
Expand Down
6 changes: 4 additions & 2 deletions src/FEValues/GeometryMapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ function precompute_values!(gm::GeometryMapping{2}, qr_points::AbstractVector{<:
return reference_shape_hessians_gradients_and_values!(gm.d2Mdξ2, gm.dMdξ, gm.M, gm.ip, qr_points)
end

function Base.copy(v::GeometryMapping)
return GeometryMapping(copy(v.ip), copy(v.M), _copy_or_nothing(v.dMdξ), _copy_or_nothing(v.d2Mdξ2))
# Build a new `GeometryMapping` for another task by calling `task_local` on the
# interpolation and on each of the stored shape value/derivative fields.
function task_local(v::GeometryMapping)
    ip = task_local(v.ip)
    M = task_local(v.M)
    dMdξ = task_local(v.dMdξ)
    d2Mdξ2 = task_local(v.d2Mdξ2)
    return GeometryMapping(ip, M, dMdξ, d2Mdξ2)
end

getngeobasefunctions(geo_mapping::GeometryMapping) = size(geo_mapping.M, 1)
Expand Down
4 changes: 2 additions & 2 deletions src/FEValues/InterfaceValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ end
# Outer constructor taking the two `FacetValues` of an interface. When only
# `facetvalues_here` is given, the "there" side defaults to a deep copy of it.
InterfaceValues(facetvalues_here::FVA, facetvalues_there::FVB = deepcopy(facetvalues_here)) where {FVA <: FacetValues, FVB <: FacetValues} =
    InterfaceValues{FVA, FVB}(facetvalues_here, facetvalues_there)

function Base.copy(iv::InterfaceValues)
return InterfaceValues(copy(iv.here), copy(iv.there))
# Build a new `InterfaceValues` for another task by duplicating both sides.
function task_local(iv::InterfaceValues)
    here = task_local(iv.here)
    there = task_local(iv.there)
    return InterfaceValues(here, there)
end

function getnbasefunctions(iv::InterfaceValues)
Expand Down
2 changes: 2 additions & 0 deletions src/Ferrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ using .CollectionsOfViews:

include("exports.jl")

# Task based multithreading support
include("multithreading.jl")

"""
AbstractRefShape{refdim}
Expand Down
10 changes: 8 additions & 2 deletions src/Quadrature/quadrature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,5 +317,11 @@ getpoints(qr::FacetQuadratureRule, face::Int) = getpoints(qr.face_rules[face])

getrefshape(::QuadratureRule{RefShape}) where {RefShape} = RefShape

# TODO: This is used in copy(::(Cell|Face)Values), but is it useful to get an actual copy?
Base.copy(qr::Union{QuadratureRule, FacetQuadratureRule}) = qr
# TODO: For typical use the quadrature rule is read-only, but seems safer to copy anyway?
# And might even be beneficial with e.g. NUMA?
# Duplicate a `QuadratureRule` by passing its weights and points through `task_local`.
# The type assertion checks that the result has the same concrete type as the input.
function task_local(qr::QR) where {refshape, QR <: QuadratureRule{refshape}}
    weights = task_local(qr.weights)
    points = task_local(qr.points)
    return QuadratureRule{refshape}(weights, points)::QR
end
# Duplicate a `FacetQuadratureRule` by mapping `task_local` over its per-face rules.
# The type assertion checks that the result has the same concrete type as the input.
function task_local(qr::QR) where {refshape, QR <: FacetQuadratureRule{refshape}}
    face_rules = map(task_local, qr.face_rules)
    return FacetQuadratureRule{refshape}(face_rules)::QR
end
8 changes: 8 additions & 0 deletions src/assembler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ matrix_handle(a::AbstractCSCAssembler) = a.K
matrix_handle(a::SymmetricCSCAssembler) = a.K.data
vector_handle(a::AbstractCSCAssembler) = a.f

# Create a `CSCAssembler` for another task. The matrix `K` and vector `f` are passed on
# as-is (the same objects as in `asm`); only the `permutation` and `sorteddofs` buffers go
# through `task_local`.
function task_local(asm::CSCAssembler)
    permutation = task_local(asm.permutation)
    sorteddofs = task_local(asm.sorteddofs)
    return CSCAssembler(asm.K, asm.f, permutation, sorteddofs)
end
# Create a `SymmetricCSCAssembler` for another task. The matrix `K` and vector `f` are
# passed on as-is (the same objects as in `asm`); only the `permutation` and `sorteddofs`
# buffers go through `task_local`.
function task_local(asm::SymmetricCSCAssembler)
    permutation = task_local(asm.permutation)
    sorteddofs = task_local(asm.sorteddofs)
    return SymmetricCSCAssembler(asm.K, asm.f, permutation, sorteddofs)
end


"""
start_assemble(K::AbstractSparseMatrixCSC; fillzero::Bool=true) -> CSCAssembler
start_assemble(K::AbstractSparseMatrixCSC, f::Vector; fillzero::Bool=true) -> CSCAssembler
Expand Down
11 changes: 11 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,14 @@ end
function celldofs!(::Vector, ::FacetCache)
throw(DeprecationError("celldofs!(v::Vector, fs::FacetCache)" => "celldofs!(v, celldofs(fc))"))
end

## Deprecations introduced in [email protected] (keep until 2.0) ##

# Base.copy -> task_local, https://github.com/Ferrite-FEM/Ferrite.jl/pull/1070
import Base: copy
# For each of the listed types, define a deprecated `copy(x)` method that forwards to
# `task_local(x)`. `@eval` is used so that the loop variable `T` can be interpolated into
# the method signature handed to `@deprecate`.
for T in (
CellValues, FacetValues, FunctionValues, GeometryMapping, InterfaceValues,
Interpolation, QuadratureRule, FacetQuadratureRule,
)
@eval @deprecate copy(x::$T) task_local(x)
end
5 changes: 4 additions & 1 deletion src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,7 @@ export
evaluate_at_points,
PointIterator,
PointLocation,
PointValues
PointValues,

# Misc
task_local
2 changes: 0 additions & 2 deletions src/interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ nvertices(::Interpolation{RefShape}) where {RefShape} = nvertices(RefShape)
nedges(::Interpolation{RefShape}) where {RefShape} = nedges(RefShape)
nfaces(::Interpolation{RefShape}) where {RefShape} = nfaces(RefShape)

Base.copy(ip::Interpolation) = ip

"""
Ferrite.getrefdim(::Interpolation)
Expand Down
8 changes: 7 additions & 1 deletion src/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ end
UpdateFlags(; nodes::Bool = true, coords::Bool = true, dofs::Bool = true) =
UpdateFlags(nodes, coords, dofs)


###############
## CellCache ##
###############
Expand Down Expand Up @@ -83,6 +82,13 @@ function reinit!(cc::CellCache, i::Int)
return cc
end

# Create a `CellCache` for another task. `flags`, `grid` and `dh` are passed on as-is,
# while `cellid`, `nodes`, `coords` and `dofs` go through `task_local`.
function task_local(cc::CellCache)
    cellid = task_local(cc.cellid)
    nodes = task_local(cc.nodes)
    coords = task_local(cc.coords)
    dofs = task_local(cc.dofs)
    return CellCache(cc.flags, cc.grid, cellid, nodes, coords, cc.dh, dofs)
end

# reinit! FEValues with CellCache
reinit!(cv::CellValues, cc::CellCache) = reinit!(cv, cc.coords)
reinit!(fv::FacetValues, cc::CellCache, f::Int) = reinit!(fv, cc.coords, f)
Expand Down
79 changes: 79 additions & 0 deletions src/multithreading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
function task_local end

"""
task_local(x::T) -> T
Duplicate `x` for a new task such that it can be used concurrently with the original `x`.
This is similar to `copy` but only the data that is known to be mutated, i.e. "scratch
data", are duplicated.
Typically, for concurrent assembly, there are some data structures that can't be shared
between the tasks, for example the local element matrix/vector and the `CellValues`.
`task_local` can thus be used to duplicate these data structures for each task based on a
"template" data structure. For example,
```julia
# "Template" local matrix and cell values
Ke = zeros(...)
cv = CellValues(...)
# Spawn `ntasks` tasks for concurrent assembly
@sync for i in 1:ntasks
Threads.@spawn begin
Ke_task = task_local(Ke)
cv_task = task_local(cv)
for cell in cells_for_task
assemble_element!(Ke_task, cv_task, ...)
end
end
end
```
See the how-to on [multi-threaded assembly](@ref tutorial-threaded-assembly) for a complete
example.
The following "user-facing" types define methods for `task_local`:
- [`CellValues`](@ref), [`FacetValues`](@ref), [`InterfaceValues`](@ref) are duplicated
such that they can be `reinit!`ed independently.
- `DenseArray` (for e.g. the local matrix and vector) are duplicated such that they can be
modified concurrently.
- [`CellCache`](@ref) (for caching element nodes and dofs) are duplicated such that they
can be `reinit!`ed independently.
The following types also define methods for `task_local` but are typically not used directly
by the user but instead used recursively by the above types:
- [`QuadratureRule`](@ref) and [`FacetQuadratureRule`](@ref)
- All types which are `isbitstype` (e.g. `Vec`, `Tensor`, `Int`, `Float64`, etc.)
"""
task_local(::Any)

# DenseVector/DenseMatrix (e.g. local matrix and vector)
# Duplicate a dense array for use in another task. The result is asserted to have the same
# type `T` as the input.
function task_local(x::T)::T where {S, T <: DenseArray{S}}
    # Internal invariant: this method is not meant for array types that are themselves
    # bits types.
    @assert !isbitstype(T)
    if isbitstype(S)
        # If the eltype isbitstype the normal shallow copy can be used...
        return copy(x)::T
    else
        # ... otherwise we recurse and call task_local on the elements
        return map(task_local, x)::T
    end
end

# FacetQuadratureRule can store the QuadratureRules as a tuple
function task_local(x::T)::T where {T <: Tuple}
    # A tuple type that is a bits type is returned unchanged
    isbitstype(T) && return x
    # Otherwise duplicate entry by entry; the assertion keeps the tuple type unchanged
    return map(task_local, x)::T
end

# General fallback for other types
# Fallback used when no more specific method applies: values whose type is a bits type are
# returned as-is, while any other type raises an error to signal that a dedicated
# `task_local` method has to be implemented for it.
function task_local(x::T)::T where {T}
    if !isbitstype(T)
        error("MethodError: task_local(::$T) is not implemented")
    end
    return x
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Metis
include("test_utils.jl")

# Unit tests
include("test_multithreading.jl")
include("test_collectionsofviews.jl")
include("test_interpolations.jl")
include("test_cellvalues.jl")
Expand Down
25 changes: 0 additions & 25 deletions test/test_cellvalues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,31 +138,6 @@
for (i, qp_x) in pairs(Ferrite.getpoints(quad_rule))
@test spatial_coordinate(cv, i, coords) ≈ qp_x
end

@testset "copy(::CellValues)" begin
cvc = copy(cv)
@test typeof(cv) == typeof(cvc)

# Test that all mutable types in FunctionValues and GeometryMapping have been copied
for key in (:fun_values, :geo_mapping)
val = getfield(cv, key)
valc = getfield(cvc, key)
for fname in fieldnames(typeof(val))
v = getfield(val, fname)
vc = getfield(valc, fname)
isbits(v) || @test v !== vc
@test v == vc
end
end
# Test that qr and detJdV is copied as expected.
# Note that qr remain aliased, as defined by `copy(qr)=qr`, see quadrature.jl.
for fname in (:qr, :detJdV)
v = getfield(cv, fname)
vc = getfield(cvc, fname)
fname === :qr || @test v !== vc
@test v == vc
end
end
end
end

Expand Down
16 changes: 16 additions & 0 deletions test/test_deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,20 @@ using Ferrite, Test
@test_throws Ferrite.DeprecationError celldofs!(v, fc)
end

# Ferrite v2
# Check that calling `copy` on each of these objects emits a deprecation warning.
@testset "Base.copy -> task_local" begin
# Interpolation
ip = Lagrange{RefQuadrilateral, 1}()
@test_deprecated copy(ip)
# Quadrature rule and the cell values built from it
qr = QuadratureRule{RefQuadrilateral}(2)
@test_deprecated copy(qr)
cv = CellValues(qr, ip)
@test_deprecated copy(cv)
# Facet quadrature rule and the facet/interface values built from it
fqr = FacetQuadratureRule{RefQuadrilateral}(2)
@test_deprecated copy(fqr)
fv = FacetValues(fqr, ip)
@test_deprecated copy(fv)
iv = InterfaceValues(fqr, ip)
@test_deprecated copy(iv)
end

end # testset deprecations
Loading

0 comments on commit 8fd03fb

Please sign in to comment.