From 00bcec3091a4eab0fa46eb1f88f72ea482f22d95 Mon Sep 17 00:00:00 2001
From: Valentin Churavy
Date: Fri, 18 Oct 2024 12:58:39 +0200
Subject: [PATCH] Optimize expand by using SignedMultiplicativeInverse

---
 src/nditeration.jl | 108 +++++++++++++++++++++++++++++++++++++++++----
 test/compiler.jl   |   2 +-
 test/localmem.jl   |   4 +-
 test/private.jl    |   2 +-
 test/test.jl       |   2 +-
 5 files changed, 104 insertions(+), 14 deletions(-)

diff --git a/src/nditeration.jl b/src/nditeration.jl
index ab3fd4ec..fd87ed66 100644
--- a/src/nditeration.jl
+++ b/src/nditeration.jl
@@ -1,5 +1,65 @@
 module NDIteration
 
+import Base.MultiplicativeInverses: SignedMultiplicativeInverse
+
+# CartesianIndex uses Int instead of Int32
+
+@eval EmptySMI() = $(Expr(:new, SignedMultiplicativeInverse{Int32}, Int32(0), typemax(Int32), 0 % Int8, 0 % UInt8))
+SMI(i) = i == 0 ? EmptySMI() : SignedMultiplicativeInverse{Int32}(i)
+
+struct FastCartesianIndices{N} <: AbstractArray{CartesianIndex{N}, N}
+    inverses::NTuple{N, SignedMultiplicativeInverse{Int32}}
+end
+
+function FastCartesianIndices(indices::NTuple{N}) where {N}
+    inverses = map(i -> SMI(Int32(i)), indices)
+    FastCartesianIndices(inverses)
+end
+
+function Base.size(FCI::FastCartesianIndices{N}) where {N}
+    ntuple(Val(N)) do I
+        FCI.inverses[I].divisor
+    end
+end
+
+@inline function Base.getindex(::FastCartesianIndices{0})
+    return CartesianIndex()
+end
+
+@inline function Base.getindex(iter::FastCartesianIndices{N}, I::Vararg{Int, N}) where {N}
+    @boundscheck checkbounds(iter, I...)
+    index = map(iter.inverses, I) do inv, i
+        @inbounds getindex(Base.OneTo(inv.divisor), i % Int32)
+    end
+    CartesianIndex(index)
+end
+
+_ind2sub_recurse(::Tuple{}, ind) = (ind + 1,)
+function _ind2sub_recurse(indslast::NTuple{1}, ind)
+    Base.@_inline_meta
+    (_lookup(ind, indslast[1]),)
+end
+
+function _ind2sub_recurse(inds, ind)
+    Base.@_inline_meta
+    assume(ind >= 0)
+    inv = inds[1]
+    indnext, f, l = _div(ind, inv)
+    (ind - l * indnext + f, _ind2sub_recurse(Base.tail(inds), indnext)...)
+end
+
+_lookup(ind, inv::SignedMultiplicativeInverse) = ind + 1
+function _div(ind, inv::SignedMultiplicativeInverse)
+    # inv.divisor == 0 && throw(DivideError())
+    assume(ind >= 0)
+    div(ind % Int32, inv), 1, inv.divisor
+end
+
+function Base._ind2sub(inv::FastCartesianIndices, ind)
+    Base.@_inline_meta
+    _ind2sub_recurse(inv.inverses, ind - 1)
+end
+
 export _Size, StaticSize, DynamicSize, get
 export NDRange, blocks, workitems, expand
 export DynamicCheck, NoDynamicCheck
@@ -50,18 +110,32 @@ struct NDRange{N, StaticBlocks, StaticWorkitems, DynamicBlock, DynamicWorkitems}
     blocks::DynamicBlock
     workitems::DynamicWorkitems
 
-    function NDRange{N, B, W}() where {N, B, W}
-        new{N, B, W, Nothing, Nothing}(nothing, nothing)
-    end
-
-    function NDRange{N, B, W}(blocks, workitems) where {N, B, W}
+    function NDRange{N, B, W}(blocks::Union{Nothing, FastCartesianIndices{N}}, workitems::Union{Nothing, FastCartesianIndices{N}}) where {N, B, W}
+        @assert B <: _Size
+        @assert W <: _Size
         new{N, B, W, typeof(blocks), typeof(workitems)}(blocks, workitems)
     end
 end
 
-@inline workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize} = range.workitems::CartesianIndices{N}
+function NDRange{N, B, W}() where {N, B, W}
+    NDRange{N, B, W}(nothing, nothing)
+end
+
+function NDRange{N, B, W}(blocks::CartesianIndices, workitems::CartesianIndices) where {N, B, W}
+    return NDRange{N, B, W}(FastCartesianIndices(size(blocks)), FastCartesianIndices(size(workitems)))
+end
+
+function NDRange{N, B, W}(blocks::Nothing, workitems::CartesianIndices) where {N, B, W}
+    return NDRange{N, B, W}(blocks, FastCartesianIndices(size(workitems)))
+end
+
+function NDRange{N, B, W}(blocks::CartesianIndices, workitems::Nothing) where {N, B, W}
+    return NDRange{N, B, W}(FastCartesianIndices(size(blocks)), workitems)
+end
+
+@inline workitems(range::NDRange{N, B, W}) where {N, B, W <: DynamicSize} = range.workitems::FastCartesianIndices{N}
 @inline workitems(range::NDRange{N, B, W}) where {N, B, W <: StaticSize} = CartesianIndices(get(W))::CartesianIndices{N}
-@inline blocks(range::NDRange{N, B}) where {N, B <: DynamicSize} = range.blocks::CartesianIndices{N}
+@inline blocks(range::NDRange{N, B}) where {N, B <: DynamicSize} = range.blocks::FastCartesianIndices{N}
 @inline blocks(range::NDRange{N, B}) where {N, B <: StaticSize} = CartesianIndices(get(B))::CartesianIndices{N}
 
 import Base.iterate
@@ -80,8 +154,8 @@ Base.length(range::NDRange) = length(blocks(range))
     CartesianIndex(nI)
 end
 
-Base.@propagate_inbounds function expand(ndrange::NDRange, groupidx::Integer, idx::Integer)
-    expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
+Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::Integer, idx::Integer) where {N}
+    return expand(ndrange, blocks(ndrange)[groupidx], workitems(ndrange)[idx])
 end
 
 Base.@propagate_inbounds function expand(ndrange::NDRange{N}, groupidx::CartesianIndex{N}, idx::Integer) where {N}
@@ -126,4 +200,20 @@ needs to perform dynamic bounds-checking.
     end
 end
+
+
+"""
+    assume(cond::Bool)
+Assume that the condition `cond` is true. This is a hint to the compiler, possibly enabling
+it to optimize more aggressively.
+"""
+@inline assume(cond::Bool) = Base.llvmcall(("""
+    declare void @llvm.assume(i1)
+    define void @entry(i8) #0 {
+        %cond = icmp eq i8 %0, 1
+        call void @llvm.assume(i1 %cond)
+        ret void
+    }
+    attributes #0 = { alwaysinline }""", "entry"),
+    Nothing, Tuple{Bool}, cond)
 
 end #module
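The gist of the src/nditeration.jl changes above: expanding a (group, workitem) pair into a global index converts linear indices into Cartesian coordinates, which repeatedly divides by the axis lengths. The patch precomputes a SignedMultiplicativeInverse per axis so those divisions lower to multiply-and-shift sequences instead of hardware integer division. The snippet below is a minimal standalone sketch of that trick, not part of the patch; the names fast_axes and fast_ind2sub are illustrative only, and the real code additionally handles a zero divisor via EmptySMI and uses assume(ind >= 0) as a compiler hint.

using Base.MultiplicativeInverses: SignedMultiplicativeInverse

# Precompute one inverse per axis length (all lengths assumed nonzero here).
fast_axes(dims) = map(d -> SignedMultiplicativeInverse{Int32}(Int32(d)), dims)

# Convert a 1-based linear index into 1-based Cartesian coordinates,
# peeling off one axis at a time, much like `_ind2sub_recurse` in the patch.
function fast_ind2sub(invs, ind::Integer)
    i = Int32(ind) - Int32(1)                        # 0-based linear index
    coords = Int[]
    for inv in invs
        q = div(i, inv)                              # i ÷ axis length, via multiply-and-shift
        push!(coords, Int(i - q * inv.divisor) + 1)  # remainder + 1 = coordinate on this axis
        i = q
    end
    return CartesianIndex(coords...)
end

# Sanity check against the stock CartesianIndices machinery.
dims = (8, 4, 3)
invs = fast_axes(dims)
ci = CartesianIndices(dims)
@assert all(fast_ind2sub(invs, n) == ci[n] for n in 1:prod(dims))
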
+""" +@inline assume(cond::Bool) = Base.llvmcall((""" + declare void @llvm.assume(i1) + define void @entry(i8) #0 { + %cond = icmp eq i8 %0, 1 + call void @llvm.assume(i1 %cond) + ret void + } + attributes #0 = { alwaysinline }""", "entry"), + Nothing, Tuple{Bool}, cond) end #module diff --git a/test/compiler.jl b/test/compiler.jl index cf86386e..ba930305 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -39,7 +39,7 @@ end function compiler_testsuite(backend, ArrayT) kernel = index(CPU(), DynamicSize(), DynamicSize()) - iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}(); + iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}() ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(KernelAbstractions.NoDynamicCheck())) @test KernelAbstractions.__index_Global_NTuple(ctx, CartesianIndex(1)) == (1,) diff --git a/test/localmem.jl b/test/localmem.jl index b03bfa74..8f1167a3 100644 --- a/test/localmem.jl +++ b/test/localmem.jl @@ -8,7 +8,7 @@ using Test end I = @index(Global, Linear) i = @index(Local, Linear) - lmem = @localmem Int (N,) # Ok iff groupsize is static + lmem = @localmem Int (N,) # Ok iff groupsize is static @inbounds begin lmem[i] = i @synchronize @@ -23,7 +23,7 @@ end end I = @index(Global, Linear) i = @index(Local, Linear) - lmem = @localmem Int (N,) # Ok iff groupsize is static + lmem = @localmem Int (N,) # Ok iff groupsize is static @inbounds begin lmem[i] = i + 3 for j in 1:2 diff --git a/test/private.jl b/test/private.jl index 02c25aea..b1a0273e 100644 --- a/test/private.jl +++ b/test/private.jl @@ -73,7 +73,7 @@ function private_testsuite(backend, ArrayT) A = ArrayT{Int}(undef, 64, 64) A .= 1 - forloop(backend())(A, Val(size(A, 2)), ndrange = size(A, 1), workgroupsize = size(A, 1)) + forloop(backend(), size(A, 1))(A, Val(size(A, 2)), ndrange = size(A, 1), workgroupsize = size(A, 1)) synchronize(backend()) @test all(Array(A)[:, 1] .== 64) @test all(Array(A)[:, 2:end] .== 1) diff --git a/test/test.jl b/test/test.jl index 8977231b..c2d07137 100644 --- a/test/test.jl +++ b/test/test.jl @@ -154,7 +154,7 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk @conditional_testset "Const" skip_tests begin let kernel = constarg(Backend(), 8, (1024,)) # this is poking at internals - iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}(); + iterspace = NDRange{1, StaticSize{(128,)}, StaticSize{(8,)}}() ctx = if Backend == CPU KernelAbstractions.mkcontext(kernel, 1, nothing, iterspace, Val(NoDynamicCheck())) else