Parameter tuning #154
FWIW, here's the latest version of the script, which uses a simple brute-force search; HyperOpt didn't help much, and I couldn't get an evolutionary algorithm to work properly.

```julia
using CUDA, GemmKernels, UnicodePlots, Octavian
# we don't need super-accurate timings
const samples = 250

function main(; M=4096, K=4096, N=4096, T=Float32, time_limit=300)
    ## set-up

    A = CUDA.rand(T, M, K)
    B = CUDA.rand(T, K, N)
    C = CUDA.zeros(T, M, N)

    C_h = zeros(T, M, N)
    Octavian.matmul!(C_h, Array(A), Array(B))

    # pow2-sized, 128-bit aligned inputs, so we can use aligned layouts.
    # we don't have transposed inputs, so everything is column major.
    @assert stride(A, 2) % 16 == 0
    global_a_layout = Layout.UnsafeAlignedColMajor{eltype(A)}
    @assert stride(B, 2) % 16 == 0
    global_b_layout = Layout.UnsafeAlignedColMajor{eltype(B)}
    # we want to do a simple C = A * B, so no need to load C first.
    global_c_layout = Layout.Zero{eltype(C)}
    @assert stride(C, 2) % 16 == 0
    global_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

    # shared layouts are similar.
    # the frequently-accessed a/b shmems are padded to avoid bank conflicts.
    shared_a_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(A)}, 8}
    shared_b_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(B)}, 8}
    shared_c_layout = shared_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

    # we use the single-stage kernel, for simplicity
    kernel = Kernel.matmul_singlestage

    # TODO: compute_warp is partially hardcoded in config.jl, requiring M >= 4 and N >= 2
    # TODO: tune warps_per_block (which may affect correctness)

    ## evaluation helpers

    function rand_params()
        [rand(0:5) for _ in 1:9]
    end

    function create_config(params)
        op_m, op_n, op_k, op_mb, op_nb, op_kb, block_m, block_n, block_k = 2 .^ params
        op_shape = (M = op_m, N = op_n, K = op_k)
        block_shape = (M = block_m, N = block_n, K = block_k)

        compute_type = promote_type(eltype(A), eltype(B))
        operator = Operator.FPUOp{op_m, op_n, op_k,
                                  op_mb, op_nb, op_kb,
                                  compute_type, eltype(C)}

        conf = try
            GemmKernels.get_config(;
                gemm_shape = (; M, N, K), block_shape, operator,
                global_a_layout, global_b_layout, global_c_layout, global_d_layout,
                shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
                is_a_col_major = true,
                is_b_col_major = true
            )
        catch err
            if isa(err, GemmKernels.ConfigError)
                return nothing
            end
            rethrow()
        end

        ## LocalArray size limit, these are in the kernel
        num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
        num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N
        if num_fragments_m * num_fragments_n >= 32
            return nothing
        end

        return conf
    end

    function measure_config(conf)
        try
            # warm-up & correctness check
            C .= 0
            GemmKernels.matmul(conf, A, B, C, C; kernel)
            if !(Array(C) ≈ C_h)
                @error "Configuration produced invalid result: $conf"
                return nothing
            end

            # benchmark
            device_synchronize()
            GC.gc(true)
            timings = zeros(samples)
            for i in 1:samples
                synchronize(stream())
                timings[i] = CUDA.@elapsed GemmKernels.matmul(conf, A, B, C, C; kernel)
            end
            minimum(timings)
        catch err
            if isa(err, CuError)
                @error "Configuration failed: $conf"
                rethrow()
            end
            @warn "Skipping configuration: $conf\n" * sprint(Base.showerror, err)
            return nothing
        end
    end

    ## actual evaluation

    total = 0
    results = Dict()

    pending_configs = Channel(2)
    Timer(time_limit) do _
        close(pending_configs)
    end

    @sync begin
        # producer
        @async begin
            while isopen(pending_configs)
                total += 1
                params = rand_params()
                conf = create_config(params)
                if conf !== nothing
                    try
                        push!(pending_configs, conf)
                    catch err
                        err == Base.closed_exception() && return
                        rethrow()
                    end
                end
            end
        end

        # consumer
        @async begin
            for conf in pending_configs
                get!(results, conf) do
                    measure_config(conf)
                end
            end
        end
    end

    attempts = length(results)
    results = filter(kv -> !isnothing(kv[2]), results)
    nresults = length(results)
    skips = total - attempts
    errors = attempts - nresults
    println("Out of $total configurations, $skips ($(round(100*skips/total; digits=1))%) were skipped, $errors ($(round(100*errors/total; digits=1))%) errored, and $nresults ($(round(100*nresults/total; digits=1))%) were actually tested.")

    configs = sort(collect(keys(results)); by=(key->results[key]), rev=true)
    times = getindex.(Ref(results), configs)

    plot = lineplot(times, title="$(M)×$(K)×$(N) $T GEMM", ylabel="time", xlabel="config")
    display(plot)

    last(configs)
end

isinteractive() || main()
```
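To try the search on a smaller problem than the 4096³ default, the script can be included from a REPL and `main` called with different keyword arguments. A minimal sketch, assuming the script above is saved as `tune.jl` (the file name and the 256-sized problem are just illustrative):

```julia
# hypothetical REPL session; `tune.jl` is an assumed file name for the script above.
# in an interactive session, `isinteractive() || main()` does not auto-run the search.
include("tune.jl")

# the keyword arguments are the ones `main` defines; the values here are illustrative.
best = main(; M=256, K=256, N=256, T=Float32, time_limit=60)
@show best  # the fastest configuration found within the time limit
```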
Some of the parameter choices are currently far from optimal, as can be quickly explored using the script above.
For example, let's do a 256x256 GEMM, FP32xFP32=FP32, using the FPU operator. On my system (RTX 6000 Ada), the default configuration (an 8x8x1 (N, M, K) operator and a 128x128x32 block) yields:
The script above optimizes this to a 4x16x8 operator and a 16x32x256 block, which yields:
For reference, CUBLAS:
So a 2x improvement, getting us way closer to CUBLAS.
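For completeness, the CUBLAS baseline can be timed the same way the script times GemmKernels. A minimal sketch, assuming the 256×256×256 FP32 problem from above (`mul!` on `CuArray`s dispatches to CUBLAS):

```julia
using CUDA, LinearAlgebra

# minimal sketch of the CUBLAS baseline for the 256×256×256 FP32 case;
# mul! on CuArrays dispatches to CUBLAS under the hood.
A, B = CUDA.rand(Float32, 256, 256), CUDA.rand(Float32, 256, 256)
C = CUDA.zeros(Float32, 256, 256)

mul!(C, A, B)  # warm-up
t = minimum(CUDA.@elapsed(mul!(C, A, B)) for _ in 1:250)
println("CUBLAS: $(1e6 * t) µs")
```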
One problem is that the current implementation makes many implicit assumptions about the parameter values, so many configurations are skipped because they error or even produce invalid results. This should be fixed before we can fully explore the parameter space.