Use MultiplicativeInverse to speedup Linear to Cartesian indexing operations #539

Open · wants to merge 2 commits into main from vc/better_expand

Conversation

vchuravy (Member) commented Oct 18, 2024
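For context: Base.MultiplicativeInverses.SignedMultiplicativeInverse precomputes a "magic" multiplier for a fixed divisor, so that div compiles to a multiply, shift, and add instead of a hardware sdiv. A minimal sketch of the technique (not the PR code itself; inv7 is an illustrative name):

```julia
using Base.MultiplicativeInverses: SignedMultiplicativeInverse

# Precompute the inverse once (e.g. per kernel launch); every
# subsequent div by that divisor then avoids the hardware sdiv.
inv7 = SignedMultiplicativeInverse{Int32}(Int32(7))

div(Int32(45), inv7)  # == div(Int32(45), Int32(7)) == 6

# The remainder can be recovered from the quotient and the stored divisor:
Int32(45) - inv7.divisor * div(Int32(45), inv7)  # == 3
```

This is the same trick compilers use for division by constants; here the divisor is only known at kernel-launch time, so the inverse is computed on the host and passed into the kernel.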

github-actions bot (Contributor) commented Oct 18, 2024

Benchmark Results

Benchmark    main    ccdc2ad...    main/ccdc2ad288a09c... (ratio)
saxpy/default/Float16/1024 0.527 ± 0.005 μs 2 ± 0.011 μs 0.264
saxpy/default/Float16/1048576 0.172 ± 0.0011 ms 1.61 ± 0.013 ms 0.107
saxpy/default/Float16/16384 3.13 ± 0.046 μs 25.7 ± 0.09 μs 0.122
saxpy/default/Float16/2048 0.693 ± 0.0055 μs 3.59 ± 0.014 μs 0.193
saxpy/default/Float16/256 0.394 ± 0.0042 μs 0.81 ± 0.005 μs 0.486
saxpy/default/Float16/262144 0.0439 ± 0.00055 ms 0.403 ± 0.00055 ms 0.109
saxpy/default/Float16/32768 5.8 ± 0.084 μs 0.0509 ± 0.00011 ms 0.114
saxpy/default/Float16/4096 1.08 ± 0.017 μs 6.76 ± 0.02 μs 0.161
saxpy/default/Float16/512 0.446 ± 0.0052 μs 1.22 ± 0.011 μs 0.366
saxpy/default/Float16/64 0.364 ± 0.0043 μs 0.515 ± 0.0039 μs 0.706
saxpy/default/Float16/65536 11.5 ± 0.18 μs 0.101 ± 0.00013 ms 0.113
saxpy/default/Float32/1024 0.43 ± 0.005 μs 1.06 ± 0.012 μs 0.405
saxpy/default/Float32/1048576 0.176 ± 0.025 ms 0.658 ± 0.0092 ms 0.268
saxpy/default/Float32/16384 2.89 ± 0.98 μs 10.8 ± 0.07 μs 0.267
saxpy/default/Float32/2048 0.53 ± 0.011 μs 1.76 ± 0.012 μs 0.302
saxpy/default/Float32/256 0.391 ± 0.006 μs 0.562 ± 0.0049 μs 0.696
saxpy/default/Float32/262144 0.0531 ± 0.012 ms 0.165 ± 0.00034 ms 0.323
saxpy/default/Float32/32768 5.5 ± 1.6 μs 21.2 ± 0.07 μs 0.259
saxpy/default/Float32/4096 0.906 ± 0.023 μs 3.05 ± 0.01 μs 0.297
saxpy/default/Float32/512 0.401 ± 0.0044 μs 0.723 ± 0.0061 μs 0.555
saxpy/default/Float32/64 0.371 ± 0.0028 μs 0.44 ± 0.0041 μs 0.844
saxpy/default/Float32/65536 12.5 ± 1.4 μs 0.0418 ± 7.1e-05 ms 0.3
saxpy/default/Float64/1024 0.527 ± 0.011 μs 1.11 ± 0.013 μs 0.474
saxpy/default/Float64/1048576 0.474 ± 0.045 ms 0.686 ± 0.02 ms 0.692
saxpy/default/Float64/16384 5.19 ± 1.3 μs 11.4 ± 0.28 μs 0.455
saxpy/default/Float64/2048 0.91 ± 0.018 μs 1.79 ± 0.011 μs 0.509
saxpy/default/Float64/256 0.393 ± 0.0039 μs 0.583 ± 0.0056 μs 0.675
saxpy/default/Float64/262144 0.0882 ± 0.0089 ms 0.166 ± 0.0014 ms 0.531
saxpy/default/Float64/32768 12 ± 1.5 μs 21.8 ± 0.62 μs 0.55
saxpy/default/Float64/4096 1.49 ± 0.3 μs 3.06 ± 0.013 μs 0.488
saxpy/default/Float64/512 0.442 ± 0.0062 μs 0.755 ± 0.0069 μs 0.585
saxpy/default/Float64/64 0.371 ± 0.0054 μs 0.457 ± 0.0043 μs 0.811
saxpy/default/Float64/65536 23 ± 2.4 μs 0.0424 ± 0.0012 ms 0.542
saxpy/static workgroup=(1024,)/Float16/1024 1.88 ± 0.024 μs 1.9 ± 0.022 μs 0.989
saxpy/static workgroup=(1024,)/Float16/1048576 0.166 ± 0.011 ms 0.159 ± 0.0029 ms 1.04
saxpy/static workgroup=(1024,)/Float16/16384 4.18 ± 0.14 μs 4.17 ± 0.13 μs 1
saxpy/static workgroup=(1024,)/Float16/2048 2.07 ± 0.038 μs 2.09 ± 0.023 μs 0.993
saxpy/static workgroup=(1024,)/Float16/256 2.57 ± 0.022 μs 2.57 ± 0.021 μs 1
saxpy/static workgroup=(1024,)/Float16/262144 0.0412 ± 0.001 ms 0.042 ± 0.0014 ms 0.982
saxpy/static workgroup=(1024,)/Float16/32768 6.55 ± 0.27 μs 6.56 ± 0.23 μs 0.999
saxpy/static workgroup=(1024,)/Float16/4096 2.39 ± 0.03 μs 2.41 ± 0.027 μs 0.989
saxpy/static workgroup=(1024,)/Float16/512 3 ± 0.043 μs 3 ± 0.028 μs 1
saxpy/static workgroup=(1024,)/Float16/64 2.26 ± 0.036 μs 2.26 ± 0.025 μs 0.998
saxpy/static workgroup=(1024,)/Float16/65536 12.1 ± 0.31 μs 12.1 ± 0.57 μs 0.997
saxpy/static workgroup=(1024,)/Float32/1024 1.93 ± 0.024 μs 1.95 ± 0.028 μs 0.987
saxpy/static workgroup=(1024,)/Float32/1048576 0.197 ± 0.028 ms 0.263 ± 0.029 ms 0.746
saxpy/static workgroup=(1024,)/Float32/16384 4.08 ± 0.58 μs 4.35 ± 0.73 μs 0.937
saxpy/static workgroup=(1024,)/Float32/2048 2.07 ± 0.03 μs 2.08 ± 0.03 μs 0.996
saxpy/static workgroup=(1024,)/Float32/256 2.42 ± 0.035 μs 2.43 ± 0.044 μs 0.996
saxpy/static workgroup=(1024,)/Float32/262144 0.0474 ± 0.0042 ms 0.0643 ± 0.0061 ms 0.736
saxpy/static workgroup=(1024,)/Float32/32768 7.02 ± 0.58 μs 7.74 ± 1 μs 0.907
saxpy/static workgroup=(1024,)/Float32/4096 2.35 ± 0.05 μs 2.37 ± 0.062 μs 0.99
saxpy/static workgroup=(1024,)/Float32/512 2.42 ± 0.035 μs 2.43 ± 0.036 μs 0.997
saxpy/static workgroup=(1024,)/Float32/64 2.65 ± 8.2 μs 2.45 ± 5.3 μs 1.08
saxpy/static workgroup=(1024,)/Float32/65536 14.2 ± 1.3 μs 16.5 ± 1.7 μs 0.859
saxpy/static workgroup=(1024,)/Float64/1024 2.03 ± 0.022 μs 2.04 ± 0.029 μs 0.993
saxpy/static workgroup=(1024,)/Float64/1048576 0.486 ± 0.054 ms 0.513 ± 0.027 ms 0.947
saxpy/static workgroup=(1024,)/Float64/16384 7.11 ± 0.88 μs 6.95 ± 1.1 μs 1.02
saxpy/static workgroup=(1024,)/Float64/2048 2.31 ± 0.035 μs 2.32 ± 0.046 μs 0.992
saxpy/static workgroup=(1024,)/Float64/256 2.4 ± 0.06 μs 2.42 ± 0.095 μs 0.994
saxpy/static workgroup=(1024,)/Float64/262144 0.099 ± 0.014 ms 0.1 ± 0.0087 ms 0.986
saxpy/static workgroup=(1024,)/Float64/32768 14.4 ± 1.6 μs 15.3 ± 0.82 μs 0.939
saxpy/static workgroup=(1024,)/Float64/4096 2.93 ± 0.25 μs 2.95 ± 0.2 μs 0.993
saxpy/static workgroup=(1024,)/Float64/512 2.38 ± 0.079 μs 2.39 ± 0.049 μs 0.998
saxpy/static workgroup=(1024,)/Float64/64 2.35 ± 0.82 μs 24.6 ± 27 μs 0.0957
saxpy/static workgroup=(1024,)/Float64/65536 25.7 ± 2.5 μs 28.5 ± 3.5 μs 0.904
time_to_load 0.729 ± 0.0033 s 0.734 ± 0.0049 s 0.993

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@vchuravy vchuravy changed the title Optimize expand to avoid one unnecessary sdiv Use MultiplicativeInverse to speedup Linear to Cartesian indexing operations Oct 21, 2024
@vchuravy vchuravy marked this pull request as ready for review October 21, 2024 12:20
@vchuravy vchuravy requested a review from maleadt October 21, 2024 12:20
vchuravy (Member, Author) commented Oct 21, 2024

@vchuravy vchuravy force-pushed the vc/better_expand branch 2 times, most recently from 102fbbd to 73dc429 on October 21, 2024 16:04
vchuravy (Member, Author) commented Oct 21, 2024

Well, the only issue is that my benchmarks are angry at me...
And I can't reproduce it locally...

maleadt (Member) commented Oct 21, 2024

From a quick look, this doesn't look to be a silver bullet. Some initial performance measurements in JuliaGPU/GPUArrays.jl#565 (comment).

Also, there are now some InexactErrors in the hot path because of the conversion to Int32:
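For illustration (not part of the PR), the difference between Julia's checked and wrapping integer conversions, which is what produces this branch:

```julia
x = Int64(2)^31              # just past typemax(Int32)

# Checked conversion: Int32(x) throws InexactError; on the GPU this
# lowers to a compare/branch/trap sequence like the IR shown here.
caught = try
    Int32(x)
    false
catch e
    e isa InexactError
end                          # caught == true

# Wrapping conversion: compiles to a plain truncation, no branch.
x % Int32                    # == typemin(Int32)
```

The diff at the end of this thread switches the hot path to the wrapping `% Int32` form for exactly this reason.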

; ││││┌ @ /home/tim/Julia/pkg/KernelAbstractions/src/nditeration.jl:31 within `getindex`
; │││││┌ @ tuple.jl:383 within `map`
; ││││││┌ @ /home/tim/Julia/pkg/KernelAbstractions/src/nditeration.jl:32 within `#5`
; │││││││┌ @ array.jl:3065 within `getindex`
; ││││││││┌ @ range.jl:923 within `_getindex`
; │││││││││┌ @ range.jl:953 within `unsafe_getindex`
; ││││││││││┌ @ number.jl:7 within `convert`
; │││││││││││┌ @ boot.jl:891 within `Int32`
; ││││││││││││┌ @ boot.jl:801 within `toInt32`
; │││││││││││││┌ @ boot.jl:764 within `checked_trunc_sint`
                %47 = add nsw i64 %43, -2147483647
                %48 = icmp ult i64 %47, -4294967296
                br i1 %48, label %L304, label %L313

L304:                                             ; preds = %L219
                call fastcc void @julia__throw_inexacterror_25251({ i64, i32 } %state)
                call void @llvm.trap()
                call void asm sideeffect "exit;", ""()
                unreachable

L313:                                             ; preds = %L219
                %49 = add nsw i64 %45, -2147483647
                %50 = icmp ult i64 %49, -4294967296
                br i1 %50, label %L335, label %L345

L335:                                             ; preds = %L313
                call fastcc void @julia__throw_inexacterror_25251({ i64, i32 } %state)
                call void @llvm.trap()
                call void asm sideeffect "exit;", ""()
                unreachable

Looks like this generates a significant amount of code. With the scalar broadcast from https://github.com/JuliaGPU/GPUArrays.jl/issues/565, compare the CUDA.jl version that simply uses hardware indices ("old") against the KA.jl version ("new"):

define ptx_kernel void @old({ i64, i32 } %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { [1 x float], [2 x [1 x i64]] } %1) local_unnamed_addr {
conversion:
  %.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 0
  %.fca.2.1.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 2, 1
  %.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 3
  %2 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %3 = zext i32 %2 to i64
  %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
  %5 = zext i32 %4 to i64
  %6 = mul nuw nsw i64 %3, %5
  %7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %8 = add nuw nsw i32 %7, 1
  %9 = zext i32 %8 to i64
  %10 = add nuw nsw i64 %6, %9
  %11 = icmp sgt i64 %.fca.2.0.extract, 0
  call void @llvm.assume(i1 %11)
  %12 = icmp sgt i64 %.fca.2.1.extract, 0
  call void @llvm.assume(i1 %12)
  %.not = icmp sgt i64 %10, %.fca.3.extract
  br i1 %.not, label %L176, label %pass

L176:                                             ; preds = %pass, %conversion
  ret void

pass:                                             ; preds = %conversion
  %.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0
  %13 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
  %14 = add nsw i64 %10, -1
  %15 = getelementptr inbounds float, float addrspace(1)* %13, i64 %14
  %.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %1, 0, 0
  store float %.fca.0.0.extract, float addrspace(1)* %15, align 4
  br label %L176
}
define ptx_kernel void @new({ i64, i32 } %state, { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { [1 x float], [2 x [1 x i64]] } %2) local_unnamed_addr {
conversion:
  %.fca.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 0, 0, 0, 0
  %.fca.0.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 0, 0, 1, 0
  %.fca.1.0.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 0
  %.fca.1.0.0.0.1.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 1
  %.fca.1.0.0.0.2.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 2
  %.fca.1.0.0.0.3.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 0, 0, 0, 3
  %.fca.1.1.0.0.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 0
  %.fca.1.1.0.0.1.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 1
  %.fca.1.1.0.0.2.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 2
  %.fca.1.1.0.0.3.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 0, 3
  %.fca.1.1.0.1.0.extract = extractvalue { [1 x [2 x [1 x i64]]], [2 x [1 x [2 x { i32, i32, i8, i8 }]]] } %0, 1, 1, 0, 1, 0
  %3 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %4 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %5 = icmp ne i32 %3, 0
  call void @llvm.assume(i1 %5)
  %6 = zext i32 %3 to i64
  %7 = sext i32 %.fca.1.0.0.0.1.extract to i64
  %8 = mul nsw i64 %7, %6
  %9 = lshr i64 %8, 32
  %10 = trunc i64 %9 to i32
  %11 = sext i8 %.fca.1.0.0.0.2.extract to i32
  %12 = mul i32 %3, %11
  %13 = add i32 %12, %10
  %abs.i = call i32 @llvm.abs.i32(i32 %.fca.1.0.0.0.0.extract, i1 false)
  %.not = icmp eq i32 %abs.i, 1
  %14 = mul i32 %.fca.1.0.0.0.0.extract, %3
  %narrow = call i8 @llvm.umin.i8(i8 %.fca.1.0.0.0.3.extract, i8 31)
  %.v = zext i8 %narrow to i32
  %15 = ashr i32 %13, %.v
  %.lobit = lshr i32 %13, 31
  %16 = add i32 %.lobit, %15
  %17 = select i1 %.not, i32 %14, i32 %16
  %18 = mul i32 %17, %.fca.1.0.0.0.0.extract
  %19 = add nuw nsw i32 %3, 1
  %20 = sub i32 %19, %18
  %21 = add i32 %17, 1
  %22 = sext i32 %20 to i64
  %23 = sext i32 %21 to i64
  %24 = icmp ne i32 %4, 0
  call void @llvm.assume(i1 %24)
  %25 = zext i32 %4 to i64
  %26 = sext i32 %.fca.1.1.0.0.1.extract to i64
  %27 = mul nsw i64 %26, %25
  %28 = lshr i64 %27, 32
  %29 = trunc i64 %28 to i32
  %30 = sext i8 %.fca.1.1.0.0.2.extract to i32
  %31 = mul nsw i32 %4, %30
  %32 = add i32 %31, %29
  %abs.i29 = call i32 @llvm.abs.i32(i32 %.fca.1.1.0.0.0.extract, i1 false)
  %.not39 = icmp eq i32 %abs.i29, 1
  %33 = mul i32 %.fca.1.1.0.0.0.extract, %4
  %narrow37 = call i8 @llvm.umin.i8(i8 %.fca.1.1.0.0.3.extract, i8 31)
  %.v36 = zext i8 %narrow37 to i32
  %34 = ashr i32 %32, %.v36
  %.lobit38 = lshr i32 %32, 31
  %35 = add i32 %.lobit38, %34
  %36 = select i1 %.not39, i32 %33, i32 %35
  %37 = mul i32 %36, %.fca.1.1.0.0.0.extract
  %38 = add nuw nsw i32 %4, 1
  %39 = sub i32 %38, %37
  %40 = add i32 %36, 1
  %41 = sext i32 %39 to i64
  %42 = sext i32 %40 to i64
  %43 = add nsw i64 %22, -1
  %44 = sext i32 %.fca.1.1.0.0.0.extract to i64
  %45 = mul nsw i64 %43, %44
  %46 = add nsw i64 %45, %41
  %47 = add nsw i64 %23, -1
  %48 = sext i32 %.fca.1.1.0.1.0.extract to i64
  %49 = mul nsw i64 %47, %48
  %50 = add nsw i64 %49, %42
  %51 = icmp sgt i64 %46, 0
  %52 = icmp sle i64 %46, %.fca.0.0.0.0.extract
  %53 = and i1 %51, %52
  %54 = icmp sgt i64 %50, 0
  %55 = icmp sle i64 %50, %.fca.0.0.1.0.extract
  %56 = and i1 %54, %55
  %57 = and i1 %56, %53
  br i1 %57, label %L340, label %L723

L340:                                             ; preds = %conversion
  %.fca.0.0.extract = extractvalue { [1 x float], [2 x [1 x i64]] } %2, 0, 0
  %.fca.2.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 2, 0
  %.fca.0.extract = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0
  %58 = add nsw i64 %42, -1
  %59 = add nsw i64 %58, %49
  %60 = mul i64 %59, %.fca.2.0.extract
  %61 = add nsw i64 %41, -1
  %62 = add nsw i64 %61, %45
  %63 = add i64 %62, %60
  %64 = bitcast i8 addrspace(1)* %.fca.0.extract to float addrspace(1)*
  %65 = getelementptr inbounds float, float addrspace(1)* %64, i64 %63
  store float %.fca.0.0.extract, float addrspace(1)* %65, align 4
  br label %L723

L723:                                             ; preds = %L340, %conversion
  ret void
}

Of course, this looks extra bad because the base kernel is so simple.


The changes needed to get rid of all exceptions:

diff --git a/src/nditeration.jl b/src/nditeration.jl
index b933c2b..e7f087c 100644
--- a/src/nditeration.jl
+++ b/src/nditeration.jl
@@ -29,7 +29,7 @@ end
 @inline function Base.getindex(iter::FastCartesianIndices{N}, I::Vararg{Int, N}) where N
     @boundscheck checkbounds(iter, I...)
     index = map(iter.inverses, I) do inv, i
-        @inbounds getindex(Base.OneTo(inv.divisor), i)
+        @inbounds getindex(Base.OneTo(inv.divisor), i%Int32)
     end
     CartesianIndex(index)
 end
@@ -43,13 +43,15 @@ end
 function _ind2sub_recurse(inds, ind)
     Base.@_inline_meta
     inv = inds[1]
+    Main.LLVM.Interop.assume(ind > 0)
     indnext, f, l = _div(ind, inv)
     (ind-l*indnext+f, _ind2sub_recurse(Base.tail(inds), indnext)...)
 end

 _lookup(ind, inv::SignedMultiplicativeInverse) = ind+1
 function _div(ind, inv::SignedMultiplicativeInverse)
-    inv.divisor == 0 && throw(DivideError())
+    #inv.divisor == 0 && throw(DivideError())
+    Main.LLVM.Interop.assume(ind >= 0)
     div(ind%Int32, inv), 1, inv.divisor
 end
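A standalone, simplified stand-in for the _ind2sub_recurse/_div pair the diff patches (fast_ind2sub is a hypothetical name, not the PR's): peel off one dimension at a time by dividing the zero-based linear index with each dimension's precomputed inverse, and check the result against CartesianIndices.

```julia
using Base.MultiplicativeInverses: SignedMultiplicativeInverse

# Base case: no dimensions left.
fast_ind2sub(::Tuple{}, ind) = ()

# Recursive case: `ind` is zero-based; divide by this dimension's size
# via its multiplicative inverse (multiply+shift, no sdiv), emit the
# one-based index for this dimension, and recurse on the quotient.
function fast_ind2sub(invs, ind)
    inv = invs[1]
    indnext = div(ind, inv)
    (ind - inv.divisor * indnext + one(ind),
     fast_ind2sub(Base.tail(invs), indnext)...)
end

dims = (3, 4)
invs = map(d -> SignedMultiplicativeInverse{Int}(d), dims)
fast_ind2sub(invs, 7)  # == Tuple(CartesianIndices(dims)[8]) == (2, 3)
```

The assume(ind >= 0) in the diff exists so LLVM can drop the sign-handling fixup (the lshr/add pair visible in the "new" IR above) once it knows the index is non-negative.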
