Skip to content

Commit

Permalink
Optimize expand by using SignedMultiplivativeInverse
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Oct 21, 2024
1 parent 419481c commit 73dc429
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 14 deletions.
90 changes: 81 additions & 9 deletions src/nditeration.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,63 @@
module NDIteration

import Base.MultiplicativeInverses: SignedMultiplicativeInverse

# CartesianIndex uses Int instead of Int32

@eval EmptySMI() = $(Expr(:new, SignedMultiplicativeInverse{Int32}, Int32(0), typemax(Int32), 0 % Int8, 0 % UInt8))
SMI(i) = i == 0 ? EmptySMI() : SignedMultiplicativeInverse{Int32}(i)

struct FastCartesianIndices{N} <: AbstractArray{CartesianIndex{N}, N}
inverses::NTuple{N, SignedMultiplicativeInverse{Int32}}
end

function FastCartesianIndices(indices::NTuple{N}) where {N}
inverses = map(i -> SMI(Int32(i)), indices)
FastCartesianIndices(inverses)
end

function Base.size(FCI::FastCartesianIndices{N}) where {N}
ntuple(Val(N)) do I
FCI.inverses[I].divisor
end
end

@inline function Base.getindex(::FastCartesianIndices{0})
return CartesianIndex()
end

@inline function Base.getindex(iter::FastCartesianIndices{N}, I::Vararg{Int, N}) where {N}
@boundscheck checkbounds(iter, I...)
index = map(iter.inverses, I) do inv, i
@inbounds getindex(Base.OneTo(inv.divisor), i)
end
CartesianIndex(index)
end

_ind2sub_recuse(::Tuple{}, ind) = (ind + 1,)
function _ind2sub_recurse(indslast::NTuple{1}, ind)
Base.@_inline_meta
(_lookup(ind, indslast[1]),)
end

function _ind2sub_recurse(inds, ind)
Base.@_inline_meta
inv = inds[1]
indnext, f, l = _div(ind, inv)
(ind - l * indnext + f, _ind2sub_recurse(Base.tail(inds), indnext)...)
end

_lookup(ind, inv::SignedMultiplicativeInverse) = ind + 1
function _div(ind, inv::SignedMultiplicativeInverse)
inv.divisor == 0 && throw(DivideError())
div(ind % Int32, inv), 1, inv.divisor
end

function Base._ind2sub(inv::FastCartesianIndices, ind)
Base.@_inline_meta
_ind2sub_recurse(inv.inverses, ind - 1)
end

export _Size, StaticSize, DynamicSize, get
export NDRange, blocks, workitems, expand
export DynamicCheck, NoDynamicCheck
Expand Down Expand Up @@ -50,18 +108,32 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
blocks::DynamicBlock
workitems::DynamicWorkitems

function NDRange{N, B, W}() where {N, B, W}
new{N, B, W, Nothing, Nothing}(nothing, nothing)
end

function NDRange{N, B, W}(blocks, workitems) where {N, B, W}
function NDRange{N, B, W}(blocks::Union{Nothing, FastCartesianIndices{N}}, workitems::Union{Nothing, FastCartesianIndices{N}}) where {N, B, W}
@assert B <: _Size
@assert W <: _Size
new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
end
end

@inline workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize} = range.workitems::CartesianIndices{N}
function NDRange{N, B, W}() where {N, B, W}
NDRange{N, B, W}(nothing, nothing)
end

function NDRange{N, B, W}(blocks::CartesianIndices, workitems::CartesianIndices) where {N, B, W}
return NDRange{N, B, W}(FastCartesianIndices(size(blocks)), FastCartesianIndices(size(workitems)))
end

function NDRange{N, B, W}(blocks::Nothing, workitems::CartesianIndices) where {N, B, W}
return NDRange{N, B, W}(blocks, FastCartesianIndices(size(workitems)))
end

function NDRange{N, B, W}(blocks::CartesianIndices, workitems::Nothing) where {N, B, W}
return NDRange{N, B, W}(FastCartesianIndices(size(blocks)), workitems)
end

@inline workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize} = range.workitems::FastCartesianIndices{N}
@inline workitems(range::NDRange{N, B, W}) where {N, B, W <: StaticSize} = CartesianIndices(get(W))::CartesianIndices{N}
@inline blocks(range::NDRange{N, B}) where {N, B <: DynamicSize} = range.blocks::CartesianIndices{N}
@inline blocks(range::NDRange{N, B}) where {N, B <: DynamicSize} = range.blocks::FastCartesianIndices{N}
@inline blocks(range::NDRange{N, B}) where {N, B <: StaticSize} = CartesianIndices(get(B))::CartesianIndices{N}

import Base.iterate
Expand All @@ -80,8 +152,8 @@ Base.length(range::NDRange) = length(blocks(range))
CartesianIndex(nI)
end

Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer)
expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::Integer) where {N}
return expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
end

Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where {N}
Expand Down
2 changes: 1 addition & 1 deletion test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end

function compiler_testsuite(backend, ArrayT)
kernel = index(CPU(), DynamicSize(), DynamicSize())
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}()
ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(KernelAbstractions.NoDynamicCheck()))
@test KernelAbstractions.__index_Global_NTuple(ctx, CartesianIndex(1)) == (1,)

Expand Down
4 changes: 2 additions & 2 deletions test/localmem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Test
end
I = @index(Global, Linear)
i = @index(Local, Linear)
lmem = @localmem Int (N,) # Ok iff groupsize is static
lmem = @localmem Int (N,) # Ok iff groupsize is static
@inbounds begin
lmem[i] = i
@synchronize
Expand All @@ -23,7 +23,7 @@ end
end
I = @index(Global, Linear)
i = @index(Local, Linear)
lmem = @localmem Int (N,) # Ok iff groupsize is static
lmem = @localmem Int (N,) # Ok iff groupsize is static
@inbounds begin
lmem[i] = i + 3
for j in 1:2
Expand Down
2 changes: 1 addition & 1 deletion test/private.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function private_testsuite(backend, ArrayT)

A = ArrayT{Int}(undef, 64, 64)
A .= 1
forloop(backend())(A, Val(size(A, 2)), ndrange = size(A, 1), workgroupsize = size(A, 1))
forloop(backend(), size(A, 1))(A, Val(size(A, 2)), ndrange = size(A, 1), workgroupsize = size(A, 1))
synchronize(backend())
@test all(Array(A)[:, 1] .== 64)
@test all(Array(A)[:, 2:end] .== 1)
Expand Down
2 changes: 1 addition & 1 deletion test/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
@conditional_testset "Const" skip_tests begin
let kernel = constarg(Backend(), 8, (1024,))
# this is poking at internals
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}();
iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}()
ctx = if Backend == CPU
KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(NoDynamicCheck()))
else
Expand Down

0 comments on commit 73dc429

Please sign in to comment.