Skip to content

Commit

Permalink
Add task_local for duplicating scratch data for concurrent use
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikekre committed Nov 8, 2024
1 parent 5153307 commit b75e345
Show file tree
Hide file tree
Showing 20 changed files with 247 additions and 111 deletions.
29 changes: 16 additions & 13 deletions docs/src/literate-howto/threaded_assembly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,12 @@ end
# purpose. Finally, for the assembler we call `start_assemble` to create a new assembler but
# note that we set `fillzero = false` because we don't want to risk that a task that starts
# a bit later will zero out data that another task have already assembled.
function ScratchData(dh::DofHandler, K::SparseMatrixCSC, f::Vector, cellvalues::CellValues)
cell_cache = CellCache(dh)
n = ndofs_per_cell(dh)
Ke = zeros(n, n)
fe = zeros(n)
asm = start_assemble(K, f; fillzero = false)
return ScratchData(cell_cache, copy(cellvalues), Ke, fe, asm)
function create_scratch(scratch::ScratchData)
return ScratchData(
task_local(scratch.cell_cache), task_local(scratch.cellvalues),
task_local(scratch.Ke), task_local(scratch.fe),
task_local(scratch.assembler)
)
end
nothing # hide

Expand Down Expand Up @@ -220,13 +219,17 @@ using OhMyThreads, TaskLocalValues

function assemble_global!(
K::SparseMatrixCSC, f::Vector, dh::DofHandler, colors,
cellvalues_template::CellValues; ntasks = Threads.nthreads()
cellvalues::CellValues; ntasks = Threads.nthreads()
)
## Zero-out existing data in K and f
_ = start_assemble(K, f)
## Body force and material stiffness
b = Vec{3}((0.0, 0.0, -1.0))
C = create_material_stiffness()
## Scratch data
scratch = ScratchData(
CellCache(dh), cellvalues,
zeros(ndofs_per_cell(dh), ndofs_per_cell(dh)), zeros(ndofs_per_cell(dh)),
start_assemble(K, f)
)
## Loop over the colors
for color in colors
## Dynamic scheduler spawning `ntasks` tasks where each task will process a chunk of
Expand All @@ -237,8 +240,8 @@ function assemble_global!(
## Tell the @tasks loop to use the scheduler defined above
@set scheduler = scheduler
## Obtain a task local scratch and unpack it
@local scratch = ScratchData(dh, K, f, cellvalues_template)
(; cell_cache, cellvalues, Ke, fe, assembler) = scratch
@local scratch = create_scratch(scratch)
local (; cell_cache, cellvalues, Ke, fe, assembler) = scratch
## Reinitialize the cell cache and then the cellvalues
reinit!(cell_cache, cellidx)
reinit!(cellvalues, cell_cache)
Expand All @@ -259,7 +262,7 @@ nothing # hide
# ```julia
# # using TaskLocalValues
# scratches = TaskLocalValue() do
# ScratchData(dh, K, f, cellvalues)
# create_scratch(scratch)
# end
# OhMyThreads.tforeach(color; scheduler) do cellidx
# # Obtain a task local scratch and unpack it
Expand Down
7 changes: 5 additions & 2 deletions src/FEValues/CellValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ function CellValues(::Type{T}, qr::QuadratureRule, ip::Interpolation, ip_geo::Ve
return CellValues(T, qr, ip, ip_geo, ValuesUpdateFlags(ip; kwargs...))
end

function Base.copy(cv::CellValues)
return CellValues(copy(cv.fun_values), copy(cv.geo_mapping), copy(cv.qr), _copy_or_nothing(cv.detJdV))
function task_local(cv::CellValues)
return CellValues(

Check warning on line 69 in src/FEValues/CellValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/CellValues.jl#L68-L69

Added lines #L68 - L69 were not covered by tests
task_local(cv.fun_values), task_local(cv.geo_mapping), task_local(cv.qr),
task_local(cv.detJdV)
)
end

# Access geometry values
Expand Down
10 changes: 6 additions & 4 deletions src/FEValues/FacetValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ function FacetValues(::Type{T}, qr::FacetQuadratureRule, ip::Interpolation, ip_g
return FacetValues(T, qr, ip, ip_geo, ValuesUpdateFlags(ip; kwargs...))
end

function Base.copy(fv::FacetValues)
fun_values = map(copy, fv.fun_values)
geo_mapping = map(copy, fv.geo_mapping)
return FacetValues(fun_values, geo_mapping, copy(fv.fqr), copy(fv.detJdV), copy(fv.normals), fv.current_facet)
function task_local(fv::FacetValues)
return FacetValues(

Check warning on line 72 in src/FEValues/FacetValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/FacetValues.jl#L71-L72

Added lines #L71 - L72 were not covered by tests
map(task_local, fv.fun_values), map(task_local, fv.geo_mapping),
task_local(fv.fqr), task_local(fv.detJdV), task_local(fv.normals),
task_local(fv.current_facet)
)
end

getngeobasefunctions(fv::FacetValues) = getngeobasefunctions(get_geo_mapping(fv))
Expand Down
15 changes: 7 additions & 8 deletions src/FEValues/FunctionValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,13 @@ function precompute_values!(fv::FunctionValues{2}, qr_points::AbstractVector{<:V
return reference_shape_hessians_gradients_and_values!(fv.d2Ndξ2, fv.dNdξ, fv.Nξ, fv.ip, qr_points)
end

function Base.copy(v::FunctionValues)
Nξ_copy = copy(v.Nξ)
Nx_copy = v.=== v.Nx ? Nξ_copy : copy(v.Nx) # Preserve aliasing
dNdx_copy = _copy_or_nothing(v.dNdx)
dNdξ_copy = _copy_or_nothing(v.dNdξ)
d2Ndx2_copy = _copy_or_nothing(v.d2Ndx2)
d2Ndξ2_copy = _copy_or_nothing(v.d2Ndξ2)
return FunctionValues(copy(v.ip), Nx_copy, Nξ_copy, dNdx_copy, dNdξ_copy, d2Ndx2_copy, d2Ndξ2_copy)
function task_local(v::FunctionValues)
= task_local(v.Nξ)
Nx = v.=== v.Nx ?: task_local(v.Nx) # Preserve aliasing
return FunctionValues(

Check warning on line 104 in src/FEValues/FunctionValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/FunctionValues.jl#L101-L104

Added lines #L101 - L104 were not covered by tests
task_local(v.ip), Nx, Nξ, task_local(v.dNdx), task_local(v.dNdξ),
task_local(v.d2Ndx2), task_local(v.d2Ndξ2)
)
end

getnbasefunctions(funvals::FunctionValues) = size(funvals.Nx, 1)
Expand Down
6 changes: 4 additions & 2 deletions src/FEValues/GeometryMapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ function precompute_values!(gm::GeometryMapping{2}, qr_points::AbstractVector{<:
return reference_shape_hessians_gradients_and_values!(gm.d2Mdξ2, gm.dMdξ, gm.M, gm.ip, qr_points)
end

function Base.copy(v::GeometryMapping)
return GeometryMapping(copy(v.ip), copy(v.M), _copy_or_nothing(v.dMdξ), _copy_or_nothing(v.d2Mdξ2))
function task_local(v::GeometryMapping)
return GeometryMapping(

Check warning on line 99 in src/FEValues/GeometryMapping.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/GeometryMapping.jl#L98-L99

Added lines #L98 - L99 were not covered by tests
task_local(v.ip), task_local(v.M), task_local(v.dMdξ), task_local(v.d2Mdξ2)
)
end

getngeobasefunctions(geo_mapping::GeometryMapping) = size(geo_mapping.M, 1)
Expand Down
4 changes: 2 additions & 2 deletions src/FEValues/InterfaceValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ end
InterfaceValues(facetvalues_here::FVA, facetvalues_there::FVB = deepcopy(FacetValues_here)) where {FVA <: FacetValues, FVB <: FacetValues} =
InterfaceValues{FVA, FVB}(facetvalues_here, facetvalues_there)

function Base.copy(iv::InterfaceValues)
return InterfaceValues(copy(iv.here), copy(iv.there))
function task_local(iv::InterfaceValues)
return InterfaceValues(task_local(iv.here), task_local(iv.there))

Check warning on line 90 in src/FEValues/InterfaceValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/InterfaceValues.jl#L89-L90

Added lines #L89 - L90 were not covered by tests
end

function getnbasefunctions(iv::InterfaceValues)
Expand Down
2 changes: 2 additions & 0 deletions src/Ferrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ using .CollectionsOfViews:

include("exports.jl")

# Task based multithreading support
include("multithreading.jl")

"""
AbstractRefShape{refdim}
Expand Down
10 changes: 8 additions & 2 deletions src/Quadrature/quadrature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,5 +317,11 @@ getpoints(qr::FacetQuadratureRule, face::Int) = getpoints(qr.face_rules[face])

getrefshape(::QuadratureRule{RefShape}) where {RefShape} = RefShape

# TODO: This is used in copy(::(Cell|Face)Values), but it it useful to get an actual copy?
Base.copy(qr::Union{QuadratureRule, FacetQuadratureRule}) = qr
# TODO: For typical use the quadrature rule is read-only, but seems safer to copy anyway?
# And might even be beneficial with e.g. NUMA?
function task_local(qr::QR) where {refshape, QR <: QuadratureRule{refshape}}
return QuadratureRule{refshape}(task_local(qr.weights), task_local(qr.points))::QR

Check warning on line 323 in src/Quadrature/quadrature.jl

View check run for this annotation

Codecov / codecov/patch

src/Quadrature/quadrature.jl#L322-L323

Added lines #L322 - L323 were not covered by tests
end
function task_local(qr::QR) where {refshape, QR <: FacetQuadratureRule{refshape}}
return FacetQuadratureRule{refshape}(map(task_local, qr.face_rules))::QR

Check warning on line 326 in src/Quadrature/quadrature.jl

View check run for this annotation

Codecov / codecov/patch

src/Quadrature/quadrature.jl#L325-L326

Added lines #L325 - L326 were not covered by tests
end
8 changes: 8 additions & 0 deletions src/assembler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ matrix_handle(a::AbstractCSCAssembler) = a.K
matrix_handle(a::SymmetricCSCAssembler) = a.K.data
vector_handle(a::AbstractCSCAssembler) = a.f

function task_local(asm::CSCAssembler)
return CSCAssembler(asm.K, asm.f, task_local(asm.permutation), task_local(asm.sorteddofs))

Check warning on line 208 in src/assembler.jl

View check run for this annotation

Codecov / codecov/patch

src/assembler.jl#L207-L208

Added lines #L207 - L208 were not covered by tests
end
function task_local(asm::SymmetricCSCAssembler)
return SymmetricCSCAssembler(asm.K, asm.f, task_local(asm.permutation), task_local(asm.sorteddofs))

Check warning on line 211 in src/assembler.jl

View check run for this annotation

Codecov / codecov/patch

src/assembler.jl#L210-L211

Added lines #L210 - L211 were not covered by tests
end


"""
start_assemble(K::AbstractSparseMatrixCSC; fillzero::Bool=true) -> CSCAssembler
start_assemble(K::AbstractSparseMatrixCSC, f::Vector; fillzero::Bool=true) -> CSCAssembler
Expand Down
11 changes: 11 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,14 @@ end
function celldofs!(::Vector, ::FacetCache)
throw(DeprecationError("celldofs!(v::Vector, fs::FacetCache)" => "celldofs!(v, celldofs(fc))"))
end

## Deprecations introduced in [email protected] (keep until 2.0) ##

# Base.copy -> task_local, https://github.com/Ferrite-FEM/Ferrite.jl/pull/1070
import Base: copy
for T in (
CellValues, FacetValues, FunctionValues, GeometryMapping, InterfaceValues,
Interpolation, QuadratureRule, FacetQuadratureRule,
)
@eval @deprecate copy(x::$T) task_local(x)
end
5 changes: 4 additions & 1 deletion src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,7 @@ export
evaluate_at_points,
PointIterator,
PointLocation,
PointValues
PointValues,

# Misc
task_local
2 changes: 0 additions & 2 deletions src/interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ nvertices(::Interpolation{RefShape}) where {RefShape} = nvertices(RefShape)
nedges(::Interpolation{RefShape}) where {RefShape} = nedges(RefShape)
nfaces(::Interpolation{RefShape}) where {RefShape} = nfaces(RefShape)

Base.copy(ip::Interpolation) = ip

"""
Ferrite.getrefdim(::Interpolation)
Expand Down
8 changes: 7 additions & 1 deletion src/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ end
UpdateFlags(; nodes::Bool = true, coords::Bool = true, dofs::Bool = true) =
UpdateFlags(nodes, coords, dofs)


###############
## CellCache ##
###############
Expand Down Expand Up @@ -83,6 +82,13 @@ function reinit!(cc::CellCache, i::Int)
return cc
end

function task_local(cc::CellCache)
return CellCache(

Check warning on line 86 in src/iterators.jl

View check run for this annotation

Codecov / codecov/patch

src/iterators.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
cc.flags, cc.grid, task_local(cc.cellid), task_local(cc.nodes),
task_local(cc.coords), cc.dh, task_local(cc.dofs)
)
end

# reinit! FEValues with CellCache
reinit!(cv::CellValues, cc::CellCache) = reinit!(cv, cc.coords)
reinit!(fv::FacetValues, cc::CellCache, f::Int) = reinit!(fv, cc.coords, f)
Expand Down
79 changes: 79 additions & 0 deletions src/multithreading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
function task_local end

"""
task_local(x::T) -> T
Duplicate `x` for a new task such that it can be used concurrently with the original `x`.
This is similar to `copy` but only the data that is known to be mutated, i.e. "scratch
data", are duplicated.
Typically, for concurrent assembly, there are some data structures that can't be shared
between the tasks, for example the local element matrix/vector and the `CellValues`.
`task_local` can thus be used to duplicate theses data structures for each task based on a
"template" data structure. For example,
```julia
# "Template" local matrix and cell values
Ke = zeros(...)
cv = CellValues(...)
# Spawn `ntasks` tasks for concurrent assembly
@sync for i in 1:ntasks
Threads.@spawn begin
Ke_task = task_local(Ke)
cv_task = task_local(Ke)
for cell in cells_for_task
assemble_element!(Ke_task, cv_task, ...)
end
end
end
```
See the how-to on [multi-threaded assembly](@ref tutorial-threaded-assembly) for a complete
example.
The following "user-facing" types define methods for `task_local`:
- [`CellValues`](@ref), [`FacetValues`](@ref), [`InterfaceValues`](@ref) are duplicated
such that they can be `reinit!`ed independently.
- `DenseArray` (for e.g. the local matrix and vector) are duplicated such that they can be
modified concurrently.
- [`CellCache`](@ref) (for caching element nodes and dofs) are duplicated such that they
can be `reinit!`ed independently.
The following types also define methods for `task_local` but are typically not used directly
by the user but instead used recursively by the above types:
- [`QuadratureRule`](@ref) and [`FacetQuadratureRule`](@ref)
- All types which are `isbitstype` (e.g. `Vec`, `Tensor`, `Int`, `Float64`, etc.)
"""
task_local(::Any)

# DenseVector/DenseMatrix (e.g. local matrix and vector)
function task_local(x::T)::T where {S, T <: DenseArray{S}}
@assert !isbitstype(T)
if isbitstype(S)

Check warning on line 55 in src/multithreading.jl

View check run for this annotation

Codecov / codecov/patch

src/multithreading.jl#L53-L55

Added lines #L53 - L55 were not covered by tests
# If the eltype isbitstype the normal shallow copy can be used...
return copy(x)::T

Check warning on line 57 in src/multithreading.jl

View check run for this annotation

Codecov / codecov/patch

src/multithreading.jl#L57

Added line #L57 was not covered by tests
else
# ... otherwise we recurse and call task_local on the elements
return map(task_local, x)::T

Check warning on line 60 in src/multithreading.jl

View check run for this annotation

Codecov / codecov/patch

src/multithreading.jl#L60

Added line #L60 was not covered by tests
end
end

# FacetQuadratureRule can store the QuadratureRules as a tuple
function task_local(x::T)::T where {T <: Tuple}
if isbitstype(T)
return x

Check warning on line 67 in src/multithreading.jl

View check run for this annotation

Codecov / codecov/patch

src/multithreading.jl#L65-L67

Added lines #L65 - L67 were not covered by tests
else
return map(task_local, x)::T

Check warning on line 69 in src/multithreading.jl

View check run for this annotation

Codecov / codecov/patch

src/multithreading.jl#L69

Added line #L69 was not covered by tests
end
end

# General fallback for other types
function task_local(x::T)::T where {T}
if !isbitstype(T)
error("MethodError: task_local(::$T) is not implemented")

Check warning on line 76 in src/multithreading.jl

View check run for this annotation

Codecov / codecov/patch

src/multithreading.jl#L74-L76

Added lines #L74 - L76 were not covered by tests
end
return x

Check warning on line 78 in src/multithreading.jl

View check run for this annotation

Codecov / codecov/patch

src/multithreading.jl#L78

Added line #L78 was not covered by tests
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Metis
include("test_utils.jl")

# Unit tests
include("test_multithreading.jl")
include("test_collectionsofviews.jl")
include("test_interpolations.jl")
include("test_cellvalues.jl")
Expand Down
25 changes: 0 additions & 25 deletions test/test_cellvalues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,31 +138,6 @@
for (i, qp_x) in pairs(Ferrite.getpoints(quad_rule))
@test spatial_coordinate(cv, i, coords) qp_x
end

@testset "copy(::CellValues)" begin
cvc = copy(cv)
@test typeof(cv) == typeof(cvc)

# Test that all mutable types in FunctionValues and GeometryMapping have been copied
for key in (:fun_values, :geo_mapping)
val = getfield(cv, key)
valc = getfield(cvc, key)
for fname in fieldnames(typeof(val))
v = getfield(val, fname)
vc = getfield(valc, fname)
isbits(v) || @test v !== vc
@test v == vc
end
end
# Test that qr and detJdV is copied as expected.
# Note that qr remain aliased, as defined by `copy(qr)=qr`, see quadrature.jl.
for fname in (:qr, :detJdV)
v = getfield(cv, fname)
vc = getfield(cvc, fname)
fname === :qr || @test v !== vc
@test v == vc
end
end
end
end

Expand Down
16 changes: 16 additions & 0 deletions test/test_deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,20 @@ using Ferrite, Test
@test_throws Ferrite.DeprecationError celldofs!(v, fc)
end

# Ferrite v2
@testset "Base.copy -> task_local" begin
ip = Lagrange{RefQuadrilateral, 1}()
@test_deprecated copy(ip)
qr = QuadratureRule{RefQuadrilateral}(2)
@test_deprecated copy(qr)
cv = CellValues(qr, ip)
@test_deprecated copy(cv)
fqr = FacetQuadratureRule{RefQuadrilateral}(2)
@test_deprecated copy(fqr)
fv = FacetValues(fqr, ip)
@test_deprecated copy(fv)
iv = InterfaceValues(fqr, ip)
@test_deprecated copy(iv)
end

end # testset deprecations
Loading

0 comments on commit b75e345

Please sign in to comment.