Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP add task_local for duplication things for multithreading #1070

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions docs/src/literate-howto/threaded_assembly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,12 @@ end
# purpose. Finally, for the assembler we call `start_assemble` to create a new assembler but
# note that we set `fillzero = false` because we don't want to risk that a task that starts
# a bit later will zero out data that another task has already assembled.
function ScratchData(dh::DofHandler, K::SparseMatrixCSC, f::Vector, cellvalues::CellValues)
cell_cache = CellCache(dh)
n = ndofs_per_cell(dh)
Ke = zeros(n, n)
fe = zeros(n)
asm = start_assemble(K, f; fillzero = false)
return ScratchData(cell_cache, copy(cellvalues), Ke, fe, asm)
function create_scratch(scratch::ScratchData)
    ## Unpack the template scratch and duplicate each of its fields with `task_local`
    (; cell_cache, cellvalues, Ke, fe, assembler) = scratch
    return ScratchData(
        task_local(cell_cache), task_local(cellvalues), task_local(Ke), task_local(fe),
        task_local(assembler)
    )
end
nothing # hide

Expand Down Expand Up @@ -220,13 +219,17 @@ using OhMyThreads, TaskLocalValues

function assemble_global!(
K::SparseMatrixCSC, f::Vector, dh::DofHandler, colors,
cellvalues_template::CellValues; ntasks = Threads.nthreads()
cellvalues::CellValues; ntasks = Threads.nthreads()
)
## Zero-out existing data in K and f
_ = start_assemble(K, f)
## Body force and material stiffness
b = Vec{3}((0.0, 0.0, -1.0))
C = create_material_stiffness()
## Scratch data
scratch = ScratchData(
CellCache(dh), cellvalues,
zeros(ndofs_per_cell(dh), ndofs_per_cell(dh)), zeros(ndofs_per_cell(dh)),
start_assemble(K, f)
)
## Loop over the colors
for color in colors
## Dynamic scheduler spawning `ntasks` tasks where each task will process a chunk of
Expand All @@ -237,8 +240,8 @@ function assemble_global!(
## Tell the @tasks loop to use the scheduler defined above
@set scheduler = scheduler
## Obtain a task local scratch and unpack it
@local scratch = ScratchData(dh, K, f, cellvalues_template)
(; cell_cache, cellvalues, Ke, fe, assembler) = scratch
@local scratch = create_scratch(scratch)
local (; cell_cache, cellvalues, Ke, fe, assembler) = scratch
## Reinitialize the cell cache and then the cellvalues
reinit!(cell_cache, cellidx)
reinit!(cellvalues, cell_cache)
Expand All @@ -259,7 +262,7 @@ nothing # hide
# ```julia
# # using TaskLocalValues
# scratches = TaskLocalValue() do
# ScratchData(dh, K, f, cellvalues)
# create_scratch(scratch)
# end
# OhMyThreads.tforeach(color; scheduler) do cellidx
# # Obtain a task local scratch and unpack it
Expand Down
7 changes: 5 additions & 2 deletions src/FEValues/CellValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ function CellValues(::Type{T}, qr::QuadratureRule, ip::Interpolation, ip_geo::Ve
return CellValues(T, qr, ip, ip_geo, ValuesUpdateFlags(ip; kwargs...))
end

function Base.copy(cv::CellValues)
return CellValues(copy(cv.fun_values), copy(cv.geo_mapping), copy(cv.qr), _copy_or_nothing(cv.detJdV))
# Build a new `CellValues` whose fields are the `task_local` duplicates of the fields of
# `cv` (function values, geometry mapping, quadrature rule and `detJdV`).
function task_local(cv::CellValues)
    fun_values = task_local(cv.fun_values)
    geo_mapping = task_local(cv.geo_mapping)
    qr = task_local(cv.qr)
    detJdV = task_local(cv.detJdV)
    return CellValues(fun_values, geo_mapping, qr, detJdV)
end

# Access geometry values
Expand Down
10 changes: 6 additions & 4 deletions src/FEValues/FacetValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ function FacetValues(::Type{T}, qr::FacetQuadratureRule, ip::Interpolation, ip_g
return FacetValues(T, qr, ip, ip_geo, ValuesUpdateFlags(ip; kwargs...))
end

function Base.copy(fv::FacetValues)
fun_values = map(copy, fv.fun_values)
geo_mapping = map(copy, fv.geo_mapping)
return FacetValues(fun_values, geo_mapping, copy(fv.fqr), copy(fv.detJdV), copy(fv.normals), fv.current_facet)
# Build a new `FacetValues` from the `task_local` duplicates of every field of `fv`,
# including the currently active facet.
function task_local(fv::FacetValues)
    fun_values = task_local(fv.fun_values)
    geo_mapping = task_local(fv.geo_mapping)
    fqr = task_local(fv.fqr)
    detJdV = task_local(fv.detJdV)
    normals = task_local(fv.normals)
    current_facet = task_local(fv.current_facet)
    return FacetValues(fun_values, geo_mapping, fqr, detJdV, normals, current_facet)
end

getngeobasefunctions(fv::FacetValues) = getngeobasefunctions(get_geo_mapping(fv))
Expand Down
15 changes: 7 additions & 8 deletions src/FEValues/FunctionValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,13 @@ function precompute_values!(fv::FunctionValues{2}, qr_points::AbstractVector{<:V
return reference_shape_hessians_gradients_and_values!(fv.d2Ndξ2, fv.dNdξ, fv.Nξ, fv.ip, qr_points)
end

function Base.copy(v::FunctionValues)
Nξ_copy = copy(v.Nξ)
Nx_copy = v.Nξ === v.Nx ? Nξ_copy : copy(v.Nx) # Preserve aliasing
dNdx_copy = _copy_or_nothing(v.dNdx)
dNdξ_copy = _copy_or_nothing(v.dNdξ)
d2Ndx2_copy = _copy_or_nothing(v.d2Ndx2)
d2Ndξ2_copy = _copy_or_nothing(v.d2Ndξ2)
return FunctionValues(copy(v.ip), Nx_copy, Nξ_copy, dNdx_copy, dNdξ_copy, d2Ndx2_copy, d2Ndξ2_copy)
# Build a new `FunctionValues` from the `task_local` duplicates of the fields of `v`.
function task_local(v::FunctionValues)
    Nξ_local = task_local(v.Nξ)
    # Preserve aliasing: when `Nx` and `Nξ` are the same object in `v` they must also be
    # the same object in the duplicate.
    Nx_local = if v.Nξ === v.Nx
        Nξ_local
    else
        task_local(v.Nx)
    end
    ip = task_local(v.ip)
    dNdx = task_local(v.dNdx)
    dNdξ = task_local(v.dNdξ)
    d2Ndx2 = task_local(v.d2Ndx2)
    d2Ndξ2 = task_local(v.d2Ndξ2)
    return FunctionValues(ip, Nx_local, Nξ_local, dNdx, dNdξ, d2Ndx2, d2Ndξ2)
end

getnbasefunctions(funvals::FunctionValues) = size(funvals.Nx, 1)
Expand Down
6 changes: 4 additions & 2 deletions src/FEValues/GeometryMapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ function precompute_values!(gm::GeometryMapping{2}, qr_points::AbstractVector{<:
return reference_shape_hessians_gradients_and_values!(gm.d2Mdξ2, gm.dMdξ, gm.M, gm.ip, qr_points)
end

function Base.copy(v::GeometryMapping)
return GeometryMapping(copy(v.ip), copy(v.M), _copy_or_nothing(v.dMdξ), _copy_or_nothing(v.d2Mdξ2))
# Build a new `GeometryMapping` from the `task_local` duplicates of the fields of `v`.
function task_local(v::GeometryMapping)
    ip = task_local(v.ip)
    M = task_local(v.M)
    dMdξ = task_local(v.dMdξ)
    d2Mdξ2 = task_local(v.d2Mdξ2)
    return GeometryMapping(ip, M, dMdξ, d2Mdξ2)
end

getngeobasefunctions(geo_mapping::GeometryMapping) = size(geo_mapping.M, 1)
Expand Down
4 changes: 2 additions & 2 deletions src/FEValues/InterfaceValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ end
# Construct `InterfaceValues` from a pair of `FacetValues`. When `facetvalues_there` is
# omitted it defaults to a `deepcopy` of `facetvalues_here`, so the two sides are separate
# objects. (The default previously referenced the undefined name `FacetValues_here`, which
# made the one-argument call throw an `UndefVarError`.)
InterfaceValues(facetvalues_here::FVA, facetvalues_there::FVB = deepcopy(facetvalues_here)) where {FVA <: FacetValues, FVB <: FacetValues} =
    InterfaceValues{FVA, FVB}(facetvalues_here, facetvalues_there)

function Base.copy(iv::InterfaceValues)
return InterfaceValues(copy(iv.here), copy(iv.there))
# Build a new `InterfaceValues` from the `task_local` duplicates of both sides of `iv`.
function task_local(iv::InterfaceValues)
    here = task_local(iv.here)
    there = task_local(iv.there)
    return InterfaceValues(here, there)
end

function getnbasefunctions(iv::InterfaceValues)
Expand Down
2 changes: 2 additions & 0 deletions src/Ferrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ using .CollectionsOfViews:

include("exports.jl")

# Task based multithreading support
include("multithreading.jl")

"""
AbstractRefShape{refdim}
Expand Down
10 changes: 8 additions & 2 deletions src/Quadrature/quadrature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,5 +317,11 @@ getpoints(qr::FacetQuadratureRule, face::Int) = getpoints(qr.face_rules[face])

getrefshape(::QuadratureRule{RefShape}) where {RefShape} = RefShape

# TODO: This is used in copy(::(Cell|Face)Values), but it it useful to get an actual copy?
Base.copy(qr::Union{QuadratureRule, FacetQuadratureRule}) = qr
# TODO: For typical use the quadrature rule is read-only, but seems safer to copy anyway?
# And might even be beneficial with e.g. NUMA?
Comment on lines +320 to +321
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I think the cost of copying here would always be negligible, and there is a risk when using e.g. PointValues where the quadrature point is changed!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would argue that task_local should always be safe and that the caller is responsible to not call task_local on immutable stuff, if he does not want the duplication.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, at least in the past, it was beneficial that the thread that used the data was the one allocating it so thats why I thought maybe it is better to duplicate this too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if that is also true for constant shared memory data. I suggest that we keep the rationale in a doc string here tho, so if we revisit this we have the exact arguments at hand (and also for users).

# Duplicate both the weights and the points of `qr`. The trailing typeassert checks that
# the duplicate has exactly the same concrete type as the input.
function task_local(qr::QR) where {refshape, QR <: QuadratureRule{refshape}}
    weights = task_local(qr.weights)
    points = task_local(qr.points)
    return QuadratureRule{refshape}(weights, points)::QR
end
# Duplicate the collection of per-face rules stored in `qr`. The trailing typeassert
# checks that the duplicate has exactly the same concrete type as the input.
function task_local(qr::QR) where {refshape, QR <: FacetQuadratureRule{refshape}}
    face_rules = task_local(qr.face_rules)
    return FacetQuadratureRule{refshape}(face_rules)::QR
end
8 changes: 8 additions & 0 deletions src/assembler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ matrix_handle(a::AbstractCSCAssembler) = a.K
matrix_handle(a::SymmetricCSCAssembler) = a.K.data
vector_handle(a::AbstractCSCAssembler) = a.f

# The duplicate refers to the same `K` and `f` as `asm`; only the `permutation` and
# `sorteddofs` buffers are duplicated.
function task_local(asm::CSCAssembler)
    permutation = task_local(asm.permutation)
    sorteddofs = task_local(asm.sorteddofs)
    return CSCAssembler(asm.K, asm.f, permutation, sorteddofs)
end
# The duplicate refers to the same `K` and `f` as `asm`; only the `permutation` and
# `sorteddofs` buffers are duplicated.
function task_local(asm::SymmetricCSCAssembler)
    permutation = task_local(asm.permutation)
    sorteddofs = task_local(asm.sorteddofs)
    return SymmetricCSCAssembler(asm.K, asm.f, permutation, sorteddofs)
end


"""
start_assemble(K::AbstractSparseMatrixCSC; fillzero::Bool=true) -> CSCAssembler
start_assemble(K::AbstractSparseMatrixCSC, f::Vector; fillzero::Bool=true) -> CSCAssembler
Expand Down
11 changes: 11 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,14 @@ end
function celldofs!(::Vector, ::FacetCache)
throw(DeprecationError("celldofs!(v::Vector, fs::FacetCache)" => "celldofs!(v, celldofs(fc))"))
end

## Deprecations introduced in Ferrite 1.x (keep until 2.0) ##

# Base.copy -> task_local, https://github.com/Ferrite-FEM/Ferrite.jl/pull/1070
# `copy` is imported from Base so that the `@deprecate` definitions below add methods to
# `Base.copy` instead of creating a new, unrelated `copy` function in this module.
import Base: copy
# Define one deprecated `copy(x::T)` method per listed type; each forwards to
# `task_local(x)`.
for T in (
CellValues, FacetValues, FunctionValues, GeometryMapping, InterfaceValues,
Interpolation, QuadratureRule, FacetQuadratureRule,
)
@eval @deprecate copy(x::$T) task_local(x)
end
5 changes: 4 additions & 1 deletion src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,7 @@ export
evaluate_at_points,
PointIterator,
PointLocation,
PointValues
PointValues,

# Misc
task_local
2 changes: 0 additions & 2 deletions src/interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ nvertices(::Interpolation{RefShape}) where {RefShape} = nvertices(RefShape)
nedges(::Interpolation{RefShape}) where {RefShape} = nedges(RefShape)
nfaces(::Interpolation{RefShape}) where {RefShape} = nfaces(RefShape)

Base.copy(ip::Interpolation) = ip

"""
Ferrite.getrefdim(::Interpolation)

Expand Down
8 changes: 7 additions & 1 deletion src/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ end
UpdateFlags(; nodes::Bool = true, coords::Bool = true, dofs::Bool = true) =
UpdateFlags(nodes, coords, dofs)


###############
## CellCache ##
###############
Expand Down Expand Up @@ -83,6 +82,13 @@ function reinit!(cc::CellCache, i::Int)
return cc
end

# The duplicate refers to the same `flags`, `grid` and `dh` as `cc`; the cell id and the
# `nodes`, `coords` and `dofs` buffers are duplicated.
function task_local(cc::CellCache)
    cellid = task_local(cc.cellid)
    nodes = task_local(cc.nodes)
    coords = task_local(cc.coords)
    dofs = task_local(cc.dofs)
    return CellCache(cc.flags, cc.grid, cellid, nodes, coords, cc.dh, dofs)
end

# reinit! FEValues with CellCache
reinit!(cv::CellValues, cc::CellCache) = reinit!(cv, cc.coords)
reinit!(fv::FacetValues, cc::CellCache, f::Int) = reinit!(fv, cc.coords, f)
Expand Down
79 changes: 79 additions & 0 deletions src/multithreading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
function task_local end

"""
    task_local(x::T) -> T

Duplicate `x` for a new task such that it can be used concurrently with the original `x`.
This is similar to `copy` but only the data that is known to be mutated, i.e. "scratch
data", is duplicated.

Typically, for concurrent assembly, there are some data structures that can't be shared
between the tasks, for example the local element matrix/vector and the `CellValues`.
`task_local` can thus be used to duplicate these data structures for each task based on a
"template" data structure. For example,

```julia
# "Template" local matrix and cell values
Ke = zeros(...)
cv = CellValues(...)

# Spawn `ntasks` tasks for concurrent assembly
@sync for i in 1:ntasks
    Threads.@spawn begin
        Ke_task = task_local(Ke)
        cv_task = task_local(cv)
        for cell in cells_for_task
            assemble_element!(Ke_task, cv_task, ...)
        end
    end
end
```

See the how-to on [multi-threaded assembly](@ref tutorial-threaded-assembly) for a complete
example.

The following "user-facing" types define methods for `task_local`:

 - [`CellValues`](@ref), [`FacetValues`](@ref), [`InterfaceValues`](@ref) are duplicated
   such that they can be `reinit!`ed independently.
 - `DenseArray` (for e.g. the local matrix and vector) are duplicated such that they can be
   modified concurrently.
 - [`CellCache`](@ref) (for caching element nodes and dofs) are duplicated such that they
   can be `reinit!`ed independently.

The following types also define methods for `task_local` but are typically not used directly
by the user but instead used recursively by the above types:

 - [`QuadratureRule`](@ref) and [`FacetQuadratureRule`](@ref)
 - `Tuple`s, which are duplicated element by element
 - All types which are `isbitstype` (e.g. `Vec`, `Tensor`, `Int`, `Float64`, etc.), which
   are returned as they are
"""
task_local(::Any)

# DenseVector/DenseMatrix (e.g. local matrix and vector)
# Duplicate a dense array. The result always has the same element type as `x`.
function task_local(x::T)::T where {S, T <: DenseArray{S}}
    @assert !isbitstype(T)
    if isbitstype(S)
        # If the eltype isbitstype the normal shallow copy can be used...
        return copy(x)::T
    else
        # ... otherwise we recurse and call task_local on the elements. The results are
        # written into a `similar` array instead of using `map`, because `map` infers the
        # eltype from the returned values and would narrow an abstract or `Union` eltype,
        # which makes the typeassert below fail.
        y = similar(x)
        map!(task_local, y, x)
        return y::T
    end
end

# FacetQuadratureRule can store the QuadratureRules as a tuple
# A tuple whose type is `isbitstype` is returned as it is; any other tuple is rebuilt by
# calling `task_local` on each of its elements.
function task_local(x::T)::T where {T <: Tuple}
    return isbitstype(T) ? x : map(task_local, x)::T
end

# General fallback for other types
# Values whose type is `isbitstype` are returned as they are. For any other type that has
# no dedicated `task_local` method a `MethodError` is thrown, since there is no generic
# way to know which parts of `x` must be duplicated.
function task_local(x::T)::T where {T}
    isbitstype(T) && return x
    # Throw an actual `MethodError` (instead of an `ErrorException` whose message merely
    # says "MethodError") so that callers can catch and test for the proper exception type.
    throw(MethodError(task_local, (x,)))
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Metis
include("test_utils.jl")

# Unit tests
include("test_multithreading.jl")
include("test_collectionsofviews.jl")
include("test_interpolations.jl")
include("test_cellvalues.jl")
Expand Down
25 changes: 0 additions & 25 deletions test/test_cellvalues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,31 +138,6 @@
for (i, qp_x) in pairs(Ferrite.getpoints(quad_rule))
@test spatial_coordinate(cv, i, coords) ≈ qp_x
end

@testset "copy(::CellValues)" begin
cvc = copy(cv)
@test typeof(cv) == typeof(cvc)

# Test that all mutable types in FunctionValues and GeometryMapping have been copied
for key in (:fun_values, :geo_mapping)
val = getfield(cv, key)
valc = getfield(cvc, key)
for fname in fieldnames(typeof(val))
v = getfield(val, fname)
vc = getfield(valc, fname)
isbits(v) || @test v !== vc
@test v == vc
end
end
# Test that qr and detJdV is copied as expected.
# Note that qr remain aliased, as defined by `copy(qr)=qr`, see quadrature.jl.
for fname in (:qr, :detJdV)
v = getfield(cv, fname)
vc = getfield(cvc, fname)
fname === :qr || @test v !== vc
@test v == vc
end
end
end
end

Expand Down
16 changes: 16 additions & 0 deletions test/test_deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,20 @@ using Ferrite, Test
@test_throws Ferrite.DeprecationError celldofs!(v, fc)
end

# Ferrite v2
# Each `copy` call below is expected to emit a deprecation warning; one call is made for
# an interpolation, both quadrature rule types and the three values types.
@testset "Base.copy -> task_local" begin
ip = Lagrange{RefQuadrilateral, 1}()
@test_deprecated copy(ip)
qr = QuadratureRule{RefQuadrilateral}(2)
@test_deprecated copy(qr)
# Values built from the cell quadrature rule
cv = CellValues(qr, ip)
@test_deprecated copy(cv)
fqr = FacetQuadratureRule{RefQuadrilateral}(2)
@test_deprecated copy(fqr)
# Values built from the facet quadrature rule
fv = FacetValues(fqr, ip)
@test_deprecated copy(fv)
iv = InterfaceValues(fqr, ip)
@test_deprecated copy(iv)
end

end # testset deprecations
Loading