diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 3d8f4565cb..ff499f2491 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -21,6 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp" struct ProblemSize final { @@ -74,9 +75,10 @@ struct ProblemSizeSplitK final struct ExecutionConfig final { - bool do_verification = true; - int init_method = 2; - bool time_kernel = false; + // 0 - no verification, 1 - CPU, 2 - GPU + int do_verification = 2; + int init_method = 2; + bool time_kernel = false; }; template @@ -125,7 +127,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl @@ -175,7 +177,7 @@ bool parse_cmd_args(int argc, else { std::cerr - << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg1: verification (0=no, 1=CPU, 2=GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl @@ -224,7 +226,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl @@ -274,7 +276,7 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index b5fecb9752..b9284b2783 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 212b72f2a6..1684213641 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index 1840390aa9..1e64e9a0a3 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp index 7a9e3f6186..30faf542dd 100644 --- a/example/01_gemm/gemm_dpp_fp16.cpp +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -34,6 +34,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device:: + ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index f8afe8d6db..28ab878ac3 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -68,6 +68,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 3cac55ef47..6cfff30dbd 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,6 +33,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp index cc14dcb8eb..108c100cbd 100644 --- a/example/01_gemm/gemm_xdl_bf16_rtn.cpp +++ b/example/01_gemm/gemm_xdl_bf16_rtn.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -34,6 +34,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 2338cdc9c1..07d51855d6 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -47,6 +47,17 @@ using DeviceGemmInstance = DeviceGemmInstance1; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_fp8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp index 979a200791..a996d034e6 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -42,6 +42,17 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp index eba0ea9d11..ecd3b7be5d 100644 --- a/example/01_gemm/gemm_xdl_fp16_v2.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -46,6 +46,17 @@ using DeviceGemmInstance = using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 8361576299..5afb3d1554 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -41,6 +41,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl BElementOp, CElementOp>; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index fe41602301..3c75a44d21 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -37,6 +37,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp index acc5fbc515..1dec165abd 100644 --- a/example/01_gemm/gemm_xdl_fp8_bf8.cpp +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -44,6 +44,17 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index cc03200b9d..3237f1a61c 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,6 +33,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp index d29cb74cd6..62037f7740 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include @@ -53,6 +53,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp index e99249389e..75971bdecf 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include @@ -52,6 +52,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp index 7d433b6145..5a02457daf 100644 --- a/example/01_gemm/gemm_xdl_streamk.cpp +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -44,6 +44,17 @@ using DeviceGemmInstance = DeviceGemmStreamK; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index b0f963fee5..d8672f6a0c 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -37,6 +37,17 @@ using DeviceGemmInstance = DeviceGemmInstance; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index a6f0d0bcfe..d315d2cdc4 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -173,6 +173,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -193,6 +194,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); @@ -325,7 +328,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << gemm.GetTypeString() << std::endl; - if(config.do_verification) + // CPU verification + if(config.do_verification == 1) { auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -354,6 +358,39 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #endif } + // GPU verification + if(config.do_verification == 2) + { + auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; + auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); + + auto ref_argument_gpu = ref_gemm_gpu.MakeArgument( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_ref_buf.GetDeviceBuffer()), + M, + N, + K, + a_element_op, + b_element_op, + c_element_op); + + // ref_invoker_gpu.Run(ref_argument_gpu); + + ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{}); + // ave_time = ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{nullptr, + // config.time_kernel}); + + c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + return ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + return true; } diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp new file mode 100644 index 0000000000..639b5fe80f --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + naive_gemm_kernel(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + index_t m, + index_t n, + index_t k, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ + using RowMajor = ck::tensor_layout::gemm::RowMajor; + + const int row_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int col_idx = blockIdx.y * blockDim.y + threadIdx.y; + + if(row_idx < m && col_idx < n) + { + + AccDataType v_acc = static_cast(0.0); + ComputeTypeA v_a = static_cast(0.0); + ComputeTypeB v_b = static_cast(0.0); + CDataType v_c = static_cast(0.0); + + for(int k_idx = 0; k_idx < k; ++k_idx) + { + // check input matrices layout + int element_idx_a = 0; + int element_idx_b = 0; + if constexpr(std::is_same_v) + { + element_idx_a = row_idx * k + k_idx; + } + else + { + element_idx_a = row_idx + m * k_idx; + } + if constexpr(std::is_same_v) + { + element_idx_b = k_idx * n + col_idx; + } + else + { + element_idx_b = k_idx + k * col_idx; + } + // apply a_element_op + a_element_op(v_a, p_a_grid[element_idx_a]); + // apply b_element_op + b_element_op(v_b, p_b_grid[element_idx_b]); + // multiply and accumulate + v_acc += static_cast(v_a) * static_cast(v_b); + } + // apply c_element_op + c_element_op(v_c, v_acc); + // check output matrix layout + int element_idx_c = 0; + if constexpr(std::is_same_v) + { + element_idx_c = row_idx * n + col_idx; + } + else + { + element_idx_c = row_idx + m * col_idx; + } + // prepare output + p_c_grid[element_idx_c] = v_c; + } +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct ReferenceGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + void* p_c_grid, + index_t m, + index_t n, + index_t k, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_c_grid_{static_cast(p_c_grid)}, + m_{m}, + n_{n}, + k_{k}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + index_t m_; + index_t n_; + index_t k_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemm::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + int block_size = 16; + dim3 block_dim(block_size, block_size, 1); + dim3 grid_dim( + (arg.m_ + block_size - 1) / block_size, (arg.n_ + block_size - 1) / block_size, 1); + + auto launch_kernel = [&]() { + const auto kernel = naive_gemm_kernel; + + return launch_and_time_kernel(stream_config, + kernel, + grid_dim, + block_dim, + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.m_, + arg.n_, + arg.k_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + return launch_kernel(); + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const void* p_a_grid, + const void* p_b_grid, + void* p_c_grid, + index_t m, + index_t n, + index_t k, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + p_a_grid, p_b_grid, p_c_grid, m, n, k, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "Device Reference Gemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck