Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial ideas for a QuadraturePointIterator #883

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions heatflow_qp_values.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
using Ferrite

# Standard element routine
function assemble_element_std!(Ke::Matrix, fe::Vector, cellvalues::CellValues)
n_basefuncs = getnbasefunctions(cellvalues)
## Loop over quadrature points
for q_point in 1:getnquadpoints(cellvalues)
## Get the quadrature weight
dΩ = getdetJdV(cellvalues, q_point)
## Loop over test shape functions
for i in 1:n_basefuncs
δu = shape_value(cellvalues, q_point, i)
∇δu = shape_gradient(cellvalues, q_point, i)
## Add contribution to fe
fe[i] += δu * dΩ
## Loop over trial shape functions
for j in 1:n_basefuncs
∇u = shape_gradient(cellvalues, q_point, j)
## Add contribution to Ke
Ke[i, j] += (∇δu ⋅ ∇u) * dΩ
end
end
end
return Ke, fe
end

# Element routine using QuadratureValuesIterator
function assemble_element_qpiter!(Ke::Matrix, fe::Vector, cellvalues)
n_basefuncs = getnbasefunctions(cellvalues)
## Loop over quadrature points
for qv in Ferrite.QuadratureValuesIterator(cellvalues)
## Get the quadrature weight
dΩ = getdetJdV(qv)
## Loop over test shape functions
for i in 1:n_basefuncs
δu = shape_value(qv, i)
∇δu = shape_gradient(qv, i)
## Add contribution to fe
fe[i] += δu * dΩ
## Loop over trial shape functions
for j in 1:n_basefuncs
∇u = shape_gradient(qv, j)
## Add contribution to Ke
Ke[i, j] += (∇δu ⋅ ∇u) * dΩ
end
end
end
return Ke, fe
end

function assemble_element_qpiter!(Ke::Matrix, fe::Vector, cellvalues, cell_coords::AbstractVector)
n_basefuncs = getnbasefunctions(cellvalues)
## Loop over quadrature points
for qv in Ferrite.QuadratureValuesIterator(cellvalues, cell_coords)
## Get the quadrature weight
dΩ = getdetJdV(qv)
## Loop over test shape functions
for i in 1:n_basefuncs
δu = shape_value(qv, i)
∇δu = shape_gradient(qv, i)
## Add contribution to fe
fe[i] += δu * dΩ
## Loop over trial shape functions
for j in 1:n_basefuncs
∇u = shape_gradient(qv, j)
## Add contribution to Ke
Ke[i, j] += (∇δu ⋅ ∇u) * dΩ
end
end
end
return Ke, fe
end

function assemble_global(cellvalues, dh; kwargs...)
assemble_global!(create_buffers(cellvalues, dh), cellvalues, dh; kwargs...)
end

function assemble_global!(buffer, cellvalues, dh::DofHandler; qp_iter::Val{QPiter}, reinit::Val{ReInit}) where {QPiter, ReInit}
(;f, K, assembler, Ke, fe) = buffer
for cell in CellIterator(dh)
fill!(Ke, 0)
fill!(fe, 0)
if QPiter
if ReInit
reinit!(cellvalues, getcoordinates(cell))
assemble_element_qpiter!(Ke, fe, cellvalues)
else
assemble_element_qpiter!(Ke, fe, cellvalues, getcoordinates(cell))
end
else
reinit!(cellvalues, getcoordinates(cell))
assemble_element_std!(Ke, fe, cellvalues)
end
assemble!(assembler, celldofs(cell), Ke, fe)
end
return K, f
end

function create_buffers(cellvalues, dh)
f = zeros(ndofs(dh))
K = create_sparsity_pattern(dh)
assembler = start_assemble(K, f)
## Local quantities
n_basefuncs = getnbasefunctions(cellvalues)
Ke = zeros(n_basefuncs, n_basefuncs)
fe = zeros(n_basefuncs)
return (;f, K, assembler, Ke, fe)
end

n = 50
grid = generate_grid(Quadrilateral, (n, n));
ip = Lagrange{RefQuadrilateral, 1}()
qr = QuadratureRule{RefQuadrilateral}(2)

dh = DofHandler(grid)
add!(dh, :u, ip)
close!(dh);

cellvalues = CellValues(qr, ip);
static_cellvalues = Ferrite.StaticCellValues(cellvalues)

stdassy(buffer, cv, dh) = assemble_global!(buffer, cv, dh; qp_iter=Val(false), reinit=Val(false))
qp_outside(buffer, cv, dh) = assemble_global!(buffer, cv, dh; qp_iter=Val(true), reinit=Val(true))
qp_inside(buffer, cv, dh) = assemble_global!(buffer, cv, dh; qp_iter=Val(true), reinit=Val(false))

Kstd, fstd = stdassy(create_buffers(cellvalues, dh), cellvalues, dh);
K_qp_o, f_qp_o = qp_outside(create_buffers(cellvalues, dh), cellvalues, dh);
K_qp_i, f_qp_i = qp_inside(create_buffers(cellvalues, dh), cellvalues, dh);

cvs_o = Ferrite.StaticCellValues(cellvalues, Val(true)) # Save cell_coords in cvs_o
Ks_o, fs_o = qp_outside(create_buffers(cvs_o, dh), cvs_o, dh);

cvs_i = Ferrite.StaticCellValues(cellvalues, Val(false)) # Don't save cell_coords in cvs_o
Ks_i, fs_i = qp_inside(create_buffers(cvs_i, dh), cvs_i, dh);

using Test
@testset "check outputs" begin
for (k, K, f) in (("qpo", K_qp_o, f_qp_o), ("qpi", K_qp_i, f_qp_i), ("so", Ks_o, fs_o), ("si", Ks_i, fs_i))
@testset "$k" begin
@test K ≈ Kstd
@test f ≈ fstd
end
end
end

# Benchmarking
using BenchmarkTools
if n ≤ 100
print("Standard: ")
@btime stdassy(buffer, $cellvalues, $dh) setup=(buffer=create_buffers(cellvalues, dh));
print("Std qpoint outside: ")
@btime qp_outside(buffer, $cellvalues, $dh) setup=(buffer=create_buffers(cellvalues, dh));
print("Std qpoint inside: ")
@btime qp_inside(buffer, $cellvalues, $dh) setup=(buffer=create_buffers(cellvalues, dh));
print("Static outside: ")
@btime qp_outside(buffer, $cvs_o, $dh) setup=(buffer=create_buffers(cvs_o, dh));
print("Static inside: ")
@btime qp_inside(buffer, $cvs_i, $dh) setup=(buffer=create_buffers(cvs_i, dh));
else
buffer = create_buffers(cellvalues, dh)
print("Standard: ")
@time stdassy(buffer, cellvalues, dh)
print("Std qpoint outside: ")
@time qp_outside(buffer, cellvalues, dh)
print("Std qpoint inside: ")
@time qp_inside(buffer, cellvalues, dh)
print("Static outside: ")
@time qp_outside(buffer, cvs_o, dh)
print("Static inside: ")
@time qp_inside(buffer, cvs_i, dh)
end
nothing

139 changes: 139 additions & 0 deletions src/FEValues/QuadratureValues.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# QuadratureValuesIterator
struct QuadratureValuesIterator{VT,XT}
v::VT
cell_coords::XT # Union{AbstractArray{<:Vec}, Nothing}
function QuadratureValuesIterator(v::V) where V
return new{V, Nothing}(v, nothing)

Check warning on line 6 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L5-L6

Added lines #L5 - L6 were not covered by tests
end
function QuadratureValuesIterator(v::V, cell_coords::VT) where {V, VT <: AbstractArray}
reinit!(v, cell_coords)
return new{V, VT}(v, cell_coords)

Check warning on line 10 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L8-L10

Added lines #L8 - L10 were not covered by tests
end
end

function Base.iterate(iterator::QuadratureValuesIterator{<:Any, Nothing}, q_point=1)
checkbounds(Bool, 1:getnquadpoints(iterator.v), q_point) || return nothing
qp_v = @inbounds quadrature_point_values(iterator.v, q_point)
return (qp_v, q_point+1)

Check warning on line 17 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L14-L17

Added lines #L14 - L17 were not covered by tests
end
function Base.iterate(iterator::QuadratureValuesIterator{<:Any, <:AbstractVector}, q_point=1)
checkbounds(Bool, 1:getnquadpoints(iterator.v), q_point) || return nothing
qp_v = @inbounds quadrature_point_values(iterator.v, q_point, iterator.cell_coords)
return (qp_v, q_point+1)

Check warning on line 22 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L19-L22

Added lines #L19 - L22 were not covered by tests
end
Base.IteratorEltype(::Type{<:QuadratureValuesIterator}) = Base.EltypeUnknown()
Base.length(iterator::QuadratureValuesIterator) = getnquadpoints(iterator.v)

Check warning on line 25 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L24-L25

Added lines #L24 - L25 were not covered by tests

# AbstractQuadratureValues
abstract type AbstractQuadratureValues end

function function_value(qp_v::AbstractQuadratureValues, u::AbstractVector, dof_range = eachindex(u))
n_base_funcs = getnbasefunctions(qp_v)
length(dof_range) == n_base_funcs || throw_incompatible_dof_length(length(dof_range), n_base_funcs)
@boundscheck checkbounds(u, dof_range)
val = function_value_init(qp_v, u)
@inbounds for (i, j) in pairs(dof_range)
val += shape_value(qp_v, i) * u[j]
end
return val

Check warning on line 38 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L30-L38

Added lines #L30 - L38 were not covered by tests
end

function function_gradient(qp_v::AbstractQuadratureValues, u::AbstractVector, dof_range = eachindex(u))
n_base_funcs = getnbasefunctions(qp_v)
length(dof_range) == n_base_funcs || throw_incompatible_dof_length(length(dof_range), n_base_funcs)
@boundscheck checkbounds(u, dof_range)
grad = function_gradient_init(qp_v, u)
@inbounds for (i, j) in pairs(dof_range)
grad += shape_gradient(qp_v, i) * u[j]
end
return grad

Check warning on line 49 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L41-L49

Added lines #L41 - L49 were not covered by tests
end

function function_symmetric_gradient(qp_v::AbstractQuadratureValues, u::AbstractVector, dof_range)
grad = function_gradient(qp_v, u, dof_range)
return symmetric(grad)

Check warning on line 54 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L52-L54

Added lines #L52 - L54 were not covered by tests
end

function function_symmetric_gradient(qp_v::AbstractQuadratureValues, u::AbstractVector)
grad = function_gradient(qp_v, u)
return symmetric(grad)

Check warning on line 59 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L57-L59

Added lines #L57 - L59 were not covered by tests
end

function function_divergence(qp_v::AbstractQuadratureValues, u::AbstractVector, dof_range = eachindex(u))
return divergence_from_gradient(function_gradient(qp_v, u, dof_range))

Check warning on line 63 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L62-L63

Added lines #L62 - L63 were not covered by tests
end

function function_curl(qp_v::AbstractQuadratureValues, u::AbstractVector, dof_range = eachindex(u))
return curl_from_gradient(function_gradient(qp_v, u, dof_range))

Check warning on line 67 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L66-L67

Added lines #L66 - L67 were not covered by tests
end

function spatial_coordinate(qp_v::AbstractQuadratureValues, x::AbstractVector{<:Vec})
n_base_funcs = getngeobasefunctions(qp_v)
length(x) == n_base_funcs || throw_incompatible_coord_length(length(x), n_base_funcs)
vec = zero(eltype(x))
@inbounds for i in 1:n_base_funcs
vec += geometric_value(qp_v, i) * x[i]
end
return vec

Check warning on line 77 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L70-L77

Added lines #L70 - L77 were not covered by tests
end

# Specific design for QuadratureValues <: AbstractQuadratureValues
# which contains standard AbstractValues
struct QuadratureValues{VT<:AbstractValues} <: AbstractQuadratureValues
v::VT
q_point::Int
Base.@propagate_inbounds function QuadratureValues(v::AbstractValues, q_point::Int)
@boundscheck checkbounds(1:getnbasefunctions(v), q_point)
return new{typeof(v)}(v, q_point)

Check warning on line 87 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L85-L87

Added lines #L85 - L87 were not covered by tests
end
end

@inline quadrature_point_values(fe_v::AbstractValues, q_point, args...) = QuadratureValues(fe_v, q_point)

Check warning on line 91 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L91

Added line #L91 was not covered by tests

@propagate_inbounds getngeobasefunctions(qv::QuadratureValues) = getngeobasefunctions(qv.v)
@propagate_inbounds geometric_value(qv::QuadratureValues, i) = geometric_value(qv.v, qv.q_point, i)
geometric_interpolation(qv::QuadratureValues) = geometric_interpolation(qv.v)

Check warning on line 95 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L93-L95

Added lines #L93 - L95 were not covered by tests

getdetJdV(qv::QuadratureValues) = @inbounds getdetJdV(qv.v, qv.q_point)

Check warning on line 97 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L97

Added line #L97 was not covered by tests

# Accessors for function values
getnbasefunctions(qv::QuadratureValues) = getnbasefunctions(qv.v)
function_interpolation(qv::QuadratureValues) = function_interpolation(qv.v)
function_difforder(qv::QuadratureValues) = function_difforder(qv.v)
shape_value_type(qv::QuadratureValues) = shape_value_type(qv.v)
shape_gradient_type(qv::QuadratureValues) = shape_gradient_type(qv.v)

Check warning on line 104 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L100-L104

Added lines #L100 - L104 were not covered by tests

@propagate_inbounds shape_value(qv::QuadratureValues, i::Int) = shape_value(qv.v, qv.q_point, i)
@propagate_inbounds shape_gradient(qv::QuadratureValues, i::Int) = shape_gradient(qv.v, qv.q_point, i)
@propagate_inbounds shape_symmetric_gradient(qv::QuadratureValues, i::Int) = shape_symmetric_gradient(qv.v, qv.q_point, i)

Check warning on line 108 in src/FEValues/QuadratureValues.jl

View check run for this annotation

Codecov / codecov/patch

src/FEValues/QuadratureValues.jl#L106-L108

Added lines #L106 - L108 were not covered by tests



#= Proposed syntax, for heatflow in general
function assemble_element!(Ke::Matrix, fe::Vector, cellvalues)
n_basefuncs = getnbasefunctions(cellvalues)
for qv in Ferrite.QuadratureValuesIterator(cellvalues)
dΩ = getdetJdV(qv)
for i in 1:n_basefuncs
δu = shape_value(qv, i)
∇δu = shape_gradient(qv, i)
fe[i] += δu * dΩ
for j in 1:n_basefuncs
∇u = shape_gradient(qv, j)
Ke[i, j] += (∇δu ⋅ ∇u) * dΩ
end
end
end
return Ke, fe
end

Where the default for a QuadratureValuesIterator would be to return a
`QuadratureValues` as above, but custom `AbstractValues` can be created where
for example the element type would be a static QuadPointValue type which doesn't
use heap allocated buffers, e.g. by only saving the cell and coordinates during reinit,
and then calculating all values for each element in the iterator.

References:
https://github.com/termi-official/Thunderbolt.jl/pull/53/files#diff-2b486be5a947c02ef2a38ff3f82af3141193af0b6f01ed9d5129b914ed1d84f6
https://github.com/Ferrite-FEM/Ferrite.jl/compare/master...kam/StaticValues2
=#
Loading
Loading