From 895e8c40c7617696d9aae166ba276fca79170f8f Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 09:36:44 -0700 Subject: [PATCH 01/19] replace buffer_atomic with global_atomic --- cmake/gtest.cmake | 3 +- include/ck/utility/amd_buffer_addressing.hpp | 36 ++++++++++++++++++-- include/ck/utility/dynamic_buffer.hpp | 6 ++-- script/cmake-ck-dev.sh | 2 +- 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 0915f53411..c1da2a22c0 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -8,7 +8,8 @@ endif() FetchContent_Declare( GTest - GIT_REPOSITORY https://github.com/google/googletest.git + #GIT_REPOSITORY https://github.com/google/googletest.git + GIT_REPOSITORY git@github.com:google/googletest.git GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 ) diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index ab22134fc6..acd41b7823 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -562,6 +562,33 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src dst_wave_addr_offset); } +template +__device__ void amd_global_atomic_add_impl(const typename vector_type::type src_thread_data, + T* addr) +{ + if constexpr(is_same::value) + { + if constexpr(N == 2) + { + __builtin_amdgcn_global_atomic_fadd_v2f16(addr, src_thread_data); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + static_for<0, 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2f16(addr + i, tmp.AsType()[i]); + }); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + static_for<0, 4, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2f16(addr + i, tmp.AsType()[i]); + }); + } + } +} + template __device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, int32x4_t dst_wave_buffer_resource, @@ -907,7 +934,7 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; -#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK +#if 0 uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x80000000; amd_buffer_atomic_add_impl( @@ -915,8 +942,11 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr #else if(dst_thread_element_valid) { - amd_buffer_atomic_add_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + ignore = dst_wave_buffer_resource; + ignore = dst_thread_addr_offset; + //amd_buffer_atomic_add_impl( + //src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + amd_global_atomic_add_impl(src_thread_data, p_dst_wave + dst_thread_element_offset); } #endif } diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 76390e614e..0dcc514a2f 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -358,13 +358,15 @@ struct DynamicBuffer bool constexpr use_amd_buffer_addressing = is_same_v, int32_t> || is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); #elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; #elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0); + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); #else bool constexpr use_amd_buffer_addressing = false; #endif diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 26326523f4..54abbbce14 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -14,7 +14,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_HIP_FLAGS="--save-temps -v -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \ From b0f295cb60c33d6478cbc56df1174ecec20716cc Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 10:02:08 -0700 Subject: [PATCH 02/19] fixed global_atomic_add --- include/ck/utility/amd_buffer_addressing.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index acd41b7823..f866693e7e 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -568,6 +568,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ { if constexpr(is_same::value) { +#if 0 if constexpr(N == 2) { __builtin_amdgcn_global_atomic_fadd_v2f16(addr, src_thread_data); @@ -586,6 +587,13 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ __builtin_amdgcn_global_atomic_fadd_v2f16(addr + i, tmp.AsType()[i]); }); } +#else + static_assert(N % 2 == 0, ""); + vector_type tmp{src_thread_data}; + static_for<0, N / 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i, tmp.template AsType()[i]); + }); +#endif } } From f9b8a5d0754bd2e6853ac6a049c0258967e345a3 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 10:17:03 -0700 Subject: [PATCH 03/19] added bf16 atomic_add --- 
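This patch extends the pairwise global-atomic path to bf16: vector elements are paired up and each pair is issued as one packed atomic, via __builtin_amdgcn_global_atomic_fadd_v2f16 / __builtin_amdgcn_global_atomic_fadd_v2bf16. A minimal standalone sketch of the same decomposition for the f16 case follows, assuming a HIP toolchain and a gfx9 target that exposes the builtin; the helper name atomic_add_f16x and its array-style signature are illustrative only, not CK's API:

#include <hip/hip_runtime.h>

// Two packed fp16 lanes, matching the operand shape of the packed atomic.
using half2_t = _Float16 __attribute__((ext_vector_type(2)));

// Adds N fp16 values to global memory with N/2 packed atomic-adds.
// Each builtin call lowers to a single packed fp16 atomic instruction,
// which is why only even N is supported.
template <int N>
__device__ void atomic_add_f16x(_Float16* addr, const _Float16 (&src)[N])
{
    static_assert(N % 2 == 0, "packed fp16 atomics operate on element pairs");
    auto* p = reinterpret_cast<half2_t*>(addr); // assumes addr is 4-byte aligned
#pragma unroll
    for(int i = 0; i < N / 2; ++i)
    {
        half2_t v = {src[2 * i], src[2 * i + 1]};
        __builtin_amdgcn_global_atomic_fadd_v2f16(p + i, v);
    }
}

The bf16 branch added below has the same shape, with ushort pairs and the _v2bf16 builtin. Working at two-element granularity is what lets a single even-N loop replace the earlier N == 2 / 4 / 8 special cases.
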
example/01_gemm/gemm_xdl_bf16_v3.cpp | 2 +- example/01_gemm/run_gemm_example_v2.inc | 2 +- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 7 ----- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 4 ++- include/ck/utility/amd_buffer_addressing.hpp | 29 +++++-------------- 5 files changed, 13 insertions(+), 31 deletions(-) diff --git a/example/01_gemm/gemm_xdl_bf16_v3.cpp b/example/01_gemm/gemm_xdl_bf16_v3.cpp index e538aee1fe..d19d76aba6 100644 --- a/example/01_gemm/gemm_xdl_bf16_v3.cpp +++ b/example/01_gemm/gemm_xdl_bf16_v3.cpp @@ -19,7 +19,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; // clang-format off using DeviceGemmV2Instance = diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 5dcf8c3faa..40cc22c360 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -272,7 +272,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) if(config.time_kernel) { - ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 5785e64d1a..2670b5e0bc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -168,7 +168,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2, bhalf_t>::value) { if(arg_.KBatch > 1) hipGetErrorString( @@ -190,7 +189,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2, bhalf_t>::value) { if(arg.KBatch > 1) hipGetErrorString(hipMemsetAsync(arg.p_c_grid, @@ -215,7 +213,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { - if constexpr(!is_same, bhalf_t>::value) { const auto kernel = kernel_gemm_xdl_cshuffle_v3 1) { - if constexpr(!is_same, bhalf_t>::value) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { @@ -473,7 +469,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { - if constexpr(!is_same, bhalf_t>::value) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -525,7 +520,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { - if constexpr(!is_same, bhalf_t>::value) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -582,7 +576,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { - if constexpr(!is_same, bhalf_t>::value) { const auto kernel = kernel_gemm_xdl_cshuffle_v3, half_t>::value || - is_same, float>::value)) + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) { if(!karg.IsReduceAdd()) { diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index f866693e7e..139c0714cd 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -568,32 +568,19 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ { if constexpr(is_same::value) { -#if 0 - if constexpr(N 
== 2) - { - __builtin_amdgcn_global_atomic_fadd_v2f16(addr, src_thread_data); - } - else if constexpr(N == 4) - { - vector_type tmp{src_thread_data}; - static_for<0, 2, 1>{}([&](auto i) { - __builtin_amdgcn_global_atomic_fadd_v2f16(addr + i, tmp.AsType()[i]); - }); - } - else if constexpr(N == 8) - { - vector_type tmp{src_thread_data}; - static_for<0, 4, 1>{}([&](auto i) { - __builtin_amdgcn_global_atomic_fadd_v2f16(addr + i, tmp.AsType()[i]); - }); - } -#else static_assert(N % 2 == 0, ""); vector_type tmp{src_thread_data}; static_for<0, N / 2, 1>{}([&](auto i) { __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i, tmp.template AsType()[i]); }); -#endif + } + else if constexpr(is_same::value) + { + static_assert(N % 2 == 0, ""); + vector_type tmp{src_thread_data}; + static_for<0, N / 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast(addr) + i, tmp.template AsType()[i]); + }); } } From c70aacd37df49277ab3ba4ea025e2bfb99deb14d Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 13:38:44 -0700 Subject: [PATCH 04/19] format --- example/01_gemm/run_gemm_example_v2.inc | 3 +- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 291 +++++++++--------- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 23 +- include/ck/utility/amd_buffer_addressing.hpp | 13 +- 4 files changed, 154 insertions(+), 176 deletions(-) diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 40cc22c360..5e5cb935b9 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -272,7 +272,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) if(config.time_kernel) { - ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 2670b5e0bc..f820f9ed2e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -168,14 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) - hipGetErrorString( - hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - } + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); }; ave_time = ck::utility::launch_and_time_kernel_with_preprocess( @@ -189,13 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - } + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); ave_time = launch_and_time_kernel( stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); @@ -213,14 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); } else { @@ -237,117 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { + 
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) { const auto kernel = kernel_gemm_xdl_cshuffle_v3< GridwiseGemm, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, - TailNumber::One>; + TailNumber::Two>; Run(kernel); } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) { const auto kernel = kernel_gemm_xdl_cshuffle_v3< GridwiseGemm, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, - TailNumber::Full>; + TailNumber::Three>; Run(kernel); } + } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Two) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Two>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Three>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Four>; - Run(kernel); - } + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); } + } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Five>; - Run(kernel); - } + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); } + } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Six) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - 
minimum_occupancy, - TailNumber::Six>; - Run(kernel); - } + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); } + } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Seven>; - Run(kernel); - } + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); } } } @@ -469,27 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); } } else @@ -520,27 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); } } else @@ -576,14 +561,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 1) { - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); } else { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index fcff776744..319982257d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -29,7 +29,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -57,7 +57,7 @@ 
template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -485,20 +485,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 139c0714cd..3ef2a658d1 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -571,7 +571,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ static_assert(N % 2 == 0, ""); vector_type tmp{src_thread_data}; static_for<0, N / 2, 1>{}([&](auto i) { - __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i, tmp.template AsType()[i]); + __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i, + tmp.template AsType()[i]); }); } else if constexpr(is_same::value) @@ -579,7 +580,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ static_assert(N % 2 == 0, ""); vector_type tmp{src_thread_data}; static_for<0, N / 2, 1>{}([&](auto i) { - __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast(addr) + i, tmp.template AsType()[i]); + __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast(addr) + i, + tmp.template AsType()[i]); }); } } @@ -939,9 +941,10 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr { ignore = dst_wave_buffer_resource; ignore = dst_thread_addr_offset; - //amd_buffer_atomic_add_impl( - //src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - amd_global_atomic_add_impl(src_thread_data, p_dst_wave + dst_thread_element_offset); + // amd_buffer_atomic_add_impl( + // src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + amd_global_atomic_add_impl(src_thread_data, + p_dst_wave + dst_thread_element_offset); } #endif } From f5ea85f34cfb76b449b76fb1068681a539e2f701 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 15:41:54 -0500 Subject: [PATCH 05/19] clang-format-12 --- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 319982257d..fcff776744 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -29,7 +29,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS 
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -57,7 +57,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -485,11 +485,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " - << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC - << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 - << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } From f71203425ead5f60a2ad688962ad22adc7a99779 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 13:53:49 -0700 Subject: [PATCH 06/19] clean --- include/ck/utility/amd_buffer_addressing.hpp | 34 ++++++++++++-------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 3ef2a658d1..1595073153 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -931,22 +931,30 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; -#if 0 - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - - amd_buffer_atomic_add_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#else - if(dst_thread_element_valid) + if constexpr(is_same::value) { - ignore = dst_wave_buffer_resource; - ignore = dst_thread_addr_offset; - // amd_buffer_atomic_add_impl( - // src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - amd_global_atomic_add_impl(src_thread_data, - p_dst_wave + dst_thread_element_offset); + if(dst_thread_element_valid) + { + + amd_global_atomic_add_impl( + src_thread_data, p_dst_wave + dst_thread_element_offset); + } } + else + { +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x80000000; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } #endif + } } // buffer_atomic_max requires: From 35e61bf63e74fd710a4e41f140ec368bb041135d Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 13:56:14 -0700 Subject: [PATCH 07/19] clean --- script/cmake-ck-dev.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 54abbbce14..26326523f4 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -14,7 +14,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_HIP_FLAGS="--save-temps -v -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \ From 0fff2a666c3ff445b2d769a8927409434d55411b Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 13:58:40 -0700 Subject: [PATCH 08/19] add guards --- include/ck/utility/amd_buffer_addressing.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 1595073153..20452821ce 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -566,9 +566,12 @@ template __device__ void amd_global_atomic_add_impl(const typename vector_type::type src_thread_data, T* addr) { + static_assert((is_same::value && (N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 2 || N == 4 || N == 8)), + "wrong! 
not implemented"); + if constexpr(is_same::value) { - static_assert(N % 2 == 0, ""); vector_type tmp{src_thread_data}; static_for<0, N / 2, 1>{}([&](auto i) { __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i, @@ -577,7 +580,6 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ } else if constexpr(is_same::value) { - static_assert(N % 2 == 0, ""); vector_type tmp{src_thread_data}; static_for<0, N / 2, 1>{}([&](auto i) { __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast(addr) + i, @@ -935,7 +937,6 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr { if(dst_thread_element_valid) { - amd_global_atomic_add_impl( src_thread_data, p_dst_wave + dst_thread_element_offset); } From 65fb572df0d879e0e44beda25ee108bf2d1cb1f6 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 31 Jul 2024 16:00:04 -0500 Subject: [PATCH 09/19] Update gtest.cmake --- cmake/gtest.cmake | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index c1da2a22c0..0915f53411 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -8,8 +8,7 @@ endif() FetchContent_Declare( GTest - #GIT_REPOSITORY https://github.com/google/googletest.git - GIT_REPOSITORY git@github.com:google/googletest.git + GIT_REPOSITORY https://github.com/google/googletest.git GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 ) From ed2d5e404d47e5985f44d71019d1ced0b44779a5 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 15:06:05 -0700 Subject: [PATCH 10/19] enabled splitk_gemm_multi_d --- .../gemm_multiply_multiply_xdl_fp8.cpp | 36 ++-- ...emm_multiply_multiply_xdl_fp8_ab_scale.cpp | 2 +- .../gpu/device/device_gemm_multiple_d.hpp | 45 ++++ ...device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 193 +++++++++++++++++- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 1 - .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 10 +- 6 files changed, 258 insertions(+), 29 deletions(-) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index b0b1aa73c1..fd7870629b 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -69,18 +69,19 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 - // clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| 
_MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| -///###### RRR - ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; -///###### RCR - < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format off + < Row, Col, DsLayout, ELayout, + A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, 256, 128, 64, + 16, 16, 32, 32, 4, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // clang-format on int main(int argc, char* argv[]) @@ -99,6 +100,8 @@ int main(int argc, char* argv[]) ck::index_t StrideD = 0; ck::index_t StrideE = N; + ck::index_t KBatch = 1; + if(argc == 1) { // use default case @@ -109,7 +112,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 11) + else if(argc == 12) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -123,13 +126,15 @@ int main(int argc, char* argv[]) StrideB = std::stoi(argv[8]); StrideD = std::stoi(argv[9]); StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); exit(0); } @@ -212,6 +217,7 @@ int main(int argc, char* argv[]) StrideB, std::array{I0, I0}, StrideE, + KBatch, a_element_op, b_element_op, cde_element_op); @@ -236,10 +242,12 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - if(do_verification) { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + Tensor c_m_n({M, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm MakeInvokerPointer() = 0; }; +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) 
+// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleDSplitK : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + + + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 92aa47d53d..08f2082d3a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -69,7 +69,7 @@ template -struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD, bhalf_t>::value) - { if(arg_.KBatch > 1) hipGetErrorString( hipMemsetAsync(arg_.p_c_grid, 0, arg_.M * arg_.N * sizeof(CDataType), stream_config.stream_id_)); - } }; ave_time = ck::utility::launch_and_time_kernel_with_preprocess( @@ -234,6 +231,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else { const auto kernel = kernel_gemm_xdl_cshuffle_v3 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + 
GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { @@ -361,7 +481,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -387,6 +530,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -416,6 +583,16 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else { const auto kernel = kernel_gemm_xdl_cshuffle_v3 StrideDs, index_t StrideC, + index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) @@ -494,7 +672,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD StrideDs, index_t StrideC, + index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override @@ -529,7 +708,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD 1) { const auto kernel = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 3a1ac6c6de..510bf8cee5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -38,8 +38,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -56,7 +55,7 @@ __global__ void karg.c_element_op); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template Date: Wed, 31 Jul 2024 17:08:46 -0500 Subject: [PATCH 11/19] format --- .../gemm_multiply_multiply_xdl_fp8.cpp | 5 +-- 
...emm_multiply_multiply_xdl_fp8_ab_scale.cpp | 2 +- .../gpu/device/device_gemm_multiple_d.hpp | 2 -- ...device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 31 +++++++++---------- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index fd7870629b..35534895b4 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -72,7 +72,7 @@ using CDEElementOp = MultiplyMultiply; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 -// clang-format off + // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -134,7 +134,8 @@ int main(int argc, char* argv[]) printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); exit(0); } diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index 92eca85b57..2568754648 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -60,7 +60,7 @@ static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 -// clang-format off + // clang-format off MakeInvokerPointer() = 0; }; - - } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 08f2082d3a..72523e5027 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -70,16 +70,16 @@ template struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation> { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -192,12 +192,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) - hipGetErrorString( - hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); }; ave_time = ck::utility::launch_and_time_kernel_with_preprocess( From 32380a2798643cf1b0b5e85112556721adb5a8d1 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 16:03:58 -0700 Subject: [PATCH 12/19] add ckProfiler --- .../gpu/gemm_multiply_multiply.hpp | 24 
+++++++++---------- ...f8_bf16_mk_nk_mn_comp_default_instance.cpp | 2 +- ...8_bf16_mk_nk_mn_comp_kpadding_instance.cpp | 2 +- ...bf16_mk_nk_mn_comp_mnkpadding_instance.cpp | 2 +- ..._bf16_mk_nk_mn_comp_mnpadding_instance.cpp | 2 +- ..._bf16_mk_nk_mn_mem_v1_default_instance.cpp | 2 +- ...bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 2 +- ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 2 +- ..._bf16_mk_nk_mn_mem_v2_default_instance.cpp | 2 +- ...bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 2 +- ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 2 +- .../profile_gemm_multiply_multiply_impl.hpp | 4 +++- .../src/profile_gemm_multiply_multiply.cpp | 20 +++++++++------- 13 files changed, 37 insertions(+), 31 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp index f8e8e8fdec..7aa95566ff 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp @@ -18,7 +18,7 @@ namespace device { namespace instance { #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( - std::vector, Row, @@ -31,7 +31,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( - std::vector, Row, @@ -44,7 +44,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_ins MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances( - std::vector, Row, @@ -57,7 +57,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_in MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances( - std::vector, Row, @@ -70,7 +70,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_i MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( - std::vector, Row, @@ -83,7 +83,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_in MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector, Row, @@ -96,7 +96,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_i MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances( - std::vector, Row, @@ -109,7 +109,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector, Row, @@ -122,7 +122,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_in MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector, Row, @@ -135,7 +135,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances( - std::vector, Row, @@ -154,7 +154,7 @@ template 
-struct DeviceOperationInstanceFactory, @@ -167,7 +167,7 @@ struct DeviceOperationInstanceFactory> { - using DeviceOp = DeviceGemmMultipleD, CLayout, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp index 81131b4de2..ebe4571eb9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp index 149e4ad144..363504d656 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp index ba71f924e0..895a726cef 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp index 
e76f4f82b3..82304d9e5b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp index 03f360a457..c1adbf2df7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp index 194615e0fa..c9c9fdd2b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp index ae82b5800e..707af413b9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void 
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp index 47bf0df2c7..adb74c5898 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp index 88ee816202..66272d71f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector, Row, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp index 2c8784bedb..45f9ac0c72 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -9,7 +9,7 @@ namespace device { namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances( - std::vector, Row, diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp index 022399a9c0..e838db596f 100644 --- a/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp @@ -48,6 +48,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, int StrideD0, int StrideD1, int 
StrideE, + int KBatch, int n_warmup, int n_iter, uint64_t rotating = 0) @@ -129,7 +130,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, d1_device_buf.ToDevice(d1_m_n.mData.data()); using DeviceOp = - ck::tensor_operation::device::DeviceGemmMultipleD, ELayout, @@ -199,6 +200,7 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, StrideB, std::array{StrideD0, StrideD1}, StrideE, + KBatch, a_element_op, b_element_op, c_element_op); diff --git a/profiler/src/profile_gemm_multiply_multiply.cpp b/profiler/src/profile_gemm_multiply_multiply.cpp index 42201f7f22..e809c6df6a 100644 --- a/profiler/src/profile_gemm_multiply_multiply.cpp +++ b/profiler/src/profile_gemm_multiply_multiply.cpp @@ -34,7 +34,7 @@ enum struct GemmDataType int profile_gemm_multiply_multiply(int argc, char* argv[]) { - if(argc != 16 && argc != 19) + if(argc != 16 && argc != 20) { printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " @@ -50,9 +50,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) printf("arg7: time kernel (0=no, 1=yes)\n"); printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); printf("optional:\n"); - printf("arg16: number of warm-up cycles (default 1)\n"); - printf("arg17: number of iterations (default 10)\n"); - printf("arg18: memory for rotating buffer (default 0, size in MB)\n"); + printf("arg16: KBatch, the number of K splits (default 1)\n"); + printf("arg17: number of warm-up cycles (default 1)\n"); + printf("arg18: number of iterations (default 10)\n"); + printf("arg19: memory for rotating buffer (default 0, size in MB)\n"); exit(1); } @@ -76,11 +77,13 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) int n_warmup = 1; int n_iter = 10; uint64_t rotating = 0; - if(argc == 18) + int KBatch = 1; + if(argc == 20) { - n_warmup = std::stoi(argv[16]); - n_iter = std::stoi(argv[17]); - rotating = std::stoull(argv[18]) * 1024 * 1024; + KBatch = std::stoi(argv[16]); + n_warmup = std::stoi(argv[17]); + n_iter = std::stoi(argv[18]); + rotating = std::stoull(argv[19]) * 1024 * 1024; } using F32 = float; @@ -146,6 +149,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, (StrideE < 0) ?
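A minimal sketch of the widened optional tail, under the assumption that the profiler keeps its strictly positional convention (the helper name parse_optional_args is hypothetical, not part of the patch):

    #include <cstdint>
    #include <string>

    // argv[16..19] = KBatch, n_warmup, n_iter, rotating-buffer size in MB
    static void parse_optional_args(int argc, char* argv[],
                                    int& KBatch, int& n_warmup, int& n_iter,
                                    uint64_t& rotating)
    {
        if(argc > 16) KBatch   = std::stoi(argv[16]); // number of K splits
        if(argc > 17) n_warmup = std::stoi(argv[17]);
        if(argc > 18) n_iter   = std::stoi(argv[18]);
        if(argc > 19) rotating = std::stoull(argv[19]) * 1024 * 1024; // MB -> bytes
    }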
DefaultStrideE : StrideE, + KBatch, n_warmup, n_iter, rotating); From bbb29a9df3e701073cbba18feb97eb9401845f1e Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 18:04:22 -0500 Subject: [PATCH 13/19] format --- .../gpu/gemm_multiply_multiply.hpp | 223 +++++++++--------- ...f8_bf16_mk_nk_mn_comp_default_instance.cpp | 20 +- ...8_bf16_mk_nk_mn_comp_kpadding_instance.cpp | 20 +- ...bf16_mk_nk_mn_comp_mnkpadding_instance.cpp | 20 +- ..._bf16_mk_nk_mn_comp_mnpadding_instance.cpp | 20 +- ..._bf16_mk_nk_mn_mem_v1_default_instance.cpp | 20 +- ...bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 20 +- ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 20 +- ..._bf16_mk_nk_mn_mem_v2_default_instance.cpp | 20 +- ...bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 20 +- ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 20 +- .../profile_gemm_multiply_multiply_impl.hpp | 20 +- .../src/profile_gemm_multiply_multiply.cpp | 4 +- 13 files changed, 224 insertions(+), 223 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp index 7aa95566ff..ec81fd7b30 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp @@ -19,133 +19,133 @@ namespace instance { #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - 
PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances); + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances); #endif template > { - using DeviceOp = DeviceGemmMultipleDSplitK, - CLayout, - ADataType, - BDataType, - Tuple, - CDataType, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::MultiplyMultiply>; + using DeviceOp = + DeviceGemmMultipleDSplitK, + CLayout, + ADataType, + BDataType, + Tuple, + CDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::MultiplyMultiply>; static auto GetInstances() { diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp index ebe4571eb9..6527d93473 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp index 363504d656..7f16a7a2c5 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp index 895a726cef..972a2ece39 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp index 82304d9e5b..d5b5e35660 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp 
b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp index c1adbf2df7..d2e64be2f6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp index c9c9fdd2b7..3a57f860f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp index 707af413b9..8315116572 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp index adb74c5898..1515021f6d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp index 66272d71f0..1b80244f84 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp index 45f9ac0c72..6978f89b96 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -10,16 +10,16 @@ namespace instance { void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& instances) + 
Col, + Tuple, + Row, + F8, + F8, + Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyMultiply>>>& instances) { add_device_operation_instances( instances, diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp index e838db596f..7b72777128 100644 --- a/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp @@ -131,16 +131,16 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitK, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CElementOp>; + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/profiler/src/profile_gemm_multiply_multiply.cpp b/profiler/src/profile_gemm_multiply_multiply.cpp index e809c6df6a..b7e80ed798 100644 --- a/profiler/src/profile_gemm_multiply_multiply.cpp +++ b/profiler/src/profile_gemm_multiply_multiply.cpp @@ -77,10 +77,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) int n_warmup = 1; int n_iter = 10; uint64_t rotating = 0; - int KBatch = 1; + int KBatch = 1; if(argc == 20) { - KBatch = std::stoi(argv[16]); + KBatch = std::stoi(argv[16]); n_warmup = std::stoi(argv[17]); n_iter = std::stoi(argv[18]); rotating = std::stoull(argv[19]) * 1024 * 1024; From 8d74dcac5f4499d05d2f89f9c0f7da83d5118a75 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 16:42:11 -0700 Subject: [PATCH 14/19] fixed naming --- ...device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 56 +++++++++---------- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 4 +- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 72523e5027..6d83480ead 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -233,7 +233,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; @@ -242,7 +242,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; @@ -257,7 +257,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; @@ -594,7 +594,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 510bf8cee5..2a19f1ab5c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -36,7 +36,7 @@ __global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) { 
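The arch guard that follows is the usual CK pattern for gfx9-limited kernels; shown in isolation as a minimal runnable sketch (generic names, not CK code):

    #include <hip/hip_runtime.h>

    __global__ void gfx9_only_kernel(float* p)
    {
    // The host pass (no __HIP_DEVICE_COMPILE__) and gfx9 device passes compile
    // the real body; every other device target gets an empty stub so the fat
    // binary still links.
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
        p[threadIdx.x] = 0.f;
    #endif
    }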
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -68,7 +68,7 @@ __global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // Pass two lds pointer is the key to tell compiler that ds_read/write From ff47f28c17bc0acc6b9ca6cee999e674ad5200c6 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Wed, 31 Jul 2024 18:43:00 -0500 Subject: [PATCH 15/19] format --- ...device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 208 +++++++++--------- 1 file changed, 104 insertions(+), 104 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 6d83480ead..e955ba9474 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -232,20 +232,20 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_multi_d; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d; + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; Run(kernel); } } @@ -256,23 +256,23 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; Run(kernel); } else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_multi_d; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; Run(kernel); } @@ -370,10 +370,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; Run(kernel); } else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == @@ -381,10 +381,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; Run(kernel); } @@ -392,12 +392,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; Run(kernel); } } @@ -407,12 +407,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; Run(kernel); } } @@ -422,12 +422,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + 
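The if/else ladder being re-indented in these hunks maps each TailNumber value to its own kernel instantiation. The pattern in miniature, as a generic analogue with hypothetical names (the literal CK dispatch covers One through Seven, Full, Odd, and Even):

    #include <cstdio>

    enum class TailNumber { One, Full };

    // One template instantiation per tail length: the tail of the K loop is
    // resolved at compile time instead of branching inside the kernel.
    template <TailNumber Tail>
    void run_tail() { std::printf("tail=%d\n", static_cast<int>(Tail)); }

    void dispatch(TailNumber tail)
    {
        if(tail == TailNumber::One) { run_tail<TailNumber::One>(); }
        else                        { run_tail<TailNumber::Full>(); }
    }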
minimum_occupancy, + TailNumber::Four>; Run(kernel); } } @@ -437,12 +437,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; Run(kernel); } } @@ -451,12 +451,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; Run(kernel); } } @@ -466,12 +466,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; Run(kernel); } } @@ -507,22 +507,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_multi_d_2lds; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } @@ -533,22 +533,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_multi_d; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } @@ -558,20 +558,20 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d; + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } @@ -584,20 +584,20 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_multi_d; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d; + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; Run(kernel); } } From 79ac8751a4399b40de89af4ba47f74d013f93da9 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 1 Aug 2024 18:18:37 -0700 Subject: [PATCH 16/19] clean --- example/01_gemm/gemm_xdl_bf16_v3.cpp | 2 +- example/01_gemm/run_gemm_example_v2.inc | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/example/01_gemm/gemm_xdl_bf16_v3.cpp b/example/01_gemm/gemm_xdl_bf16_v3.cpp index d19d76aba6..e538aee1fe 100644 --- a/example/01_gemm/gemm_xdl_bf16_v3.cpp +++ b/example/01_gemm/gemm_xdl_bf16_v3.cpp @@ -19,7 +19,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmDefault = 
ck::tensor_operation::device::GemmSpecialization::Default; // clang-format off using DeviceGemmV2Instance = diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 5e5cb935b9..5dcf8c3faa 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -272,8 +272,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) if(config.time_kernel) { - ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = From a1cd282e4811fada4b54d106b46fa72fff9f03dc Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 1 Aug 2024 18:20:44 -0700 Subject: [PATCH 17/19] clean --- .../gemm_multiply_multiply_xdl_fp8.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index 35534895b4..51557f8778 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -72,16 +72,15 @@ using CDEElementOp = MultiplyMultiply; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 - // clang-format off - < Row, Col, DsLayout, ELayout, - A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, - 256, 256, 128, 64, - 16, 16, 32, 32, 4, 2, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RRR + ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, 
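For orientation when reading the parameter table above: an instance such as DeviceOpInstance is exercised through the usual CK device-op flow. A sketch only, with the MakeArgument parameter list elided (see the example's main() for the full call):

    auto device_op = DeviceOpInstance{};
    auto invoker   = device_op.MakeInvoker();
    // auto argument = device_op.MakeArgument(/* pointers, M/N/K, strides,
    //                                           element-wise ops */);
    // if(device_op.IsSupportedArgument(argument))
    //     invoker.Run(argument, StreamConfig{nullptr, true});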
S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; +///###### RCR + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // clang-format on int main(int argc, char* argv[]) From 7d69eb3b15a83bff38b1153c0d5e11fd145fb52b Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Thu, 1 Aug 2024 18:50:52 -0700 Subject: [PATCH 18/19] add guards --- example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp | 1 + include/ck/utility/amd_buffer_addressing.hpp | 2 ++ 2 files changed, 3 insertions(+) diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp index 5fea43ffc3..580f38a79f 100644 --- a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp @@ -208,6 +208,7 @@ int main(int argc, char* argv[]) StrideB, std::array{StrideD, StrideD}, StrideE, + 1, a_element_op, b_element_op, cde_element_op); diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 20452821ce..d4ee5c886c 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -578,6 +578,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } +#if defined(__gfx942__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; @@ -586,6 +587,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } +#endif } template From dddff92a99c8d0e6ae4f48c603d59d868c067c4d Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 5 Aug 2024 12:28:50 -0700 Subject: [PATCH 19/19] add mn_looping --- cmake/gtest.cmake | 3 +- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 17 +- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 914 +++++++++--------- profiler/src/CMakeLists.txt | 304 +++--- 4 files changed, 649 insertions(+), 589 deletions(-) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 0915f53411..dbb9642e20 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -8,7 +8,8 @@ endif() FetchContent_Declare( GTest - GIT_REPOSITORY https://github.com/google/googletest.git + GIT_REPOSITORY git@github.com:google/googletest.git + #GIT_REPOSITORY https://github.com/google/googletest.git GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 ) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 57a25526ce..07aa68bfd9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -144,6 +144,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 num_CUs ? 
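The grid clamp introduced above, restated as a standalone helper (hypothetical name; the patch obtains num_CUs through its own plumbing):

    #include <hip/hip_runtime.h>

    // Cap the launched grid at the device's CU count: with tile looping, extra
    // workgroups beyond what can be resident only add scheduling overhead.
    inline int clamp_grid_to_cus(int gdx)
    {
        int dev = 0;
        hipDeviceProp_t prop{};
        (void)hipGetDevice(&dev);
        (void)hipGetDeviceProperties(&prop, dev);
        return gdx > prop.multiProcessorCount ? prop.multiProcessorCount : gdx;
    }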
num_CUs : gdx; + + std::cout << "gdx = " << gdx << " new_gdx = " << new_gdx << std::endl; + float ave_time = 0; index_t k_grain = arg.KBatch * KPerBlock; @@ -179,10 +184,13 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1201,27 +1215,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -1232,7 +1225,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy - auto a_blockwise_copy = + using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - + BlockwiseGemmPipe::GlobalBufferNum>; // B matrix blockwise copy - auto b_blockwise_copy = + using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + BlockwiseGemmPipe::GlobalBufferNum>; // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1318,186 +1298,230 @@ struct GridwiseGemm_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); - // shuffle C and write out + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
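The heart of the mn_looping change: the grid no longer equals the tile count, so each workgroup walks the C-tile space in grid-sized strides. A minimal sketch assuming a 1-D grid (hypothetical kernel; the real code threads total_num_blocks and global_num_blocks through GridwiseGemm::Run):

    #include <hip/hip_runtime.h>

    __global__ void persistent_gemm(int total_num_blocks, int global_num_blocks)
    {
        for(int block_id = blockIdx.x; block_id < total_num_blocks;
            block_id += global_num_blocks)
        {
            // map block_id -> (m_block, n_block), rebuild the A/B copy origins
            // for that tile, run the K mainloop, then the CShuffle epilogue
        }
    }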
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + using CThreadwiseCopy = + 
ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + ; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + index_t block_id = get_block_1d_id(); + for(; block_id < total_num_blocks; block_id += global_num_blocks) { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
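On the retained "HACK ... into SGPR" comment: __builtin_amdgcn_readfirstlane broadcasts lane 0's value across the wave, making it provably wave-uniform so the compiler can keep it in a scalar register. In isolation:

    // Wave-uniform tile base offset -> one SGPR instead of a VGPR per lane.
    __device__ int tile_base(int block_m_id, int m_per_block)
    {
        return __builtin_amdgcn_readfirstlane(block_m_id * m_per_block);
    }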
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id)); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + auto a_blockwise_copy = + ABlockwiseCopy( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + auto b_blockwise_copy = + BBlockwiseCopy( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + auto c_shuffle_block_copy_lds_to_global = CThreadwiseCopy {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, make_multi_index(0, 0, 0, 0), c_grid_desc_mblock_mperblock_nblock_nperblock, make_multi_index(block_m_id, 0, block_n_id, 0), c_element_op}; - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS block_sync_lds(); @@ -1538,7 +1562,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 const BDataType* p_b_grid, 
CDataType* p_c_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + const index_t global_num_blocks, + const index_t total_num_blocks) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); @@ -1562,7 +1588,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + c_grid_desc_mblock_mperblock_nblock_nperblock, + global_num_blocks, + total_num_blocks + ); } template ( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1596,27 +1627,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -1627,7 +1637,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy - auto a_blockwise_copy = + using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + BlockwiseGemmPipe::GlobalBufferNum>; // B matrix blockwise copy - auto b_blockwise_copy = + using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + BlockwiseGemmPipe::GlobalBufferNum>; // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1723,185 +1721,229 @@ struct GridwiseGemm_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); - // shuffle C and write out + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! 
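This second Run body is the two-LDS variant: the explicit p_shared_0/p_shared_1 pair lets the pipeline prefetch into one buffer while computing out of the other. The alternation, reduced to a sketch (hypothetical helper; the real ping-pong is internal to the pipeline):

    __device__ inline char* pick_lds_buffer(char* p_shared_0, char* p_shared_1,
                                            int stage)
    {
        // even stages compute from buffer 0 while buffer 1 fills, then swap
        return (stage & 1) ? p_shared_1 : p_shared_0;
    }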
+ constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + 
+                                 n_thread_data_on_block_idx[I2]),
+                ck::tensor_operation::element_wise::PassThrough{}};
+
+        // shuffle: blockwise copy C from LDS to global
+        using CThreadwiseCopy =
+            ThreadGroupTensorSliceTransfer_v6r1<
+                ThisThreadBlock,            // ThreadGroup
+                CElementwiseOperation,      // ElementwiseOperation,
+                CGlobalMemoryDataOperation, // DstInMemOp,
+                Sequence<1,
+                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                         1,
+                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+                CShuffleDataType,     // typename SrcData,
+                CDataType,            // typename DstData,
+                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
+                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
+                3,                                              // index_t VectorDim,
+                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
+                true,   // bool ThreadTransferSrcResetCoordinateAfterRun,
+                false>; // bool ThreadTransferDstResetCoordinateAfterRun>
+
+        // space filling curve for threadwise C in VGPR
+        constexpr auto sfc_c_vgpr =
+            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
+                              Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                              Sequence<CShuffleMXdlPerWavePerShuffle,
+                                       CShuffleNXdlPerWavePerShuffle,
+                                       1,
+                                       1,
+                                       M2,
+                                       1,
+                                       M4,
+                                       1>>{};
+
+        // space filling curve for shuffled blockwise C in global mem
+        constexpr auto sfc_c_global =
+            SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
+                              Sequence<0, 2, 1, 3>,
+                              Sequence<1,
+                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                                       1,
+                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
+
+        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
+
+        static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
+
+        index_t block_id = get_block_1d_id();
+        for(; block_id < total_num_blocks; block_id += global_num_blocks)
         {
-            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
-                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
-                          "wrong!");
-
-            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
-            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
-
-            // TODO: hacky, fix it!
-            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-            // TODO: hacky, fix it!
-            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
-                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
-            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
-            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
-            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
-            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
-            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
-            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
-            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
-
-            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
-                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
-
-            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                static_cast<CShuffleDataType*>(p_shared_0),
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
-            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
-                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                make_tuple(
-                    make_freeze_transform(I0),
-                    make_unmerge_transform(make_tuple(
-                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
-                        M1,                                      // M1 = MWave
-                        M2,                                      // M2 * M3 * M4 = MPerXdl
-                        M3,
-                        M4)),
-                    make_freeze_transform(I0),
-                    make_unmerge_transform(make_tuple(
-                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
-                        N1,                                      // N1 = NWave
-                        N2))),                                   // N2 = NPerXdl
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-                make_tuple(
-                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
-
-            // calculate origin of thread output tensor on global memory
-            // blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block =
-                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
-
-            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
-            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
-
-            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
-                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto m_thread_data_on_block_idx =
-                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
-                    make_multi_index(m_thread_data_on_block));
-
-            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
-                    make_tuple(Sequence<0, 1, 2>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto n_thread_data_on_block_idx =
-                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
-                    make_multi_index(n_thread_data_on_block));
-
-            // shuffle: threadwise copy C from VGPR to LDS
-            auto c_thread_copy_vgpr_to_lds =
-                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
-                                                   CShuffleDataType,
-                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
-                                                   ck::tensor_operation::element_wise::PassThrough,
-                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
-                                                            CShuffleNXdlPerWavePerShuffle,
-                                                            I1,
-                                                            I1,
-                                                            M2,
-                                                            I1,
-                                                            M4,
-                                                            I1>,
-                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                                   7,
-                                                   1,
-                                                   InMemoryDataOperationEnum::Set,
-                                                   1,
-                                                   true>{
-                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                    make_multi_index(0,
-                                     0,
-                                     m_thread_data_on_block_idx[I1],
-                                     n_thread_data_on_block_idx[I1],
-                                     m_thread_data_on_block_idx[I2],
-                                     m_thread_data_on_block_idx[I3],
-                                     m_thread_data_on_block_idx[I4],
-                                     n_thread_data_on_block_idx[I2]),
-                    ck::tensor_operation::element_wise::PassThrough{}};
-
-            // shuffle: blockwise copy C from LDS to global
-            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
-                ThisThreadBlock,            // ThreadGroup
-                CElementwiseOperation,      // ElementwiseOperation,
-                CGlobalMemoryDataOperation, // DstInMemOp,
-                Sequence<1,
-                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                         1,
-                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-                CShuffleDataType,     // typename SrcData,
-                CDataType,            // typename DstData,
-                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
-                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
-                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
-                3,                                              // index_t VectorDim,
-                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
-                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
-                false> // bool ThreadTransferDstResetCoordinateAfterRun>
-                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
-                 make_multi_index(0, 0, 0, 0),
-                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                 make_multi_index(block_m_id, 0, block_n_id, 0),
-                 c_element_op};
+            const auto block_work_idx =
+                block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
-
-            // space filling curve for threadwise C in VGPR
-            constexpr auto sfc_c_vgpr =
-                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
-                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                  Sequence<CShuffleMXdlPerWavePerShuffle,
-                                           CShuffleNXdlPerWavePerShuffle,
-                                           1,
-                                           1,
-                                           M2,
-                                           1,
-                                           M4,
-                                           1>>{};
-
-            // space filling curve for shuffled blockwise C in global mem
-            constexpr auto sfc_c_global =
-                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
-                                  Sequence<0, 2, 1, 3>,
-                                  Sequence<1,
-                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                                           1,
-                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
-
-            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
-
-            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
+            if(!block_2_ctile_map.ValidCTileIndex(
+                   block_work_idx,
+                   make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                              c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+            {
+                return;
+            }
+
+            const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
+            const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
+
+            // HACK: this forces m/n_block_data_idx_on_grid into SGPR
+            const index_t m_block_data_idx_on_grid =
+                __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
+
+            const index_t n_block_data_idx_on_grid =
+                __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
+
+            auto a_blockwise_copy =
+                ABlockwiseCopy(
+                    a_grid_desc_ak0_m_ak1,
+                    make_multi_index(0, m_block_data_idx_on_grid, 0),
+                    a_element_op,
+                    a_block_desc_ak0_m_ak1,
+                    make_multi_index(0, 0, 0),
+                    ck::tensor_operation::element_wise::PassThrough{});
+
+            auto b_blockwise_copy =
+                BBlockwiseCopy(
+                    b_grid_desc_bk0_n_bk1,
+                    make_multi_index(0, n_block_data_idx_on_grid, 0),
+                    b_element_op,
+                    b_block_desc_bk0_n_bk1,
+                    make_multi_index(0, 0, 0),
+                    ck::tensor_operation::element_wise::PassThrough{});
+
+            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
+                a_grid_desc_ak0_m_ak1,
+                a_block_desc_ak0_m_ak1,
+                a_blockwise_copy,
+                a_grid_buf,
+                a_block_bufs,
+                a_block_slice_copy_step,
+                b_grid_desc_bk0_n_bk1,
+                b_block_desc_bk0_n_bk1,
+                b_blockwise_copy,
+                b_grid_buf,
+                b_block_bufs,
+                b_block_slice_copy_step,
+                c_thread_buf,
+                num_k_block_main_loop);
+
+            auto c_shuffle_block_copy_lds_to_global = CThreadwiseCopy
+                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                 make_multi_index(0, 0, 0, 0),
+                 c_grid_desc_mblock_mperblock_nblock_nperblock,
+                 make_multi_index(block_m_id, 0, block_n_id, 0),
+ c_element_op}; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS @@ -1944,7 +1986,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 CDataType* p_c_grid, void* p_shared_0, void* p_shared_1, - const Problem& problem) + const Problem& problem, + const index_t global_num_blocks, + const index_t total_num_blocks + ) { const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); @@ -1970,7 +2015,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + c_grid_desc_mblock_mperblock_nblock_nperblock, + global_num_blocks, + total_num_blocks + ); } }; diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 44f65674be..16f36dabda 100755 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -1,82 +1,82 @@ # ckProfiler set(PROFILER_SOURCES profiler.cpp - profile_gemm.cpp - profile_reduce.cpp - profile_groupnorm_bwd_data.cpp - profile_groupnorm_fwd.cpp - profile_layernorm_bwd_data.cpp - profile_layernorm_bwd_gamma_beta.cpp - profile_groupnorm_bwd_gamma_beta.cpp - profile_layernorm_fwd.cpp - profile_max_pool3d_fwd.cpp - profile_avg_pool3d_bwd.cpp - profile_max_pool3d_bwd.cpp - profile_softmax.cpp - profile_batchnorm_fwd.cpp - profile_batchnorm_bwd.cpp - profile_batchnorm_infer.cpp - profile_conv_tensor_rearrange.cpp - profile_transpose.cpp - profile_permute_scale.cpp + #profile_gemm.cpp + #profile_reduce.cpp + #profile_groupnorm_bwd_data.cpp + #profile_groupnorm_fwd.cpp + #profile_layernorm_bwd_data.cpp + #profile_layernorm_bwd_gamma_beta.cpp + #profile_groupnorm_bwd_gamma_beta.cpp + #profile_layernorm_fwd.cpp + #profile_max_pool3d_fwd.cpp + #profile_avg_pool3d_bwd.cpp + #profile_max_pool3d_bwd.cpp + #profile_softmax.cpp + #profile_batchnorm_fwd.cpp + #profile_batchnorm_bwd.cpp + #profile_batchnorm_infer.cpp + #profile_conv_tensor_rearrange.cpp + #profile_transpose.cpp + #profile_permute_scale.cpp ) -if(GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) - endif() - list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) - list(APPEND PROFILER_SOURCES 
profile_gemm_multiply_multiply.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) - list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) - list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) - -endif() - -if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) - endif() - list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) -endif() - -if(DL_KERNELS) - list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) -endif() +#if(GPU_TARGETS MATCHES "gfx9") +# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) +# list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) +# list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) +# endif() +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) +# endif() +# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) +# list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) +# list(APPEND PROFILER_SOURCES profile_gemm_universal_reduce.cpp) +# list(APPEND PROFILER_SOURCES 
profile_gemm_universal_streamk.cpp) +# list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) +# list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) +# list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) +# list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd_outelementop.cpp) +# +#endif() +# +#if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) +# endif() +# list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +#endif() +# +#if(DL_KERNELS) +# list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) +# list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) +#endif() set(PROFILER_EXECUTABLE ckProfiler) @@ -84,83 +84,83 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES}) target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) - -if(GPU_TARGETS MATCHES "gfx9") - if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE 
device_gemm_add_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) - endif() - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) -endif() - -if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) - endif() - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) -endif() - -if(DL_KERNELS) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) 
- target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) -endif() +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) +#target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) +# +#if(GPU_TARGETS MATCHES "gfx9") +# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) +# endif() +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_tile_loop_instance) +# endif() +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) +# 
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_streamk_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convscale_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance) +#endif() +# +#if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") +# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) +# endif() +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +#endif() +# +#if(DL_KERNELS) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) +# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +#endif() rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
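
The gridwise GEMM change in this series replaces one-workgroup-per-tile
dispatch with a persistent-workgroup loop: the kernel is launched with only
global_num_blocks workgroups, and each workgroup strides through all
total_num_blocks output tiles. A minimal sketch of that pattern in plain HIP
(persistent_gemm_sketch, launched_blocks and tile_count are illustrative
placeholders, not CK APIs):

#include <hip/hip_runtime.h>

// Sketch only: the grid-stride tile loop introduced by the patch.
// launched_blocks plays the role of global_num_blocks (workgroups launched),
// tile_count plays the role of total_num_blocks (number of M x N output tiles).
__global__ void persistent_gemm_sketch(int launched_blocks, int tile_count)
{
    for(int tile_id = blockIdx.x; tile_id < tile_count; tile_id += launched_blocks)
    {
        // 1. map tile_id to (block_m_id, block_n_id)   -> Block2CTileMap
        // 2. run the K main loop for this tile         -> blockwise GEMM pipeline
        // 3. write the tile out through LDS            -> C-shuffle epilogue
    }
}

Hoisting the tile-independent C-shuffle setup (descriptors, LDS buffer,
space-filling curves) in front of this loop, as the diff above does, keeps the
per-tile loop body cheap to repeat.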