Parameter tuning #154

Open
maleadt opened this issue Sep 15, 2023 · 1 comment

maleadt commented Sep 15, 2023

Some of the current parameter choices are far from optimal, as a quick exploration with the following script shows:

using CUDA, GemmKernels
using Hyperopt
using Octavian

# we don't need super-accurate timings
const samples = 250

function main()
    M = K = N = 4096

    A = CUDA.rand(Float32, M, K)
    B = CUDA.rand(Float32, K, N)
    C = CUDA.zeros(Float32, M, N)

    C_h = zeros(Float32, M, N)
    Octavian.matmul!(C_h, Array(A), Array(B))

    # pow2-sized, 128-bit aligned inputs, so we can use aligned layouts.
    # we don't have transposed inputs, so everything is column major.
    @assert stride(A, 2) % 16 == 0
    global_a_layout = Layout.UnsafeAlignedColMajor{eltype(A)}
    @assert stride(B, 2) % 16 == 0
    global_b_layout = Layout.UnsafeAlignedColMajor{eltype(B)}
    # we want to do a simple C = A * B, so no need to load C first.
    global_c_layout = Layout.Zero{eltype(C)}
    @assert stride(C, 2) % 16 == 0
    global_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

    # shared layouts are similar.
    # the frequently-accessed a/b shmems are padded to avoid bank conflicts.
    shared_a_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(A)}, 8}
    shared_b_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(B)}, 8}
    shared_c_layout = shared_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

    # we use the single-stage kernel, for simplicity
    kernel = Kernel.matmul_singlestage

    # TODO: compute_warp is partially hardcoded in config.jl, requiring M>=4 and N >=2
    # TODO: tune warps_per_block (which may affect correctness)

    total = 0
    attempts = 0
    benchmarks = 0

    ho = @hyperopt for i = 1000,
                       OPERATOR_M = 2 .^ (1:4),
                       OPERATOR_N = 2 .^ (1:4),
                       OPERATOR_K = 2 .^ (1:4),
                       BLOCK_M = 2 .^ (1:8),
                       BLOCK_N = 2 .^ (1:8),
                       BLOCK_K = 2 .^ (1:8)
        op_shape = (M = OPERATOR_M, N = OPERATOR_N, K = OPERATOR_K)
        block_shape = (M = BLOCK_M, N = BLOCK_N, K = BLOCK_K)
        total += 1

        # validate the operator shape
        ## may not be larger than the block shape
        if op_shape.M > block_shape.M ||
           op_shape.N > block_shape.N ||
           op_shape.K > block_shape.K
            return Inf
        end
        ## the FPU operator's base shape is 4x8x1. can we relax this?
        if op_shape.M < 4 || op_shape.M % 4 != 0 ||
           op_shape.N < 8 || op_shape.N % 8 != 0
            return Inf
        end
        ## LocalArray size limits (these are the ways FPUOp instantiates them)
        if op_shape.M÷4 * op_shape.K >= 32 ||
           op_shape.K * op_shape.N÷8 >= 32 ||
           op_shape.M÷4 * op_shape.N÷8 >= 32
            # in isolation, i.e. https://github.com/JuliaGPU/GemmKernels.jl/issues/99,
            # a LocalArray of 32 elements is fine, but in the context of the kernel,
            # it's too large. I don't know why.
            return Inf
        end

        # validate the block shape
        ## needs to exactly cover the inputs, so that we can use the unsafe layouts.
        if M % block_shape.M != 0 || N % block_shape.N != 0 || K % block_shape.K != 0
            return Inf
        end
        ## need to be 128-bit aligned so that we can perform vectorized loads
        # XXX: is this correct?
        if block_shape.M * sizeof(eltype(A)) % 16 != 0 ||
           block_shape.N * sizeof(eltype(B)) % 16 != 0 ||
           block_shape.K * sizeof(eltype(C)) % 16 != 0
            return Inf
        end

        compute_type = promote_type(eltype(A), eltype(B))
        operator = Operator.FPUOp{OPERATOR_M, OPERATOR_N, OPERATOR_K, compute_type, eltype(C)}

        conf = GemmKernels.get_config(;
            gemm_shape = (; M, N, K), block_shape, operator,

            global_a_layout, global_b_layout, global_c_layout, global_d_layout,
            shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,

            is_a_col_major = true,
            is_b_col_major = true
        )

        ## another LocalArray size limit; these arrays are instantiated in the kernel itself
        num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
        num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N
        if num_fragments_m * num_fragments_n >= 32
            return Inf
        end

        try
            # warm-up & correctness check
            attempts += 1
            C .= 0
            GemmKernels.matmul(conf, A, B, C, C; kernel)
            if !(Array(C) ≈ C_h)
                @warn "Configuration produced invalid result: $conf"
                return Inf
            end

            # benchmark
            benchmarks += 1
            device_synchronize()
            GC.gc(true)
            timings = zeros(samples)
            for i in 1:samples
                synchronize(stream())
                timings[i] = CUDA.@elapsed GemmKernels.matmul(conf, A, B, C, C; kernel)
            end

            minimum(timings)
        catch err
            if isa(err, CuError)
                @error "Configuration failed: $conf"
                rethrow()
            end
            @info "Skipping configuration: $conf\n" * sprint(Base.showerror, err)
            # TODO: introduce GemmKernels.ConfigError, to differentiate from e.g.
            #       compilation errors, which we want to report verbosely.
            Inf
        end
    end

    skips = total - attempts
    errors = attempts - benchmarks
    println("Out of $total configurations, $skips ($(round(100*skips/total; digits=1))%) were skipped, $errors ($(round(100*errors/total; digits=1))%) errored, and $benchmarks ($(round(100*benchmarks/total; digits=1))%) were actually tested.")

    ho
end

isinteractive() || println(main())

For example, let's do a 256x256 GEMM, FP32xFP32=FP32, using the FPU operator. On my system (RTX 6000 Ada), the default configuration (8x8x1 (N, M, K) and block 128x128x32) yields:

julia> main()
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  53.589 μs … 115.339 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     55.609 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   55.744 μs ±   1.117 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                   ▂▂▄▆▇▆█▆█▄▃▄▁
  ▁▁▁▁▁▂▂▃▄▃▄▅▄▅▆▇▇█████████████▇▇▆▅▅▄▄▅▄▄▄▃▃▃▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁ ▄
  53.6 μs         Histogram: frequency by time         58.6 μs <

 Memory estimate: 2.98 KiB, allocs estimate: 50.

The script above optimizes this to 4x16x8 and 16x32x256, which yields:

julia> main()
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.549 μs …  87.980 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     24.030 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   24.212 μs ± 974.562 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                ▁▃▃▇▇██▅▆▅▃▃▂
  ▂▁▁▂▂▂▂▂▃▃▃▄▅▆█████████████▇█▆▆▆▅▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂ ▄
  22.5 μs         Histogram: frequency by time         26.5 μs <

 Memory estimate: 2.98 KiB, allocs estimate: 50.

For reference, CUBLAS:

julia> @benchmark CUDA.@sync mul!(C, A, B)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  19.850 μs …  75.970 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     20.790 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   20.884 μs ± 735.922 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                ▂▂▆███▃▃▂▁
  ▂▂▁▂▂▂▂▂▂▃▃▄▆▇███████████▇▆▅▅▅▄▄▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▄
  19.8 μs         Histogram: frequency by time         22.7 μs <

 Memory estimate: 592 bytes, allocs estimate: 20.

So a 2x improvement, getting us way closer to CUBLAS.
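
Spelled out by hand, the tuned configuration looks roughly like this. It's a minimal sketch that reuses the layouts, kernel, and A/B/C/M/N/K from the script above; the concrete shapes are hard-coded, and the 4x16x8 is taken to be (M, N, K), since that is the only ordering the operator checks in the script admit:

# tuned operator and block shapes from the run above
operator    = Operator.FPUOp{4, 16, 8, Float32, Float32}  # M, N, K, compute type, C/D eltype (as in the script)
block_shape = (M = 16, N = 32, K = 256)

conf = GemmKernels.get_config(;
    gemm_shape = (; M, N, K), block_shape, operator,

    global_a_layout, global_b_layout, global_c_layout, global_d_layout,
    shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,

    is_a_col_major = true,
    is_b_col_major = true
)
GemmKernels.matmul(conf, A, B, C, C; kernel)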


One problem is that the current implementation makes lots of implicit assumptions about the parameter values, so many configurations have to be skipped, error out, or even produce incorrect results. This should be fixed before we can fully explore the parameter space.
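
To make that concrete, here is one such implicit assumption, lifted from the validity checks in the script above (the particular block size is just a hypothetical, non-dividing example): the unsafe aligned layouts only work when the block shape tiles the GEMM shape exactly, so configurations like this have to be filtered out up front instead of being benchmarked.

M = N = K = 4096
block_shape = (M = 96, N = 128, K = 32)  # hypothetical block size that doesn't divide 4096
valid = M % block_shape.M == 0 && N % block_shape.N == 0 && K % block_shape.K == 0
@show valid  # false: skip this configuration rather than risk out-of-bounds accesses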

maleadt commented Nov 14, 2023

FWIW, here's the latest version of the script, now using a simple brute-force search; Hyperopt didn't help much, and I couldn't get an evolutionary algorithm to work properly.

using CUDA, GemmKernels, UnicodePlots, Octavian

# we don't need super-accurate timings
const samples = 250

function main(; M=4096, K=4096, N=4096, T=Float32, time_limit=300)
    ## set-up

    A = CUDA.rand(T, M, K)
    B = CUDA.rand(T, K, N)
    C = CUDA.zeros(T, M, N)

    C_h = zeros(T, M, N)
    Octavian.matmul!(C_h, Array(A), Array(B))

    # pow2-sized, 128-bit aligned inputs, so we can use aligned layouts.
    # we don't have transposed inputs, so everything is column major.
    @assert stride(A, 2) % 16 == 0
    global_a_layout = Layout.UnsafeAlignedColMajor{eltype(A)}
    @assert stride(B, 2) % 16 == 0
    global_b_layout = Layout.UnsafeAlignedColMajor{eltype(B)}
    # we want to do a simple C = A * B, so no need to load C first.
    global_c_layout = Layout.Zero{eltype(C)}
    @assert stride(C, 2) % 16 == 0
    global_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

    # shared layouts are similar.
    # the frequently-accessed a/b shmems are padded to avoid bank conflicts.
    shared_a_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(A)}, 8}
    shared_b_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(B)}, 8}
    shared_c_layout = shared_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

    # we use the single-stage kernel, for simplicity
    kernel = Kernel.matmul_singlestage

    # TODO: compute_warp is partially hardcoded in config.jl, requiring M>=4 and N >=2
    # TODO: tune warps_per_block (which may affect correctness)


    ## evaluation helpers

    function rand_params()
        [rand(0:5) for _ in 1:9]
    end

    function create_config(params)
        op_m, op_n, op_k, op_mb, op_nb, op_kb, block_m, block_n, block_k = 2 .^ params
        op_shape = (M = op_m, N = op_n, K = op_k)
        block_shape = (M = block_m, N = block_n, K = block_k)

        compute_type = promote_type(eltype(A), eltype(B))
        operator = Operator.FPUOp{op_m, op_n, op_k,
                                  op_mb, op_nb, op_kb,
                                  compute_type, eltype(C)}

        conf = try
            GemmKernels.get_config(;
                gemm_shape = (; M, N, K), block_shape, operator,

                global_a_layout, global_b_layout, global_c_layout, global_d_layout,
                shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,

                is_a_col_major = true,
                is_b_col_major = true
            )
        catch err
            if isa(err, GemmKernels.ConfigError)
                return nothing
            end
            rethrow()
        end

        ## LocalArray size limit; these arrays are instantiated in the kernel itself
        num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
        num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N
        if num_fragments_m * num_fragments_n >= 32
            return nothing
        end

        return conf
    end

    function measure_config(conf)
        try
            # warm-up & correctness check
            C .= 0
            GemmKernels.matmul(conf, A, B, C, C; kernel)
            if !(Array(C) ≈ C_h)
                @error "Configuration produced invalid result: $conf"
                return nothing
            end

            # benchmark
            device_synchronize()
            GC.gc(true)
            timings = zeros(samples)
            for i in 1:samples
                synchronize(stream())
                timings[i] = CUDA.@elapsed GemmKernels.matmul(conf, A, B, C, C; kernel)
            end

            minimum(timings)
        catch err
            if isa(err, CuError)
                @error "Configuration failed: $conf"
                rethrow()
            end
            @warn "Skipping configuration: $conf\n" * sprint(Base.showerror, err)
            return nothing
        end
    end


    ## actual evaluation

    total = 0
    results = Dict()
    pending_configs = Channel(2)
    Timer(time_limit) do _
        close(pending_configs)
    end
    @sync begin
        # producer
        @async begin
            while isopen(pending_configs)
                total += 1
                params = rand_params()
                conf = create_config(params)
                if conf !== nothing
                    try
                        push!(pending_configs, conf)
                    catch err
                        err == Base.closed_exception() && return
                        rethrow()
                    end
                end
            end
        end

        # consumer
        @async begin
            for conf in pending_configs
                get!(results, conf) do
                    measure_config(conf)
                end
            end
        end
    end

    attempts = length(results)
    results = filter(kv -> !isnothing(kv[2]), results)
    nresults = length(results)

    skips = total - attempts
    errors = attempts - nresults
    println("Out of $total configurations, $skips ($(round(100*skips/total; digits=1))%) were skipped, $errors ($(round(100*errors/total; digits=1))%) errored, and $nresults ($(round(100*nresults/total; digits=1))%) were actually tested.")

    configs = sort(collect(keys(results)); by=(key->results[key]), rev=true)
    times = getindex.(Ref(results), configs)
    plot = lineplot(times, title="$(M)×$(K)×$(N) $T GEMM", ylabel="time", xlabel="config")
    display(plot)

    last(configs)
end

isinteractive() || main()
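
As a usage note (an illustrative invocation, not one of the runs above), the keyword arguments of main make it easy to repeat the search for other problem shapes or with a different time budget:

# shorter search on a smaller problem; M/K/N, T and time_limit are the
# keyword arguments exposed by main() above
best_conf = main(; M=2048, K=2048, N=2048, time_limit=60)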
