From 42acf899963895637f048c60c52cf0e8e28f5969 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 10 Jul 2024 08:39:47 +0000 Subject: [PATCH 01/30] driver ok --- driver/CMakeLists.txt | 1 + driver/dm_multimarginloss.cpp | 40 ++ driver/driver.hpp | 6 +- driver/multimarginloss_driver.hpp | 510 ++++++++++++++++++ include/miopen/miopen.h | 96 ++++ src/CMakeLists.txt | 9 + src/include/miopen/multimarginloss.hpp | 74 +++ .../miopen/multimarginloss/invoke_params.hpp | 61 +++ .../multimarginloss/problem_description.hpp | 130 +++++ .../miopen/multimarginloss/solvers.hpp | 77 +++ src/include/miopen/solver_id.hpp | 3 +- src/include/miopen/tensor_view_utils.hpp | 80 +++ src/kernels/MIOpenLossReduce.cpp | 51 ++ src/kernels/MIOpenMultiMarginLoss.cpp | 149 +++++ src/kernels/tensor_view.hpp | 83 +++ src/kernels/warp_shuffle.hpp | 68 +++ src/multimarginloss.cpp | 144 +++++ src/multimarginloss/problem_description.cpp | 51 ++ src/multimarginloss_api.cpp | 141 +++++ src/solver.cpp | 10 + .../forward_reduced_multimarginloss.cpp | 223 ++++++++ .../forward_unreduced_multimarginloss.cpp | 124 +++++ 22 files changed, 2128 insertions(+), 3 deletions(-) create mode 100644 driver/dm_multimarginloss.cpp create mode 100644 driver/multimarginloss_driver.hpp create mode 100644 src/include/miopen/multimarginloss.hpp create mode 100644 src/include/miopen/multimarginloss/invoke_params.hpp create mode 100644 src/include/miopen/multimarginloss/problem_description.hpp create mode 100644 src/include/miopen/multimarginloss/solvers.hpp create mode 100644 src/include/miopen/tensor_view_utils.hpp create mode 100644 src/kernels/MIOpenLossReduce.cpp create mode 100644 src/kernels/MIOpenMultiMarginLoss.cpp create mode 100644 src/kernels/tensor_view.hpp create mode 100644 src/kernels/warp_shuffle.hpp create mode 100644 src/multimarginloss.cpp create mode 100644 src/multimarginloss/problem_description.cpp create mode 100644 src/multimarginloss_api.cpp create mode 100644 src/solver/multimarginloss/forward_reduced_multimarginloss.cpp create mode 100644 src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 8f19a90eb6..d1160be35c 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -44,6 +44,7 @@ add_executable(MIOpenDriver dm_fusion.cpp dm_gemm.cpp dm_groupnorm.cpp + dm_multimarginloss.cpp dm_layernorm.cpp dm_lrn.cpp dm_pool.cpp diff --git a/driver/dm_multimarginloss.cpp b/driver/dm_multimarginloss.cpp new file mode 100644 index 0000000000..0a74712448 --- /dev/null +++ b/driver/dm_multimarginloss.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "multimarginloss_driver.hpp" +#include "registry_driver_maker.hpp" + +static Driver* makeDriver(const std::string& base_arg) +{ + if(base_arg == "multimarginloss") + return new MultiMarginLossDriver(); + if(base_arg == "multimarginlossfp16") + return new MultiMarginLossDriver(); + if(base_arg == "multimarginlossbfp16") + return new MultiMarginLossDriver(); + return nullptr; +} + +REGISTER_DRIVER_MAKER(makeDriver); diff --git a/driver/driver.hpp b/driver/driver.hpp index 5bb0a29042..a3f6ec9735 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -174,7 +174,8 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], " "tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], " "groupnorm[bfp16|fp16], cat[bfp16|fp16], addlayernorm[bfp16|fp16], " - "t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16]\n"); + "t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], " + "multimarginloss[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -202,7 +203,8 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "addlayernorm" && arg != "addlayernormfp16" && arg != "addlayernormbfp16" && arg != "t5layernorm" && arg != "t5layernormfp16" && arg != "t5layernormbfp16" && arg != "adam" && arg != "adamfp16" && arg != "ampadam" && arg != "reduceextreme" && - arg != "reduceextremefp16" && arg != "reduceextremebfp16" && arg != "--version") + arg != "reduceextremefp16" && arg != "reduceextremebfp16" && arg != "multimarginloss" && + arg != "multimarginlossfp16" && arg != "multimarginlossbfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); Usage(); diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp new file mode 100644 index 0000000000..58993a4bc1 --- /dev/null +++ b/driver/multimarginloss_driver.hpp @@ -0,0 +1,510 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_MIOPEN_MULTIMARGINLOSS_DRIVER_HPP +#define GUARD_MIOPEN_MULTIMARGINLOSS_DRIVER_HPP + +#include "InputFlags.hpp" +#include "driver.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include "random.hpp" +#include +#include +#include +#include +#include +#include +#include +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> +#include + +template +int32_t mloMultiMarginLossUnreducedForwardRunHost(miopenTensorDescriptor_t iDesc, + miopenTensorDescriptor_t tDesc, + miopenTensorDescriptor_t wDesc, + miopenTensorDescriptor_t oDesc, + long p, + float margin, + Tgpu* I, + const uint64_t* T, + Tgpu* W, + Tcheck* Ohost) +{ + auto I_tv = miopen::get_inner_expanded_tv<2>(miopen::deref(iDesc)); + auto T_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(tDesc)); + auto W_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(wDesc)); + auto O_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(oDesc)); + auto N = I_tv.size[0], C = I_tv.size[1]; + + int32_t ret = 0; + + for(size_t n = 0; n < N; n++) + { + Tcheck loss = 0; + uint64_t y = T[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + continue; + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + Tcheck t = margin - static_cast(I[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(I[I_tv.get_tensor_view_idx({n, c})]); + + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = W[W_tv.get_tensor_view_idx({y})] * t; + loss += t / static_cast(C); + } + Ohost[O_tv.get_tensor_view_idx({n})] = loss; + } + return ret; +} + +template +int32_t mloMultiMarginLossReducedForwardRunHost(miopenTensorDescriptor_t iDesc, + miopenTensorDescriptor_t tDesc, + miopenTensorDescriptor_t wDesc, + long p, + float margin, + const float divisor, + Tgpu* I, + const uint64_t* T, + Tgpu* W, + Tcheck* Ohost) +{ + auto I_tv = miopen::get_inner_expanded_tv<2>(miopen::deref(iDesc)); + auto T_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(tDesc)); + auto W_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(wDesc)); + auto N = I_tv.size[0], C = I_tv.size[1]; + + int32_t ret = 0; + + for(size_t n = 0; n < N; n++) + { + Tcheck loss = 0; + uint64_t y = T[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + continue; + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + Tcheck t = margin - static_cast(I[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(I[I_tv.get_tensor_view_idx({n, c})]); + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = W[W_tv.get_tensor_view_idx({y})] * t; + loss += t / static_cast(C); + } + Ohost[0] += loss; + } + Ohost[0] /= divisor; + + return ret; +}; + +template +class MultiMarginLossDriver : public Driver +{ +public: + MultiMarginLossDriver() : Driver() + { + miopenCreateTensorDescriptor(&iDesc); + miopenCreateTensorDescriptor(&tDesc); + miopenCreateTensorDescriptor(&wDesc); + miopenCreateTensorDescriptor(&oDesc); + + data_type = miopen_type{}; + } + + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + + Tref GetTolerance(); + int VerifyBackward() override; + int VerifyForward() override; + ~MultiMarginLossDriver() override + { + miopenDestroyTensorDescriptor(iDesc); + miopenDestroyTensorDescriptor(tDesc); + miopenDestroyTensorDescriptor(wDesc); + miopenDestroyTensorDescriptor(oDesc); + } + +private: + InputFlags inflags; + + // forw = 0 -> run both fw, bw, = 1 -> run only fw, = 2 -> run only bw + int forw; + + miopenTensorDescriptor_t iDesc; + miopenTensorDescriptor_t tDesc; + miopenTensorDescriptor_t wDesc; + miopenTensorDescriptor_t oDesc; + + std::unique_ptr i_dev; + std::unique_ptr t_dev; + std::unique_ptr w_dev; + std::unique_ptr o_dev; + std::unique_ptr workspace_dev; + + std::vector I; + std::vector T; + std::vector W; + std::vector O; + std::vector Ohost; + std::vector workspace; + + long p; + float margin; + miopenLossReductionMode_t reduction_mode; + size_t ws_sizeInBytes; +}; + +template +int MultiMarginLossDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward Take (Default=1)", "int"); + inflags.AddInputFlag("dim", 'D', "41x4", "Dim of input tensor (Default=41x4)", "tensor"); + inflags.AddInputFlag("contiguous", 'C', "1", "Tensor is contiguous or not", "int"); + inflags.AddInputFlag("iter", 'i', "1", "Number of Iterations (Default=1)", "int"); + inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int"); + inflags.AddInputFlag("reduce", + 'R', + "none", + "Specifies the reduction to apply to the output ('none'|'mean'|'sum') " + "(Default=none to indicate no reduction)", + "str"); + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + + auto reduction = inflags.GetValueStr("reduce"); + if(reduction != "none" && reduction != "mean" && reduction != "sum") + return miopenStatusInvalidValue; + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + forw = inflags.GetValueInt("forw"); + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::GetandSetData() +{ + // Set tensor description + std::vector in_len = inflags.GetValueTensor("dim").lengths; + int N = in_len[0], C = in_len[1]; + if(inflags.GetValueInt("contiguous") == 1) + { + SetTensorNd(iDesc, in_len, data_type); + + std::vector t_len = {N}; + SetTensorNd(tDesc, t_len, miopenInt64); + + std::vector w_len = {C}; + SetTensorNd(wDesc, w_len, data_type); + } + else + { + std::vector in_strides(in_len.size()); + in_strides.back() = 1; + for(int i = in_len.size() - 2; i >= 0; --i) + in_strides[i] = in_strides[i + 1] * in_len[i + 1]; + in_strides[0] *= 2; + SetTensorNd(iDesc, in_len, in_strides, data_type); + + std::vector t_len = {N}; + std::vector t_strides = {2}; + SetTensorNd(tDesc, t_len, t_strides, miopenInt64); + + std::vector w_lens = {C}; + std::vector w_strides = {2}; + SetTensorNd(wDesc, w_lens, w_strides, data_type); + } + + // Set p and margin + // p = 1 or 2 + p = prng::gen_A_to_B(static_cast(1), static_cast(3)); + margin = prng::gen_A_to_B(static_cast(0.5), static_cast(1.5)); + + // Set reduction_mode + auto reduction = inflags.GetValueStr("reduce"); + if(reduction == "none") + reduction_mode = MIOPEN_LOSS_REDUCTION_NONE; + else if(reduction == "mean") + reduction_mode = MIOPEN_LOSS_REDUCTION_MEAN; + else if(reduction == "sum") + reduction_mode = MIOPEN_LOSS_REDUCTION_SUM; + + // Set output tensor description (forw = 1 or = 0) + if(forw == 0 || forw == 1) + { + if(reduction == "none") + { + std::vector o_lens = {N}; + SetTensorNd(oDesc, o_lens, data_type); + } + else + { + std::vector o_lens = {1}; + SetTensorNd(oDesc, o_lens, data_type); + } + } + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::AllocateBuffersAndCopy() +{ + uint32_t ctx = 0; + + // for unpacked tensor, we need to use GetTensorSpace instead of GetTensorSize + size_t i_sz = GetTensorSpace(iDesc); + size_t t_sz = GetTensorSpace(tDesc); + size_t w_sz = GetTensorSpace(wDesc); + i_dev = std::unique_ptr(new GPUMem(ctx, i_sz, sizeof(Tgpu))); + t_dev = std::unique_ptr(new GPUMem(ctx, t_sz, sizeof(uint64_t))); + w_dev = std::unique_ptr(new GPUMem(ctx, w_sz, sizeof(Tgpu))); + I = std::vector(i_sz); + T = std::vector(t_sz); + W = std::vector(w_sz); + for(int i = 0; i < i_sz; i++) + { + I[i] = prng::gen_A_to_B(static_cast(-1), static_cast(1)); + } + int C = miopen::deref(iDesc).GetLengths()[1]; + // 0 to C - 1 + for(int i = 0; i < t_sz; i++) + { + T[i] = prng::gen_A_to_B(static_cast(0), static_cast(C)); + } + for(int i = 0; i < w_sz; i++) + { + W[i] = prng::gen_A_to_B(static_cast(-1), static_cast(1)); + } + + if(i_dev->ToGPU(GetStream(), I.data()) != 0) + std::cerr << "Error copying (I) to GPU, size: " << i_dev->GetSize() << std::endl; + + if(t_dev->ToGPU(GetStream(), T.data()) != 0) + std::cerr << "Error copying (T) to GPU, size: " << t_dev->GetSize() << std::endl; + + if(w_dev->ToGPU(GetStream(), W.data()) != 0) + std::cerr << "Error copying (W) to GPU, size: " << w_dev->GetSize() << std::endl; + + if(forw == 0 || forw == 1) + { + size_t o_sz = GetTensorSpace(oDesc); + if(reduction_mode != MIOPEN_LOSS_REDUCTION_NONE) + { + miopenGetMultiMarginLossForwardWorkspaceSize(GetHandle(), + iDesc, + tDesc, + wDesc, + oDesc, + p, + margin, + reduction_mode, + &ws_sizeInBytes); + if(ws_sizeInBytes == static_cast(-1)) + { + return miopenStatusAllocFailed; + } + } + else + ws_sizeInBytes = 0; + + o_dev = std::unique_ptr(new GPUMem(ctx, o_sz, sizeof(Tgpu))); + O = std::vector(o_sz); + Ohost = std::vector(o_sz); + std::fill(O.begin(), O.end(), 0); + std::fill(Ohost.begin(), Ohost.end(), 0); + if(o_dev->ToGPU(GetStream(), O.data()) != 0) + std::cerr << "Error copying (out) to GPU, size: " << o_dev->GetSize() << std::endl; + + size_t ws_sz = ws_sizeInBytes / sizeof(Tgpu); + workspace_dev = std::unique_ptr(new GPUMem(ctx, ws_sz, sizeof(Tgpu))); + workspace = std::vector(ws_sz); + std::fill(workspace.begin(), workspace.end(), 0); + + if(workspace_dev->ToGPU(GetStream(), workspace.data()) != 0) + std::cerr << "Error copying (workspace) to GPU, size: " << workspace_dev->GetSize() + << std::endl; + } + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::RunForwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenMultiMarginLossForward( + GetHandle(), + iDesc, + i_dev->GetMem(), + tDesc, + t_dev->GetMem(), + wDesc, + w_dev->GetMem(), + oDesc, + o_dev->GetMem(), + p, + margin, + reduction_mode, + (reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) ? nullptr : workspace_dev->GetMem(), + ws_sizeInBytes); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Forward MultiMarginLoss Elapsed: " + << t.gettime_ms() / iter << " ms" << std::endl; + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Forward MultiMarginLoss Elapsed: " << kernel_average_time + << " ms" << std::endl; + } + + if(o_dev->FromGPU(GetStream(), O.data()) != 0) + std::cerr << "Error copying (o_dev) from GPU, size: " << o_dev->GetSize() << std::endl; + if(workspace_dev->FromGPU(GetStream(), workspace.data()) != 0) + std::cerr << "Error copying (workspace_dev) from GPU, size: " << workspace_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::RunForwardCPU() +{ + if(reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) + { + mloMultiMarginLossUnreducedForwardRunHost( + iDesc, tDesc, wDesc, oDesc, p, margin, I.data(), T.data(), W.data(), Ohost.data()); + } + else + { + float divisor = (reduction_mode == MIOPEN_LOSS_REDUCTION_MEAN) + ? miopen::deref(iDesc).GetLengths()[0] + : 1; + mloMultiMarginLossReducedForwardRunHost( + iDesc, tDesc, wDesc, p, margin, divisor, I.data(), T.data(), W.data(), Ohost.data()); + } + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::RunBackwardGPU() +{ + return miopenStatusSuccess; +} + +template +Tref MultiMarginLossDriver::GetTolerance() +{ + // Computation error of fp16 is ~2^13 (=8192) bigger than + // the one of fp32 because mantissa is shorter by 13 bits. + auto tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3; + + // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. + if(std::is_same::value) + tolerance *= 8.0; + return tolerance; +} + +template +int MultiMarginLossDriver::VerifyForward() +{ + RunForwardCPU(); + + const Tref tolerance = GetTolerance(); + auto error = miopen::rms_range(Ohost, O); + if(!std::isfinite(error) || error > tolerance) + { + std::cout << "Forward MultiMarginLoss FAILED: " << error << " > " << tolerance << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Forward MultiMarginLoss Verifies OK on CPU reference (" << error << " < " + << tolerance << ')' << std::endl; + } + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::VerifyBackward() +{ + return miopenStatusSuccess; +} + +#endif // GUARD_MIOPEN_MULTIMARGINLOSS_DRIVER_HPP diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 9821b94912..3cfe97c2ee 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -7221,6 +7221,102 @@ miopenFusedAdamWithOutput(miopenHandle_t handle, // CLOSEOUT SGD DOXYGEN GROUP #endif // MIOPEN_BETA_API +#ifdef MIOPEN_BETA_API + +/*! @ingroup LossFunction + * @enum miopenLossReductionMode_t + * Reduction mode for loss function + */ +typedef enum +{ + MIOPEN_LOSS_REDUCTION_NONE = 0, /*!< output tensor elements are not reduced */ + MIOPEN_LOSS_REDUCTION_SUM = 1, /*!< output tensor elements are summed up */ + MIOPEN_LOSS_REDUCTION_MEAN = 2, /*!< output tensor elements are summed up and divided with total + number of elements to get mean value */ +} miopenLossReductionMode_t; + +// MultiMarginLoss APIs +/** @addtogroup LossFunction + * + * @{ + */ + +/*! @brief Helper function to query the minimum workspace size required by the +MultiMarginLossForward call + * + * @param [in] handle MIOpen Handle + * @param [in] inputDesc Tensor descriptor for input tensor (N, C) where N is the batch +size and C is the number of classes + * @param [in] targetDesc Tensor descriptor for target tensor, must have shape (N). Each +value is between 0 and C - 1 + * @param [in] weightDesc Tensor descriptor for weight tensor. It is a manual rescaling +weight given to each class. It has to be a Tensor of size C + * @param [in] outputDesc Tensor descriptor for output tensor. If reduction is 'none, +then it must have shape (N). Otherwise, it is a scalar + * @param [in] p Has a default value of 1. The only supported values are 1 and 2 + * @param [in] margin Has a default value of 1 + * @param [in] reduction Reduction mode (sum, mean). For none reduction we don't need to +use this function + * @param [out] sizeInBytes Pointer to data to return the minimum workspace size + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t +miopenGetMultiMarginLossForwardWorkspaceSize(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + miopenTensorDescriptor_t targetDesc, + miopenTensorDescriptor_t weightDesc, + miopenTensorDescriptor_t outputDesc, + long p, + float margin, + miopenLossReductionMode_t reduction, + size_t* sizeInBytes); + +/*! @brief Execute a MultiMarginLoss forward layer + * + * @param [in] handle MIOpen handle + * @param [in] inputDesc Tensor descriptor for input tensor (N, C) where N is the +batch size and C is the number of classes. + * @param [in] input Data tensor input + * @param [in] targetDesc Tensor descriptor for target tensor, must have shape (N). +Each value is between 0 and C - 1 + * @param [in] target Data tensor target + * @param [in] weightDesc Tensor descriptor for weight tensor. It is a manual +rescaling weight given to each class. It has to be a Tensor of size C + * @param [in] weight Data tensor weight + * @param [in] outputDesc Tensor descriptor for output tensor. If reduction is 'none, +then it must have shape (N). Otherwise, it is a scalar. + * @param [out] output Data tensor output + * @param [in] p Has a default value of 1. The only supported values are 1 +and 2 + * @param [in] margin Has a default value of 1 + * @param [in] reduction Reduction mode. If reduction mode is mean or sum, you must + * provide param workspace and workspaceSizeInBytes. Call + * miopenGetMultiMarginLossForwardWorkspaceSize to get workspaceSizeInBytes + * @param [in] workspace Address of the allocated workspace data. Set = nullptr if +reduction = 'none' + * @param [in] workspaceSizeInBytes Size in bytes of the allocated workspace data. Set = 0 if +reduction = 'none + * @return miopenStatus_t + */ +miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t targetDesc, + const void* target, + miopenTensorDescriptor_t weightDesc, + const void* weight, + miopenTensorDescriptor_t outputDesc, + void* output, + long p, + float margin, + miopenLossReductionMode_t reduction, + void* workspace, + size_t workspaceSizeInBytes); + +/** @} */ +// CLOSEOUT LossFunction DOXYGEN GROUP +#endif + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1cb8a1fb0c..9fe5940d27 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -154,6 +154,8 @@ set( MIOpen_Source lrn_api.cpp mha/mha_descriptor.cpp mha/problem_description.cpp + multimarginloss/problem_description.cpp + multimarginloss_api.cpp op_args.cpp operator.cpp performance_config.cpp @@ -284,6 +286,8 @@ set( MIOpen_Source solver/layernorm/forward_t5layernorm.cpp solver/mha/mha_solver_backward.cpp solver/mha/mha_solver_forward.cpp + solver/multimarginloss/forward_reduced_multimarginloss.cpp + solver/multimarginloss/forward_unreduced_multimarginloss.cpp solver/pooling/forward2d.cpp solver/pooling/forwardNaive.cpp solver/pooling/forwardNd.cpp @@ -446,7 +450,9 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/neuron.inc kernels/rocm_version.inc kernels/stride_array.hpp + kernels/tensor_view.hpp kernels/utilities.inc + kernels/warp_shuffle.hpp kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c16_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c32_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1024vgprs_fp16_fp16acc_f2x3_c16_stride1.inc @@ -485,8 +491,10 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenConvDirGenFwd.cl kernels/MIOpenGroupNorm.cpp kernels/MIOpenLayerNorm.cpp + kernels/MIOpenLossReduce.cpp kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl + kernels/MIOpenMultiMarginLoss.cpp kernels/MIOpenNeuron.cl kernels/MIOpenPooling.cl kernels/MIOpenPoolingBwd.cl @@ -616,6 +624,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN lrn.cpp mlo_dir_conv.cpp exec_utils.cpp + multimarginloss.cpp ocl/activ_ocl.cpp ocl/batchnormocl.cpp ocl/convolutionocl.cpp diff --git a/src/include/miopen/multimarginloss.hpp b/src/include/miopen/multimarginloss.hpp new file mode 100644 index 0000000000..fc24ccecce --- /dev/null +++ b/src/include/miopen/multimarginloss.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_MULTIMARGINLOSS_HPP_ +#define MIOPEN_MULTIMARGINLOSS_HPP_ + +#include "miopen/miopen.h" +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +std::size_t GetMultiMarginLossForwardWorkspaceSize(Handle& handle, + const TensorDescriptor& iDesc, + const TensorDescriptor& tDesc, + const TensorDescriptor& wDesc, + const TensorDescriptor& oDesc, + long p, + float margin, + miopenLossReductionMode_t reduction); + +miopenStatus_t MultiMarginLossUnreducedForward(Handle& handle, + const TensorDescriptor& iDesc, + ConstData_t i, + const TensorDescriptor& tDesc, + ConstData_t t, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& oDesc, + Data_t o, + long p, + float margin); + +miopenStatus_t MultiMarginLossForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& iDesc, + ConstData_t i, + const TensorDescriptor& tDesc, + ConstData_t t, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& oDesc, + Data_t o, + long p, + float margin, + miopenLossReductionMode_t reduction); + +} // namespace miopen +#endif // _MIOPEN_MULTIMARGINLOSS_HPP_ diff --git a/src/include/miopen/multimarginloss/invoke_params.hpp b/src/include/miopen/multimarginloss/invoke_params.hpp new file mode 100644 index 0000000000..24e323946f --- /dev/null +++ b/src/include/miopen/multimarginloss/invoke_params.hpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include + +namespace miopen { +namespace multimarginloss { + +struct InvokeParams : public miopen::InvokeParams +{ + InvokeParams() = default; + + const TensorDescriptor* iDesc = nullptr; + const TensorDescriptor* tDesc = nullptr; + const TensorDescriptor* wDesc = nullptr; + const TensorDescriptor* oDesc = nullptr; + + ConstData_t i = nullptr; + ConstData_t t = nullptr; + ConstData_t w = nullptr; + Data_t o = nullptr; + + long p; + float margin; + float divisor = 0; + + Data_t workspace = nullptr; + std::size_t workspace_size = 0; + std::size_t GetWorkspaceSize() const { return workspace_size; } + Data_t GetWorkspace() const { return workspace; } +}; + +} // namespace multimarginloss + +} // namespace miopen diff --git a/src/include/miopen/multimarginloss/problem_description.hpp b/src/include/miopen/multimarginloss/problem_description.hpp new file mode 100644 index 0000000000..10d11c0a07 --- /dev/null +++ b/src/include/miopen/multimarginloss/problem_description.hpp @@ -0,0 +1,130 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include "miopen/miopen.h" +#include +#include +#include +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace multimarginloss { + +struct ForwardProblemDescription : ProblemDescriptionBase +{ + ForwardProblemDescription(const TensorDescriptor& iDesc_, + const TensorDescriptor& tDesc_, + const TensorDescriptor& wDesc_, + const TensorDescriptor& oDesc_, + const long p_, + const float margin_, + const float divisor_) + : iDesc(iDesc_), + tDesc(tDesc_), + wDesc(wDesc_), + oDesc(oDesc_), + p(p_), + margin(margin_), + divisor(divisor_) + { + if(iDesc.GetType() != oDesc.GetType() || iDesc.GetType() != wDesc.GetType()) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Input, output, weight tensor types do not match."); + } + if(tDesc.GetType() != miopenInt64) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Target tensor type should be miopenInt64."); + } + if(iDesc.GetNumDims() != 2) + { + MIOPEN_THROW(miopenStatusBadParm, "MultiMarginLoss: Input tensor need to be 2D tensor"); + } + if(tDesc.GetNumDims() != 1 || tDesc.GetLengths()[0] != iDesc.GetLengths()[0]) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Target tensor need to be 1D tensor. If input " + "tensor has shape (N, C) then target tensor must have shape (N)"); + } + if(wDesc.GetNumDims() != 1 || wDesc.GetLengths()[0] != iDesc.GetLengths()[1]) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Weight tensor need to be 1D tensor. If input " + "tensor has shape (N, C) then weight tensor must have shape (C)"); + } + // Check output tensor dimension + if(divisor == 0) + { + // non-reduction case + if(oDesc.GetNumDims() != 1 || oDesc.GetLengths()[0] != iDesc.GetLengths()[0]) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Output tensor need to be " + "1D tensor. If input " + "tensor has shape (N, C) then output tensor must have shape (N)"); + } + } + else + { + // reduction case + if(oDesc.GetNumDims() != 1 || oDesc.GetLengths()[0] != 1) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Output tensor need to be a scalar."); + } + } + } + + const TensorDescriptor& GetiDesc() const { return iDesc; } + const TensorDescriptor& GettDesc() const { return tDesc; } + const TensorDescriptor& GetwDesc() const { return wDesc; } + const TensorDescriptor& GetoDesc() const { return oDesc; } + long Getp() const { return p; } + float Getmargin() const { return margin; } + float Getdivisor() const { return divisor; } + + NetworkConfig MakeNetworkConfig() const override; + +private: + TensorDescriptor iDesc; + TensorDescriptor tDesc; + TensorDescriptor wDesc; + TensorDescriptor oDesc; + long p; + float margin; + float divisor; +}; + +} // namespace multimarginloss + +} // namespace miopen diff --git a/src/include/miopen/multimarginloss/solvers.hpp b/src/include/miopen/multimarginloss/solvers.hpp new file mode 100644 index 0000000000..1a018b1d6a --- /dev/null +++ b/src/include/miopen/multimarginloss/solvers.hpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once + +#include +#include +#include + +namespace miopen { + +namespace solver { + +namespace multimarginloss { + +using ForwardMultiMarginLossSolver = + NonTunableSolverBase; + +struct MultiMarginLossUnreducedForward final : ForwardMultiMarginLossSolver +{ + const std::string& SolverDbId() const override + { + return GetSolverDbId(); + } + bool + IsApplicable(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const override; + ConvSolution + GetSolution(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const override; +}; + +struct MultiMarginLossForward final : ForwardMultiMarginLossSolver +{ + const std::string& SolverDbId() const override + { + return GetSolverDbId(); + } + bool + IsApplicable(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const override; + ConvSolution + GetSolution(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const override; + std::size_t GetWorkspaceSize( + const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const override; + bool MayNeedWorkspace() const override { return true; } +}; + +} // namespace multimarginloss + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index 6ddd83bcef..0d2ae9cd34 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -58,7 +58,8 @@ enum class Primitive Cat, Mha, Softmax, - Adam + Adam, + MultiMarginLoss, }; struct MIOPEN_INTERNALS_EXPORT Id diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp new file mode 100644 index 0000000000..afeaaeea78 --- /dev/null +++ b/src/include/miopen/tensor_view_utils.hpp @@ -0,0 +1,80 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef MIOPEN_TENSOR_VIEW_UTIL_HPP_ +#define MIOPEN_TENSOR_VIEW_UTIL_HPP_ + +#include "../../kernels/tensor_view.hpp" +#include "miopen/tensor.hpp" + +namespace miopen { + +template +inline tensor_view_t get_inner_expanded_tv(const TensorDescriptor Desc) +{ + auto dims = Desc.GetLengths(); + auto strides = Desc.GetStrides(); + + tensor_view_t tensor_view; + for(size_t i = 0; i < N; ++i) + { + if(i < dims.size()) + { + tensor_view.stride[i] = strides[i]; + tensor_view.size[i] = dims[i]; + } + else + { + tensor_view.stride[i] = (i == 0 ? 1 : strides[i - 1]); + tensor_view.size[i] = 1; + } + } + return tensor_view; +} + +template +inline void slice_tv(tensor_view_t& tensor_view, int32_t sliceCount, const int32_t* slices) +{ + for(int32_t i = 0; i < sliceCount; i++) + { + int32_t dim = slices[4 * i + 0]; + int32_t start = slices[4 * i + 1]; + int32_t end = slices[4 * i + 2]; + int32_t step = slices[4 * i + 3]; + + if(end > static_cast(tensor_view.size[dim])) + end = tensor_view.size[dim]; + + auto len = end - start; + + tensor_view.size[dim] = (len + step - 1) / step; + tensor_view.stride[dim] *= step; + } +} + +} // namespace miopen + +#endif // MIOPEN_TENSOR_VIEW_UTIL_HPP_ diff --git a/src/kernels/MIOpenLossReduce.cpp b/src/kernels/MIOpenLossReduce.cpp new file mode 100644 index 0000000000..023af37626 --- /dev/null +++ b/src/kernels/MIOpenLossReduce.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" +#include "warp_shuffle.hpp" + +template +__device__ void LossSum(const DTYPE* __restrict__ input, DTYPE* __restrict__ output, size_t N) +{ + auto gid = blockIdx.x * blockDim.x + threadIdx.x; + + FLOAT_ACCUM val = gid < N ? CVT_FLOAT2ACCUM(input[gid]) : CVT_FP32_2ACCUM(0.0f); + val = block_reduce_sum(val); + + if(threadIdx.x == 0) + output[blockIdx.x] = CVT_ACCUM2FLOAT(val); +} + +extern "C" __global__ void +ReduceSumLoss(const FLOAT* __restrict__ input, FLOAT* __restrict__ output, size_t N) +{ + // instantiate the kernel + LossSum(input, output, N); +} diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp new file mode 100644 index 0000000000..6855270329 --- /dev/null +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -0,0 +1,149 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" +#include "tensor_view.hpp" + +template +__device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, + const uint64_t* __restrict__ T, + const DTYPE* __restrict__ W, + DTYPE* __restrict__ O, + long p, + float margin, + tensor_view_t<2> I_tv, + tensor_view_t<1> T_tv, + tensor_view_t<1> W_tv, + tensor_view_t<1> O_tv) +{ + const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; + + size_t N = I_tv.size[0], C = I_tv.size[1]; + size_t n = gid; + if(n >= N) + return; + + FLOAT_ACCUM loss = 0; + uint64_t y = T[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + { + // TODO: need to handle invalid target index value + return; + } + + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + FLOAT_ACCUM t = margin - CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, y})]) + + CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, c})]); + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; + loss += t / (float)C; + } + O[O_tv.get_tensor_view_idx({n})] = CVT_ACCUM2FLOAT(loss); +} + +extern "C" __global__ void MultiMarginLossUnreducedForward2d(const FLOAT* __restrict__ I, + const uint64_t* __restrict__ T, + const FLOAT* __restrict__ W, + FLOAT* __restrict__ O, + long p, + float margin, + tensor_view_t<2> I_tv, + tensor_view_t<1> T_tv, + tensor_view_t<1> W_tv, + tensor_view_t<1> O_tv) +{ + // instantiate the kernel + multimarginlossunreducedforward2d(I, T, W, O, p, margin, I_tv, T_tv, W_tv, O_tv); +} + +template +__device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, + const uint64_t* __restrict__ T, + const DTYPE* __restrict__ W, + DTYPE* __restrict__ lsum, + long p, + float margin, + const float divisor, + tensor_view_t<2> I_tv, + tensor_view_t<1> T_tv, + tensor_view_t<1> W_tv) +{ + const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; + + size_t N = I_tv.size[0], C = I_tv.size[1]; + size_t n = gid; + if(n >= N) + return; + + FLOAT_ACCUM loss = 0; + long y = T[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + { + // TODO: need to handle invalid target index value + return; + } + + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + FLOAT_ACCUM t = margin - CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, y})]) + + CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, c})]); + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; + loss += t / (float)C; + } + + lsum[n] = CVT_ACCUM2FLOAT(loss / divisor); +} + +extern "C" __global__ void MultiMarginLossForward2d(const FLOAT* __restrict__ I, + const uint64_t* __restrict__ T, + const FLOAT* __restrict__ W, + FLOAT* __restrict__ lsum, + long p, + float margin, + const float divisor, + tensor_view_t<2> I_tv, + tensor_view_t<1> T_tv, + tensor_view_t<1> W_tv) +{ + // instantiate the kernel + multimarginlossforward2d(I, T, W, lsum, p, margin, divisor, I_tv, T_tv, W_tv); +} \ No newline at end of file diff --git a/src/kernels/tensor_view.hpp b/src/kernels/tensor_view.hpp new file mode 100644 index 0000000000..d7294d8992 --- /dev/null +++ b/src/kernels/tensor_view.hpp @@ -0,0 +1,83 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef GUARD_TENSOR_VIEW_HPP +#define GUARD_TENSOR_VIEW_HPP + +template +struct tensor_layout_t; + +template +struct tensor_view_t +{ + constexpr uint64_t get_tensor_view_idx(const uint64_t (&layout)[N]) + { + static_assert(N > 0); + uint64_t idx = 0; + for(auto i = 0; i < N; ++i) + { + idx += stride[i] * layout[i]; + } + return idx; + } + + constexpr uint64_t get_tensor_view_idx(const tensor_layout_t& tensor_layout) + { + return get_tensor_view_idx(tensor_layout.layout); + } + + uint64_t stride[N]; + uint64_t size[N]; +}; + +template +struct tensor_layout_t +{ + // Make tensor layout at index using tensor view + constexpr tensor_layout_t(const tensor_view_t& tensor_view, uint64_t idx) + { + static_assert(N > 0); + uint64_t temp = idx; + if constexpr(N == 1) + { + layout[0] = idx; + } + else + { + for(auto i = N - 1; i > 1; --i) + { + layout[i] = temp % tensor_view.size[i]; + temp = temp / tensor_view.size[i]; + } + layout[1] = temp % tensor_view.size[1]; + layout[0] = temp / tensor_view.size[1]; + } + } + + uint64_t layout[N]; +}; + +#endif // GUARD_TENSOR_VIEW_HPP diff --git a/src/kernels/warp_shuffle.hpp b/src/kernels/warp_shuffle.hpp new file mode 100644 index 0000000000..3aa75660f6 --- /dev/null +++ b/src/kernels/warp_shuffle.hpp @@ -0,0 +1,68 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" + +__device__ FLOAT_ACCUM warp_reduce_sum(FLOAT_ACCUM val) +{ + if(warpSize >= 64) + val += __shfl_down(val, 32); + if(warpSize >= 32) + val += __shfl_down(val, 16); + if(warpSize >= 16) + val += __shfl_down(val, 8); + if(warpSize >= 8) + val += __shfl_down(val, 4); + if(warpSize >= 4) + val += __shfl_down(val, 2); + if(warpSize >= 2) + val += __shfl_down(val, 1); + return val; +} + +__device__ FLOAT_ACCUM block_reduce_sum(FLOAT_ACCUM val) +{ + static __shared__ FLOAT_ACCUM shared[REDUCE_SIZE / warpSize]; + auto lane = threadIdx.x % warpSize; + auto wid = threadIdx.x / warpSize; + + val = warp_reduce_sum(val); + + if(lane == 0) + shared[wid] = val; + __syncthreads(); + + val = threadIdx.x < REDUCE_SIZE / warpSize ? shared[lane] : 0; + if(wid == 0) + val = warp_reduce_sum(val); + + return val; +} \ No newline at end of file diff --git a/src/multimarginloss.cpp b/src/multimarginloss.cpp new file mode 100644 index 0000000000..fe9c041d1a --- /dev/null +++ b/src/multimarginloss.cpp @@ -0,0 +1,144 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +miopenStatus_t MultiMarginLossUnreducedForward(Handle& handle, + const TensorDescriptor& iDesc, + ConstData_t i, + const TensorDescriptor& tDesc, + ConstData_t t, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& oDesc, + Data_t o, + const long p, + const float margin) +{ + const auto problem = + multimarginloss::ForwardProblemDescription{iDesc, tDesc, wDesc, oDesc, p, margin, 0}; + + const auto invoke_params = [&]() { + auto tmp = multimarginloss::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.iDesc = &iDesc; + tmp.i = i; + tmp.tDesc = &tDesc; + tmp.t = t; + tmp.wDesc = &wDesc; + tmp.w = w; + tmp.oDesc = &oDesc; + tmp.o = o; + tmp.p = p; + tmp.margin = margin; + return tmp; + }(); + + const auto algo = AlgorithmName{"MultiMarginLossUnreducedForward"}; + const auto solvers = + solver::SolverContainer{}; + + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +std::size_t GetMultiMarginLossForwardWorkspaceSize(Handle& handle, + const TensorDescriptor& iDesc, + const TensorDescriptor& tDesc, + const TensorDescriptor& wDesc, + const TensorDescriptor& oDesc, + const long p, + const float margin, + miopenLossReductionMode_t reduction) +{ + auto ctx = ExecutionContext{&handle}; + const float divisor = (reduction == MIOPEN_LOSS_REDUCTION_MEAN) ? iDesc.GetLengths()[0] : 1; + const auto problem = + multimarginloss::ForwardProblemDescription{iDesc, tDesc, wDesc, oDesc, p, margin, divisor}; + + const auto solvers = solver::SolverContainer{}; + + auto pair_size_vector = solvers.GetWorkspaceSizes(ctx, problem); + return pair_size_vector.empty() ? static_cast(-1) : pair_size_vector.front().second; +} + +miopenStatus_t MultiMarginLossForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& iDesc, + ConstData_t i, + const TensorDescriptor& tDesc, + ConstData_t t, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& oDesc, + Data_t o, + const long p, + const float margin, + miopenLossReductionMode_t reduction) +{ + const float divisor = (reduction == MIOPEN_LOSS_REDUCTION_MEAN) ? iDesc.GetLengths()[0] : 1; + const auto problem = + multimarginloss::ForwardProblemDescription{iDesc, tDesc, wDesc, oDesc, p, margin, divisor}; + + const auto invoke_params = [&]() { + auto tmp = multimarginloss::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.iDesc = &iDesc; + tmp.i = i; + tmp.tDesc = &tDesc; + tmp.t = t; + tmp.wDesc = &wDesc; + tmp.w = w; + tmp.oDesc = &oDesc; + tmp.o = o; + tmp.p = p; + tmp.margin = margin; + tmp.workspace = workspace; + tmp.workspace_size = workspaceSizeInBytes; + tmp.divisor = divisor; + return tmp; + }(); + + const auto algo = AlgorithmName{"MultiMarginLossForward"}; + const auto solvers = solver::SolverContainer{}; + + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/multimarginloss/problem_description.cpp b/src/multimarginloss/problem_description.cpp new file mode 100644 index 0000000000..295169ff36 --- /dev/null +++ b/src/multimarginloss/problem_description.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include + +#include + +namespace miopen { + +namespace multimarginloss { + +NetworkConfig ForwardProblemDescription::MakeNetworkConfig() const +{ + // TODO: edit after finish solver + std::ostringstream ss; + ss << "multilmarginloss_fwd"; + ss << "itype" << iDesc.GetType(); + ss << "ilen"; + auto ilen = iDesc.GetLengths(); + for(int32_t i = 0; i < ilen.size(); i++) + ss << ilen[i] << "_"; + ss << "divisor" << divisor; + return NetworkConfig{ss.str()}; +} + +} // namespace multimarginloss + +} // namespace miopen diff --git a/src/multimarginloss_api.cpp b/src/multimarginloss_api.cpp new file mode 100644 index 0000000000..6ef96d400b --- /dev/null +++ b/src/multimarginloss_api.cpp @@ -0,0 +1,141 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "miopen/miopen.h" +#include +#include +#include +#include +#include + +extern "C" miopenStatus_t +miopenGetMultiMarginLossForwardWorkspaceSize(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + miopenTensorDescriptor_t targetDesc, + miopenTensorDescriptor_t weightDesc, + miopenTensorDescriptor_t outputDesc, + const long p, + const float margin, + miopenLossReductionMode_t reduction, + size_t* sizeInBytes) +{ + MIOPEN_LOG_FUNCTION( + handle, inputDesc, targetDesc, weightDesc, outputDesc, p, margin, reduction); + + if(reduction != MIOPEN_LOSS_REDUCTION_SUM && reduction != MIOPEN_LOSS_REDUCTION_MEAN) + { + MIOPEN_THROW(miopenStatusBadParm, + "miopenGetMultiMarginLossForwardWorkspaceSize: reduction should be " + "MIOPEN_LOSS_REDUCTION_SUM or MIOPEN_LOSS_REDUCTION_MEAN."); + } + else + { + return miopen::try_([&] { + miopen::deref(sizeInBytes) = + miopen::GetMultiMarginLossForwardWorkspaceSize(miopen::deref(handle), + miopen::deref(inputDesc), + miopen::deref(targetDesc), + miopen::deref(weightDesc), + miopen::deref(outputDesc), + p, + margin, + reduction); + }); + } +} + +miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t targetDesc, + const void* target, + miopenTensorDescriptor_t weightDesc, + const void* weight, + miopenTensorDescriptor_t outputDesc, + void* output, + const long p, + const float margin, + miopenLossReductionMode_t reduction, + void* workspace, + size_t workspaceSizeInBytes) +{ + MIOPEN_LOG_FUNCTION(handle, + inputDesc, + input, + targetDesc, + target, + weightDesc, + weight, + outputDesc, + output, + p, + margin, + reduction, + workspace, + workspaceSizeInBytes); + + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + return miopen::try_([&] { + miopen::MultiMarginLossUnreducedForward(miopen::deref(handle), + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(targetDesc), + DataCast(target), + miopen::deref(weightDesc), + DataCast(weight), + miopen::deref(outputDesc), + DataCast(output), + p, + margin); + }); + } + else if(reduction == MIOPEN_LOSS_REDUCTION_SUM || reduction == MIOPEN_LOSS_REDUCTION_MEAN) + { + return miopen::try_([&] { + miopen::MultiMarginLossForward(miopen::deref(handle), + DataCast(workspace), + workspaceSizeInBytes, + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(targetDesc), + DataCast(target), + miopen::deref(weightDesc), + DataCast(weight), + miopen::deref(outputDesc), + DataCast(output), + p, + margin, + reduction); + }); + } + else + { + MIOPEN_THROW(miopenStatusBadParm, + "miopenMultiMarginLossForward: reduction should be " + "MIOPEN_LOSS_REDUCTION_NONE, " + "MIOPEN_LOSS_REDUCTION_SUM or MIOPEN_LOSS_REDUCTION_MEAN."); + } +} diff --git a/src/solver.cpp b/src/solver.cpp index 2cec8a1e3c..1501371e5e 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include @@ -669,6 +670,15 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) fusion::ConvWinoFuryRxSFused<2, 3>{}.SolverDbId(), miopenConvolutionAlgoWinograd); + Register(registry, + ++id, + Primitive::MultiMarginLoss, + multimarginloss::MultiMarginLossUnreducedForward{}.SolverDbId()); + Register(registry, + ++id, + Primitive::MultiMarginLoss, + multimarginloss::MultiMarginLossForward{}.SolverDbId()); + // IMPORTANT: New solvers should be added to the end of the function! } diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp new file mode 100644 index 0000000000..75857de821 --- /dev/null +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -0,0 +1,223 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/miopen.h" +#include +#include +#include +#include +#include +#include +#include + +#define LOCAL_SIZE_MULTIMARGINLOSS 256 +#define LOCAL_SIZE_REDUCE 256 + +namespace miopen { + +namespace solver { + +namespace multimarginloss { + +bool MultiMarginLossForward::IsApplicable( + const ExecutionContext& /*context*/, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + // TODO: edit later + // if(problem.GetiDesc().GetLengths()[1] > 24) + // return false; + return true; +} + +ConvSolution MultiMarginLossForward::GetSolution( + const ExecutionContext& /*context*/, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + auto result = ConvSolution{miopenStatusSuccess}; + + auto xgrid = problem.GetiDesc().GetLengths()[0]; + auto dtype = problem.GetiDesc().GetType(); + + { + /* Phase 1: Calc loss for each element. */ + size_t xlocalsize = LOCAL_SIZE_MULTIMARGINLOSS; + size_t xgridsize = AlignUp(xgrid, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenMultiMarginLoss.cpp"; + kernel.kernel_name = "MultiMarginLossForward2d"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + { + /* Phase 2: Reduce */ + auto _size = xgrid; + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"REDUCE_SIZE", LOCAL_SIZE_REDUCE}, + }; + do + { + size_t xlocalsize = LOCAL_SIZE_REDUCE; + size_t xgridsize = AlignUp(_size, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenLossReduce.cpp"; + kernel.kernel_name = "ReduceSumLoss"; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + _size = AlignUp(_size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE; + } while(_size > 1); + } + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) params = raw_params.CastTo(); + auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); + auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); + auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); + + float elapsed = 0.0f; + HipEventPtr start; + HipEventPtr stop; + + bool reset_profiling_state = false; + if(handle_.IsProfilingEnabled()) + { + reset_profiling_state = true; + handle_.EnableProfiling(false); + start = miopen::make_hip_event(); + stop = miopen::make_hip_event(); + hipEventRecord(start.get(), handle_.GetStream()); + } + /* Phase 1: Calc loss for each element. */ + { + decltype(auto) kernel = handle_.Run(kernels.front()); + kernel(params.i, + params.t, + params.w, + params.workspace, + params.p, + params.margin, + params.divisor, + i_tv, + t_tv, + w_tv); + } + + /* Phase 2: Reduce */ + auto size = deref(params.iDesc).GetLengths()[0]; + auto reduce_in = params.workspace; + auto reduce_out = + static_cast(static_cast(params.workspace) + + size * get_data_size(deref(params.oDesc).GetType())); + + for(int i = 1; i < kernels.size(); ++i) + { + decltype(auto) kernel = handle_.Run(kernels[i]); + if(i + 1 != kernels.size()) + { + kernel(reduce_in, reduce_out, size); + std::swap(reduce_in, reduce_out); + } + else + { + kernel(reduce_in, params.o, size); + } + size = AlignUp(size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE; + } + + if(reset_profiling_state) + handle_.EnableProfiling(true); + if(handle_.IsProfilingEnabled()) + { + hipEventRecord(stop.get(), handle_.GetStream()); + hipEventSynchronize(stop.get()); + hipEventElapsedTime(&elapsed, start.get(), stop.get()); + + // Clean up + hipEventDestroy(start.get()); + hipEventDestroy(stop.get()); + handle_.ResetKernelTime(); + handle_.AccumKernelTime(elapsed); + }; + }; + }; + + return result; +} + +std::size_t MultiMarginLossForward::GetWorkspaceSize( + const ExecutionContext& /*context*/, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + auto elem = problem.GetiDesc().GetLengths()[0]; + return (elem + (elem + LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE) * + get_data_size(problem.GetoDesc().GetType()); +} + +} // namespace multimarginloss + +} // namespace solver + +} // namespace miopen diff --git a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp new file mode 100644 index 0000000000..a585f66cb2 --- /dev/null +++ b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp @@ -0,0 +1,124 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/miopen.h" +#include +#include +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 256 + +namespace miopen { + +namespace solver { + +namespace multimarginloss { + +bool MultiMarginLossUnreducedForward::IsApplicable( + const ExecutionContext& /*context*/, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + // TODO: edit later + // if(problem.GetiDesc().GetLengths()[1] > 24) + // return false; + return true; +} + +ConvSolution MultiMarginLossUnreducedForward::GetSolution( + const ExecutionContext& /*context*/, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + auto result = ConvSolution{miopenStatusSuccess}; + + auto xgrid = problem.GetiDesc().GetLengths()[0]; + + { + auto dtype = problem.GetiDesc().GetType(); + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = AlignUp(xgrid, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenMultiMarginLoss.cpp"; + kernel.kernel_name = "MultiMarginLossUnreducedForward2d"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); + auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); + auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); + auto o_tv = get_inner_expanded_tv<1>(deref(params.oDesc)); + + kernel(params.i, + params.t, + params.w, + params.o, + params.p, + params.margin, + i_tv, + t_tv, + w_tv, + o_tv); + }; + }; + + return result; +} + +} // namespace multimarginloss + +} // namespace solver + +} // namespace miopen From b26abf97f9a4856c972782f00fe2b05735eb66cc Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 10 Jul 2024 10:37:18 +0000 Subject: [PATCH 02/30] add gtest --- driver/multimarginloss_driver.hpp | 40 ++--- test/cpu_multimarginloss.hpp | 112 +++++++++++++ test/gtest/multimarginloss.cpp | 110 +++++++++++++ test/gtest/multimarginloss.hpp | 259 ++++++++++++++++++++++++++++++ test/tensor_holder.hpp | 5 + 5 files changed, 506 insertions(+), 20 deletions(-) create mode 100644 test/cpu_multimarginloss.hpp create mode 100644 test/gtest/multimarginloss.cpp create mode 100644 test/gtest/multimarginloss.hpp diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index 58993a4bc1..20af57a487 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -49,10 +49,10 @@ int32_t mloMultiMarginLossUnreducedForwardRunHost(miopenTensorDescriptor_t iDesc miopenTensorDescriptor_t oDesc, long p, float margin, - Tgpu* I, - const uint64_t* T, - Tgpu* W, - Tcheck* Ohost) + Tgpu* input, + const uint64_t* target, + Tgpu* weight, + Tcheck* ref_output) { auto I_tv = miopen::get_inner_expanded_tv<2>(miopen::deref(iDesc)); auto T_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(tDesc)); @@ -65,24 +65,24 @@ int32_t mloMultiMarginLossUnreducedForwardRunHost(miopenTensorDescriptor_t iDesc for(size_t n = 0; n < N; n++) { Tcheck loss = 0; - uint64_t y = T[T_tv.get_tensor_view_idx({n})]; + uint64_t y = target[T_tv.get_tensor_view_idx({n})]; if(y >= C) continue; for(size_t c = 0; c < C; c++) { if(y == c) continue; - Tcheck t = margin - static_cast(I[I_tv.get_tensor_view_idx({n, y})]) + - static_cast(I[I_tv.get_tensor_view_idx({n, c})]); + Tcheck t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); if(t < 0) continue; if(p == 2) t = t * t; - t = W[W_tv.get_tensor_view_idx({y})] * t; + t = weight[W_tv.get_tensor_view_idx({y})] * t; loss += t / static_cast(C); } - Ohost[O_tv.get_tensor_view_idx({n})] = loss; + ref_output[O_tv.get_tensor_view_idx({n})] = loss; } return ret; } @@ -94,10 +94,10 @@ int32_t mloMultiMarginLossReducedForwardRunHost(miopenTensorDescriptor_t iDesc, long p, float margin, const float divisor, - Tgpu* I, - const uint64_t* T, - Tgpu* W, - Tcheck* Ohost) + Tgpu* input, + const uint64_t* target, + Tgpu* weight, + Tcheck* ref_output) { auto I_tv = miopen::get_inner_expanded_tv<2>(miopen::deref(iDesc)); auto T_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(tDesc)); @@ -109,25 +109,25 @@ int32_t mloMultiMarginLossReducedForwardRunHost(miopenTensorDescriptor_t iDesc, for(size_t n = 0; n < N; n++) { Tcheck loss = 0; - uint64_t y = T[T_tv.get_tensor_view_idx({n})]; + uint64_t y = target[T_tv.get_tensor_view_idx({n})]; if(y >= C) continue; for(size_t c = 0; c < C; c++) { if(y == c) continue; - Tcheck t = margin - static_cast(I[I_tv.get_tensor_view_idx({n, y})]) + - static_cast(I[I_tv.get_tensor_view_idx({n, c})]); + Tcheck t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); if(t < 0) continue; if(p == 2) t = t * t; - t = W[W_tv.get_tensor_view_idx({y})] * t; + t = weight[W_tv.get_tensor_view_idx({y})] * t; loss += t / static_cast(C); } - Ohost[0] += loss; + ref_output[0] += loss; } - Ohost[0] /= divisor; + ref_output[0] /= divisor; return ret; }; @@ -242,7 +242,7 @@ int MultiMarginLossDriver::GetandSetData() { // Set tensor description std::vector in_len = inflags.GetValueTensor("dim").lengths; - int N = in_len[0], C = in_len[1]; + size_t N = in_len[0], C = in_len[1]; if(inflags.GetValueInt("contiguous") == 1) { SetTensorNd(iDesc, in_len, data_type); diff --git a/test/cpu_multimarginloss.hpp b/test/cpu_multimarginloss.hpp new file mode 100644 index 0000000000..8e790feb5a --- /dev/null +++ b/test/cpu_multimarginloss.hpp @@ -0,0 +1,112 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef GUARD_CPU_MULTIMARGINLOSS_HPP +#define GUARD_CPU_MULTIMARGINLOSS_HPP + +#include "tensor_holder.hpp" +#include + +template +void cpu_multimarginloss_unreduced_forward(tensor input, + tensor target, + tensor weight, + tensor& ref_output, + long p, + float margin) +{ + auto I_tv = miopen::get_inner_expanded_tv<2>(input.desc); + auto T_tv = miopen::get_inner_expanded_tv<1>(target.desc); + auto W_tv = miopen::get_inner_expanded_tv<1>(weight.desc); + auto O_tv = miopen::get_inner_expanded_tv<1>(ref_output.desc); + auto N = I_tv.size[0], C = I_tv.size[1]; + + for(size_t n = 0; n < N; n++) + { + float loss = 0; + uint64_t y = target[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + continue; + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + float t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); + + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = weight[W_tv.get_tensor_view_idx({y})] * t; + loss += t / static_cast(C); + } + ref_output[O_tv.get_tensor_view_idx({n})] = loss; + } +} + +template +void cpu_multimarginloss_reduced_forward(tensor input, + tensor target, + tensor weight, + tensor& ref_output, + long p, + float margin, + const float divisor) +{ + auto I_tv = miopen::get_inner_expanded_tv<2>(input.desc); + auto T_tv = miopen::get_inner_expanded_tv<1>(target.desc); + auto W_tv = miopen::get_inner_expanded_tv<1>(weight.desc); + auto N = I_tv.size[0], C = I_tv.size[1]; + + float sum = 0; + for(size_t n = 0; n < N; n++) + { + float loss = 0; + uint64_t y = target[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + continue; + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + float t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); + + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = weight[W_tv.get_tensor_view_idx({y})] * t; + loss += t / static_cast(C); + } + sum += loss; + } + sum /= divisor; + ref_output[0] = static_cast(sum); +} + +#endif diff --git a/test/gtest/multimarginloss.cpp b/test/gtest/multimarginloss.cpp new file mode 100644 index 0000000000..ce7e501a7c --- /dev/null +++ b/test/gtest/multimarginloss.cpp @@ -0,0 +1,110 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "multimarginloss.hpp" +#include + +MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) + +namespace multimarginloss { + +std::string GetFloatArg() +{ + const auto& tmp = env::value(MIOPEN_TEST_FLOAT_ARG); + if(tmp.empty()) + { + return ""; + } + return tmp; +} + +struct MultiMarginLossForwardTestFloat : MultiMarginLossForwardTest +{ +}; + +struct MultiMarginLossForwardTestHalf : MultiMarginLossForwardTest +{ +}; + +struct MultiMarginLossForwardTestBFloat16 : MultiMarginLossForwardTest +{ +}; + +} // namespace multimarginloss + +using namespace multimarginloss; + +TEST_P(MultiMarginLossForwardTestFloat, ) +{ + if(!MIOPEN_TEST_ALL || + (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--float")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +TEST_P(MultiMarginLossForwardTestHalf, ) +{ + if(!MIOPEN_TEST_ALL || + (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--half")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +TEST_P(MultiMarginLossForwardTestBFloat16, ) +{ + if(!MIOPEN_TEST_ALL || + (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--bfloat16")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; +INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, + MultiMarginLossForwardTestFloat, + testing::ValuesIn(MultiMarginLossTestConfigs())); +INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, + MultiMarginLossForwardTestHalf, + testing::ValuesIn(MultiMarginLossTestConfigs())); +INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, + MultiMarginLossForwardTestBFloat16, + testing::ValuesIn(MultiMarginLossTestConfigs())); diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp new file mode 100644 index 0000000000..53a452770e --- /dev/null +++ b/test/gtest/multimarginloss.hpp @@ -0,0 +1,259 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "../driver/tensor_driver.hpp" +#include "cpu_multimarginloss.hpp" +#include "get_handle.hpp" +#include "random.hpp" +#include "tensor_holder.hpp" +#include "verify.hpp" +#include +#include +#include +#include + +struct MultiMarginLossTestCase +{ + std::vector dims; + bool cont; + miopenLossReductionMode_t reduction_mode; + long p; +}; + +std::vector MultiMarginLossTestConfigs() +{ + // clang-format off + return { + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1}, + {{22, 12}, false, MIOPEN_LOSS_REDUCTION_SUM, 1}, + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_NONE, 1}, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{9456, 13}, true, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{543210, 7}, false, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_SUM, 1 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_NONE, 1 }, + }; + // clang-format on +} + +template +struct MultiMarginLossForwardTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto&& handle = get_handle(); + config = GetParam(); + + auto in_dims = config.dims; + reduction_mode = config.reduction_mode; + p = config.p; + margin = prng::gen_A_to_B(0.5, 1.5); + size_t N = in_dims[0], C = in_dims[1]; + + if(std::is_same::value && + reduction_mode != MIOPEN_LOSS_REDUCTION_NONE && N >= 100000) + { + std::cerr << "For fp16 forward reduction test, too many elements in input tensor can " + "lead to fp16 " + "overflow or underflow. If reduction mean, divisor is very big lead to " + "underflow. If reduction sum, result is too big lead to overflow." + << std::endl; + GTEST_SKIP(); + } + + if(config.cont) + { + input = tensor{in_dims}; + // input is (N, C) -> target is (N) + target = tensor{std::vector{N}}; + // input is (N, C) -> weight is (C) + weight = tensor{std::vector{C}}; + } + else + { + std::vector in_strides(in_dims.size()); + in_strides.back() = 1; + for(int i = in_dims.size() - 2; i >= 0; --i) + in_strides[i] = in_strides[i + 1] * in_dims[i + 1]; + in_strides[0] *= 2; + input = tensor{in_dims, in_strides}; + + std::vector t_len = {N}; + std::vector t_strides = {2}; + target = tensor{t_len, t_strides}; + + std::vector w_lens = {C}; + std::vector w_strides = {2}; + weight = tensor{w_lens, w_strides}; + } + + auto gen_in_value = [](auto...) { + return prng::gen_A_to_B(static_cast(-1), static_cast(1)); + }; + std::generate(input.begin(), input.end(), gen_in_value); + input_dev = handle.Write(input.data); + + for(auto& ptr : target) + { + ptr = prng::gen_A_to_B(0, C); + } + target_dev = handle.Write(target.data); + + auto gen_weight_value = [](auto...) { + return prng::gen_A_to_B(static_cast(-1), static_cast(1)); + }; + std::generate(weight.begin(), weight.end(), gen_weight_value); + weight_dev = handle.Write(weight.data); + + if(reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) + { + // input is (N, C) -> output is (N) + output = tensor{std::vector{N}}; + ref_output = tensor{std::vector{N}}; + } + else + { + // Tensor with 1 element to store result after reduce + output = tensor{std::vector{1}}; + ref_output = tensor{std::vector{1}}; + } + std::fill(output.begin(), output.end(), 0); + std::fill(ref_output.begin(), ref_output.end(), 0); + output_dev = handle.Write(output.data); + + if(reduction_mode != MIOPEN_LOSS_REDUCTION_NONE) + { + ws_sizeInBytes = miopen::GetMultiMarginLossForwardWorkspaceSize(handle, + input.desc, + target.desc, + weight.desc, + output.desc, + p, + margin, + reduction_mode); + if(ws_sizeInBytes == static_cast(-1)) + GTEST_SKIP(); + workspace = tensor{std::vector{ws_sizeInBytes / sizeof(T)}}; + std::fill(workspace.begin(), workspace.end(), 0); + workspace_dev = handle.Write(workspace.data); + } + } + void RunTest() + { + auto&& handle = get_handle(); + miopenStatus_t status; + if(reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) + { + cpu_multimarginloss_unreduced_forward(input, target, weight, ref_output, p, margin); + + status = miopen::MultiMarginLossUnreducedForward(handle, + input.desc, + input_dev.get(), + target.desc, + target_dev.get(), + weight.desc, + weight_dev.get(), + output.desc, + output_dev.get(), + p, + margin); + } + else + { + cpu_multimarginloss_reduced_forward( + input, + target, + weight, + ref_output, + p, + margin, + (reduction_mode == MIOPEN_LOSS_REDUCTION_MEAN) ? input.desc.GetLengths()[0] : 1); + + status = miopen::MultiMarginLossForward(handle, + workspace_dev.get(), + ws_sizeInBytes, + input.desc, + input_dev.get(), + target.desc, + target_dev.get(), + weight.desc, + weight_dev.get(), + output.desc, + output_dev.get(), + p, + margin, + reduction_mode); + } + EXPECT_EQ(status, miopenStatusSuccess); + + // Write from GPU to CPU + output.data = handle.Read(output_dev, output.data.size()); + } + + void Verify() + { + // Computation error of fp16 is ~2^13 (=8192) bigger than + // the one of fp32 because mantissa is shorter by 13 bits. + auto threshold = std::is_same::value ? 1.5e-6 : 8.2e-3; + + // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. + if(std::is_same::value) + threshold *= 8.0; + + auto error = miopen::rms_range(ref_output, output); + EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output)); + // When doing reduction with big test, floating point precision error is high. I raise + // threshold from *10 to *30 to pass big test + EXPECT_TRUE(error < threshold * 30) << "Error output beyond tolerance " + "Error:" + << error << ", Threshold x 30: " << threshold * 10; + } + MultiMarginLossTestCase config; + + tensor input; + tensor target; + tensor weight; + tensor output; + tensor workspace; + + tensor ref_output; + + miopen::Allocator::ManageDataPtr input_dev; + miopen::Allocator::ManageDataPtr target_dev; + miopen::Allocator::ManageDataPtr weight_dev; + miopen::Allocator::ManageDataPtr output_dev; + miopen::Allocator::ManageDataPtr workspace_dev; + + long p; + float margin; + miopenLossReductionMode_t reduction_mode; + size_t ws_sizeInBytes; +}; diff --git a/test/tensor_holder.hpp b/test/tensor_holder.hpp index 8ff00e6532..5428ce54cd 100644 --- a/test/tensor_holder.hpp +++ b/test/tensor_holder.hpp @@ -137,6 +137,11 @@ struct miopen_type : std::integral_constant +struct miopen_type : std::integral_constant +{ +}; + template struct tensor { From ac4c9fe8974ec2ddf854f4aaab69a8d75acd9bb1 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Sun, 14 Jul 2024 16:48:10 +0000 Subject: [PATCH 03/30] update condition --- src/include/miopen/tensor.hpp | 1 + src/kernels/MIOpenMultiMarginLoss.cpp | 1 + .../forward_reduced_multimarginloss.cpp | 9 ++++++--- .../forward_unreduced_multimarginloss.cpp | 9 ++++++--- src/tensor.cpp | 16 ++++++++++++++++ 5 files changed, 30 insertions(+), 6 deletions(-) diff --git a/src/include/miopen/tensor.hpp b/src/include/miopen/tensor.hpp index 5794fdb21a..70d7751bfa 100644 --- a/src/include/miopen/tensor.hpp +++ b/src/include/miopen/tensor.hpp @@ -216,6 +216,7 @@ struct MIOPEN_INTERNALS_EXPORT TensorDescriptor : miopenTensorDescriptor } bool IsPacked() const; + bool IsContiguous() const; /// Checks all lengths and strides. bool AllDimsFitIntoInt() const; /// Checks only lengths. diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 6855270329..23bddbce36 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -31,6 +31,7 @@ #include "float_types.h" #include "tensor_view.hpp" +// TODO: try to pass margin as DTYPE to optimize memory access template __device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, const uint64_t* __restrict__ T, diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index 75857de821..cc903e3b81 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -46,9 +46,12 @@ bool MultiMarginLossForward::IsApplicable( const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const { - // TODO: edit later - // if(problem.GetiDesc().GetLengths()[1] > 24) - // return false; + if((problem.GetiDesc().GetType() == miopenHalf || + problem.GetiDesc().GetType() == miopenBFloat16) && + problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] > 40) + return false; + if(problem.GetiDesc().GetLengths()[1] > 30) + return false; return true; } diff --git a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp index a585f66cb2..fe7e4c4be4 100644 --- a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp @@ -45,9 +45,12 @@ bool MultiMarginLossUnreducedForward::IsApplicable( const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const { - // TODO: edit later - // if(problem.GetiDesc().GetLengths()[1] > 24) - // return false; + if((problem.GetiDesc().GetType() == miopenHalf || + problem.GetiDesc().GetType() == miopenBFloat16) && + problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] > 40) + return false; + if(problem.GetiDesc().GetLengths()[1] > 30) + return false; return true; } diff --git a/src/tensor.cpp b/src/tensor.cpp index 669a5842ea..b4b89d31ca 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -458,6 +458,22 @@ std::size_t TensorDescriptor::GetNumBytes() const bool TensorDescriptor::IsPacked() const { return this->packed; } +bool TensorDescriptor::IsContiguous() const +{ + size_t plane_size = 1; + size_t dims_of_shape = lens.size(); + + for(int index = dims_of_shape - 1; index >= 0; --index) + { + if((lens[index] != 1) && (strides[index] != plane_size)) + { + return false; + } + plane_size *= static_cast(lens[index]); + } + return true; +} + bool TensorDescriptor::AllLengthsFitIntoInt() const { if(std::any_of(lens.cbegin(), lens.cend(), [](std::size_t x) { From cad5f79c0b6de3e5588b98fed30da33aaa1e5647 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Mon, 15 Jul 2024 08:10:44 +0000 Subject: [PATCH 04/30] remove redundant header + some fix by clangd recommend --- driver/multimarginloss_driver.hpp | 10 ++++---- .../multimarginloss/problem_description.hpp | 2 -- .../miopen/multimarginloss/solvers.hpp | 1 - src/include/miopen/solver_id.hpp | 1 - src/include/miopen/tensor_view_utils.hpp | 2 +- src/kernels/MIOpenMultiMarginLoss.cpp | 2 +- src/kernels/warp_shuffle.hpp | 2 +- src/multimarginloss/problem_description.cpp | 1 - src/multimarginloss_api.cpp | 2 +- test/gtest/multimarginloss.cpp | 25 ++++++++++++++++++- test/gtest/multimarginloss.hpp | 23 ----------------- test/tensor_holder.hpp | 4 --- 12 files changed, 33 insertions(+), 42 deletions(-) diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index 20af57a487..e944a5d325 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -312,9 +312,9 @@ int MultiMarginLossDriver::AllocateBuffersAndCopy() size_t i_sz = GetTensorSpace(iDesc); size_t t_sz = GetTensorSpace(tDesc); size_t w_sz = GetTensorSpace(wDesc); - i_dev = std::unique_ptr(new GPUMem(ctx, i_sz, sizeof(Tgpu))); - t_dev = std::unique_ptr(new GPUMem(ctx, t_sz, sizeof(uint64_t))); - w_dev = std::unique_ptr(new GPUMem(ctx, w_sz, sizeof(Tgpu))); + i_dev = std::make_unique(ctx, i_sz, sizeof(Tgpu)); + t_dev = std::make_unique(ctx, t_sz, sizeof(uint64_t)); + w_dev = std::make_unique(ctx, w_sz, sizeof(Tgpu)); I = std::vector(i_sz); T = std::vector(t_sz); W = std::vector(w_sz); @@ -364,7 +364,7 @@ int MultiMarginLossDriver::AllocateBuffersAndCopy() else ws_sizeInBytes = 0; - o_dev = std::unique_ptr(new GPUMem(ctx, o_sz, sizeof(Tgpu))); + o_dev = std::make_unique(ctx, o_sz, sizeof(Tgpu)); O = std::vector(o_sz); Ohost = std::vector(o_sz); std::fill(O.begin(), O.end(), 0); @@ -373,7 +373,7 @@ int MultiMarginLossDriver::AllocateBuffersAndCopy() std::cerr << "Error copying (out) to GPU, size: " << o_dev->GetSize() << std::endl; size_t ws_sz = ws_sizeInBytes / sizeof(Tgpu); - workspace_dev = std::unique_ptr(new GPUMem(ctx, ws_sz, sizeof(Tgpu))); + workspace_dev = std::make_unique(ctx, ws_sz, sizeof(Tgpu)); workspace = std::vector(ws_sz); std::fill(workspace.begin(), workspace.end(), 0); diff --git a/src/include/miopen/multimarginloss/problem_description.hpp b/src/include/miopen/multimarginloss/problem_description.hpp index 10d11c0a07..23cf200714 100644 --- a/src/include/miopen/multimarginloss/problem_description.hpp +++ b/src/include/miopen/multimarginloss/problem_description.hpp @@ -30,8 +30,6 @@ #include #include #include -#include -#include namespace miopen { diff --git a/src/include/miopen/multimarginloss/solvers.hpp b/src/include/miopen/multimarginloss/solvers.hpp index 1a018b1d6a..db76d65d10 100644 --- a/src/include/miopen/multimarginloss/solvers.hpp +++ b/src/include/miopen/multimarginloss/solvers.hpp @@ -27,7 +27,6 @@ #include #include -#include namespace miopen { diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index 0d2ae9cd34..e93cdf0e4b 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -32,7 +32,6 @@ #include #include -#include namespace miopen { diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp index afeaaeea78..226e33749d 100644 --- a/src/include/miopen/tensor_view_utils.hpp +++ b/src/include/miopen/tensor_view_utils.hpp @@ -28,7 +28,7 @@ #define MIOPEN_TENSOR_VIEW_UTIL_HPP_ #include "../../kernels/tensor_view.hpp" -#include "miopen/tensor.hpp" +#include namespace miopen { diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 23bddbce36..9bb3faac0f 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -147,4 +147,4 @@ extern "C" __global__ void MultiMarginLossForward2d(const FLOAT* __restrict__ I, { // instantiate the kernel multimarginlossforward2d(I, T, W, lsum, p, margin, divisor, I_tv, T_tv, W_tv); -} \ No newline at end of file +} diff --git a/src/kernels/warp_shuffle.hpp b/src/kernels/warp_shuffle.hpp index 3aa75660f6..771f7f7954 100644 --- a/src/kernels/warp_shuffle.hpp +++ b/src/kernels/warp_shuffle.hpp @@ -65,4 +65,4 @@ __device__ FLOAT_ACCUM block_reduce_sum(FLOAT_ACCUM val) val = warp_reduce_sum(val); return val; -} \ No newline at end of file +} diff --git a/src/multimarginloss/problem_description.cpp b/src/multimarginloss/problem_description.cpp index 295169ff36..0656696ca7 100644 --- a/src/multimarginloss/problem_description.cpp +++ b/src/multimarginloss/problem_description.cpp @@ -34,7 +34,6 @@ namespace multimarginloss { NetworkConfig ForwardProblemDescription::MakeNetworkConfig() const { - // TODO: edit after finish solver std::ostringstream ss; ss << "multilmarginloss_fwd"; ss << "itype" << iDesc.GetType(); diff --git a/src/multimarginloss_api.cpp b/src/multimarginloss_api.cpp index 6ef96d400b..ae68972bf2 100644 --- a/src/multimarginloss_api.cpp +++ b/src/multimarginloss_api.cpp @@ -23,7 +23,7 @@ * SOFTWARE. * *******************************************************************************/ -#include "miopen/miopen.h" +#include #include #include #include diff --git a/test/gtest/multimarginloss.cpp b/test/gtest/multimarginloss.cpp index ce7e501a7c..37170205e1 100644 --- a/test/gtest/multimarginloss.cpp +++ b/test/gtest/multimarginloss.cpp @@ -56,7 +56,9 @@ struct MultiMarginLossForwardTestBFloat16 : MultiMarginLossForwardTest } // namespace multimarginloss -using namespace multimarginloss; +using multimarginloss::MultiMarginLossForwardTestBFloat16; +using multimarginloss::MultiMarginLossForwardTestFloat; +using multimarginloss::MultiMarginLossForwardTestHalf; TEST_P(MultiMarginLossForwardTestFloat, ) { @@ -99,6 +101,27 @@ TEST_P(MultiMarginLossForwardTestBFloat16, ) GTEST_SKIP(); } }; + +std::vector MultiMarginLossTestConfigs() +{ + // clang-format off + return { + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1}, + {{22, 12}, false, MIOPEN_LOSS_REDUCTION_SUM, 1}, + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_NONE, 1}, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{9456, 13}, true, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{543210, 7}, false, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_SUM, 1 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_NONE, 1 }, + }; + // clang-format on +} + INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, MultiMarginLossForwardTestFloat, testing::ValuesIn(MultiMarginLossTestConfigs())); diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp index 53a452770e..62c93d3dcd 100644 --- a/test/gtest/multimarginloss.hpp +++ b/test/gtest/multimarginloss.hpp @@ -24,16 +24,13 @@ * *******************************************************************************/ -#include "../driver/tensor_driver.hpp" #include "cpu_multimarginloss.hpp" #include "get_handle.hpp" -#include "random.hpp" #include "tensor_holder.hpp" #include "verify.hpp" #include #include #include -#include struct MultiMarginLossTestCase { @@ -43,26 +40,6 @@ struct MultiMarginLossTestCase long p; }; -std::vector MultiMarginLossTestConfigs() -{ - // clang-format off - return { - {{22, 12}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1}, - {{22, 12}, false, MIOPEN_LOSS_REDUCTION_SUM, 1}, - {{22, 12}, true, MIOPEN_LOSS_REDUCTION_NONE, 1}, - {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, - {{9456, 13}, true, MIOPEN_LOSS_REDUCTION_SUM, 2 }, - {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_NONE, 2 }, - {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, - {{543210, 7}, false, MIOPEN_LOSS_REDUCTION_SUM, 2 }, - {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_NONE, 2 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_SUM, 1 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_NONE, 1 }, - }; - // clang-format on -} - template struct MultiMarginLossForwardTest : public ::testing::TestWithParam { diff --git a/test/tensor_holder.hpp b/test/tensor_holder.hpp index 5428ce54cd..4bf198e9a8 100644 --- a/test/tensor_holder.hpp +++ b/test/tensor_holder.hpp @@ -27,7 +27,6 @@ #define GUARD_TENSOR_HOLDER_HPP #include "ford.hpp" -#include "network_data.hpp" #include #include #include @@ -44,9 +43,6 @@ using hip_bfloat16 = bfloat16; using float8 = miopen_f8::hip_f8; using bfloat8 = miopen_f8::hip_f8; -#include -#include - template void visit_tensor_size(std::size_t n, F f) { From 679fc435fcd46495002c57cf677b9e53325fd15f Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Mon, 15 Jul 2024 09:54:21 +0000 Subject: [PATCH 05/30] fix warnings by make analyze --- src/kernels/MIOpenMultiMarginLoss.cpp | 1 - src/multimarginloss/problem_description.cpp | 4 ++-- test/gtest/multimarginloss.cpp | 6 +++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 9bb3faac0f..6ade61245b 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -31,7 +31,6 @@ #include "float_types.h" #include "tensor_view.hpp" -// TODO: try to pass margin as DTYPE to optimize memory access template __device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, const uint64_t* __restrict__ T, diff --git a/src/multimarginloss/problem_description.cpp b/src/multimarginloss/problem_description.cpp index 0656696ca7..497f0d119c 100644 --- a/src/multimarginloss/problem_description.cpp +++ b/src/multimarginloss/problem_description.cpp @@ -39,8 +39,8 @@ NetworkConfig ForwardProblemDescription::MakeNetworkConfig() const ss << "itype" << iDesc.GetType(); ss << "ilen"; auto ilen = iDesc.GetLengths(); - for(int32_t i = 0; i < ilen.size(); i++) - ss << ilen[i] << "_"; + for(unsigned long i : ilen) + ss << i << "_"; ss << "divisor" << divisor; return NetworkConfig{ss.str()}; } diff --git a/test/gtest/multimarginloss.cpp b/test/gtest/multimarginloss.cpp index 37170205e1..2bfed37696 100644 --- a/test/gtest/multimarginloss.cpp +++ b/test/gtest/multimarginloss.cpp @@ -60,7 +60,7 @@ using multimarginloss::MultiMarginLossForwardTestBFloat16; using multimarginloss::MultiMarginLossForwardTestFloat; using multimarginloss::MultiMarginLossForwardTestHalf; -TEST_P(MultiMarginLossForwardTestFloat, ) +TEST_P(MultiMarginLossForwardTestFloat, MMLFwdTest) { if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--float")) @@ -74,7 +74,7 @@ TEST_P(MultiMarginLossForwardTestFloat, ) } }; -TEST_P(MultiMarginLossForwardTestHalf, ) +TEST_P(MultiMarginLossForwardTestHalf, MMLFwdTest) { if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--half")) @@ -88,7 +88,7 @@ TEST_P(MultiMarginLossForwardTestHalf, ) } }; -TEST_P(MultiMarginLossForwardTestBFloat16, ) +TEST_P(MultiMarginLossForwardTestBFloat16, MMLFwdTest) { if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--bfloat16")) From 7e2b420274bbe99b0c61d361e872286091d06934 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Tue, 16 Jul 2024 03:25:18 +0000 Subject: [PATCH 06/30] check p value --- src/include/miopen/multimarginloss/problem_description.hpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/include/miopen/multimarginloss/problem_description.hpp b/src/include/miopen/multimarginloss/problem_description.hpp index 23cf200714..c3eeecfa08 100644 --- a/src/include/miopen/multimarginloss/problem_description.hpp +++ b/src/include/miopen/multimarginloss/problem_description.hpp @@ -101,6 +101,11 @@ struct ForwardProblemDescription : ProblemDescriptionBase "MultiMarginLoss: Output tensor need to be a scalar."); } } + // Check p value + if(p != 1 && p != 2) + { + MIOPEN_THROW(miopenStatusBadParm, "MultiMarginLoss: p need to be equal 1 or 2."); + } } const TensorDescriptor& GetiDesc() const { return iDesc; } From 2d6bc8523b3830420af710a9e0dc7be4f8d7d070 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 17 Jul 2024 08:34:19 +0000 Subject: [PATCH 07/30] partial fix reviewer comment --- driver/CMakeLists.txt | 2 +- driver/multimarginloss_driver.hpp | 2 +- src/CMakeLists.txt | 2 +- .../miopen/multimarginloss/solvers.hpp | 6 ++++ src/kernels/MIOpenLossReduce.cpp | 4 +-- src/kernels/MIOpenMultiMarginLoss.cpp | 18 +++++------ src/kernels/warp_shuffle.hpp | 5 ++++ .../forward_reduced_multimarginloss.cpp | 15 ++++++++++ .../forward_unreduced_multimarginloss.cpp | 14 +++++++-- test/gtest/multimarginloss.cpp | 20 ------------- test/gtest/multimarginloss.hpp | 30 +++++++++++++++++++ test/tensor_holder.hpp | 4 +++ 12 files changed, 86 insertions(+), 36 deletions(-) diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index d1160be35c..b151c05432 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -44,9 +44,9 @@ add_executable(MIOpenDriver dm_fusion.cpp dm_gemm.cpp dm_groupnorm.cpp - dm_multimarginloss.cpp dm_layernorm.cpp dm_lrn.cpp + dm_multimarginloss.cpp dm_pool.cpp dm_reduce.cpp dm_reduceextreme.cpp diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index e944a5d325..2580e18b3a 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -205,7 +205,7 @@ int MultiMarginLossDriver::AddCmdLineArgs() { inflags.AddInputFlag("forw", 'F', "1", "Run only Forward Take (Default=1)", "int"); inflags.AddInputFlag("dim", 'D', "41x4", "Dim of input tensor (Default=41x4)", "tensor"); - inflags.AddInputFlag("contiguous", 'C', "1", "Tensor is contiguous or not", "int"); + inflags.AddInputFlag("contiguous", 'C', "1", "Tensor is contiguous or not (Default=1)", "int"); inflags.AddInputFlag("iter", 'i', "1", "Number of Iterations (Default=1)", "int"); inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9fe5940d27..7d2cd1fa85 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -618,12 +618,12 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN adam.cpp addlayernorm.cpp cat.cpp + exec_utils.cpp groupnorm.cpp kernel_cache.cpp layernorm.cpp lrn.cpp mlo_dir_conv.cpp - exec_utils.cpp multimarginloss.cpp ocl/activ_ocl.cpp ocl/batchnormocl.cpp diff --git a/src/include/miopen/multimarginloss/solvers.hpp b/src/include/miopen/multimarginloss/solvers.hpp index db76d65d10..078eaa0ce0 100644 --- a/src/include/miopen/multimarginloss/solvers.hpp +++ b/src/include/miopen/multimarginloss/solvers.hpp @@ -44,6 +44,9 @@ struct MultiMarginLossUnreducedForward final : ForwardMultiMarginLossSolver return GetSolverDbId(); } bool + IsImprovementOverROCm(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const; + bool IsApplicable(const ExecutionContext& context, const miopen::multimarginloss::ForwardProblemDescription& problem) const override; ConvSolution @@ -58,6 +61,9 @@ struct MultiMarginLossForward final : ForwardMultiMarginLossSolver return GetSolverDbId(); } bool + IsImprovementOverROCm(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const; + bool IsApplicable(const ExecutionContext& context, const miopen::multimarginloss::ForwardProblemDescription& problem) const override; ConvSolution diff --git a/src/kernels/MIOpenLossReduce.cpp b/src/kernels/MIOpenLossReduce.cpp index 023af37626..a86839d3c7 100644 --- a/src/kernels/MIOpenLossReduce.cpp +++ b/src/kernels/MIOpenLossReduce.cpp @@ -32,7 +32,7 @@ #include "warp_shuffle.hpp" template -__device__ void LossSum(const DTYPE* __restrict__ input, DTYPE* __restrict__ output, size_t N) +__device__ void LossSum(const DTYPE* __restrict__ input, DTYPE* __restrict__ output, uint64_t N) { auto gid = blockIdx.x * blockDim.x + threadIdx.x; @@ -44,7 +44,7 @@ __device__ void LossSum(const DTYPE* __restrict__ input, DTYPE* __restrict__ out } extern "C" __global__ void -ReduceSumLoss(const FLOAT* __restrict__ input, FLOAT* __restrict__ output, size_t N) +ReduceSumLoss(const FLOAT* __restrict__ input, FLOAT* __restrict__ output, uint64_t N) { // instantiate the kernel LossSum(input, output, N); diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 6ade61245b..224f600d49 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -36,8 +36,8 @@ __device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, const uint64_t* __restrict__ T, const DTYPE* __restrict__ W, DTYPE* __restrict__ O, - long p, - float margin, + const long p, + const float margin, tensor_view_t<2> I_tv, tensor_view_t<1> T_tv, tensor_view_t<1> W_tv, @@ -78,8 +78,8 @@ extern "C" __global__ void MultiMarginLossUnreducedForward2d(const FLOAT* __rest const uint64_t* __restrict__ T, const FLOAT* __restrict__ W, FLOAT* __restrict__ O, - long p, - float margin, + const long p, + const float margin, tensor_view_t<2> I_tv, tensor_view_t<1> T_tv, tensor_view_t<1> W_tv, @@ -94,8 +94,8 @@ __device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, const uint64_t* __restrict__ T, const DTYPE* __restrict__ W, DTYPE* __restrict__ lsum, - long p, - float margin, + const long p, + const float margin, const float divisor, tensor_view_t<2> I_tv, tensor_view_t<1> T_tv, @@ -109,7 +109,7 @@ __device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, return; FLOAT_ACCUM loss = 0; - long y = T[T_tv.get_tensor_view_idx({n})]; + uint64_t y = T[T_tv.get_tensor_view_idx({n})]; if(y >= C) { // TODO: need to handle invalid target index value @@ -137,8 +137,8 @@ extern "C" __global__ void MultiMarginLossForward2d(const FLOAT* __restrict__ I, const uint64_t* __restrict__ T, const FLOAT* __restrict__ W, FLOAT* __restrict__ lsum, - long p, - float margin, + const long p, + const float margin, const float divisor, tensor_view_t<2> I_tv, tensor_view_t<1> T_tv, diff --git a/src/kernels/warp_shuffle.hpp b/src/kernels/warp_shuffle.hpp index 771f7f7954..db536bd2ae 100644 --- a/src/kernels/warp_shuffle.hpp +++ b/src/kernels/warp_shuffle.hpp @@ -24,6 +24,9 @@ * *******************************************************************************/ +#ifndef GUARD_WARP_SHUFFLE_HPP +#define GUARD_WARP_SHUFFLE_HPP + #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include #include @@ -66,3 +69,5 @@ __device__ FLOAT_ACCUM block_reduce_sum(FLOAT_ACCUM val) return val; } + +#endif // GUARD_WARP_SHUFFLE_HPP diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index cc903e3b81..ba311455f6 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -42,6 +42,19 @@ namespace solver { namespace multimarginloss { +bool MultiMarginLossForward::IsImprovementOverROCm( + const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + if((problem.GetiDesc().GetType() == miopenHalf || + problem.GetiDesc().GetType() == miopenBFloat16) && + problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] > 40) + return false; + if(problem.GetiDesc().GetLengths()[1] > 30) + return false; + return true; +} + bool MultiMarginLossForward::IsApplicable( const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const @@ -80,6 +93,7 @@ ConvSolution MultiMarginLossForward::GetSolution( const auto build_params = KernelBuildParameters{ {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, }; @@ -102,6 +116,7 @@ ConvSolution MultiMarginLossForward::GetSolution( const auto build_params = KernelBuildParameters{ {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, {"REDUCE_SIZE", LOCAL_SIZE_REDUCE}, }; diff --git a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp index fe7e4c4be4..54a59ba5ca 100644 --- a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp @@ -41,8 +41,8 @@ namespace solver { namespace multimarginloss { -bool MultiMarginLossUnreducedForward::IsApplicable( - const ExecutionContext& /*context*/, +bool MultiMarginLossUnreducedForward::IsImprovementOverROCm( + const ExecutionContext& context, const miopen::multimarginloss::ForwardProblemDescription& problem) const { if((problem.GetiDesc().GetType() == miopenHalf || @@ -54,6 +54,15 @@ bool MultiMarginLossUnreducedForward::IsApplicable( return true; } +bool MultiMarginLossUnreducedForward::IsApplicable( + const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + if(!IsImprovementOverROCm(context, problem)) + return false; + return true; +} + ConvSolution MultiMarginLossUnreducedForward::GetSolution( const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const @@ -78,6 +87,7 @@ ConvSolution MultiMarginLossUnreducedForward::GetSolution( const auto build_params = KernelBuildParameters{ {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, }; diff --git a/test/gtest/multimarginloss.cpp b/test/gtest/multimarginloss.cpp index 2bfed37696..db150522f6 100644 --- a/test/gtest/multimarginloss.cpp +++ b/test/gtest/multimarginloss.cpp @@ -102,26 +102,6 @@ TEST_P(MultiMarginLossForwardTestBFloat16, MMLFwdTest) } }; -std::vector MultiMarginLossTestConfigs() -{ - // clang-format off - return { - {{22, 12}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1}, - {{22, 12}, false, MIOPEN_LOSS_REDUCTION_SUM, 1}, - {{22, 12}, true, MIOPEN_LOSS_REDUCTION_NONE, 1}, - {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, - {{9456, 13}, true, MIOPEN_LOSS_REDUCTION_SUM, 2 }, - {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_NONE, 2 }, - {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, - {{543210, 7}, false, MIOPEN_LOSS_REDUCTION_SUM, 2 }, - {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_NONE, 2 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_SUM, 1 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_NONE, 1 }, - }; - // clang-format on -} - INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, MultiMarginLossForwardTestFloat, testing::ValuesIn(MultiMarginLossTestConfigs())); diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp index 62c93d3dcd..792bad6aa4 100644 --- a/test/gtest/multimarginloss.hpp +++ b/test/gtest/multimarginloss.hpp @@ -38,8 +38,38 @@ struct MultiMarginLossTestCase bool cont; miopenLossReductionMode_t reduction_mode; long p; + + friend std::ostream& operator<<(std::ostream& os, const MultiMarginLossTestCase& tc) + { + os << " dims:"; + os << tc.dims[0]; + for(int i = 1; i < tc.dims.size(); i++) + os << "x" << tc.dims[i]; + os << " cont:" << tc.cont << " reduction_mode:" << tc.reduction_mode << " p:" << tc.p; + return os; + } }; +inline std::vector MultiMarginLossTestConfigs() +{ + // clang-format off + return { + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1}, + {{22, 12}, false, MIOPEN_LOSS_REDUCTION_SUM, 1}, + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_NONE, 1}, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{9456, 13}, true, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{543210, 7}, false, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_SUM, 1 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_NONE, 1 }, + }; + // clang-format on +} + template struct MultiMarginLossForwardTest : public ::testing::TestWithParam { diff --git a/test/tensor_holder.hpp b/test/tensor_holder.hpp index 4bf198e9a8..5428ce54cd 100644 --- a/test/tensor_holder.hpp +++ b/test/tensor_holder.hpp @@ -27,6 +27,7 @@ #define GUARD_TENSOR_HOLDER_HPP #include "ford.hpp" +#include "network_data.hpp" #include #include #include @@ -43,6 +44,9 @@ using hip_bfloat16 = bfloat16; using float8 = miopen_f8::hip_f8; using bfloat8 = miopen_f8::hip_f8; +#include +#include + template void visit_tensor_size(std::size_t n, F f) { From 4f5092c0066564bb77587b2bb2b9e33af233f203 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 17 Jul 2024 08:54:00 +0000 Subject: [PATCH 08/30] handle getworkspace internal --- driver/multimarginloss_driver.hpp | 50 ++++++++++++------------------- include/miopen/miopen.h | 3 +- src/multimarginloss.cpp | 5 ++++ src/multimarginloss_api.cpp | 31 +++++++------------ 4 files changed, 36 insertions(+), 53 deletions(-) diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index 2580e18b3a..e817efccd3 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -345,24 +345,13 @@ int MultiMarginLossDriver::AllocateBuffersAndCopy() if(forw == 0 || forw == 1) { size_t o_sz = GetTensorSpace(oDesc); - if(reduction_mode != MIOPEN_LOSS_REDUCTION_NONE) + + miopenGetMultiMarginLossForwardWorkspaceSize( + GetHandle(), iDesc, tDesc, wDesc, oDesc, p, margin, reduction_mode, &ws_sizeInBytes); + if(ws_sizeInBytes == static_cast(-1)) { - miopenGetMultiMarginLossForwardWorkspaceSize(GetHandle(), - iDesc, - tDesc, - wDesc, - oDesc, - p, - margin, - reduction_mode, - &ws_sizeInBytes); - if(ws_sizeInBytes == static_cast(-1)) - { - return miopenStatusAllocFailed; - } + return miopenStatusAllocFailed; } - else - ws_sizeInBytes = 0; o_dev = std::make_unique(ctx, o_sz, sizeof(Tgpu)); O = std::vector(o_sz); @@ -396,21 +385,20 @@ int MultiMarginLossDriver::RunForwardGPU() for(int i = 0; i < inflags.GetValueInt("iter"); i++) { - miopenMultiMarginLossForward( - GetHandle(), - iDesc, - i_dev->GetMem(), - tDesc, - t_dev->GetMem(), - wDesc, - w_dev->GetMem(), - oDesc, - o_dev->GetMem(), - p, - margin, - reduction_mode, - (reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) ? nullptr : workspace_dev->GetMem(), - ws_sizeInBytes); + miopenMultiMarginLossForward(GetHandle(), + iDesc, + i_dev->GetMem(), + tDesc, + t_dev->GetMem(), + wDesc, + w_dev->GetMem(), + oDesc, + o_dev->GetMem(), + p, + margin, + reduction_mode, + workspace_dev->GetMem(), + ws_sizeInBytes); float time = 0.0; miopenGetKernelTime(GetHandle(), &time); diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 3cfe97c2ee..f6e4b0e915 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -7255,8 +7255,7 @@ weight given to each class. It has to be a Tensor of size C then it must have shape (N). Otherwise, it is a scalar * @param [in] p Has a default value of 1. The only supported values are 1 and 2 * @param [in] margin Has a default value of 1 - * @param [in] reduction Reduction mode (sum, mean). For none reduction we don't need to -use this function + * @param [in] reduction Reduction mode (sum, mean) * @param [out] sizeInBytes Pointer to data to return the minimum workspace size * @return miopenStatus_t */ diff --git a/src/multimarginloss.cpp b/src/multimarginloss.cpp index fe9c041d1a..fb5c4b4df5 100644 --- a/src/multimarginloss.cpp +++ b/src/multimarginloss.cpp @@ -24,6 +24,7 @@ * *******************************************************************************/ +#include "miopen/miopen.h" #include #include #include @@ -84,6 +85,10 @@ std::size_t GetMultiMarginLossForwardWorkspaceSize(Handle& handle, const float margin, miopenLossReductionMode_t reduction) { + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + return static_cast(0); + } auto ctx = ExecutionContext{&handle}; const float divisor = (reduction == MIOPEN_LOSS_REDUCTION_MEAN) ? iDesc.GetLengths()[0] : 1; const auto problem = diff --git a/src/multimarginloss_api.cpp b/src/multimarginloss_api.cpp index ae68972bf2..cb0bdeb7e8 100644 --- a/src/multimarginloss_api.cpp +++ b/src/multimarginloss_api.cpp @@ -44,26 +44,17 @@ miopenGetMultiMarginLossForwardWorkspaceSize(miopenHandle_t handle, MIOPEN_LOG_FUNCTION( handle, inputDesc, targetDesc, weightDesc, outputDesc, p, margin, reduction); - if(reduction != MIOPEN_LOSS_REDUCTION_SUM && reduction != MIOPEN_LOSS_REDUCTION_MEAN) - { - MIOPEN_THROW(miopenStatusBadParm, - "miopenGetMultiMarginLossForwardWorkspaceSize: reduction should be " - "MIOPEN_LOSS_REDUCTION_SUM or MIOPEN_LOSS_REDUCTION_MEAN."); - } - else - { - return miopen::try_([&] { - miopen::deref(sizeInBytes) = - miopen::GetMultiMarginLossForwardWorkspaceSize(miopen::deref(handle), - miopen::deref(inputDesc), - miopen::deref(targetDesc), - miopen::deref(weightDesc), - miopen::deref(outputDesc), - p, - margin, - reduction); - }); - } + return miopen::try_([&] { + miopen::deref(sizeInBytes) = + miopen::GetMultiMarginLossForwardWorkspaceSize(miopen::deref(handle), + miopen::deref(inputDesc), + miopen::deref(targetDesc), + miopen::deref(weightDesc), + miopen::deref(outputDesc), + p, + margin, + reduction); + }); } miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, From 4618ab9912a9374b54ea8d253ed16b144f4ea214 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 17 Jul 2024 09:44:52 +0000 Subject: [PATCH 09/30] edit network config + gtest --- src/multimarginloss/problem_description.cpp | 1 + .../forward_reduced_multimarginloss.cpp | 12 ++-- .../forward_unreduced_multimarginloss.cpp | 12 ++-- test/cpu_multimarginloss.hpp | 26 ++++---- test/gtest/multimarginloss.cpp | 2 +- test/gtest/multimarginloss.hpp | 66 ++++++++++--------- 6 files changed, 61 insertions(+), 58 deletions(-) diff --git a/src/multimarginloss/problem_description.cpp b/src/multimarginloss/problem_description.cpp index 497f0d119c..ca094f819d 100644 --- a/src/multimarginloss/problem_description.cpp +++ b/src/multimarginloss/problem_description.cpp @@ -41,6 +41,7 @@ NetworkConfig ForwardProblemDescription::MakeNetworkConfig() const auto ilen = iDesc.GetLengths(); for(unsigned long i : ilen) ss << i << "_"; + ss << "cont" << iDesc.IsContiguous(); ss << "divisor" << divisor; return NetworkConfig{ss.str()}; } diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index ba311455f6..fa2e803f81 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -43,16 +43,16 @@ namespace solver { namespace multimarginloss { bool MultiMarginLossForward::IsImprovementOverROCm( - const ExecutionContext& context, + const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const { + if(problem.GetiDesc().GetLengths()[1] <= 30) + return true; if((problem.GetiDesc().GetType() == miopenHalf || problem.GetiDesc().GetType() == miopenBFloat16) && - problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] > 40) - return false; - if(problem.GetiDesc().GetLengths()[1] > 30) - return false; - return true; + problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] <= 40) + return true; + return false; } bool MultiMarginLossForward::IsApplicable( diff --git a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp index 54a59ba5ca..729df7d431 100644 --- a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp @@ -42,16 +42,16 @@ namespace solver { namespace multimarginloss { bool MultiMarginLossUnreducedForward::IsImprovementOverROCm( - const ExecutionContext& context, + const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const { + if(problem.GetiDesc().GetLengths()[1] <= 30) + return true; if((problem.GetiDesc().GetType() == miopenHalf || problem.GetiDesc().GetType() == miopenBFloat16) && - problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] > 40) - return false; - if(problem.GetiDesc().GetLengths()[1] > 30) - return false; - return true; + problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] <= 40) + return true; + return false; } bool MultiMarginLossUnreducedForward::IsApplicable( diff --git a/test/cpu_multimarginloss.hpp b/test/cpu_multimarginloss.hpp index 8e790feb5a..0bde558d62 100644 --- a/test/cpu_multimarginloss.hpp +++ b/test/cpu_multimarginloss.hpp @@ -46,23 +46,23 @@ void cpu_multimarginloss_unreduced_forward(tensor input, for(size_t n = 0; n < N; n++) { - float loss = 0; - uint64_t y = target[T_tv.get_tensor_view_idx({n})]; + double loss = 0; + uint64_t y = target[T_tv.get_tensor_view_idx({n})]; if(y >= C) continue; for(size_t c = 0; c < C; c++) { if(y == c) continue; - float t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + - static_cast(input[I_tv.get_tensor_view_idx({n, c})]); + double t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); if(t < 0) continue; if(p == 2) t = t * t; - t = weight[W_tv.get_tensor_view_idx({y})] * t; - loss += t / static_cast(C); + t = static_cast(weight[W_tv.get_tensor_view_idx({y})]) * t; + loss += t / C; } ref_output[O_tv.get_tensor_view_idx({n})] = loss; } @@ -82,26 +82,26 @@ void cpu_multimarginloss_reduced_forward(tensor input, auto W_tv = miopen::get_inner_expanded_tv<1>(weight.desc); auto N = I_tv.size[0], C = I_tv.size[1]; - float sum = 0; + double sum = 0; for(size_t n = 0; n < N; n++) { - float loss = 0; - uint64_t y = target[T_tv.get_tensor_view_idx({n})]; + double loss = 0; + uint64_t y = target[T_tv.get_tensor_view_idx({n})]; if(y >= C) continue; for(size_t c = 0; c < C; c++) { if(y == c) continue; - float t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + - static_cast(input[I_tv.get_tensor_view_idx({n, c})]); + double t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); if(t < 0) continue; if(p == 2) t = t * t; - t = weight[W_tv.get_tensor_view_idx({y})] * t; - loss += t / static_cast(C); + t = static_cast(weight[W_tv.get_tensor_view_idx({y})]) * t; + loss += t / C; } sum += loss; } diff --git a/test/gtest/multimarginloss.cpp b/test/gtest/multimarginloss.cpp index db150522f6..dfc226258a 100644 --- a/test/gtest/multimarginloss.cpp +++ b/test/gtest/multimarginloss.cpp @@ -107,7 +107,7 @@ INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, testing::ValuesIn(MultiMarginLossTestConfigs())); INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, MultiMarginLossForwardTestHalf, - testing::ValuesIn(MultiMarginLossTestConfigs())); + testing::ValuesIn(MultiMarginLossFp16TestConfigs())); INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, MultiMarginLossForwardTestBFloat16, testing::ValuesIn(MultiMarginLossTestConfigs())); diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp index 792bad6aa4..f8d614e91f 100644 --- a/test/gtest/multimarginloss.hpp +++ b/test/gtest/multimarginloss.hpp @@ -54,18 +54,35 @@ inline std::vector MultiMarginLossTestConfigs() { // clang-format off return { - {{22, 12}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1}, - {{22, 12}, false, MIOPEN_LOSS_REDUCTION_SUM, 1}, - {{22, 12}, true, MIOPEN_LOSS_REDUCTION_NONE, 1}, - {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, - {{9456, 13}, true, MIOPEN_LOSS_REDUCTION_SUM, 2 }, - {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_NONE, 2 }, - {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, - {{543210, 7}, false, MIOPEN_LOSS_REDUCTION_SUM, 2 }, - {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_NONE, 2 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_SUM, 1 }, - {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_NONE, 1 }, + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1}, + {{22, 12}, false, MIOPEN_LOSS_REDUCTION_SUM, 1}, + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_NONE, 1}, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{9456, 13}, true, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{543210, 7}, false, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_SUM, 1 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_NONE, 1 }, + }; + // clang-format on +} + +// Remove big tests with reduction from FP16 test because the result will be overflow/ underflow +inline std::vector MultiMarginLossFp16TestConfigs() +{ + // clang-format off + return { + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_MEAN, 1}, + {{22, 12}, false, MIOPEN_LOSS_REDUCTION_SUM, 1}, + {{22, 12}, true, MIOPEN_LOSS_REDUCTION_NONE, 1}, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_MEAN, 2 }, + {{9456, 13}, true, MIOPEN_LOSS_REDUCTION_SUM, 2 }, + {{9456, 13}, false, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{543210, 7}, true, MIOPEN_LOSS_REDUCTION_NONE, 2 }, + {{3995776, 6}, true, MIOPEN_LOSS_REDUCTION_NONE, 1 }, }; // clang-format on } @@ -85,17 +102,6 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam(0.5, 1.5); size_t N = in_dims[0], C = in_dims[1]; - if(std::is_same::value && - reduction_mode != MIOPEN_LOSS_REDUCTION_NONE && N >= 100000) - { - std::cerr << "For fp16 forward reduction test, too many elements in input tensor can " - "lead to fp16 " - "overflow or underflow. If reduction mean, divisor is very big lead to " - "underflow. If reduction sum, result is too big lead to overflow." - << std::endl; - GTEST_SKIP(); - } - if(config.cont) { input = tensor{in_dims}; @@ -167,7 +173,7 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam(-1)) - GTEST_SKIP(); + GTEST_SKIP() << "Call GetMultiMarginLossForwardWorkspaceSize failed!"; workspace = tensor{std::vector{ws_sizeInBytes / sizeof(T)}}; std::fill(workspace.begin(), workspace.end(), 0); workspace_dev = handle.Write(workspace.data); @@ -229,19 +235,15 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam::value ? 1.5e-6 : 8.2e-3; - + auto tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3; // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. if(std::is_same::value) - threshold *= 8.0; + tolerance *= 8.0; auto error = miopen::rms_range(ref_output, output); EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output)); - // When doing reduction with big test, floating point precision error is high. I raise - // threshold from *10 to *30 to pass big test - EXPECT_TRUE(error < threshold * 30) << "Error output beyond tolerance " - "Error:" - << error << ", Threshold x 30: " << threshold * 10; + EXPECT_TRUE(error < tolerance) << "Error output beyond tolerance. Error:" << error + << ", Tolerance: " << tolerance * 10; } MultiMarginLossTestCase config; From a2a8580545aff67e00f25b59e045d61a4411f090 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Thu, 18 Jul 2024 10:39:53 +0000 Subject: [PATCH 10/30] change driver tref to float --- driver/dm_multimarginloss.cpp | 6 +++--- driver/multimarginloss_driver.hpp | 23 ++++++++++++----------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/driver/dm_multimarginloss.cpp b/driver/dm_multimarginloss.cpp index 0a74712448..1c924bb7dc 100644 --- a/driver/dm_multimarginloss.cpp +++ b/driver/dm_multimarginloss.cpp @@ -29,11 +29,11 @@ static Driver* makeDriver(const std::string& base_arg) { if(base_arg == "multimarginloss") - return new MultiMarginLossDriver(); + return new MultiMarginLossDriver(); if(base_arg == "multimarginlossfp16") - return new MultiMarginLossDriver(); + return new MultiMarginLossDriver(); if(base_arg == "multimarginlossbfp16") - return new MultiMarginLossDriver(); + return new MultiMarginLossDriver(); return nullptr; } diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index e817efccd3..7867e39c0e 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -64,7 +64,7 @@ int32_t mloMultiMarginLossUnreducedForwardRunHost(miopenTensorDescriptor_t iDesc for(size_t n = 0; n < N; n++) { - Tcheck loss = 0; + double loss = 0; uint64_t y = target[T_tv.get_tensor_view_idx({n})]; if(y >= C) continue; @@ -72,17 +72,17 @@ int32_t mloMultiMarginLossUnreducedForwardRunHost(miopenTensorDescriptor_t iDesc { if(y == c) continue; - Tcheck t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + - static_cast(input[I_tv.get_tensor_view_idx({n, c})]); + double t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); if(t < 0) continue; if(p == 2) t = t * t; t = weight[W_tv.get_tensor_view_idx({y})] * t; - loss += t / static_cast(C); + loss += t / C; } - ref_output[O_tv.get_tensor_view_idx({n})] = loss; + ref_output[O_tv.get_tensor_view_idx({n})] = static_cast(loss); } return ret; } @@ -106,9 +106,10 @@ int32_t mloMultiMarginLossReducedForwardRunHost(miopenTensorDescriptor_t iDesc, int32_t ret = 0; + double sum_loss = 0; for(size_t n = 0; n < N; n++) { - Tcheck loss = 0; + double loss = 0; uint64_t y = target[T_tv.get_tensor_view_idx({n})]; if(y >= C) continue; @@ -116,18 +117,18 @@ int32_t mloMultiMarginLossReducedForwardRunHost(miopenTensorDescriptor_t iDesc, { if(y == c) continue; - Tcheck t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + - static_cast(input[I_tv.get_tensor_view_idx({n, c})]); + double t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); if(t < 0) continue; if(p == 2) t = t * t; t = weight[W_tv.get_tensor_view_idx({y})] * t; - loss += t / static_cast(C); + loss += t / C; } - ref_output[0] += loss; + sum_loss += loss; } - ref_output[0] /= divisor; + ref_output[0] = static_cast(sum_loss / divisor); return ret; }; From 355ecf4e491cdc502a9b925f12fc1662d6dd048f Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Fri, 19 Jul 2024 03:09:27 +0000 Subject: [PATCH 11/30] small fix --- .../multimarginloss/forward_reduced_multimarginloss.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index fa2e803f81..639e70a15a 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -56,14 +56,10 @@ bool MultiMarginLossForward::IsImprovementOverROCm( } bool MultiMarginLossForward::IsApplicable( - const ExecutionContext& /*context*/, + const ExecutionContext& context, const miopen::multimarginloss::ForwardProblemDescription& problem) const { - if((problem.GetiDesc().GetType() == miopenHalf || - problem.GetiDesc().GetType() == miopenBFloat16) && - problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] > 40) - return false; - if(problem.GetiDesc().GetLengths()[1] > 30) + if(!IsImprovementOverROCm(context, problem)) return false; return true; } From 77d01009950a33f5e56ca3b17834a599518386a1 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 24 Jul 2024 10:53:51 +0000 Subject: [PATCH 12/30] add MIOPEN_EXPORT --- include/miopen/miopen.h | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 89e0d821fc..5d2ca7dffd 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -7613,20 +7613,20 @@ reduction = 'none' reduction = 'none * @return miopenStatus_t */ -miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, - miopenTensorDescriptor_t inputDesc, - const void* input, - miopenTensorDescriptor_t targetDesc, - const void* target, - miopenTensorDescriptor_t weightDesc, - const void* weight, - miopenTensorDescriptor_t outputDesc, - void* output, - long p, - float margin, - miopenLossReductionMode_t reduction, - void* workspace, - size_t workspaceSizeInBytes); +MIOPEN_EXPORT miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t targetDesc, + const void* target, + miopenTensorDescriptor_t weightDesc, + const void* weight, + miopenTensorDescriptor_t outputDesc, + void* output, + long p, + float margin, + miopenLossReductionMode_t reduction, + void* workspace, + size_t workspaceSizeInBytes); /** @} */ // CLOSEOUT LossFunction DOXYGEN GROUP From 8006142736f3bcc3e7fd82cf5956985950a3da01 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Mon, 29 Jul 2024 03:39:49 +0000 Subject: [PATCH 13/30] partial fix comment --- driver/multimarginloss_driver.hpp | 11 +-- src/include/miopen/multimarginloss.hpp | 68 ++++++++++--------- src/kernels/MIOpenMultiMarginLoss.cpp | 4 +- .../forward_reduced_multimarginloss.cpp | 11 ++- test/gtest/multimarginloss.hpp | 2 +- 5 files changed, 45 insertions(+), 51 deletions(-) diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index 7867e39c0e..623fbca6c9 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -355,21 +355,14 @@ int MultiMarginLossDriver::AllocateBuffersAndCopy() } o_dev = std::make_unique(ctx, o_sz, sizeof(Tgpu)); - O = std::vector(o_sz); - Ohost = std::vector(o_sz); - std::fill(O.begin(), O.end(), 0); - std::fill(Ohost.begin(), Ohost.end(), 0); + O = std::vector(o_sz, static_cast(0)); + Ohost = std::vector(o_sz, static_cast(0)); if(o_dev->ToGPU(GetStream(), O.data()) != 0) std::cerr << "Error copying (out) to GPU, size: " << o_dev->GetSize() << std::endl; size_t ws_sz = ws_sizeInBytes / sizeof(Tgpu); workspace_dev = std::make_unique(ctx, ws_sz, sizeof(Tgpu)); workspace = std::vector(ws_sz); - std::fill(workspace.begin(), workspace.end(), 0); - - if(workspace_dev->ToGPU(GetStream(), workspace.data()) != 0) - std::cerr << "Error copying (workspace) to GPU, size: " << workspace_dev->GetSize() - << std::endl; } return miopenStatusSuccess; diff --git a/src/include/miopen/multimarginloss.hpp b/src/include/miopen/multimarginloss.hpp index fc24ccecce..4d8d0f6566 100644 --- a/src/include/miopen/multimarginloss.hpp +++ b/src/include/miopen/multimarginloss.hpp @@ -34,41 +34,43 @@ namespace miopen { struct Handle; struct TensorDescriptor; -std::size_t GetMultiMarginLossForwardWorkspaceSize(Handle& handle, - const TensorDescriptor& iDesc, - const TensorDescriptor& tDesc, - const TensorDescriptor& wDesc, - const TensorDescriptor& oDesc, - long p, - float margin, - miopenLossReductionMode_t reduction); +MIOPEN_INTERNALS_EXPORT std::size_t +GetMultiMarginLossForwardWorkspaceSize(Handle& handle, + const TensorDescriptor& iDesc, + const TensorDescriptor& tDesc, + const TensorDescriptor& wDesc, + const TensorDescriptor& oDesc, + long p, + float margin, + miopenLossReductionMode_t reduction); -miopenStatus_t MultiMarginLossUnreducedForward(Handle& handle, - const TensorDescriptor& iDesc, - ConstData_t i, - const TensorDescriptor& tDesc, - ConstData_t t, - const TensorDescriptor& wDesc, - ConstData_t w, - const TensorDescriptor& oDesc, - Data_t o, - long p, - float margin); +MIOPEN_INTERNALS_EXPORT miopenStatus_t +MultiMarginLossUnreducedForward(Handle& handle, + const TensorDescriptor& iDesc, + ConstData_t i, + const TensorDescriptor& tDesc, + ConstData_t t, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& oDesc, + Data_t o, + long p, + float margin); -miopenStatus_t MultiMarginLossForward(Handle& handle, - Data_t workspace, - size_t workspaceSizeInBytes, - const TensorDescriptor& iDesc, - ConstData_t i, - const TensorDescriptor& tDesc, - ConstData_t t, - const TensorDescriptor& wDesc, - ConstData_t w, - const TensorDescriptor& oDesc, - Data_t o, - long p, - float margin, - miopenLossReductionMode_t reduction); +MIOPEN_INTERNALS_EXPORT miopenStatus_t MultiMarginLossForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& iDesc, + ConstData_t i, + const TensorDescriptor& tDesc, + ConstData_t t, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& oDesc, + Data_t o, + long p, + float margin, + miopenLossReductionMode_t reduction); } // namespace miopen #endif // _MIOPEN_MULTIMARGINLOSS_HPP_ diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 224f600d49..52a032cdbe 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -69,7 +69,7 @@ __device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, if(p == 2) t = t * t; t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; - loss += t / (float)C; + loss += t / C; } O[O_tv.get_tensor_view_idx({n})] = CVT_ACCUM2FLOAT(loss); } @@ -127,7 +127,7 @@ __device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, if(p == 2) t = t * t; t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; - loss += t / (float)C; + loss += t / C; } lsum[n] = CVT_ACCUM2FLOAT(loss / divisor); diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index 639e70a15a..d00d41cd40 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -155,10 +155,9 @@ ConvSolution MultiMarginLossForward::GetSolution( HipEventPtr start; HipEventPtr stop; - bool reset_profiling_state = false; - if(handle_.IsProfilingEnabled()) + const bool profiling = handle_.IsProfilingEnabled(); + if(profiling) { - reset_profiling_state = true; handle_.EnableProfiling(false); start = miopen::make_hip_event(); stop = miopen::make_hip_event(); @@ -201,9 +200,7 @@ ConvSolution MultiMarginLossForward::GetSolution( size = AlignUp(size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE; } - if(reset_profiling_state) - handle_.EnableProfiling(true); - if(handle_.IsProfilingEnabled()) + if(profiling) { hipEventRecord(stop.get(), handle_.GetStream()); hipEventSynchronize(stop.get()); @@ -214,6 +211,8 @@ ConvSolution MultiMarginLossForward::GetSolution( hipEventDestroy(stop.get()); handle_.ResetKernelTime(); handle_.AccumKernelTime(elapsed); + + handle_.EnableProfiling(true); }; }; }; diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp index f8d614e91f..47b1a7b51b 100644 --- a/test/gtest/multimarginloss.hpp +++ b/test/gtest/multimarginloss.hpp @@ -173,7 +173,7 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam(-1)) - GTEST_SKIP() << "Call GetMultiMarginLossForwardWorkspaceSize failed!"; + GTEST_FAIL() << "Call GetMultiMarginLossForwardWorkspaceSize failed!"; workspace = tensor{std::vector{ws_sizeInBytes / sizeof(T)}}; std::fill(workspace.begin(), workspace.end(), 0); workspace_dev = handle.Write(workspace.data); From e044841cee75cb5ace102b820b6efa283bc0fa95 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Mon, 29 Jul 2024 08:15:28 +0000 Subject: [PATCH 14/30] change gtest format --- test/gtest/multimarginloss.cpp | 30 +++++++++++++++--------------- test/gtest/multimarginloss.hpp | 9 ++++----- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/test/gtest/multimarginloss.cpp b/test/gtest/multimarginloss.cpp index dfc226258a..df2653af41 100644 --- a/test/gtest/multimarginloss.cpp +++ b/test/gtest/multimarginloss.cpp @@ -42,25 +42,25 @@ std::string GetFloatArg() return tmp; } -struct MultiMarginLossForwardTestFloat : MultiMarginLossForwardTest +struct GPU_MultiMarginLoss_FP32 : MultiMarginLossForwardTest { }; -struct MultiMarginLossForwardTestHalf : MultiMarginLossForwardTest +struct GPU_MultiMarginLoss_FP16 : MultiMarginLossForwardTest { }; -struct MultiMarginLossForwardTestBFloat16 : MultiMarginLossForwardTest +struct GPU_MultiMarginLoss_BFP16 : MultiMarginLossForwardTest { }; } // namespace multimarginloss -using multimarginloss::MultiMarginLossForwardTestBFloat16; -using multimarginloss::MultiMarginLossForwardTestFloat; -using multimarginloss::MultiMarginLossForwardTestHalf; +using multimarginloss::GPU_MultiMarginLoss_BFP16; +using multimarginloss::GPU_MultiMarginLoss_FP16; +using multimarginloss::GPU_MultiMarginLoss_FP32; -TEST_P(MultiMarginLossForwardTestFloat, MMLFwdTest) +TEST_P(GPU_MultiMarginLoss_FP32, Test) { if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--float")) @@ -74,7 +74,7 @@ TEST_P(MultiMarginLossForwardTestFloat, MMLFwdTest) } }; -TEST_P(MultiMarginLossForwardTestHalf, MMLFwdTest) +TEST_P(GPU_MultiMarginLoss_FP16, Test) { if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--half")) @@ -88,7 +88,7 @@ TEST_P(MultiMarginLossForwardTestHalf, MMLFwdTest) } }; -TEST_P(MultiMarginLossForwardTestBFloat16, MMLFwdTest) +TEST_P(GPU_MultiMarginLoss_BFP16, Test) { if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && env::value(MIOPEN_TEST_FLOAT_ARG) == "--bfloat16")) @@ -102,12 +102,12 @@ TEST_P(MultiMarginLossForwardTestBFloat16, MMLFwdTest) } }; -INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, - MultiMarginLossForwardTestFloat, +INSTANTIATE_TEST_SUITE_P(Smoke, + GPU_MultiMarginLoss_FP32, testing::ValuesIn(MultiMarginLossTestConfigs())); -INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, - MultiMarginLossForwardTestHalf, +INSTANTIATE_TEST_SUITE_P(Smoke, + GPU_MultiMarginLoss_FP16, testing::ValuesIn(MultiMarginLossFp16TestConfigs())); -INSTANTIATE_TEST_SUITE_P(MultiMarginLossTestSet, - MultiMarginLossForwardTestBFloat16, +INSTANTIATE_TEST_SUITE_P(Smoke, + GPU_MultiMarginLoss_BFP16, testing::ValuesIn(MultiMarginLossTestConfigs())); diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp index 47b1a7b51b..d69f98e0f3 100644 --- a/test/gtest/multimarginloss.hpp +++ b/test/gtest/multimarginloss.hpp @@ -41,7 +41,7 @@ struct MultiMarginLossTestCase friend std::ostream& operator<<(std::ostream& os, const MultiMarginLossTestCase& tc) { - os << " dims:"; + os << "dims:"; os << tc.dims[0]; for(int i = 1; i < tc.dims.size(); i++) os << "x" << tc.dims[i]; @@ -225,7 +225,7 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam(output_dev, output.data.size()); @@ -241,9 +241,8 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam Date: Tue, 30 Jul 2024 03:36:43 +0000 Subject: [PATCH 15/30] try to trigger CI/CD --- src/kernels/MIOpenMultiMarginLoss.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 52a032cdbe..5b0e69c1eb 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -107,7 +107,6 @@ __device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, size_t n = gid; if(n >= N) return; - FLOAT_ACCUM loss = 0; uint64_t y = T[T_tv.get_tensor_view_idx({n})]; if(y >= C) From 2af818a872aa659ef96357d7b106d891047ec53c Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Tue, 30 Jul 2024 10:36:16 +0000 Subject: [PATCH 16/30] use MultiBufferWorkspaceTraits --- .../forward_reduced_multimarginloss.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index d00d41cd40..29a0041d84 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -180,10 +180,12 @@ ConvSolution MultiMarginLossForward::GetSolution( /* Phase 2: Reduce */ auto size = deref(params.iDesc).GetLengths()[0]; + auto data_size = get_data_size(deref(params.iDesc).GetType()); + auto wt = MultiBufferWorkspaceTraits{ + size * data_size, (size + LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE * data_size}; auto reduce_in = params.workspace; auto reduce_out = - static_cast(static_cast(params.workspace) + - size * get_data_size(deref(params.oDesc).GetType())); + static_cast(static_cast(params.workspace) + wt.GetOffset(1)); for(int i = 1; i < kernels.size(); ++i) { @@ -224,9 +226,11 @@ std::size_t MultiMarginLossForward::GetWorkspaceSize( const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const { - auto elem = problem.GetiDesc().GetLengths()[0]; - return (elem + (elem + LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE) * - get_data_size(problem.GetoDesc().GetType()); + auto size = problem.GetiDesc().GetLengths()[0]; + auto data_size = get_data_size(problem.GetiDesc().GetType()); + return MultiBufferWorkspaceTraits{size * data_size, + (size + LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE * data_size} + .GetSize(); } } // namespace multimarginloss From 9b4469c5b532eb1956a8596c4f256e0df6dc3604 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Thu, 1 Aug 2024 06:36:50 +0000 Subject: [PATCH 17/30] add extern C --- src/multimarginloss_api.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/multimarginloss_api.cpp b/src/multimarginloss_api.cpp index cb0bdeb7e8..f39b477873 100644 --- a/src/multimarginloss_api.cpp +++ b/src/multimarginloss_api.cpp @@ -57,20 +57,20 @@ miopenGetMultiMarginLossForwardWorkspaceSize(miopenHandle_t handle, }); } -miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, - miopenTensorDescriptor_t inputDesc, - const void* input, - miopenTensorDescriptor_t targetDesc, - const void* target, - miopenTensorDescriptor_t weightDesc, - const void* weight, - miopenTensorDescriptor_t outputDesc, - void* output, - const long p, - const float margin, - miopenLossReductionMode_t reduction, - void* workspace, - size_t workspaceSizeInBytes) +extern "C" miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t targetDesc, + const void* target, + miopenTensorDescriptor_t weightDesc, + const void* weight, + miopenTensorDescriptor_t outputDesc, + void* output, + const long p, + const float margin, + miopenLossReductionMode_t reduction, + void* workspace, + size_t workspaceSizeInBytes) { MIOPEN_LOG_FUNCTION(handle, inputDesc, From b62c4d8365199920281aa96c1b89cf013a2fbef8 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Fri, 2 Aug 2024 04:18:19 +0000 Subject: [PATCH 18/30] use std::byte instead of char --- src/solver/multimarginloss/forward_reduced_multimarginloss.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index 29a0041d84..bec47becea 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -185,7 +185,7 @@ ConvSolution MultiMarginLossForward::GetSolution( size * data_size, (size + LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE * data_size}; auto reduce_in = params.workspace; auto reduce_out = - static_cast(static_cast(params.workspace) + wt.GetOffset(1)); + static_cast(static_cast(params.workspace) + wt.GetOffset(1)); for(int i = 1; i < kernels.size(); ++i) { From eb2994c916df4db5945b3f4a88ab4aa79d8191fd Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Thu, 15 Aug 2024 08:50:12 +0000 Subject: [PATCH 19/30] combine duplicated code in driver, gtest, src --- driver/multimarginloss_driver.hpp | 97 ++++++------------- src/include/miopen/multimarginloss.hpp | 13 --- .../multimarginloss/problem_description.hpp | 10 +- src/multimarginloss.cpp | 76 +++++---------- src/multimarginloss/problem_description.cpp | 2 +- src/multimarginloss_api.cpp | 19 +--- test/cpu_multimarginloss.hpp | 60 +++--------- test/gtest/multimarginloss.hpp | 80 ++++++--------- 8 files changed, 103 insertions(+), 254 deletions(-) diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index 623fbca6c9..244143b3b6 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -31,7 +31,6 @@ #include "tensor_driver.hpp" #include "timer.hpp" #include "random.hpp" -#include #include #include #include @@ -43,16 +42,17 @@ #include template -int32_t mloMultiMarginLossUnreducedForwardRunHost(miopenTensorDescriptor_t iDesc, - miopenTensorDescriptor_t tDesc, - miopenTensorDescriptor_t wDesc, - miopenTensorDescriptor_t oDesc, - long p, - float margin, - Tgpu* input, - const uint64_t* target, - Tgpu* weight, - Tcheck* ref_output) +int32_t mloMultiMarginLossForwardRunHost(const miopenTensorDescriptor_t iDesc, + const miopenTensorDescriptor_t tDesc, + const miopenTensorDescriptor_t wDesc, + const miopenTensorDescriptor_t oDesc, + const long p, + const float margin, + const float divisor, + const Tgpu* input, + const uint64_t* target, + const Tgpu* weight, + Tcheck* ref_output) { auto I_tv = miopen::get_inner_expanded_tv<2>(miopen::deref(iDesc)); auto T_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(tDesc)); @@ -60,7 +60,8 @@ int32_t mloMultiMarginLossUnreducedForwardRunHost(miopenTensorDescriptor_t iDesc auto O_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(oDesc)); auto N = I_tv.size[0], C = I_tv.size[1]; - int32_t ret = 0; + int32_t ret = 0; + double sum_loss = 0; for(size_t n = 0; n < N; n++) { @@ -82,57 +83,16 @@ int32_t mloMultiMarginLossUnreducedForwardRunHost(miopenTensorDescriptor_t iDesc t = weight[W_tv.get_tensor_view_idx({y})] * t; loss += t / C; } - ref_output[O_tv.get_tensor_view_idx({n})] = static_cast(loss); + if(divisor != 0) + sum_loss += loss; + else + ref_output[O_tv.get_tensor_view_idx({n})] = static_cast(loss); } + if(divisor != 0) + ref_output[0] = static_cast(sum_loss / divisor); return ret; } -template -int32_t mloMultiMarginLossReducedForwardRunHost(miopenTensorDescriptor_t iDesc, - miopenTensorDescriptor_t tDesc, - miopenTensorDescriptor_t wDesc, - long p, - float margin, - const float divisor, - Tgpu* input, - const uint64_t* target, - Tgpu* weight, - Tcheck* ref_output) -{ - auto I_tv = miopen::get_inner_expanded_tv<2>(miopen::deref(iDesc)); - auto T_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(tDesc)); - auto W_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(wDesc)); - auto N = I_tv.size[0], C = I_tv.size[1]; - - int32_t ret = 0; - - double sum_loss = 0; - for(size_t n = 0; n < N; n++) - { - double loss = 0; - uint64_t y = target[T_tv.get_tensor_view_idx({n})]; - if(y >= C) - continue; - for(size_t c = 0; c < C; c++) - { - if(y == c) - continue; - double t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + - static_cast(input[I_tv.get_tensor_view_idx({n, c})]); - if(t < 0) - continue; - if(p == 2) - t = t * t; - t = weight[W_tv.get_tensor_view_idx({y})] * t; - loss += t / C; - } - sum_loss += loss; - } - ref_output[0] = static_cast(sum_loss / divisor); - - return ret; -}; - template class MultiMarginLossDriver : public Driver { @@ -427,19 +387,16 @@ int MultiMarginLossDriver::RunForwardGPU() template int MultiMarginLossDriver::RunForwardCPU() { - if(reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) + float divisor; + switch(reduction_mode) { - mloMultiMarginLossUnreducedForwardRunHost( - iDesc, tDesc, wDesc, oDesc, p, margin, I.data(), T.data(), W.data(), Ohost.data()); - } - else - { - float divisor = (reduction_mode == MIOPEN_LOSS_REDUCTION_MEAN) - ? miopen::deref(iDesc).GetLengths()[0] - : 1; - mloMultiMarginLossReducedForwardRunHost( - iDesc, tDesc, wDesc, p, margin, divisor, I.data(), T.data(), W.data(), Ohost.data()); + case MIOPEN_LOSS_REDUCTION_NONE: divisor = 0; break; + case MIOPEN_LOSS_REDUCTION_MEAN: divisor = miopen::deref(iDesc).GetLengths()[0]; break; + case MIOPEN_LOSS_REDUCTION_SUM: divisor = 1; break; } + + mloMultiMarginLossForwardRunHost( + iDesc, tDesc, wDesc, oDesc, p, margin, divisor, I.data(), T.data(), W.data(), Ohost.data()); return miopenStatusSuccess; } diff --git a/src/include/miopen/multimarginloss.hpp b/src/include/miopen/multimarginloss.hpp index 4d8d0f6566..fa718923a2 100644 --- a/src/include/miopen/multimarginloss.hpp +++ b/src/include/miopen/multimarginloss.hpp @@ -44,19 +44,6 @@ GetMultiMarginLossForwardWorkspaceSize(Handle& handle, float margin, miopenLossReductionMode_t reduction); -MIOPEN_INTERNALS_EXPORT miopenStatus_t -MultiMarginLossUnreducedForward(Handle& handle, - const TensorDescriptor& iDesc, - ConstData_t i, - const TensorDescriptor& tDesc, - ConstData_t t, - const TensorDescriptor& wDesc, - ConstData_t w, - const TensorDescriptor& oDesc, - Data_t o, - long p, - float margin); - MIOPEN_INTERNALS_EXPORT miopenStatus_t MultiMarginLossForward(Handle& handle, Data_t workspace, size_t workspaceSizeInBytes, diff --git a/src/include/miopen/multimarginloss/problem_description.hpp b/src/include/miopen/multimarginloss/problem_description.hpp index c3eeecfa08..15e6d0b1a6 100644 --- a/src/include/miopen/multimarginloss/problem_description.hpp +++ b/src/include/miopen/multimarginloss/problem_description.hpp @@ -45,14 +45,14 @@ struct ForwardProblemDescription : ProblemDescriptionBase const TensorDescriptor& oDesc_, const long p_, const float margin_, - const float divisor_) + const miopenLossReductionMode_t reduction_) : iDesc(iDesc_), tDesc(tDesc_), wDesc(wDesc_), oDesc(oDesc_), p(p_), margin(margin_), - divisor(divisor_) + reduction(reduction_) { if(iDesc.GetType() != oDesc.GetType() || iDesc.GetType() != wDesc.GetType()) { @@ -81,7 +81,7 @@ struct ForwardProblemDescription : ProblemDescriptionBase "tensor has shape (N, C) then weight tensor must have shape (C)"); } // Check output tensor dimension - if(divisor == 0) + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) { // non-reduction case if(oDesc.GetNumDims() != 1 || oDesc.GetLengths()[0] != iDesc.GetLengths()[0]) @@ -114,7 +114,7 @@ struct ForwardProblemDescription : ProblemDescriptionBase const TensorDescriptor& GetoDesc() const { return oDesc; } long Getp() const { return p; } float Getmargin() const { return margin; } - float Getdivisor() const { return divisor; } + miopenLossReductionMode_t Getreduction() const { return reduction; } NetworkConfig MakeNetworkConfig() const override; @@ -125,7 +125,7 @@ struct ForwardProblemDescription : ProblemDescriptionBase TensorDescriptor oDesc; long p; float margin; - float divisor; + miopenLossReductionMode_t reduction; }; } // namespace multimarginloss diff --git a/src/multimarginloss.cpp b/src/multimarginloss.cpp index fb5c4b4df5..e3933c7ae4 100644 --- a/src/multimarginloss.cpp +++ b/src/multimarginloss.cpp @@ -36,46 +36,6 @@ namespace miopen { -miopenStatus_t MultiMarginLossUnreducedForward(Handle& handle, - const TensorDescriptor& iDesc, - ConstData_t i, - const TensorDescriptor& tDesc, - ConstData_t t, - const TensorDescriptor& wDesc, - ConstData_t w, - const TensorDescriptor& oDesc, - Data_t o, - const long p, - const float margin) -{ - const auto problem = - multimarginloss::ForwardProblemDescription{iDesc, tDesc, wDesc, oDesc, p, margin, 0}; - - const auto invoke_params = [&]() { - auto tmp = multimarginloss::InvokeParams{}; - tmp.type = InvokeType::Run; - tmp.iDesc = &iDesc; - tmp.i = i; - tmp.tDesc = &tDesc; - tmp.t = t; - tmp.wDesc = &wDesc; - tmp.w = w; - tmp.oDesc = &oDesc; - tmp.o = o; - tmp.p = p; - tmp.margin = margin; - return tmp; - }(); - - const auto algo = AlgorithmName{"MultiMarginLossUnreducedForward"}; - const auto solvers = - solver::SolverContainer{}; - - solvers.ExecutePrimitive(handle, problem, algo, invoke_params); - - return miopenStatusSuccess; -} - std::size_t GetMultiMarginLossForwardWorkspaceSize(Handle& handle, const TensorDescriptor& iDesc, const TensorDescriptor& tDesc, @@ -89,10 +49,9 @@ std::size_t GetMultiMarginLossForwardWorkspaceSize(Handle& handle, { return static_cast(0); } - auto ctx = ExecutionContext{&handle}; - const float divisor = (reduction == MIOPEN_LOSS_REDUCTION_MEAN) ? iDesc.GetLengths()[0] : 1; - const auto problem = - multimarginloss::ForwardProblemDescription{iDesc, tDesc, wDesc, oDesc, p, margin, divisor}; + auto ctx = ExecutionContext{&handle}; + const auto problem = multimarginloss::ForwardProblemDescription{ + iDesc, tDesc, wDesc, oDesc, p, margin, reduction}; const auto solvers = solver::SolverContainer{}; @@ -115,9 +74,8 @@ miopenStatus_t MultiMarginLossForward(Handle& handle, const float margin, miopenLossReductionMode_t reduction) { - const float divisor = (reduction == MIOPEN_LOSS_REDUCTION_MEAN) ? iDesc.GetLengths()[0] : 1; - const auto problem = - multimarginloss::ForwardProblemDescription{iDesc, tDesc, wDesc, oDesc, p, margin, divisor}; + const auto problem = multimarginloss::ForwardProblemDescription{ + iDesc, tDesc, wDesc, oDesc, p, margin, reduction}; const auto invoke_params = [&]() { auto tmp = multimarginloss::InvokeParams{}; @@ -134,14 +92,30 @@ miopenStatus_t MultiMarginLossForward(Handle& handle, tmp.margin = margin; tmp.workspace = workspace; tmp.workspace_size = workspaceSizeInBytes; - tmp.divisor = divisor; + switch(reduction) + { + case MIOPEN_LOSS_REDUCTION_NONE: tmp.divisor = 0; break; + case MIOPEN_LOSS_REDUCTION_MEAN: tmp.divisor = iDesc.GetLengths()[0]; break; + case MIOPEN_LOSS_REDUCTION_SUM: tmp.divisor = 1; break; + } return tmp; }(); - const auto algo = AlgorithmName{"MultiMarginLossForward"}; - const auto solvers = solver::SolverContainer{}; + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + const auto algo = AlgorithmName{"MultiMarginLossUnreducedForward"}; + const auto solvers = + solver::SolverContainer{}; - solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + } + else + { + const auto algo = AlgorithmName{"MultiMarginLossForward"}; + const auto solvers = + solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + } return miopenStatusSuccess; } diff --git a/src/multimarginloss/problem_description.cpp b/src/multimarginloss/problem_description.cpp index ca094f819d..b275d5a4fe 100644 --- a/src/multimarginloss/problem_description.cpp +++ b/src/multimarginloss/problem_description.cpp @@ -42,7 +42,7 @@ NetworkConfig ForwardProblemDescription::MakeNetworkConfig() const for(unsigned long i : ilen) ss << i << "_"; ss << "cont" << iDesc.IsContiguous(); - ss << "divisor" << divisor; + ss << "reduction" << reduction; return NetworkConfig{ss.str()}; } diff --git a/src/multimarginloss_api.cpp b/src/multimarginloss_api.cpp index f39b477873..479cfd91b4 100644 --- a/src/multimarginloss_api.cpp +++ b/src/multimarginloss_api.cpp @@ -87,23 +87,8 @@ extern "C" miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, workspace, workspaceSizeInBytes); - if(reduction == MIOPEN_LOSS_REDUCTION_NONE) - { - return miopen::try_([&] { - miopen::MultiMarginLossUnreducedForward(miopen::deref(handle), - miopen::deref(inputDesc), - DataCast(input), - miopen::deref(targetDesc), - DataCast(target), - miopen::deref(weightDesc), - DataCast(weight), - miopen::deref(outputDesc), - DataCast(output), - p, - margin); - }); - } - else if(reduction == MIOPEN_LOSS_REDUCTION_SUM || reduction == MIOPEN_LOSS_REDUCTION_MEAN) + if(reduction == MIOPEN_LOSS_REDUCTION_NONE || reduction == MIOPEN_LOSS_REDUCTION_SUM || + reduction == MIOPEN_LOSS_REDUCTION_MEAN) { return miopen::try_([&] { miopen::MultiMarginLossForward(miopen::deref(handle), diff --git a/test/cpu_multimarginloss.hpp b/test/cpu_multimarginloss.hpp index 0bde558d62..c2b7fbb5ca 100644 --- a/test/cpu_multimarginloss.hpp +++ b/test/cpu_multimarginloss.hpp @@ -31,12 +31,13 @@ #include template -void cpu_multimarginloss_unreduced_forward(tensor input, - tensor target, - tensor weight, - tensor& ref_output, - long p, - float margin) +void cpu_multimarginloss_forward(tensor input, + tensor target, + tensor weight, + tensor& ref_output, + const long p, + const float margin, + const float divisor) { auto I_tv = miopen::get_inner_expanded_tv<2>(input.desc); auto T_tv = miopen::get_inner_expanded_tv<1>(target.desc); @@ -44,44 +45,6 @@ void cpu_multimarginloss_unreduced_forward(tensor input, auto O_tv = miopen::get_inner_expanded_tv<1>(ref_output.desc); auto N = I_tv.size[0], C = I_tv.size[1]; - for(size_t n = 0; n < N; n++) - { - double loss = 0; - uint64_t y = target[T_tv.get_tensor_view_idx({n})]; - if(y >= C) - continue; - for(size_t c = 0; c < C; c++) - { - if(y == c) - continue; - double t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + - static_cast(input[I_tv.get_tensor_view_idx({n, c})]); - - if(t < 0) - continue; - if(p == 2) - t = t * t; - t = static_cast(weight[W_tv.get_tensor_view_idx({y})]) * t; - loss += t / C; - } - ref_output[O_tv.get_tensor_view_idx({n})] = loss; - } -} - -template -void cpu_multimarginloss_reduced_forward(tensor input, - tensor target, - tensor weight, - tensor& ref_output, - long p, - float margin, - const float divisor) -{ - auto I_tv = miopen::get_inner_expanded_tv<2>(input.desc); - auto T_tv = miopen::get_inner_expanded_tv<1>(target.desc); - auto W_tv = miopen::get_inner_expanded_tv<1>(weight.desc); - auto N = I_tv.size[0], C = I_tv.size[1]; - double sum = 0; for(size_t n = 0; n < N; n++) { @@ -103,10 +66,13 @@ void cpu_multimarginloss_reduced_forward(tensor input, t = static_cast(weight[W_tv.get_tensor_view_idx({y})]) * t; loss += t / C; } - sum += loss; + if(divisor == 0) + ref_output[O_tv.get_tensor_view_idx({n})] = loss; + else + sum += loss; } - sum /= divisor; - ref_output[0] = static_cast(sum); + if(divisor != 0) + ref_output[0] = static_cast(sum / divisor); } #endif diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp index d69f98e0f3..604d370825 100644 --- a/test/gtest/multimarginloss.hpp +++ b/test/gtest/multimarginloss.hpp @@ -162,69 +162,49 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam(-1)) + GTEST_FAIL() << "Call GetMultiMarginLossForwardWorkspaceSize failed!"; + if(ws_sizeInBytes > 0) { - ws_sizeInBytes = miopen::GetMultiMarginLossForwardWorkspaceSize(handle, - input.desc, - target.desc, - weight.desc, - output.desc, - p, - margin, - reduction_mode); - if(ws_sizeInBytes == static_cast(-1)) - GTEST_FAIL() << "Call GetMultiMarginLossForwardWorkspaceSize failed!"; workspace = tensor{std::vector{ws_sizeInBytes / sizeof(T)}}; std::fill(workspace.begin(), workspace.end(), 0); workspace_dev = handle.Write(workspace.data); } + else + { + workspace_dev = nullptr; + } } void RunTest() { auto&& handle = get_handle(); miopenStatus_t status; - if(reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) + float divisor; + switch(reduction_mode) { - cpu_multimarginloss_unreduced_forward(input, target, weight, ref_output, p, margin); - - status = miopen::MultiMarginLossUnreducedForward(handle, - input.desc, - input_dev.get(), - target.desc, - target_dev.get(), - weight.desc, - weight_dev.get(), - output.desc, - output_dev.get(), - p, - margin); + case MIOPEN_LOSS_REDUCTION_NONE: divisor = 0; break; + case MIOPEN_LOSS_REDUCTION_MEAN: divisor = input.desc.GetLengths()[0]; break; + case MIOPEN_LOSS_REDUCTION_SUM: divisor = 1; break; } - else - { - cpu_multimarginloss_reduced_forward( - input, - target, - weight, - ref_output, - p, - margin, - (reduction_mode == MIOPEN_LOSS_REDUCTION_MEAN) ? input.desc.GetLengths()[0] : 1); + cpu_multimarginloss_forward(input, target, weight, ref_output, p, margin, divisor); + + status = miopen::MultiMarginLossForward(handle, + workspace_dev.get(), + ws_sizeInBytes, + input.desc, + input_dev.get(), + target.desc, + target_dev.get(), + weight.desc, + weight_dev.get(), + output.desc, + output_dev.get(), + p, + margin, + reduction_mode); - status = miopen::MultiMarginLossForward(handle, - workspace_dev.get(), - ws_sizeInBytes, - input.desc, - input_dev.get(), - target.desc, - target_dev.get(), - weight.desc, - weight_dev.get(), - output.desc, - output_dev.get(), - p, - margin, - reduction_mode); - } ASSERT_EQ(status, miopenStatusSuccess); // Write from GPU to CPU From 8b4b050b19dc8787818f488126f945c046bd8cbd Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Fri, 16 Aug 2024 09:13:05 +0000 Subject: [PATCH 20/30] fma --- src/kernels/MIOpenMultiMarginLoss.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 5b0e69c1eb..1928ab1ff6 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -68,8 +68,9 @@ __device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, continue; if(p == 2) t = t * t; - t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; - loss += t / C; + t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; + FLOAT_ACCUM rC = 1 / static_cast(C); + loss = fma(t, rC, loss); } O[O_tv.get_tensor_view_idx({n})] = CVT_ACCUM2FLOAT(loss); } @@ -125,8 +126,9 @@ __device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, continue; if(p == 2) t = t * t; - t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; - loss += t / C; + t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; + FLOAT_ACCUM rC = 1 / static_cast(C); + loss = fma(t, rC, loss); } lsum[n] = CVT_ACCUM2FLOAT(loss / divisor); From 5384ab8fd3b1964ff600a76f3eeef5f8c433c398 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 21 Aug 2024 03:30:26 +0000 Subject: [PATCH 21/30] optimize kernel + update condition in solver --- driver/multimarginloss_driver.hpp | 12 +++++++- .../multimarginloss/problem_description.hpp | 5 ++++ src/kernels/MIOpenMultiMarginLoss.cpp | 14 +++++---- .../forward_unreduced_multimarginloss.cpp | 30 ++++++++++++++----- 4 files changed, 47 insertions(+), 14 deletions(-) diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index 244143b3b6..21ae142786 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -31,6 +31,7 @@ #include "tensor_driver.hpp" #include "timer.hpp" #include "random.hpp" +#include #include #include #include @@ -167,7 +168,7 @@ int MultiMarginLossDriver::AddCmdLineArgs() inflags.AddInputFlag("forw", 'F', "1", "Run only Forward Take (Default=1)", "int"); inflags.AddInputFlag("dim", 'D', "41x4", "Dim of input tensor (Default=41x4)", "tensor"); inflags.AddInputFlag("contiguous", 'C', "1", "Tensor is contiguous or not (Default=1)", "int"); - inflags.AddInputFlag("iter", 'i', "1", "Number of Iterations (Default=1)", "int"); + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); inflags.AddInputFlag( @@ -334,6 +335,8 @@ int MultiMarginLossDriver::RunForwardGPU() float kernel_total_time = 0; float kernel_first_time = 0; + std::vector time_vector; + Timer t; START_TIME @@ -359,8 +362,15 @@ int MultiMarginLossDriver::RunForwardGPU() kernel_total_time += time; if(i == 0) kernel_first_time = time; + else + time_vector.push_back(time); } + std::cerr << "Min between iterations: " + << *std::min_element(time_vector.begin(), time_vector.end()) << std::endl; + std::cerr << "Max between iterations: " + << *std::max_element(time_vector.begin(), time_vector.end()) << std::endl; + if(inflags.GetValueInt("time") == 1) { STOP_TIME diff --git a/src/include/miopen/multimarginloss/problem_description.hpp b/src/include/miopen/multimarginloss/problem_description.hpp index 15e6d0b1a6..db1c8fa875 100644 --- a/src/include/miopen/multimarginloss/problem_description.hpp +++ b/src/include/miopen/multimarginloss/problem_description.hpp @@ -115,6 +115,11 @@ struct ForwardProblemDescription : ProblemDescriptionBase long Getp() const { return p; } float Getmargin() const { return margin; } miopenLossReductionMode_t Getreduction() const { return reduction; } + bool allContiguousTensor() const + { + return iDesc.IsContiguous() && tDesc.IsContiguous() && wDesc.IsContiguous() && + oDesc.IsContiguous(); + } NetworkConfig MakeNetworkConfig() const override; diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 1928ab1ff6..8b2c04a7a5 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -51,27 +51,29 @@ __device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, return; FLOAT_ACCUM loss = 0; - uint64_t y = T[T_tv.get_tensor_view_idx({n})]; + size_t y = T[T_tv.get_tensor_view_idx({n})]; if(y >= C) { // TODO: need to handle invalid target index value return; } + FLOAT_ACCUM Iny = CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, y})]); + FLOAT_ACCUM Wy = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]); + for(size_t c = 0; c < C; c++) { if(y == c) continue; - FLOAT_ACCUM t = margin - CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, y})]) + - CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, c})]); + FLOAT_ACCUM t = margin - Iny + CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, c})]); if(t < 0) continue; if(p == 2) t = t * t; - t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; - FLOAT_ACCUM rC = 1 / static_cast(C); - loss = fma(t, rC, loss); + t = Wy * t; + loss += t; } + loss /= C; O[O_tv.get_tensor_view_idx({n})] = CVT_ACCUM2FLOAT(loss); } diff --git a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp index 729df7d431..1c0e27a74f 100644 --- a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp @@ -45,13 +45,29 @@ bool MultiMarginLossUnreducedForward::IsImprovementOverROCm( const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const { - if(problem.GetiDesc().GetLengths()[1] <= 30) - return true; - if((problem.GetiDesc().GetType() == miopenHalf || - problem.GetiDesc().GetType() == miopenBFloat16) && - problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] <= 40) - return true; - return false; + int C = problem.GetiDesc().GetLengths()[1]; + if(problem.allContiguousTensor()) + { + switch(problem.GetiDesc().GetType()) + { + case miopenFloat: return C <= 32; + case miopenHalf: return C <= 43; + case miopenBFloat16: return C <= 44; + // Have not tested with other types yet + default: return true; + } + } + else + { + switch(problem.GetiDesc().GetType()) + { + case miopenFloat: return C <= 31; + case miopenHalf: return C <= 38; + case miopenBFloat16: return C <= 40; + // Have not tested with other types yet + default: return true; + } + } } bool MultiMarginLossUnreducedForward::IsApplicable( From 86149f6ac37b2a9ccca4e051979bc1b229b9b8ed Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 21 Aug 2024 07:39:40 +0000 Subject: [PATCH 22/30] merge kernel + avoid convert to float after each kernel --- driver/multimarginloss_driver.hpp | 17 +-- src/CMakeLists.txt | 5 +- src/kernels/MIOpenMultiMarginLoss.cpp | 94 ++++------------ src/kernels/MIOpenReduceSum.cpp | 103 ++++++++++++++++++ .../{warp_shuffle.hpp => block_reduce.hpp} | 53 +++++---- .../{MIOpenLossReduce.cpp => warp_reduce.hpp} | 33 +++--- .../forward_reduced_multimarginloss.cpp | 69 ++++++++---- .../forward_unreduced_multimarginloss.cpp | 3 +- test/gtest/multimarginloss.hpp | 4 +- 9 files changed, 223 insertions(+), 158 deletions(-) create mode 100644 src/kernels/MIOpenReduceSum.cpp rename src/kernels/{warp_shuffle.hpp => block_reduce.hpp} (59%) rename src/kernels/{MIOpenLossReduce.cpp => warp_reduce.hpp} (70%) diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index 21ae142786..9bf354f805 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -154,7 +154,6 @@ class MultiMarginLossDriver : public Driver std::vector W; std::vector O; std::vector Ohost; - std::vector workspace; long p; float margin; @@ -321,9 +320,7 @@ int MultiMarginLossDriver::AllocateBuffersAndCopy() if(o_dev->ToGPU(GetStream(), O.data()) != 0) std::cerr << "Error copying (out) to GPU, size: " << o_dev->GetSize() << std::endl; - size_t ws_sz = ws_sizeInBytes / sizeof(Tgpu); - workspace_dev = std::make_unique(ctx, ws_sz, sizeof(Tgpu)); - workspace = std::vector(ws_sz); + workspace_dev = std::make_unique(ctx, ws_sizeInBytes, sizeof(std::byte)); } return miopenStatusSuccess; @@ -335,8 +332,6 @@ int MultiMarginLossDriver::RunForwardGPU() float kernel_total_time = 0; float kernel_first_time = 0; - std::vector time_vector; - Timer t; START_TIME @@ -362,15 +357,8 @@ int MultiMarginLossDriver::RunForwardGPU() kernel_total_time += time; if(i == 0) kernel_first_time = time; - else - time_vector.push_back(time); } - std::cerr << "Min between iterations: " - << *std::min_element(time_vector.begin(), time_vector.end()) << std::endl; - std::cerr << "Max between iterations: " - << *std::max_element(time_vector.begin(), time_vector.end()) << std::endl; - if(inflags.GetValueInt("time") == 1) { STOP_TIME @@ -387,9 +375,6 @@ int MultiMarginLossDriver::RunForwardGPU() if(o_dev->FromGPU(GetStream(), O.data()) != 0) std::cerr << "Error copying (o_dev) from GPU, size: " << o_dev->GetSize() << std::endl; - if(workspace_dev->FromGPU(GetStream(), workspace.data()) != 0) - std::cerr << "Error copying (workspace_dev) from GPU, size: " << workspace_dev->GetSize() - << std::endl; return miopenStatusSuccess; } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0fa4a81746..9dd357540d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -452,6 +452,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenReduceCalculation.hpp kernels/MIOpenReduceExtreme.hpp kernels/bfloat16_dev.hpp + kernels/block_reduce.hpp kernels/conv_common.inc kernels/conv_sizes.inc kernels/float_types.h @@ -470,7 +471,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/stride_array.hpp kernels/tensor_view.hpp kernels/utilities.inc - kernels/warp_shuffle.hpp + kernels/warp_reduce.hpp kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c16_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c32_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1024vgprs_fp16_fp16acc_f2x3_c16_stride1.inc @@ -510,7 +511,6 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenGroupNorm.cpp kernels/MIOpenGetitem.cpp kernels/MIOpenLayerNorm.cpp - kernels/MIOpenLossReduce.cpp kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl kernels/MIOpenMultiMarginLoss.cpp @@ -525,6 +525,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenConv1x1J1_stride.cl kernels/MIOpenReduceCalculation.cpp kernels/MIOpenReduceExtreme.cpp + kernels/MIOpenReduceSum.cpp kernels/MIOpenRoPE.cpp kernels/MIOpenSoftmax.cl kernels/MIOpenSoftmaxAttn.cpp diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 8b2c04a7a5..1ab7811eba 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -31,17 +31,17 @@ #include "float_types.h" #include "tensor_view.hpp" -template -__device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, - const uint64_t* __restrict__ T, - const DTYPE* __restrict__ W, - DTYPE* __restrict__ O, - const long p, - const float margin, - tensor_view_t<2> I_tv, - tensor_view_t<1> T_tv, - tensor_view_t<1> W_tv, - tensor_view_t<1> O_tv) +template +__device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, + const uint64_t* __restrict__ T, + const DTYPE* __restrict__ W, + void* __restrict__ O, + const long p, + const float margin, + tensor_view_t<2> I_tv, + tensor_view_t<1> T_tv, + tensor_view_t<1> W_tv, + tensor_view_t<1> O_tv) { const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; @@ -74,79 +74,25 @@ __device__ void multimarginlossunreducedforward2d(const DTYPE* __restrict__ I, loss += t; } loss /= C; - O[O_tv.get_tensor_view_idx({n})] = CVT_ACCUM2FLOAT(loss); -} - -extern "C" __global__ void MultiMarginLossUnreducedForward2d(const FLOAT* __restrict__ I, - const uint64_t* __restrict__ T, - const FLOAT* __restrict__ W, - FLOAT* __restrict__ O, - const long p, - const float margin, - tensor_view_t<2> I_tv, - tensor_view_t<1> T_tv, - tensor_view_t<1> W_tv, - tensor_view_t<1> O_tv) -{ - // instantiate the kernel - multimarginlossunreducedforward2d(I, T, W, O, p, margin, I_tv, T_tv, W_tv, O_tv); -} - -template -__device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, - const uint64_t* __restrict__ T, - const DTYPE* __restrict__ W, - DTYPE* __restrict__ lsum, - const long p, - const float margin, - const float divisor, - tensor_view_t<2> I_tv, - tensor_view_t<1> T_tv, - tensor_view_t<1> W_tv) -{ - const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; - - size_t N = I_tv.size[0], C = I_tv.size[1]; - size_t n = gid; - if(n >= N) - return; - FLOAT_ACCUM loss = 0; - uint64_t y = T[T_tv.get_tensor_view_idx({n})]; - if(y >= C) + switch(REDUCTION_T) { - // TODO: need to handle invalid target index value - return; + case 0: static_cast(O)[O_tv.get_tensor_view_idx({n})] = CVT_ACCUM2FLOAT(loss); break; + case 1: static_cast(O)[n] = loss; break; + case 2: static_cast(O)[n] = loss / N; break; } - - for(size_t c = 0; c < C; c++) - { - if(y == c) - continue; - FLOAT_ACCUM t = margin - CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, y})]) + - CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, c})]); - if(t < 0) - continue; - if(p == 2) - t = t * t; - t = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]) * t; - FLOAT_ACCUM rC = 1 / static_cast(C); - loss = fma(t, rC, loss); - } - - lsum[n] = CVT_ACCUM2FLOAT(loss / divisor); } extern "C" __global__ void MultiMarginLossForward2d(const FLOAT* __restrict__ I, const uint64_t* __restrict__ T, const FLOAT* __restrict__ W, - FLOAT* __restrict__ lsum, + void* __restrict__ O, const long p, const float margin, - const float divisor, tensor_view_t<2> I_tv, tensor_view_t<1> T_tv, - tensor_view_t<1> W_tv) + tensor_view_t<1> W_tv, + tensor_view_t<1> O_tv) { // instantiate the kernel - multimarginlossforward2d(I, T, W, lsum, p, margin, divisor, I_tv, T_tv, W_tv); -} + multimarginlossforward2d(I, T, W, O, p, margin, I_tv, T_tv, W_tv, O_tv); +} \ No newline at end of file diff --git a/src/kernels/MIOpenReduceSum.cpp b/src/kernels/MIOpenReduceSum.cpp new file mode 100644 index 0000000000..ec431bc1e0 --- /dev/null +++ b/src/kernels/MIOpenReduceSum.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" +#include "tensor_view.hpp" +#include "block_reduce.hpp" + +template +__device__ void +ReduceSum(const FLOAT_ACCUM* input, TO* output, uint64_t N, tensor_view_t<1> output_tv) +{ + uint64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + + FLOAT_ACCUM val = gid < N ? input[gid] : CVT_FP32_2ACCUM(0.0f); + val = block_reduce(val); + + if(threadIdx.x == 0) + output[output_tv.get_tensor_view_idx({blockIdx.x})] = CVT_ACCUM2FLOAT(val); +} + +extern "C" __global__ void ReduceSum(const FLOAT_ACCUM* __restrict__ input, + FLOAT* __restrict__ output, + uint64_t N, + tensor_view_t<1> output_tv) +{ + // instantiate the kernel + ReduceSum(input, output, N, output_tv); +} + +extern "C" __global__ void ReduceSumFLOATACCUM(const FLOAT_ACCUM* __restrict__ input, + FLOAT_ACCUM* __restrict__ output, + uint64_t N) +{ + uint64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + + FLOAT_ACCUM val = gid < N ? input[gid] : 0.0f; + val = block_reduce(val); + + if(threadIdx.x == 0) + output[blockIdx.x] = val; +} + +template +__device__ void Reduce1dSum(const FLOAT_ACCUM* __restrict__ input, + TO* __restrict__ output, + uint64_t output_numel, + uint64_t inner_size, + uint64_t outer_size, + tensor_view_t<1> output_tv) +{ + uint64_t tid = threadIdx.x; + uint64_t oidx = blockIdx.x; + + // use double instead of FLOAT_ACCUM for better precision + double sum_double = 0.0; + for(uint64_t i = tid; i < outer_size * inner_size; i += blockDim.x) + sum_double += static_cast( + input[i / inner_size * output_numel * inner_size + oidx * inner_size + i % inner_size]); + + FLOAT_ACCUM sum = static_cast(sum_double); + sum = block_reduce(sum); + + if(tid == 0) + output[output_tv.get_tensor_view_idx({oidx})] = CVT_ACCUM2FLOAT(sum); +} + +extern "C" __global__ void Reduce1dSum(const FLOAT_ACCUM* __restrict__ input, + FLOAT* __restrict__ output, + uint64_t output_numel, + uint64_t inner_size, + uint64_t outer_size, + tensor_view_t<1> output_tv) +{ + // instantiate the kernel + Reduce1dSum(input, output, output_numel, inner_size, outer_size, output_tv); +} \ No newline at end of file diff --git a/src/kernels/warp_shuffle.hpp b/src/kernels/block_reduce.hpp similarity index 59% rename from src/kernels/warp_shuffle.hpp rename to src/kernels/block_reduce.hpp index db536bd2ae..7d298a6ead 100644 --- a/src/kernels/warp_shuffle.hpp +++ b/src/kernels/block_reduce.hpp @@ -23,9 +23,8 @@ * SOFTWARE. * *******************************************************************************/ - -#ifndef GUARD_WARP_SHUFFLE_HPP -#define GUARD_WARP_SHUFFLE_HPP +#ifndef GUARD_BLOCK_REDUCE_HPP +#define GUARD_BLOCK_REDUCE_HPP #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include @@ -33,41 +32,41 @@ #endif #include "float_types.h" +#include "warp_reduce.hpp" -__device__ FLOAT_ACCUM warp_reduce_sum(FLOAT_ACCUM val) +enum class ReduceThreadDim : int32_t { - if(warpSize >= 64) - val += __shfl_down(val, 32); - if(warpSize >= 32) - val += __shfl_down(val, 16); - if(warpSize >= 16) - val += __shfl_down(val, 8); - if(warpSize >= 8) - val += __shfl_down(val, 4); - if(warpSize >= 4) - val += __shfl_down(val, 2); - if(warpSize >= 2) - val += __shfl_down(val, 1); - return val; -} + X = 1 << 0, + Y = 1 << 1, + Z = 1 << 2, +}; -__device__ FLOAT_ACCUM block_reduce_sum(FLOAT_ACCUM val) +template +__device__ FLOAT_ACCUM block_reduce(FLOAT_ACCUM val) { - static __shared__ FLOAT_ACCUM shared[REDUCE_SIZE / warpSize]; - auto lane = threadIdx.x % warpSize; - auto wid = threadIdx.x / warpSize; + if(reduce_size == warpSize) + return warp_reduce(val); - val = warp_reduce_sum(val); + static __shared__ FLOAT_ACCUM shared[reduce_size / warpSize]; + uint64_t tid = 0; + if(static_cast(thread_dim) & static_cast(ReduceThreadDim::X)) + tid += threadIdx.x; + if(static_cast(thread_dim) & static_cast(ReduceThreadDim::Y)) + tid = tid * blockDim.y + threadIdx.y; + if(static_cast(thread_dim) & static_cast(ReduceThreadDim::Z)) + tid = tid * blockDim.z + threadIdx.z; + const uint64_t lane = tid % warpSize; + const uint64_t wid = tid / warpSize; + val = warp_reduce(val); if(lane == 0) shared[wid] = val; __syncthreads(); - val = threadIdx.x < REDUCE_SIZE / warpSize ? shared[lane] : 0; + val = tid < reduce_size / warpSize ? shared[lane] : 0; if(wid == 0) - val = warp_reduce_sum(val); - + val = warp_reduce(val); return val; } -#endif // GUARD_WARP_SHUFFLE_HPP +#endif // GUARD_BLOCK_REDUCE_HPP diff --git a/src/kernels/MIOpenLossReduce.cpp b/src/kernels/warp_reduce.hpp similarity index 70% rename from src/kernels/MIOpenLossReduce.cpp rename to src/kernels/warp_reduce.hpp index a86839d3c7..f1490bd214 100644 --- a/src/kernels/MIOpenLossReduce.cpp +++ b/src/kernels/warp_reduce.hpp @@ -23,29 +23,36 @@ * SOFTWARE. * *******************************************************************************/ +#ifndef GUARD_WARP_REDUCE_HPP +#define GUARD_WARP_REDUCE_HPP + #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include #include #endif #include "float_types.h" -#include "warp_shuffle.hpp" -template -__device__ void LossSum(const DTYPE* __restrict__ input, DTYPE* __restrict__ output, uint64_t N) +enum class BinaryOp_t { - auto gid = blockIdx.x * blockDim.x + threadIdx.x; + Add, +}; - FLOAT_ACCUM val = gid < N ? CVT_FLOAT2ACCUM(input[gid]) : CVT_FP32_2ACCUM(0.0f); - val = block_reduce_sum(val); +template +struct BinaryFunc; - if(threadIdx.x == 0) - output[blockIdx.x] = CVT_ACCUM2FLOAT(val); -} +template +struct BinaryFunc +{ + constexpr void exec(T& a, const T& b) { a += b; } +}; -extern "C" __global__ void -ReduceSumLoss(const FLOAT* __restrict__ input, FLOAT* __restrict__ output, uint64_t N) +template +__device__ FLOAT_ACCUM warp_reduce(FLOAT_ACCUM val) { - // instantiate the kernel - LossSum(input, output, N); + for(auto d = ws / 2; d >= 1; d >>= 1) + BinaryFunc{}.exec(val, __shfl_down(val, d)); + return val; } + +#endif // GUARD_WARP_REDUCE_HPP diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index bec47becea..dfec238da7 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -91,6 +91,7 @@ ConvSolution MultiMarginLossForward::GetSolution( {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"REDUCTION_TYPE", static_cast(problem.Getreduction())}, }; kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); @@ -107,7 +108,7 @@ ConvSolution MultiMarginLossForward::GetSolution( } { - /* Phase 2: Reduce */ + /* Phase 2: Reduce FLOAT_ACCUM -> FLOAT_ACCUM */ auto _size = xgrid; const auto build_params = KernelBuildParameters{ {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, @@ -116,7 +117,7 @@ ConvSolution MultiMarginLossForward::GetSolution( {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, {"REDUCE_SIZE", LOCAL_SIZE_REDUCE}, }; - do + while(_size > LOCAL_SIZE_REDUCE) { size_t xlocalsize = LOCAL_SIZE_REDUCE; size_t xgridsize = AlignUp(_size, xlocalsize); @@ -126,8 +127,8 @@ ConvSolution MultiMarginLossForward::GetSolution( size_t zgridsize = 1; auto kernel = KernelInfo{}; - kernel.kernel_file = "MIOpenLossReduce.cpp"; - kernel.kernel_name = "ReduceSumLoss"; + kernel.kernel_file = "MIOpenReduceSum.cpp"; + kernel.kernel_name = "ReduceSumFLOATACCUM"; kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); @@ -140,8 +141,32 @@ ConvSolution MultiMarginLossForward::GetSolution( kernel.g_wk.push_back(zgridsize); result.construction_params.push_back(kernel); - _size = AlignUp(_size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE; - } while(_size > 1); + _size = (_size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE; + } + + // Last kernel reduce: FLOAT_ACCUM -> FLOAT + size_t xlocalsize = LOCAL_SIZE_REDUCE; + size_t xgridsize = AlignUp(_size, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenReduceSum.cpp"; + kernel.kernel_name = "ReduceSum"; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); } result.invoker_factory = [](const std::vector& kernels) { @@ -150,6 +175,7 @@ ConvSolution MultiMarginLossForward::GetSolution( auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); + auto o_tv = get_inner_expanded_tv<1>(deref(params.oDesc)); float elapsed = 0.0f; HipEventPtr start; @@ -180,28 +206,25 @@ ConvSolution MultiMarginLossForward::GetSolution( /* Phase 2: Reduce */ auto size = deref(params.iDesc).GetLengths()[0]; - auto data_size = get_data_size(deref(params.iDesc).GetType()); + auto data_size = get_data_size(miopenFloat); auto wt = MultiBufferWorkspaceTraits{ - size * data_size, (size + LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE * data_size}; + size * data_size, (size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE * data_size}; auto reduce_in = params.workspace; auto reduce_out = static_cast(static_cast(params.workspace) + wt.GetOffset(1)); - for(int i = 1; i < kernels.size(); ++i) + int kernelCnt = 1; + for(kernelCnt; kernelCnt < kernels.size() - 1; kernelCnt++) { - decltype(auto) kernel = handle_.Run(kernels[i]); - if(i + 1 != kernels.size()) - { - kernel(reduce_in, reduce_out, size); - std::swap(reduce_in, reduce_out); - } - else - { - kernel(reduce_in, params.o, size); - } - size = AlignUp(size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE; + decltype(auto) kernel = handle_.Run(kernels[kernelCnt]); + kernel(reduce_in, reduce_out, size); + std::swap(reduce_in, reduce_out); + size = (size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE; } + decltype(auto) kernel = handle_.Run(kernels[kernelCnt]); + kernel(reduce_in, params.o, size, o_tv); + if(profiling) { hipEventRecord(stop.get(), handle_.GetStream()); @@ -227,9 +250,9 @@ std::size_t MultiMarginLossForward::GetWorkspaceSize( const miopen::multimarginloss::ForwardProblemDescription& problem) const { auto size = problem.GetiDesc().GetLengths()[0]; - auto data_size = get_data_size(problem.GetiDesc().GetType()); - return MultiBufferWorkspaceTraits{size * data_size, - (size + LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE * data_size} + auto data_size = get_data_size(miopenFloat); + return MultiBufferWorkspaceTraits{ + size * data_size, (size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE * data_size} .GetSize(); } diff --git a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp index 1c0e27a74f..f6c3ec7b47 100644 --- a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp @@ -98,13 +98,14 @@ ConvSolution MultiMarginLossUnreducedForward::GetSolution( auto kernel = KernelInfo{}; kernel.kernel_file = "MIOpenMultiMarginLoss.cpp"; - kernel.kernel_name = "MultiMarginLossUnreducedForward2d"; + kernel.kernel_name = "MultiMarginLossForward2d"; const auto build_params = KernelBuildParameters{ {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"REDUCTION_TYPE", static_cast(problem.Getreduction())}, }; kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp index 604d370825..bba72cecb9 100644 --- a/test/gtest/multimarginloss.hpp +++ b/test/gtest/multimarginloss.hpp @@ -168,7 +168,7 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam 0) { - workspace = tensor{std::vector{ws_sizeInBytes / sizeof(T)}}; + workspace = tensor{std::vector{ws_sizeInBytes / sizeof(float)}}; std::fill(workspace.begin(), workspace.end(), 0); workspace_dev = handle.Write(workspace.data); } @@ -230,7 +230,7 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam target; tensor weight; tensor output; - tensor workspace; + tensor workspace; tensor ref_output; From fdc42329c9118e1c80bfc465032d9cbdfbe1d8df Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 21 Aug 2024 08:19:30 +0000 Subject: [PATCH 23/30] update isimprovementoverrocm --- .../forward_reduced_multimarginloss.cpp | 30 ++++++++++++++----- .../forward_unreduced_multimarginloss.cpp | 2 +- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index dfec238da7..b58fa3020d 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -46,13 +46,29 @@ bool MultiMarginLossForward::IsImprovementOverROCm( const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const { - if(problem.GetiDesc().GetLengths()[1] <= 30) - return true; - if((problem.GetiDesc().GetType() == miopenHalf || - problem.GetiDesc().GetType() == miopenBFloat16) && - problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] <= 40) - return true; - return false; + int C = problem.GetiDesc().GetLengths()[1]; + if(problem.allContiguousTensor()) + { + switch(problem.GetiDesc().GetType()) + { + case miopenFloat: return C <= 33; + case miopenHalf: return C <= 43; + case miopenBFloat16: return C <= 44; + // Have not tested with other types yet + default: return true; + } + } + else + { + switch(problem.GetiDesc().GetType()) + { + case miopenFloat: return C <= 31; + case miopenHalf: return C <= 38; + case miopenBFloat16: return C <= 40; + // Have not tested with other types yet + default: return true; + } + } } bool MultiMarginLossForward::IsApplicable( diff --git a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp index f6c3ec7b47..2d95aecf20 100644 --- a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp @@ -50,7 +50,7 @@ bool MultiMarginLossUnreducedForward::IsImprovementOverROCm( { switch(problem.GetiDesc().GetType()) { - case miopenFloat: return C <= 32; + case miopenFloat: return C <= 33; case miopenHalf: return C <= 43; case miopenBFloat16: return C <= 44; // Have not tested with other types yet From 047ca0a38d3e88e7e6b5acd906cd354a4eb68595 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 21 Aug 2024 09:57:43 +0000 Subject: [PATCH 24/30] remove OUTPUT_TYPE from solver prelu --- src/solver/prelu/backward_prelu_multi_weights.cpp | 1 - src/solver/prelu/backward_prelu_single_weight.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/src/solver/prelu/backward_prelu_multi_weights.cpp b/src/solver/prelu/backward_prelu_multi_weights.cpp index 5fed375a2b..c8137bee2e 100644 --- a/src/solver/prelu/backward_prelu_multi_weights.cpp +++ b/src/solver/prelu/backward_prelu_multi_weights.cpp @@ -87,7 +87,6 @@ MultiWeightsBackward::GetSolution(const ExecutionContext& context, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, - {"OUTPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype}, {"REDUCE_SIZE", LOCAL_SIZE_MW_REDUCE_BWD}, }; result.construction_params.push_back( diff --git a/src/solver/prelu/backward_prelu_single_weight.cpp b/src/solver/prelu/backward_prelu_single_weight.cpp index 2f8ba825f4..dba2694761 100644 --- a/src/solver/prelu/backward_prelu_single_weight.cpp +++ b/src/solver/prelu/backward_prelu_single_weight.cpp @@ -97,7 +97,6 @@ SingleWeightBackward::GetSolution(const ExecutionContext& /*context*/, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, - {"OUTPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype}, {"REDUCE_SIZE", LOCAL_SIZE_SW_REDUCE_BWD}, }; while(size > LOCAL_SIZE_SW_REDUCE_BWD) From 1b222e1433950bba071b8350f49eae3fc9877f2b Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Tue, 27 Aug 2024 10:19:11 +0000 Subject: [PATCH 25/30] avoid using divisor at all --- .../miopen/multimarginloss/invoke_params.hpp | 1 - src/kernels/MIOpenMultiMarginLoss.cpp | 1 + src/multimarginloss.cpp | 7 +------ .../forward_reduced_multimarginloss.cpp | 4 ++-- test/cpu_multimarginloss.hpp | 15 +++++++++++---- test/gtest/multimarginloss.hpp | 11 +++-------- 6 files changed, 18 insertions(+), 21 deletions(-) diff --git a/src/include/miopen/multimarginloss/invoke_params.hpp b/src/include/miopen/multimarginloss/invoke_params.hpp index 24e323946f..9d8a8e3498 100644 --- a/src/include/miopen/multimarginloss/invoke_params.hpp +++ b/src/include/miopen/multimarginloss/invoke_params.hpp @@ -48,7 +48,6 @@ struct InvokeParams : public miopen::InvokeParams long p; float margin; - float divisor = 0; Data_t workspace = nullptr; std::size_t workspace_size = 0; diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index 1ab7811eba..c1bf908e3b 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -79,6 +79,7 @@ __device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, case 0: static_cast(O)[O_tv.get_tensor_view_idx({n})] = CVT_ACCUM2FLOAT(loss); break; case 1: static_cast(O)[n] = loss; break; case 2: static_cast(O)[n] = loss / N; break; + default: break; } } diff --git a/src/multimarginloss.cpp b/src/multimarginloss.cpp index e3933c7ae4..ec5412b88e 100644 --- a/src/multimarginloss.cpp +++ b/src/multimarginloss.cpp @@ -92,12 +92,7 @@ miopenStatus_t MultiMarginLossForward(Handle& handle, tmp.margin = margin; tmp.workspace = workspace; tmp.workspace_size = workspaceSizeInBytes; - switch(reduction) - { - case MIOPEN_LOSS_REDUCTION_NONE: tmp.divisor = 0; break; - case MIOPEN_LOSS_REDUCTION_MEAN: tmp.divisor = iDesc.GetLengths()[0]; break; - case MIOPEN_LOSS_REDUCTION_SUM: tmp.divisor = 1; break; - } + return tmp; }(); diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index b58fa3020d..b70e3a775c 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -214,10 +214,10 @@ ConvSolution MultiMarginLossForward::GetSolution( params.workspace, params.p, params.margin, - params.divisor, i_tv, t_tv, - w_tv); + w_tv, + o_tv); } /* Phase 2: Reduce */ diff --git a/test/cpu_multimarginloss.hpp b/test/cpu_multimarginloss.hpp index c2b7fbb5ca..1df2ac6e75 100644 --- a/test/cpu_multimarginloss.hpp +++ b/test/cpu_multimarginloss.hpp @@ -27,6 +27,7 @@ #ifndef GUARD_CPU_MULTIMARGINLOSS_HPP #define GUARD_CPU_MULTIMARGINLOSS_HPP +#include "miopen/miopen.h" #include "tensor_holder.hpp" #include @@ -37,7 +38,7 @@ void cpu_multimarginloss_forward(tensor input, tensor& ref_output, const long p, const float margin, - const float divisor) + miopenLossReductionMode_t reduction_mode) { auto I_tv = miopen::get_inner_expanded_tv<2>(input.desc); auto T_tv = miopen::get_inner_expanded_tv<1>(target.desc); @@ -66,13 +67,19 @@ void cpu_multimarginloss_forward(tensor input, t = static_cast(weight[W_tv.get_tensor_view_idx({y})]) * t; loss += t / C; } - if(divisor == 0) + if(reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) ref_output[O_tv.get_tensor_view_idx({n})] = loss; else sum += loss; } - if(divisor != 0) - ref_output[0] = static_cast(sum / divisor); + if(reduction_mode == MIOPEN_LOSS_REDUCTION_MEAN) + { + ref_output[0] = static_cast(sum / N); + } + else if(reduction_mode == MIOPEN_LOSS_REDUCTION_SUM) + { + ref_output[0] = static_cast(sum); + } } #endif diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp index bba72cecb9..493fa725c1 100644 --- a/test/gtest/multimarginloss.hpp +++ b/test/gtest/multimarginloss.hpp @@ -181,14 +181,9 @@ struct MultiMarginLossForwardTest : public ::testing::TestWithParam(input, target, weight, ref_output, p, margin, divisor); + + cpu_multimarginloss_forward( + input, target, weight, ref_output, p, margin, reduction_mode); status = miopen::MultiMarginLossForward(handle, workspace_dev.get(), From d83b8330659fabd8a8ac611c1aab35bc2649ad05 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 28 Aug 2024 04:07:37 +0000 Subject: [PATCH 26/30] remove tensor_view modify --- src/kernels/tensor_view.hpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/kernels/tensor_view.hpp b/src/kernels/tensor_view.hpp index 8cdab4b942..c9357dd729 100644 --- a/src/kernels/tensor_view.hpp +++ b/src/kernels/tensor_view.hpp @@ -35,22 +35,17 @@ struct tensor_layout_t; template struct tensor_view_t { - constexpr uint64_t get_tensor_view_idx(const uint64_t (&layout)[N]) + // Get index in tensor view at tensor layout + constexpr uint64_t get_tensor_view_idx(const tensor_layout_t& tensor_layout) { static_assert(N > 0); uint64_t idx = 0; for(auto i = 0; i < N; ++i) { - idx += stride[i] * layout[i]; + idx += stride[i] * tensor_layout.layout[i]; } return idx; } - - constexpr uint64_t get_tensor_view_idx(const tensor_layout_t& tensor_layout) - { - return get_tensor_view_idx(tensor_layout.layout); - } - uint64_t stride[N]; uint64_t size[N]; }; From f90005a4d0d471f4081d32b533aed86a7c9580ad Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 28 Aug 2024 07:48:02 +0000 Subject: [PATCH 27/30] remove divisor in driver --- driver/multimarginloss_driver.hpp | 32 ++++++++++---------- src/kernels/MIOpenMultiMarginLoss.cpp | 2 +- src/kernels/MIOpenReduceSum.cpp | 2 +- src/multimarginloss_api.cpp | 43 ++++++++++----------------- 4 files changed, 35 insertions(+), 44 deletions(-) diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp index 9bf354f805..ebf13eb883 100644 --- a/driver/multimarginloss_driver.hpp +++ b/driver/multimarginloss_driver.hpp @@ -31,7 +31,6 @@ #include "tensor_driver.hpp" #include "timer.hpp" #include "random.hpp" -#include #include #include #include @@ -49,7 +48,7 @@ int32_t mloMultiMarginLossForwardRunHost(const miopenTensorDescriptor_t iDesc, const miopenTensorDescriptor_t oDesc, const long p, const float margin, - const float divisor, + const miopenLossReductionMode_t reduction_mode, const Tgpu* input, const uint64_t* target, const Tgpu* weight, @@ -84,13 +83,15 @@ int32_t mloMultiMarginLossForwardRunHost(const miopenTensorDescriptor_t iDesc, t = weight[W_tv.get_tensor_view_idx({y})] * t; loss += t / C; } - if(divisor != 0) + if(reduction_mode != MIOPEN_LOSS_REDUCTION_NONE) sum_loss += loss; else ref_output[O_tv.get_tensor_view_idx({n})] = static_cast(loss); } - if(divisor != 0) - ref_output[0] = static_cast(sum_loss / divisor); + if(reduction_mode == MIOPEN_LOSS_REDUCTION_MEAN) + ref_output[0] = static_cast(sum_loss / N); + else if(reduction_mode == MIOPEN_LOSS_REDUCTION_SUM) + ref_output[0] = static_cast(sum_loss); return ret; } @@ -382,16 +383,17 @@ int MultiMarginLossDriver::RunForwardGPU() template int MultiMarginLossDriver::RunForwardCPU() { - float divisor; - switch(reduction_mode) - { - case MIOPEN_LOSS_REDUCTION_NONE: divisor = 0; break; - case MIOPEN_LOSS_REDUCTION_MEAN: divisor = miopen::deref(iDesc).GetLengths()[0]; break; - case MIOPEN_LOSS_REDUCTION_SUM: divisor = 1; break; - } - - mloMultiMarginLossForwardRunHost( - iDesc, tDesc, wDesc, oDesc, p, margin, divisor, I.data(), T.data(), W.data(), Ohost.data()); + mloMultiMarginLossForwardRunHost(iDesc, + tDesc, + wDesc, + oDesc, + p, + margin, + reduction_mode, + I.data(), + T.data(), + W.data(), + Ohost.data()); return miopenStatusSuccess; } diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp b/src/kernels/MIOpenMultiMarginLoss.cpp index c1bf908e3b..2443f7863a 100644 --- a/src/kernels/MIOpenMultiMarginLoss.cpp +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -96,4 +96,4 @@ extern "C" __global__ void MultiMarginLossForward2d(const FLOAT* __restrict__ I, { // instantiate the kernel multimarginlossforward2d(I, T, W, O, p, margin, I_tv, T_tv, W_tv, O_tv); -} \ No newline at end of file +} diff --git a/src/kernels/MIOpenReduceSum.cpp b/src/kernels/MIOpenReduceSum.cpp index ec431bc1e0..367544cdbb 100644 --- a/src/kernels/MIOpenReduceSum.cpp +++ b/src/kernels/MIOpenReduceSum.cpp @@ -100,4 +100,4 @@ extern "C" __global__ void Reduce1dSum(const FLOAT_ACCUM* __restrict__ input, { // instantiate the kernel Reduce1dSum(input, output, output_numel, inner_size, outer_size, output_tv); -} \ No newline at end of file +} diff --git a/src/multimarginloss_api.cpp b/src/multimarginloss_api.cpp index 479cfd91b4..0e0d34b371 100644 --- a/src/multimarginloss_api.cpp +++ b/src/multimarginloss_api.cpp @@ -87,31 +87,20 @@ extern "C" miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, workspace, workspaceSizeInBytes); - if(reduction == MIOPEN_LOSS_REDUCTION_NONE || reduction == MIOPEN_LOSS_REDUCTION_SUM || - reduction == MIOPEN_LOSS_REDUCTION_MEAN) - { - return miopen::try_([&] { - miopen::MultiMarginLossForward(miopen::deref(handle), - DataCast(workspace), - workspaceSizeInBytes, - miopen::deref(inputDesc), - DataCast(input), - miopen::deref(targetDesc), - DataCast(target), - miopen::deref(weightDesc), - DataCast(weight), - miopen::deref(outputDesc), - DataCast(output), - p, - margin, - reduction); - }); - } - else - { - MIOPEN_THROW(miopenStatusBadParm, - "miopenMultiMarginLossForward: reduction should be " - "MIOPEN_LOSS_REDUCTION_NONE, " - "MIOPEN_LOSS_REDUCTION_SUM or MIOPEN_LOSS_REDUCTION_MEAN."); - } + return miopen::try_([&] { + miopen::MultiMarginLossForward(miopen::deref(handle), + DataCast(workspace), + workspaceSizeInBytes, + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(targetDesc), + DataCast(target), + miopen::deref(weightDesc), + DataCast(weight), + miopen::deref(outputDesc), + DataCast(output), + p, + margin, + reduction); + }); } From 898286f3a82bd4662efd55a57763d55dd8cd6655 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Wed, 28 Aug 2024 09:01:04 +0000 Subject: [PATCH 28/30] merge solver to avoid duplicated code --- src/CMakeLists.txt | 1 - .../miopen/multimarginloss/solvers.hpp | 17 -- src/multimarginloss.cpp | 22 +-- src/solver.cpp | 4 - .../forward_reduced_multimarginloss.cpp | 160 +++++++++++------- .../forward_unreduced_multimarginloss.cpp | 154 ----------------- 6 files changed, 103 insertions(+), 255 deletions(-) delete mode 100644 src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 32da2a6229..0a5694f5ec 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -302,7 +302,6 @@ set( MIOpen_Source solver/mha/mha_solver_backward.cpp solver/mha/mha_solver_forward.cpp solver/multimarginloss/forward_reduced_multimarginloss.cpp - solver/multimarginloss/forward_unreduced_multimarginloss.cpp solver/pooling/forward2d.cpp solver/pooling/forwardNaive.cpp solver/pooling/forwardNd.cpp diff --git a/src/include/miopen/multimarginloss/solvers.hpp b/src/include/miopen/multimarginloss/solvers.hpp index 078eaa0ce0..82c5c79910 100644 --- a/src/include/miopen/multimarginloss/solvers.hpp +++ b/src/include/miopen/multimarginloss/solvers.hpp @@ -37,23 +37,6 @@ namespace multimarginloss { using ForwardMultiMarginLossSolver = NonTunableSolverBase; -struct MultiMarginLossUnreducedForward final : ForwardMultiMarginLossSolver -{ - const std::string& SolverDbId() const override - { - return GetSolverDbId(); - } - bool - IsImprovementOverROCm(const ExecutionContext& context, - const miopen::multimarginloss::ForwardProblemDescription& problem) const; - bool - IsApplicable(const ExecutionContext& context, - const miopen::multimarginloss::ForwardProblemDescription& problem) const override; - ConvSolution - GetSolution(const ExecutionContext& context, - const miopen::multimarginloss::ForwardProblemDescription& problem) const override; -}; - struct MultiMarginLossForward final : ForwardMultiMarginLossSolver { const std::string& SolverDbId() const override diff --git a/src/multimarginloss.cpp b/src/multimarginloss.cpp index ec5412b88e..53b28ce78b 100644 --- a/src/multimarginloss.cpp +++ b/src/multimarginloss.cpp @@ -45,10 +45,6 @@ std::size_t GetMultiMarginLossForwardWorkspaceSize(Handle& handle, const float margin, miopenLossReductionMode_t reduction) { - if(reduction == MIOPEN_LOSS_REDUCTION_NONE) - { - return static_cast(0); - } auto ctx = ExecutionContext{&handle}; const auto problem = multimarginloss::ForwardProblemDescription{ iDesc, tDesc, wDesc, oDesc, p, margin, reduction}; @@ -96,21 +92,9 @@ miopenStatus_t MultiMarginLossForward(Handle& handle, return tmp; }(); - if(reduction == MIOPEN_LOSS_REDUCTION_NONE) - { - const auto algo = AlgorithmName{"MultiMarginLossUnreducedForward"}; - const auto solvers = - solver::SolverContainer{}; - - solvers.ExecutePrimitive(handle, problem, algo, invoke_params); - } - else - { - const auto algo = AlgorithmName{"MultiMarginLossForward"}; - const auto solvers = - solver::SolverContainer{}; - solvers.ExecutePrimitive(handle, problem, algo, invoke_params); - } + const auto algo = AlgorithmName{"MultiMarginLossForward"}; + const auto solvers = solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); return miopenStatusSuccess; } diff --git a/src/solver.cpp b/src/solver.cpp index 11ada34965..e1c44c7f2d 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -681,10 +681,6 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::ReLU, prelu::MultiWeightsBackward{}.SolverDbId()); Register(registry, ++id, Primitive::ReLU, prelu::SingleWeightBackward{}.SolverDbId()); - Register(registry, - ++id, - Primitive::MultiMarginLoss, - multimarginloss::MultiMarginLossUnreducedForward{}.SolverDbId()); Register(registry, ++id, Primitive::MultiMarginLoss, diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp index b70e3a775c..cb1f9f2b4b 100644 --- a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp @@ -86,9 +86,9 @@ ConvSolution MultiMarginLossForward::GetSolution( { auto result = ConvSolution{miopenStatusSuccess}; + // Start building result.construction_params auto xgrid = problem.GetiDesc().GetLengths()[0]; auto dtype = problem.GetiDesc().GetType(); - { /* Phase 1: Calc loss for each element. */ size_t xlocalsize = LOCAL_SIZE_MULTIMARGINLOSS; @@ -123,7 +123,10 @@ ConvSolution MultiMarginLossForward::GetSolution( result.construction_params.push_back(kernel); } + if(problem.Getreduction() != MIOPEN_LOSS_REDUCTION_NONE) { + // If Reduction = NONE, then we should run second kernel to calculate mean/sum of result + // from first kernel above /* Phase 2: Reduce FLOAT_ACCUM -> FLOAT_ACCUM */ auto _size = xgrid; const auto build_params = KernelBuildParameters{ @@ -184,79 +187,113 @@ ConvSolution MultiMarginLossForward::GetSolution( result.construction_params.push_back(kernel); } + // End building result.construction_params - result.invoker_factory = [](const std::vector& kernels) { - return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { - decltype(auto) params = raw_params.CastTo(); - auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); - auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); - auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); - auto o_tv = get_inner_expanded_tv<1>(deref(params.oDesc)); - - float elapsed = 0.0f; - HipEventPtr start; - HipEventPtr stop; - - const bool profiling = handle_.IsProfilingEnabled(); - if(profiling) - { - handle_.EnableProfiling(false); - start = miopen::make_hip_event(); - stop = miopen::make_hip_event(); - hipEventRecord(start.get(), handle_.GetStream()); - } - /* Phase 1: Calc loss for each element. */ - { + // Start building result.invoker_factory + if(problem.Getreduction() == MIOPEN_LOSS_REDUCTION_NONE) + { + // Reduction = None -> invoke 1 kernel + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); + auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); + auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); + auto o_tv = get_inner_expanded_tv<1>(deref(params.oDesc)); + kernel(params.i, params.t, params.w, - params.workspace, + params.o, params.p, params.margin, i_tv, t_tv, w_tv, o_tv); - } - - /* Phase 2: Reduce */ - auto size = deref(params.iDesc).GetLengths()[0]; - auto data_size = get_data_size(miopenFloat); - auto wt = MultiBufferWorkspaceTraits{ - size * data_size, (size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE * data_size}; - auto reduce_in = params.workspace; - auto reduce_out = - static_cast(static_cast(params.workspace) + wt.GetOffset(1)); - - int kernelCnt = 1; - for(kernelCnt; kernelCnt < kernels.size() - 1; kernelCnt++) - { + }; + }; + } + else + { + // Reduction != None -> invoke 2 kernels + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) params = raw_params.CastTo(); + auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); + auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); + auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); + auto o_tv = get_inner_expanded_tv<1>(deref(params.oDesc)); + + float elapsed = 0.0f; + HipEventPtr start; + HipEventPtr stop; + + const bool profiling = handle_.IsProfilingEnabled(); + if(profiling) + { + handle_.EnableProfiling(false); + start = miopen::make_hip_event(); + stop = miopen::make_hip_event(); + hipEventRecord(start.get(), handle_.GetStream()); + } + /* Phase 1: Calc loss for each element. */ + { + decltype(auto) kernel = handle_.Run(kernels.front()); + kernel(params.i, + params.t, + params.w, + params.workspace, + params.p, + params.margin, + i_tv, + t_tv, + w_tv, + o_tv); + } + + /* Phase 2: Reduce */ + auto size = deref(params.iDesc).GetLengths()[0]; + auto data_size = get_data_size(miopenFloat); + auto wt = MultiBufferWorkspaceTraits{size * data_size, + (size + LOCAL_SIZE_REDUCE - 1) / + LOCAL_SIZE_REDUCE * data_size}; + auto reduce_in = params.workspace; + auto reduce_out = static_cast(static_cast(params.workspace) + + wt.GetOffset(1)); + + int kernelCnt = 1; + for(kernelCnt; kernelCnt < kernels.size() - 1; kernelCnt++) + { + decltype(auto) kernel = handle_.Run(kernels[kernelCnt]); + kernel(reduce_in, reduce_out, size); + std::swap(reduce_in, reduce_out); + size = (size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE; + } + decltype(auto) kernel = handle_.Run(kernels[kernelCnt]); - kernel(reduce_in, reduce_out, size); - std::swap(reduce_in, reduce_out); - size = (size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE; - } - - decltype(auto) kernel = handle_.Run(kernels[kernelCnt]); - kernel(reduce_in, params.o, size, o_tv); - - if(profiling) - { - hipEventRecord(stop.get(), handle_.GetStream()); - hipEventSynchronize(stop.get()); - hipEventElapsedTime(&elapsed, start.get(), stop.get()); - - // Clean up - hipEventDestroy(start.get()); - hipEventDestroy(stop.get()); - handle_.ResetKernelTime(); - handle_.AccumKernelTime(elapsed); - - handle_.EnableProfiling(true); + kernel(reduce_in, params.o, size, o_tv); + + if(profiling) + { + hipEventRecord(stop.get(), handle_.GetStream()); + hipEventSynchronize(stop.get()); + hipEventElapsedTime(&elapsed, start.get(), stop.get()); + + // Clean up + hipEventDestroy(start.get()); + hipEventDestroy(stop.get()); + handle_.ResetKernelTime(); + handle_.AccumKernelTime(elapsed); + + handle_.EnableProfiling(true); + }; }; }; - }; + } + // End building result.invoker_factory return result; } @@ -265,6 +302,9 @@ std::size_t MultiMarginLossForward::GetWorkspaceSize( const ExecutionContext& /*context*/, const miopen::multimarginloss::ForwardProblemDescription& problem) const { + if(problem.Getreduction() == MIOPEN_LOSS_REDUCTION_NONE) + return 0; + auto size = problem.GetiDesc().GetLengths()[0]; auto data_size = get_data_size(miopenFloat); return MultiBufferWorkspaceTraits{ diff --git a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp deleted file mode 100644 index 2d95aecf20..0000000000 --- a/src/solver/multimarginloss/forward_unreduced_multimarginloss.cpp +++ /dev/null @@ -1,154 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2024 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ - -#include "miopen/miopen.h" -#include -#include -#include -#include -#include -#include -#include - -#define LOCAL_SIZE 256 - -namespace miopen { - -namespace solver { - -namespace multimarginloss { - -bool MultiMarginLossUnreducedForward::IsImprovementOverROCm( - const ExecutionContext& /*context*/, - const miopen::multimarginloss::ForwardProblemDescription& problem) const -{ - int C = problem.GetiDesc().GetLengths()[1]; - if(problem.allContiguousTensor()) - { - switch(problem.GetiDesc().GetType()) - { - case miopenFloat: return C <= 33; - case miopenHalf: return C <= 43; - case miopenBFloat16: return C <= 44; - // Have not tested with other types yet - default: return true; - } - } - else - { - switch(problem.GetiDesc().GetType()) - { - case miopenFloat: return C <= 31; - case miopenHalf: return C <= 38; - case miopenBFloat16: return C <= 40; - // Have not tested with other types yet - default: return true; - } - } -} - -bool MultiMarginLossUnreducedForward::IsApplicable( - const ExecutionContext& context, - const miopen::multimarginloss::ForwardProblemDescription& problem) const -{ - if(!IsImprovementOverROCm(context, problem)) - return false; - return true; -} - -ConvSolution MultiMarginLossUnreducedForward::GetSolution( - const ExecutionContext& /*context*/, - const miopen::multimarginloss::ForwardProblemDescription& problem) const -{ - auto result = ConvSolution{miopenStatusSuccess}; - - auto xgrid = problem.GetiDesc().GetLengths()[0]; - - { - auto dtype = problem.GetiDesc().GetType(); - size_t xlocalsize = LOCAL_SIZE; - size_t xgridsize = AlignUp(xgrid, xlocalsize); - size_t ylocalsize = 1; - size_t ygridsize = 1; - size_t zlocalsize = 1; - size_t zgridsize = 1; - - auto kernel = KernelInfo{}; - kernel.kernel_file = "MIOpenMultiMarginLoss.cpp"; - kernel.kernel_name = "MultiMarginLossForward2d"; - - const auto build_params = KernelBuildParameters{ - {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, - {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, - {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, - {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, - {"REDUCTION_TYPE", static_cast(problem.Getreduction())}, - }; - - kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); - - kernel.l_wk.push_back(xlocalsize); - kernel.l_wk.push_back(ylocalsize); - kernel.l_wk.push_back(zlocalsize); - - kernel.g_wk.push_back(xgridsize); - kernel.g_wk.push_back(ygridsize); - kernel.g_wk.push_back(zgridsize); - - result.construction_params.push_back(kernel); - } - - result.invoker_factory = [](const std::vector& kernels) { - return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { - decltype(auto) kernel = handle_.Run(kernels.front()); - decltype(auto) params = raw_params.CastTo(); - - auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); - auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); - auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); - auto o_tv = get_inner_expanded_tv<1>(deref(params.oDesc)); - - kernel(params.i, - params.t, - params.w, - params.o, - params.p, - params.margin, - i_tv, - t_tv, - w_tv, - o_tv); - }; - }; - - return result; -} - -} // namespace multimarginloss - -} // namespace solver - -} // namespace miopen From d3b8b42149462e902ac056824a334d980f06edc1 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Thu, 5 Sep 2024 10:55:50 +0000 Subject: [PATCH 29/30] rename file --- src/CMakeLists.txt | 2 +- ..._reduced_multimarginloss.cpp => forward_multimarginloss.cpp} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/solver/multimarginloss/{forward_reduced_multimarginloss.cpp => forward_multimarginloss.cpp} (100%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0a5694f5ec..18276dc758 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -301,7 +301,7 @@ set( MIOpen_Source solver/layernorm/forward_t5layernorm.cpp solver/mha/mha_solver_backward.cpp solver/mha/mha_solver_forward.cpp - solver/multimarginloss/forward_reduced_multimarginloss.cpp + solver/multimarginloss/forward_multimarginloss.cpp solver/pooling/forward2d.cpp solver/pooling/forwardNaive.cpp solver/pooling/forwardNd.cpp diff --git a/src/solver/multimarginloss/forward_reduced_multimarginloss.cpp b/src/solver/multimarginloss/forward_multimarginloss.cpp similarity index 100% rename from src/solver/multimarginloss/forward_reduced_multimarginloss.cpp rename to src/solver/multimarginloss/forward_multimarginloss.cpp From 9978ea98ed3280b0a6c6799b573555faf02be830 Mon Sep 17 00:00:00 2001 From: littlecutebird Date: Mon, 30 Sep 2024 03:56:57 +0000 Subject: [PATCH 30/30] add missing headers --- src/solver/multimarginloss/forward_multimarginloss.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/solver/multimarginloss/forward_multimarginloss.cpp b/src/solver/multimarginloss/forward_multimarginloss.cpp index cb1f9f2b4b..939222ca2f 100644 --- a/src/solver/multimarginloss/forward_multimarginloss.cpp +++ b/src/solver/multimarginloss/forward_multimarginloss.cpp @@ -24,7 +24,9 @@ * *******************************************************************************/ +#include "miopen/buffer_info.hpp" #include "miopen/miopen.h" +#include "miopen/mlo_internal.hpp" #include #include #include