-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use MultiplicativeInverse to speedup Linear to Cartesian indexing operations #539
base: main
Are you sure you want to change the base?
Conversation
Benchmark Results
Benchmark PlotsA plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR. |
This stack of pull requests is managed by Graphite. Learn more about stacking. |
102fbbd
to
73dc429
Compare
Well the only issue is that my benchmarks are angry at me... |
From a quick look, this doesn't look to be a silver bullet. Some initial performance measurements in JuliaGPU/GPUArrays.jl#565 (comment). Also, there's now some ; ││││┌ @ /home/tim/Julia/pkg/KernelAbstractions/src/nditeration.jl:31 within `getindex`
; │││││┌ @ tuple.jl:383 within `map`
; ││││││┌ @ /home/tim/Julia/pkg/KernelAbstractions/src/nditeration.jl:32 within `#5`
; │││││││┌ @ array.jl:3065 within `getindex`
; ││││││││┌ @ range.jl:923 within `_getindex`
; │││││││││┌ @ range.jl:953 within `unsafe_getindex`
; ││││││││││┌ @ number.jl:7 within `convert`
; │││││││││││┌ @ boot.jl:891 within `Int32`
; ││││││││││││┌ @ boot.jl:801 within `toInt32`
; │││││││││││││┌ @ boot.jl:764 within `checked_trunc_sint`
%47 = add nsw i64 %43, -2147483647
%48 = icmp ult i64 %47, -4294967296
br i1 %48, label %L304, label %L313
L304: ; preds = %L219
call fastcc void @julia__throw_inexacterror_25251({ i64, i32 } %state)
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable
L313: ; preds = %L219
%49 = add nsw i64 %45, -2147483647
%50 = icmp ult i64 %49, -4294967296
br i1 %50, label %L335, label %L345
L335: ; preds = %L313
call fastcc void @julia__throw_inexacterror_25251({ i64, i32 } %state)
call void @llvm.trap()
call void asm sideeffect "exit;", ""()
unreachable Looks like this generates a significant amount of code. With the scalar broadcast from ttps://github.com/JuliaGPU/GPUArrays.jl/issues/565, the CUDA.jl version that simply uses hardware indices vs. the KA.jl version: define ptx_kernel void @old({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
%.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
%2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%3 = zext i32 %2 to i64
%4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%5 = zext i32 %4 to i64
%6 = mul nuw nsw i64 %3, %5
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = add nuw nsw i32 %7, 1
%9 = zext i32 %8 to i64
%10 = add nuw nsw i64 %6, %9
%11 = icmp sgt i64 %.fca.2.0.extract, 0
call void @llvm.assume(i1 %11)
%12 = icmp sgt i64 %.fca.2.1.extract, 0
call void @llvm.assume(i1 %12)
%.not = icmp sgt i64 %10, %.fca.3.extract
br i1 %.not, label %L176, label %pass
L176: ; preds = %pass, %conversion
ret void
pass: ; preds = %conversion
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
%13 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%14 = add nsw i64 %10, -1
%15 = getelementptr inbounds float, float addrspace(1)* %13, i64 %14
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
store float %.fca.0.0.extract, float addrspace(1)* %15, align 4
br label %L176
} define ptx_kernel void @new({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x float], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
%.fca.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 0, 0, 0, 0
%.fca.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 0, 0, 1, 0
%.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 0
%.fca.1.0.0.0.1.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 1
%.fca.1.0.0.0.2.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 2
%.fca.1.0.0.0.3.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 3
%.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 0
%.fca.1.1.0.0.1.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 1
%.fca.1.1.0.0.2.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 2
%.fca.1.1.0.0.3.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 3
%.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 1, 0
%3 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%4 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%5 = icmp ne i32 %3, 0
call void @llvm.assume(i1 %5)
%6 = zext i32 %3 to i64
%7 = sext i32 %.fca.1.0.0.0.1.extract to i64
%8 = mul nsw i64 %7, %6
%9 = lshr i64 %8, 32
%10 = trunc i64 %9 to i32
%11 = sext i8 %.fca.1.0.0.0.2.extract to i32
%12 = mul i32 %3, %11
%13 = add i32 %12, %10
%abs.i = call i32 @llvm.abs.i32(i32 %.fca.1.0.0.0.0.extract, i1 false)
%.not = icmp eq i32 %abs.i, 1
%14 = mul i32 %.fca.1.0.0.0.0.extract, %3
%narrow = call i8 @llvm.umin.i8(i8 %.fca.1.0.0.0.3.extract, i8 31)
%.v = zext i8 %narrow to i32
%15 = ashr i32 %13, %.v
%.lobit = lshr i32 %13, 31
%16 = add i32 %.lobit, %15
%17 = select i1 %.not, i32 %14, i32 %16
%18 = mul i32 %17, %.fca.1.0.0.0.0.extract
%19 = add nuw nsw i32 %3, 1
%20 = sub i32 %19, %18
%21 = add i32 %17, 1
%22 = sext i32 %20 to i64
%23 = sext i32 %21 to i64
%24 = icmp ne i32 %4, 0
call void @llvm.assume(i1 %24)
%25 = zext i32 %4 to i64
%26 = sext i32 %.fca.1.1.0.0.1.extract to i64
%27 = mul nsw i64 %26, %25
%28 = lshr i64 %27, 32
%29 = trunc i64 %28 to i32
%30 = sext i8 %.fca.1.1.0.0.2.extract to i32
%31 = mul nsw i32 %4, %30
%32 = add i32 %31, %29
%abs.i29 = call i32 @llvm.abs.i32(i32 %.fca.1.1.0.0.0.extract, i1 false)
%.not39 = icmp eq i32 %abs.i29, 1
%33 = mul i32 %.fca.1.1.0.0.0.extract, %4
%narrow37 = call i8 @llvm.umin.i8(i8 %.fca.1.1.0.0.3.extract, i8 31)
%.v36 = zext i8 %narrow37 to i32
%34 = ashr i32 %32, %.v36
%.lobit38 = lshr i32 %32, 31
%35 = add i32 %.lobit38, %34
%36 = select i1 %.not39, i32 %33, i32 %35
%37 = mul i32 %36, %.fca.1.1.0.0.0.extract
%38 = add nuw nsw i32 %4, 1
%39 = sub i32 %38, %37
%40 = add i32 %36, 1
%41 = sext i32 %39 to i64
%42 = sext i32 %40 to i64
%43 = add nsw i64 %22, -1
%44 = sext i32 %.fca.1.1.0.0.0.extract to i64
%45 = mul nsw i64 %43, %44
%46 = add nsw i64 %45, %41
%47 = add nsw i64 %23, -1
%48 = sext i32 %.fca.1.1.0.1.0.extract to i64
%49 = mul nsw i64 %47, %48
%50 = add nsw i64 %49, %42
%51 = icmp sgt i64 %46, 0
%52 = icmp sle i64 %46, %.fca.0.0.0.0.extract
%53 = and i1 %51, %52
%54 = icmp sgt i64 %50, 0
%55 = icmp sle i64 %50, %.fca.0.0.1.0.extract
%56 = and i1 %54, %55
%57 = and i1 %56, %53
br i1 %57, label %L340, label %L723
L340: ; preds = %conversion
%.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %2, 0, 0
%.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
%.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
%58 = add nsw i64 %42, -1
%59 = add nsw i64 %58, %49
%60 = mul i64 %59, %.fca.2.0.extract
%61 = add nsw i64 %41, -1
%62 = add nsw i64 %61, %45
%63 = add i64 %62, %60
%64 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
%65 = getelementptr inbounds float, float addrspace(1)* %64, i64 %63
store float %.fca.0.0.extract, float addrspace(1)* %65, align 4
br label %L723
L723: ; preds = %L340, %conversion
ret void
} Of course, this looks extra bad because the base kernel is so simple. The changes to get rid of all exceptions: diff --git a/src/nditeration.jl b/src/nditeration.jl
index b933c2b..e7f087c 100644
--- a/src/nditeration.jl
+++ b/src/nditeration.jl
@@ -29,7 +29,7 @@ end
@inline function Base.getindex(iter::FastCartesianIndices{N}, I::Vararg{Int, N}) where N
@boundscheck checkbounds(iter, I...)
index = map(iter.inverses, I) do inv, i
- @inbounds getindex(Base.OneTo(inv.divisor), i)
+ @inbounds getindex(Base.OneTo(inv.divisor), i%Int32)
end
CartesianIndex(index)
end
@@ -43,13 +43,15 @@ end
function _ind2sub_recurse(inds, ind)
Base.@_inline_meta
inv = inds[1]
+ Main.LLVM.Interop.assume(ind > 0)
indnext, f, l = _div(ind, inv)
(ind-l*indnext+f, _ind2sub_recurse(Base.tail(inds), indnext)...)
end
_lookup(ind, inv::SignedMultiplicativeInverse) = ind+1
function _div(ind, inv::SignedMultiplicativeInverse)
- inv.divisor == 0 && throw(DivideError())
+ #inv.divisor == 0 && throw(DivideError())
+ Main.LLVM.Interop.assume(ind >= 0)
div(ind%Int32, inv), 1, inv.divisor
end |
73dc429
to
00bcec3
Compare
Related to JuliaGPU/Metal.jl#101
Using the idea from @N5N3 JuliaGPU/Metal.jl#101 (comment)