diff --git a/docs/src/literate-howto/threaded_assembly.jl b/docs/src/literate-howto/threaded_assembly.jl index 47cbd32297..1e17b78c58 100644 --- a/docs/src/literate-howto/threaded_assembly.jl +++ b/docs/src/literate-howto/threaded_assembly.jl @@ -172,13 +172,12 @@ end # purpose. Finally, for the assembler we call `start_assemble` to create a new assembler but # note that we set `fillzero = false` because we don't want to risk that a task that starts # a bit later will zero out data that another task have already assembled. -function ScratchData(dh::DofHandler, K::SparseMatrixCSC, f::Vector, cellvalues::CellValues) - cell_cache = CellCache(dh) - n = ndofs_per_cell(dh) - Ke = zeros(n, n) - fe = zeros(n) - asm = start_assemble(K, f; fillzero = false) - return ScratchData(cell_cache, copy(cellvalues), Ke, fe, asm) +function create_scratch(scratch::ScratchData) + return ScratchData( + task_local(scratch.cell_cache), task_local(scratch.cellvalues), + task_local(scratch.Ke), task_local(scratch.fe), + task_local(scratch.assembler) + ) end nothing # hide @@ -220,13 +219,17 @@ using OhMyThreads, TaskLocalValues function assemble_global!( K::SparseMatrixCSC, f::Vector, dh::DofHandler, colors, - cellvalues_template::CellValues; ntasks = Threads.nthreads() + cellvalues::CellValues; ntasks = Threads.nthreads() ) - ## Zero-out existing data in K and f - _ = start_assemble(K, f) ## Body force and material stiffness b = Vec{3}((0.0, 0.0, -1.0)) C = create_material_stiffness() + ## Scratch data + scratch = ScratchData( + CellCache(dh), cellvalues, + zeros(ndofs_per_cell(dh), ndofs_per_cell(dh)), zeros(ndofs_per_cell(dh)), + start_assemble(K, f) + ) ## Loop over the colors for color in colors ## Dynamic scheduler spawning `ntasks` tasks where each task will process a chunk of @@ -237,8 +240,8 @@ function assemble_global!( ## Tell the @tasks loop to use the scheduler defined above @set scheduler = scheduler ## Obtain a task local scratch and unpack it - @local scratch = 
ScratchData(dh, K, f, cellvalues_template) - (; cell_cache, cellvalues, Ke, fe, assembler) = scratch + @local scratch = create_scratch(scratch) + local (; cell_cache, cellvalues, Ke, fe, assembler) = scratch ## Reinitialize the cell cache and then the cellvalues reinit!(cell_cache, cellidx) reinit!(cellvalues, cell_cache) @@ -259,7 +262,7 @@ nothing # hide # ```julia # # using TaskLocalValues # scratches = TaskLocalValue() do -# ScratchData(dh, K, f, cellvalues) +# create_scratch(scratch) # end # OhMyThreads.tforeach(color; scheduler) do cellidx # # Obtain a task local scratch and unpack it diff --git a/src/FEValues/CellValues.jl b/src/FEValues/CellValues.jl index 9f7583c1ab..a603f1f3c9 100644 --- a/src/FEValues/CellValues.jl +++ b/src/FEValues/CellValues.jl @@ -65,8 +65,11 @@ function CellValues(::Type{T}, qr::QuadratureRule, ip::Interpolation, ip_geo::Ve return CellValues(T, qr, ip, ip_geo, ValuesUpdateFlags(ip; kwargs...)) end -function Base.copy(cv::CellValues) - return CellValues(copy(cv.fun_values), copy(cv.geo_mapping), copy(cv.qr), _copy_or_nothing(cv.detJdV)) +function task_local(cv::CellValues) + return CellValues( + task_local(cv.fun_values), task_local(cv.geo_mapping), task_local(cv.qr), + task_local(cv.detJdV) + ) end # Access geometry values diff --git a/src/FEValues/FacetValues.jl b/src/FEValues/FacetValues.jl index 43c9860cc3..a897d3b75d 100644 --- a/src/FEValues/FacetValues.jl +++ b/src/FEValues/FacetValues.jl @@ -68,10 +68,12 @@ function FacetValues(::Type{T}, qr::FacetQuadratureRule, ip::Interpolation, ip_g return FacetValues(T, qr, ip, ip_geo, ValuesUpdateFlags(ip; kwargs...)) end -function Base.copy(fv::FacetValues) - fun_values = map(copy, fv.fun_values) - geo_mapping = map(copy, fv.geo_mapping) - return FacetValues(fun_values, geo_mapping, copy(fv.fqr), copy(fv.detJdV), copy(fv.normals), fv.current_facet) +function task_local(fv::FacetValues) + return FacetValues( + map(task_local, fv.fun_values), map(task_local, fv.geo_mapping), + 
task_local(fv.fqr), task_local(fv.detJdV), task_local(fv.normals), + task_local(fv.current_facet) + ) end getngeobasefunctions(fv::FacetValues) = getngeobasefunctions(get_geo_mapping(fv)) diff --git a/src/FEValues/FunctionValues.jl b/src/FEValues/FunctionValues.jl index 006b365789..5e5d45f91e 100644 --- a/src/FEValues/FunctionValues.jl +++ b/src/FEValues/FunctionValues.jl @@ -98,14 +98,13 @@ function precompute_values!(fv::FunctionValues{2}, qr_points::AbstractVector{<:V return reference_shape_hessians_gradients_and_values!(fv.d2Ndξ2, fv.dNdξ, fv.Nξ, fv.ip, qr_points) end -function Base.copy(v::FunctionValues) - Nξ_copy = copy(v.Nξ) - Nx_copy = v.Nξ === v.Nx ? Nξ_copy : copy(v.Nx) # Preserve aliasing - dNdx_copy = _copy_or_nothing(v.dNdx) - dNdξ_copy = _copy_or_nothing(v.dNdξ) - d2Ndx2_copy = _copy_or_nothing(v.d2Ndx2) - d2Ndξ2_copy = _copy_or_nothing(v.d2Ndξ2) - return FunctionValues(copy(v.ip), Nx_copy, Nξ_copy, dNdx_copy, dNdξ_copy, d2Ndx2_copy, d2Ndξ2_copy) +function task_local(v::FunctionValues) + Nξ = task_local(v.Nξ) + Nx = v.Nξ === v.Nx ? 
Nξ : task_local(v.Nx) # Preserve aliasing + return FunctionValues( + task_local(v.ip), Nx, Nξ, task_local(v.dNdx), task_local(v.dNdξ), + task_local(v.d2Ndx2), task_local(v.d2Ndξ2) + ) end getnbasefunctions(funvals::FunctionValues) = size(funvals.Nx, 1) diff --git a/src/FEValues/GeometryMapping.jl b/src/FEValues/GeometryMapping.jl index 85aed2d45a..07734d7a77 100644 --- a/src/FEValues/GeometryMapping.jl +++ b/src/FEValues/GeometryMapping.jl @@ -95,8 +95,10 @@ function precompute_values!(gm::GeometryMapping{2}, qr_points::AbstractVector{<: return reference_shape_hessians_gradients_and_values!(gm.d2Mdξ2, gm.dMdξ, gm.M, gm.ip, qr_points) end -function Base.copy(v::GeometryMapping) - return GeometryMapping(copy(v.ip), copy(v.M), _copy_or_nothing(v.dMdξ), _copy_or_nothing(v.d2Mdξ2)) +function task_local(v::GeometryMapping) + return GeometryMapping( + task_local(v.ip), task_local(v.M), task_local(v.dMdξ), task_local(v.d2Mdξ2) + ) end getngeobasefunctions(geo_mapping::GeometryMapping) = size(geo_mapping.M, 1) diff --git a/src/FEValues/InterfaceValues.jl b/src/FEValues/InterfaceValues.jl index 8866811aa4..5475b6c947 100644 --- a/src/FEValues/InterfaceValues.jl +++ b/src/FEValues/InterfaceValues.jl @@ -86,8 +86,8 @@ end InterfaceValues(facetvalues_here::FVA, facetvalues_there::FVB = deepcopy(FacetValues_here)) where {FVA <: FacetValues, FVB <: FacetValues} = InterfaceValues{FVA, FVB}(facetvalues_here, facetvalues_there) -function Base.copy(iv::InterfaceValues) - return InterfaceValues(copy(iv.here), copy(iv.there)) +function task_local(iv::InterfaceValues) + return InterfaceValues(task_local(iv.here), task_local(iv.there)) end function getnbasefunctions(iv::InterfaceValues) diff --git a/src/Ferrite.jl b/src/Ferrite.jl index 7b358463be..f9f39dd394 100644 --- a/src/Ferrite.jl +++ b/src/Ferrite.jl @@ -31,6 +31,8 @@ using .CollectionsOfViews: include("exports.jl") +# Task based multithreading support +include("multithreading.jl") """ AbstractRefShape{refdim} diff --git 
a/src/Quadrature/quadrature.jl b/src/Quadrature/quadrature.jl index dae3cc9795..081efe5221 100644 --- a/src/Quadrature/quadrature.jl +++ b/src/Quadrature/quadrature.jl @@ -317,5 +317,11 @@ getpoints(qr::FacetQuadratureRule, face::Int) = getpoints(qr.face_rules[face]) getrefshape(::QuadratureRule{RefShape}) where {RefShape} = RefShape -# TODO: This is used in copy(::(Cell|Face)Values), but it it useful to get an actual copy? -Base.copy(qr::Union{QuadratureRule, FacetQuadratureRule}) = qr +# TODO: For typical use the quadrature rule is read-only, but seems safer to copy anyway? +# And might even be beneficial with e.g. NUMA? +function task_local(qr::QR) where {refshape, QR <: QuadratureRule{refshape}} + return QuadratureRule{refshape}(task_local(qr.weights), task_local(qr.points))::QR +end +function task_local(qr::QR) where {refshape, QR <: FacetQuadratureRule{refshape}} + return FacetQuadratureRule{refshape}(map(task_local, qr.face_rules))::QR +end diff --git a/src/assembler.jl b/src/assembler.jl index 078f5fcd34..81f68e2ded 100644 --- a/src/assembler.jl +++ b/src/assembler.jl @@ -204,6 +204,14 @@ matrix_handle(a::AbstractCSCAssembler) = a.K matrix_handle(a::SymmetricCSCAssembler) = a.K.data vector_handle(a::AbstractCSCAssembler) = a.f +function task_local(asm::CSCAssembler) + return CSCAssembler(asm.K, asm.f, task_local(asm.permutation), task_local(asm.sorteddofs)) +end +function task_local(asm::SymmetricCSCAssembler) + return SymmetricCSCAssembler(asm.K, asm.f, task_local(asm.permutation), task_local(asm.sorteddofs)) +end + + """ start_assemble(K::AbstractSparseMatrixCSC; fillzero::Bool=true) -> CSCAssembler start_assemble(K::AbstractSparseMatrixCSC, f::Vector; fillzero::Bool=true) -> CSCAssembler diff --git a/src/deprecations.jl b/src/deprecations.jl index 06b931f804..e3adf1ec33 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -425,3 +425,14 @@ end function celldofs!(::Vector, ::FacetCache) throw(DeprecationError("celldofs!(v::Vector, 
fs::FacetCache)" => "celldofs!(v, celldofs(fc))")) end + +## Deprecations introduced in Ferrite@1.x (keep until 2.0) ## + +# Base.copy -> task_local, https://github.com/Ferrite-FEM/Ferrite.jl/pull/1070 +import Base: copy +for T in ( + CellValues, FacetValues, FunctionValues, GeometryMapping, InterfaceValues, + Interpolation, QuadratureRule, FacetQuadratureRule, + ) + @eval @deprecate copy(x::$T) task_local(x) +end diff --git a/src/exports.jl b/src/exports.jl index c710861c15..15507ecc32 100644 --- a/src/exports.jl +++ b/src/exports.jl @@ -184,4 +184,7 @@ export evaluate_at_points, PointIterator, PointLocation, - PointValues + PointValues, + + # Misc + task_local diff --git a/src/interpolations.jl b/src/interpolations.jl index 5aaed4fdfb..b8398815e2 100644 --- a/src/interpolations.jl +++ b/src/interpolations.jl @@ -116,8 +116,6 @@ nvertices(::Interpolation{RefShape}) where {RefShape} = nvertices(RefShape) nedges(::Interpolation{RefShape}) where {RefShape} = nedges(RefShape) nfaces(::Interpolation{RefShape}) where {RefShape} = nfaces(RefShape) -Base.copy(ip::Interpolation) = ip - """ Ferrite.getrefdim(::Interpolation) diff --git a/src/iterators.jl b/src/iterators.jl index 0abf2972b1..f44db84417 100644 --- a/src/iterators.jl +++ b/src/iterators.jl @@ -9,7 +9,6 @@ end UpdateFlags(; nodes::Bool = true, coords::Bool = true, dofs::Bool = true) = UpdateFlags(nodes, coords, dofs) - ############### ## CellCache ## ############### @@ -83,6 +82,13 @@ function reinit!(cc::CellCache, i::Int) return cc end +function task_local(cc::CellCache) + return CellCache( + cc.flags, cc.grid, task_local(cc.cellid), task_local(cc.nodes), + task_local(cc.coords), cc.dh, task_local(cc.dofs) + ) +end + # reinit! 
FEValues with CellCache reinit!(cv::CellValues, cc::CellCache) = reinit!(cv, cc.coords) reinit!(fv::FacetValues, cc::CellCache, f::Int) = reinit!(fv, cc.coords, f) diff --git a/src/multithreading.jl b/src/multithreading.jl new file mode 100644 index 0000000000..fad41ef767 --- /dev/null +++ b/src/multithreading.jl @@ -0,0 +1,79 @@ +function task_local end + +""" + task_local(x::T) -> T + +Duplicate `x` for a new task such that it can be used concurrently with the original `x`. +This is similar to `copy` but only the data that is known to be mutated, i.e. "scratch +data", is duplicated. + +Typically, for concurrent assembly, there are some data structures that can't be shared +between the tasks, for example the local element matrix/vector and the `CellValues`. +`task_local` can thus be used to duplicate these data structures for each task based on a +"template" data structure. For example, + +```julia +# "Template" local matrix and cell values +Ke = zeros(...) +cv = CellValues(...) + +# Spawn `ntasks` tasks for concurrent assembly +@sync for i in 1:ntasks + Threads.@spawn begin + Ke_task = task_local(Ke) + cv_task = task_local(cv) + for cell in cells_for_task + assemble_element!(Ke_task, cv_task, ...) + end + end +end +``` + +See the how-to on [multi-threaded assembly](@ref tutorial-threaded-assembly) for a complete +example. + +The following "user-facing" types define methods for `task_local`: + + - [`CellValues`](@ref), [`FacetValues`](@ref), [`InterfaceValues`](@ref) are duplicated + such that they can be `reinit!`ed independently. + - `DenseArray` (for e.g. the local matrix and vector) are duplicated such that they can be + modified concurrently. + - [`CellCache`](@ref) (for caching element nodes and dofs) are duplicated such that they + can be `reinit!`ed independently. 
+ +The following types also define methods for `task_local` but are typically not used directly +by the user but instead used recursively by the above types: + + - [`QuadratureRule`](@ref) and [`FacetQuadratureRule`](@ref) + - All types which are `isbitstype` (e.g. `Vec`, `Tensor`, `Int`, `Float64`, etc.) +""" +task_local(::Any) + +# DenseVector/DenseMatrix (e.g. local matrix and vector) +function task_local(x::T)::T where {S, T <: DenseArray{S}} + @assert !isbitstype(T) + if isbitstype(S) + # If the eltype isbitstype the normal shallow copy can be used... + return copy(x)::T + else + # ... otherwise we recurse and call task_local on the elements + return map(task_local, x)::T + end +end + +# FacetQuadratureRule can store the QuadratureRules as a tuple +function task_local(x::T)::T where {T <: Tuple} + if isbitstype(T) + return x + else + return map(task_local, x)::T + end +end + +# General fallback for other types +function task_local(x::T)::T where {T} + if !isbitstype(T) + error("MethodError: task_local(::$T) is not implemented") + end + return x +end diff --git a/test/runtests.jl b/test/runtests.jl index 57cd82d8a7..7339fdd1f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ import Metis include("test_utils.jl") # Unit tests +include("test_multithreading.jl") include("test_collectionsofviews.jl") include("test_interpolations.jl") include("test_cellvalues.jl") diff --git a/test/test_cellvalues.jl b/test/test_cellvalues.jl index e8be79127b..fc35ca7c66 100644 --- a/test/test_cellvalues.jl +++ b/test/test_cellvalues.jl @@ -138,31 +138,6 @@ for (i, qp_x) in pairs(Ferrite.getpoints(quad_rule)) @test spatial_coordinate(cv, i, coords) ≈ qp_x end - - @testset "copy(::CellValues)" begin - cvc = copy(cv) - @test typeof(cv) == typeof(cvc) - - # Test that all mutable types in FunctionValues and GeometryMapping have been copied - for key in (:fun_values, :geo_mapping) - val = getfield(cv, key) - valc = getfield(cvc, key) - for fname in 
fieldnames(typeof(val)) - v = getfield(val, fname) - vc = getfield(valc, fname) - isbits(v) || @test v !== vc - @test v == vc - end - end - # Test that qr and detJdV is copied as expected. - # Note that qr remain aliased, as defined by `copy(qr)=qr`, see quadrature.jl. - for fname in (:qr, :detJdV) - v = getfield(cv, fname) - vc = getfield(cvc, fname) - fname === :qr || @test v !== vc - @test v == vc - end - end end end diff --git a/test/test_deprecations.jl b/test/test_deprecations.jl index 6bb672e912..46ed1be964 100644 --- a/test/test_deprecations.jl +++ b/test/test_deprecations.jl @@ -131,4 +131,20 @@ using Ferrite, Test @test_throws Ferrite.DeprecationError celldofs!(v, fc) end + # Ferrite v2 + @testset "Base.copy -> task_local" begin + ip = Lagrange{RefQuadrilateral, 1}() + @test_deprecated copy(ip) + qr = QuadratureRule{RefQuadrilateral}(2) + @test_deprecated copy(qr) + cv = CellValues(qr, ip) + @test_deprecated copy(cv) + fqr = FacetQuadratureRule{RefQuadrilateral}(2) + @test_deprecated copy(fqr) + fv = FacetValues(fqr, ip) + @test_deprecated copy(fv) + iv = InterfaceValues(fqr, ip) + @test_deprecated copy(iv) + end + end # testset deprecations diff --git a/test/test_facevalues.jl b/test/test_facevalues.jl index 672713b315..aee08b76b5 100644 --- a/test/test_facevalues.jl +++ b/test/test_facevalues.jl @@ -145,35 +145,6 @@ # end end - - @testset "copy(::FacetValues)" begin - fvc = copy(fv) - @test typeof(fv) == typeof(fvc) - - # Test that all mutable types in FunctionValues and GeometryMapping have been copied - for key in (:fun_values, :geo_mapping) - for i in eachindex(getfield(fv, key)) - val = getfield(fv, key)[i] - valc = getfield(fvc, key)[i] - for fname in fieldnames(typeof(val)) - v = getfield(val, fname) - vc = getfield(valc, fname) - isbits(v) || @test v !== vc - @test v == vc - end - end - end - # Test that fqr, detJdV, and normals, are copied as expected. - # Note that qr remain aliased, as defined by `copy(qr)=qr`, see quadrature.jl. 
- for fname in (:fqr, :detJdV, :normals) - v = getfield(fv, fname) - vc = getfield(fvc, fname) - if fname !== :fqr # Test unaliased - @test v !== vc - end - @test v == vc - end - end end end diff --git a/test/test_interfacevalues.jl b/test/test_interfacevalues.jl index 8fb31dc741..0296b0fe59 100644 --- a/test/test_interfacevalues.jl +++ b/test/test_interfacevalues.jl @@ -233,26 +233,6 @@ ic = first(InterfaceIterator(dh)) @test dof_range(ic, :p) == (9:12, 25:28) end - # Test copy - iv = InterfaceValues(FacetQuadratureRule{RefQuadrilateral}(2), DiscontinuousLagrange{RefQuadrilateral, 1}()) - ivc = copy(iv) - @test typeof(iv) == typeof(ivc) - for fname in fieldnames(typeof(iv)) - v = getfield(iv, fname) - vc = getfield(ivc, fname) - if hasmethod(pointer, Tuple{typeof(v)}) - @test pointer(v) != pointer(vc) - end - v isa FacetValues && continue - for fname in fieldnames(typeof(vc)) - v2 = getfield(v, fname) - vc2 = getfield(vc, fname) - if hasmethod(pointer, Tuple{typeof(v2)}) - @test pointer(v2) != pointer(vc2) - end - @test v2 == vc2 - end - end @testset "undefined transformation matrix error path" begin it = Ferrite.InterfaceOrientationInfo{DummyRefShapes.RefDodecahedron, DummyRefShapes.RefDodecahedron}(false, 0, 0, 1, 1) @test_throws ArgumentError("transformation is not implemented") Ferrite.get_transformation_matrix(it) diff --git a/test/test_multithreading.jl b/test/test_multithreading.jl new file mode 100644 index 0000000000..a15ba09976 --- /dev/null +++ b/test/test_multithreading.jl @@ -0,0 +1,71 @@ +using Ferrite, Test +using LinearAlgebra: Symmetric +using SparseArrays: sprand + +function equivalent_but_distinct(x::T, y::T) where {T} + if isbitstype(T) + @test x === y + elseif T <: AbstractArray + @test x !== y + for i in eachindex(x, y) + equivalent_but_distinct(x[i], y[i]) + end + else + # (mutable) struct, Tuple, etc. 
+ for s in fieldnames(T) + equivalent_but_distinct(getfield(x, s), getfield(y, s)) + end + end + return +end + +@testset "task_local" begin + # General fallback for bitstypes + for x in (1, 1.0, true, nothing) + equivalent_but_distinct(x, task_local(x)) + end + # task_local(::Array) behaves like copy(::Array) + for x in (rand(1), rand(1, 1), rand(1, 1, 1)) + equivalent_but_distinct(x, task_local(x)) + end + # task_local(::QuadratureRule) behaves like copy(::QuadratureRule) + # NOTE: Unlike the old `copy(qr) = qr`, the weights and points arrays are duplicated. + for qr in (QuadratureRule{RefTriangle}(2), FacetQuadratureRule{RefTriangle}(2)) + equivalent_but_distinct(qr, task_local(qr)) + end + # Interpolations are assumed to be singletons + for ip in (Lagrange{RefTriangle, 1}(), Lagrange{RefTriangle, 2}()^2) + equivalent_but_distinct(ip, task_local(ip)) + end + # GeometryMapping, FunctionValues + ip = Lagrange{RefTriangle, 2}() + qr = QuadratureRule{RefTriangle}(2) + for DiffOrder in (0, 1, 2) + gm = Ferrite.GeometryMapping{DiffOrder}(Float64, ip, qr) + equivalent_but_distinct(gm, task_local(gm)) + fv = Ferrite.FunctionValues{DiffOrder}(Float64, ip, qr, Ferrite.VectorizedInterpolation{2}(ip)) + equivalent_but_distinct(fv, task_local(fv)) + end + # CellValues + for cv in (CellValues(qr, ip), CellValues(qr, ip; update_hessians = true)) + equivalent_but_distinct(cv, task_local(cv)) + end + # FacetValues + fqr = FacetQuadratureRule{RefTriangle}(2) + fv = FacetValues(fqr, ip) + equivalent_but_distinct(fv, task_local(fv)) + # InterfaceValues + iv = InterfaceValues(fqr, ip) + equivalent_but_distinct(iv, task_local(iv)) + # CSCAssembler, SymmetricCSCAssembler + let K = sprand(10, 10, 0.5), f = rand(10) + for assembler in (start_assemble(K, f), start_assemble(Symmetric(K), f)) + tl = task_local(assembler) + @test tl.K === assembler.K + @test tl.f === assembler.f + @test tl.permutation !== assembler.permutation + @test tl.sorteddofs !== assembler.sorteddofs + end + end + # TODO: Test CellCache +end