diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 256901aa94..8e445c0a66 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -48,6 +48,7 @@ add_executable(MIOpenDriver dm_groupnorm.cpp dm_layernorm.cpp dm_lrn.cpp + dm_multimarginloss.cpp dm_pool.cpp dm_prelu.cpp dm_reduce.cpp diff --git a/driver/dm_multimarginloss.cpp b/driver/dm_multimarginloss.cpp new file mode 100644 index 0000000000..1c924bb7dc --- /dev/null +++ b/driver/dm_multimarginloss.cpp @@ -0,0 +1,40 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#include "multimarginloss_driver.hpp" +#include "registry_driver_maker.hpp" + +static Driver* makeDriver(const std::string& base_arg) +{ + if(base_arg == "multimarginloss") + return new MultiMarginLossDriver(); + if(base_arg == "multimarginlossfp16") + return new MultiMarginLossDriver(); + if(base_arg == "multimarginlossbfp16") + return new MultiMarginLossDriver(); + return nullptr; +} + +REGISTER_DRIVER_MAKER(makeDriver); diff --git a/driver/driver.hpp b/driver/driver.hpp index aa0b89f10a..83d27c7c0e 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -176,7 +176,8 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) "t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], " "adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, " "getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], " - "prelu[bfp16|fp16], glu[bfp16|fp16]\n"); + "prelu[bfp16|fp16], glu[bfp16|fp16], " + "multimarginloss[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -210,6 +211,7 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "reducecalculationfp16" && arg != "reducecalculationbfp16" && arg != "rope" && arg != "ropefp16" && arg != "ropebfp16" && arg != "prelu" && arg != "prelufp16" && arg != "prelubfp16" && arg != "glu" && arg != "glufp16" && arg != "glubfp16" && + arg != "multimarginloss" && arg != "multimarginlossfp16" && arg != "multimarginlossbfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); diff --git a/driver/multimarginloss_driver.hpp b/driver/multimarginloss_driver.hpp new file mode 100644 index 0000000000..ebf13eb883 --- /dev/null +++ b/driver/multimarginloss_driver.hpp @@ -0,0 +1,446 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef GUARD_MIOPEN_MULTIMARGINLOSS_DRIVER_HPP +#define GUARD_MIOPEN_MULTIMARGINLOSS_DRIVER_HPP + +#include "InputFlags.hpp" +#include "driver.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include "random.hpp" +#include +#include +#include +#include +#include +#include +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> +#include + +template +int32_t mloMultiMarginLossForwardRunHost(const miopenTensorDescriptor_t iDesc, + const miopenTensorDescriptor_t tDesc, + const miopenTensorDescriptor_t wDesc, + const miopenTensorDescriptor_t oDesc, + const long p, + const float margin, + const miopenLossReductionMode_t reduction_mode, + const Tgpu* input, + const uint64_t* target, + const Tgpu* weight, + Tcheck* ref_output) +{ + auto I_tv = miopen::get_inner_expanded_tv<2>(miopen::deref(iDesc)); + auto T_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(tDesc)); + auto W_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(wDesc)); + auto O_tv = miopen::get_inner_expanded_tv<1>(miopen::deref(oDesc)); + auto N = I_tv.size[0], C = I_tv.size[1]; + + int32_t ret = 0; + double sum_loss = 0; + + for(size_t n = 0; n < N; n++) + { + double loss = 0; + uint64_t y = target[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + continue; + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + double t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); + + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = weight[W_tv.get_tensor_view_idx({y})] * t; + loss += t / C; + } + if(reduction_mode != MIOPEN_LOSS_REDUCTION_NONE) + sum_loss += loss; + else + ref_output[O_tv.get_tensor_view_idx({n})] = static_cast(loss); + } + if(reduction_mode == MIOPEN_LOSS_REDUCTION_MEAN) + ref_output[0] = static_cast(sum_loss / N); + else if(reduction_mode == MIOPEN_LOSS_REDUCTION_SUM) + ref_output[0] = 
static_cast(sum_loss); + return ret; +} + +template +class MultiMarginLossDriver : public Driver +{ +public: + MultiMarginLossDriver() : Driver() + { + miopenCreateTensorDescriptor(&iDesc); + miopenCreateTensorDescriptor(&tDesc); + miopenCreateTensorDescriptor(&wDesc); + miopenCreateTensorDescriptor(&oDesc); + + data_type = miopen_type{}; + } + + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + + Tref GetTolerance(); + int VerifyBackward() override; + int VerifyForward() override; + ~MultiMarginLossDriver() override + { + miopenDestroyTensorDescriptor(iDesc); + miopenDestroyTensorDescriptor(tDesc); + miopenDestroyTensorDescriptor(wDesc); + miopenDestroyTensorDescriptor(oDesc); + } + +private: + InputFlags inflags; + + // forw = 0 -> run both fw, bw, = 1 -> run only fw, = 2 -> run only bw + int forw; + + miopenTensorDescriptor_t iDesc; + miopenTensorDescriptor_t tDesc; + miopenTensorDescriptor_t wDesc; + miopenTensorDescriptor_t oDesc; + + std::unique_ptr i_dev; + std::unique_ptr t_dev; + std::unique_ptr w_dev; + std::unique_ptr o_dev; + std::unique_ptr workspace_dev; + + std::vector I; + std::vector T; + std::vector W; + std::vector O; + std::vector Ohost; + + long p; + float margin; + miopenLossReductionMode_t reduction_mode; + size_t ws_sizeInBytes; +}; + +template +int MultiMarginLossDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward Take (Default=1)", "int"); + inflags.AddInputFlag("dim", 'D', "41x4", "Dim of input tensor (Default=41x4)", "tensor"); + inflags.AddInputFlag("contiguous", 'C', "1", "Tensor is contiguous or not (Default=1)", "int"); + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + 
inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int"); + inflags.AddInputFlag("reduce", + 'R', + "none", + "Specifies the reduction to apply to the output ('none'|'mean'|'sum') " + "(Default=none to indicate no reduction)", + "str"); + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + + auto reduction = inflags.GetValueStr("reduce"); + if(reduction != "none" && reduction != "mean" && reduction != "sum") + return miopenStatusInvalidValue; + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + forw = inflags.GetValueInt("forw"); + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::GetandSetData() +{ + // Set tensor description + std::vector in_len = inflags.GetValueTensor("dim").lengths; + size_t N = in_len[0], C = in_len[1]; + if(inflags.GetValueInt("contiguous") == 1) + { + SetTensorNd(iDesc, in_len, data_type); + + std::vector t_len = {N}; + SetTensorNd(tDesc, t_len, miopenInt64); + + std::vector w_len = {C}; + SetTensorNd(wDesc, w_len, data_type); + } + else + { + std::vector in_strides(in_len.size()); + in_strides.back() = 1; + for(int i = in_len.size() - 2; i >= 0; --i) + in_strides[i] = in_strides[i + 1] * in_len[i + 1]; + in_strides[0] *= 2; + SetTensorNd(iDesc, in_len, in_strides, data_type); + + std::vector t_len = {N}; + std::vector t_strides = {2}; + SetTensorNd(tDesc, t_len, t_strides, miopenInt64); + + std::vector w_lens = {C}; + std::vector w_strides = {2}; + SetTensorNd(wDesc, w_lens, w_strides, data_type); + } + + // Set p and margin + // p = 1 or 2 + p = prng::gen_A_to_B(static_cast(1), static_cast(3)); + margin = prng::gen_A_to_B(static_cast(0.5), static_cast(1.5)); + + // Set 
reduction_mode + auto reduction = inflags.GetValueStr("reduce"); + if(reduction == "none") + reduction_mode = MIOPEN_LOSS_REDUCTION_NONE; + else if(reduction == "mean") + reduction_mode = MIOPEN_LOSS_REDUCTION_MEAN; + else if(reduction == "sum") + reduction_mode = MIOPEN_LOSS_REDUCTION_SUM; + + // Set output tensor description (forw = 1 or = 0) + if(forw == 0 || forw == 1) + { + if(reduction == "none") + { + std::vector o_lens = {N}; + SetTensorNd(oDesc, o_lens, data_type); + } + else + { + std::vector o_lens = {1}; + SetTensorNd(oDesc, o_lens, data_type); + } + } + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::AllocateBuffersAndCopy() +{ + uint32_t ctx = 0; + + // for unpacked tensor, we need to use GetTensorSpace instead of GetTensorSize + size_t i_sz = GetTensorSpace(iDesc); + size_t t_sz = GetTensorSpace(tDesc); + size_t w_sz = GetTensorSpace(wDesc); + i_dev = std::make_unique(ctx, i_sz, sizeof(Tgpu)); + t_dev = std::make_unique(ctx, t_sz, sizeof(uint64_t)); + w_dev = std::make_unique(ctx, w_sz, sizeof(Tgpu)); + I = std::vector(i_sz); + T = std::vector(t_sz); + W = std::vector(w_sz); + for(int i = 0; i < i_sz; i++) + { + I[i] = prng::gen_A_to_B(static_cast(-1), static_cast(1)); + } + int C = miopen::deref(iDesc).GetLengths()[1]; + // 0 to C - 1 + for(int i = 0; i < t_sz; i++) + { + T[i] = prng::gen_A_to_B(static_cast(0), static_cast(C)); + } + for(int i = 0; i < w_sz; i++) + { + W[i] = prng::gen_A_to_B(static_cast(-1), static_cast(1)); + } + + if(i_dev->ToGPU(GetStream(), I.data()) != 0) + std::cerr << "Error copying (I) to GPU, size: " << i_dev->GetSize() << std::endl; + + if(t_dev->ToGPU(GetStream(), T.data()) != 0) + std::cerr << "Error copying (T) to GPU, size: " << t_dev->GetSize() << std::endl; + + if(w_dev->ToGPU(GetStream(), W.data()) != 0) + std::cerr << "Error copying (W) to GPU, size: " << w_dev->GetSize() << std::endl; + + if(forw == 0 || forw == 1) + { + size_t o_sz = GetTensorSpace(oDesc); + + 
miopenGetMultiMarginLossForwardWorkspaceSize( + GetHandle(), iDesc, tDesc, wDesc, oDesc, p, margin, reduction_mode, &ws_sizeInBytes); + if(ws_sizeInBytes == static_cast(-1)) + { + return miopenStatusAllocFailed; + } + + o_dev = std::make_unique(ctx, o_sz, sizeof(Tgpu)); + O = std::vector(o_sz, static_cast(0)); + Ohost = std::vector(o_sz, static_cast(0)); + if(o_dev->ToGPU(GetStream(), O.data()) != 0) + std::cerr << "Error copying (out) to GPU, size: " << o_dev->GetSize() << std::endl; + + workspace_dev = std::make_unique(ctx, ws_sizeInBytes, sizeof(std::byte)); + } + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::RunForwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenMultiMarginLossForward(GetHandle(), + iDesc, + i_dev->GetMem(), + tDesc, + t_dev->GetMem(), + wDesc, + w_dev->GetMem(), + oDesc, + o_dev->GetMem(), + p, + margin, + reduction_mode, + workspace_dev->GetMem(), + ws_sizeInBytes); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Forward MultiMarginLoss Elapsed: " + << t.gettime_ms() / iter << " ms" << std::endl; + + float kernel_average_time = + iter > 1 ? 
(kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Forward MultiMarginLoss Elapsed: " << kernel_average_time + << " ms" << std::endl; + } + + if(o_dev->FromGPU(GetStream(), O.data()) != 0) + std::cerr << "Error copying (o_dev) from GPU, size: " << o_dev->GetSize() << std::endl; + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::RunForwardCPU() +{ + mloMultiMarginLossForwardRunHost(iDesc, + tDesc, + wDesc, + oDesc, + p, + margin, + reduction_mode, + I.data(), + T.data(), + W.data(), + Ohost.data()); + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::RunBackwardGPU() +{ + return miopenStatusSuccess; +} + +template +Tref MultiMarginLossDriver::GetTolerance() +{ + // Computation error of fp16 is ~2^13 (=8192) bigger than + // the one of fp32 because mantissa is shorter by 13 bits. + auto tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3; + + // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. 
+ if(std::is_same::value) + tolerance *= 8.0; + return tolerance; +} + +template +int MultiMarginLossDriver::VerifyForward() +{ + RunForwardCPU(); + + const Tref tolerance = GetTolerance(); + auto error = miopen::rms_range(Ohost, O); + if(!std::isfinite(error) || error > tolerance) + { + std::cout << "Forward MultiMarginLoss FAILED: " << error << " > " << tolerance << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Forward MultiMarginLoss Verifies OK on CPU reference (" << error << " < " + << tolerance << ')' << std::endl; + } + + return miopenStatusSuccess; +} + +template +int MultiMarginLossDriver::VerifyBackward() +{ + return miopenStatusSuccess; +} + +#endif // GUARD_MIOPEN_MULTIMARGINLOSS_DRIVER_HPP diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 7ed36c72a4..6591a264cc 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -7791,6 +7791,101 @@ MIOPEN_EXPORT miopenStatus_t miopenPReLUBackward(miopenHandle_t handle, // CLOSEOUT RELU DOXYGEN GROUP #endif // MIOPEN_BETA_API +#ifdef MIOPEN_BETA_API + +/*! @ingroup LossFunction + * @enum miopenLossReductionMode_t + * Reduction mode for loss function + */ +typedef enum +{ + MIOPEN_LOSS_REDUCTION_NONE = 0, /*!< output tensor elements are not reduced */ + MIOPEN_LOSS_REDUCTION_SUM = 1, /*!< output tensor elements are summed up */ + MIOPEN_LOSS_REDUCTION_MEAN = 2, /*!< output tensor elements are summed up and divided by the total + number of elements to get the mean value */ +} miopenLossReductionMode_t; + +// MultiMarginLoss APIs +/** @addtogroup LossFunction + * + * @{ + */ + +/*! @brief Helper function to query the minimum workspace size required by the +MultiMarginLossForward call + * + * @param [in] handle MIOpen Handle + * @param [in] inputDesc Tensor descriptor for input tensor (N, C) where N is the batch +size and C is the number of classes + * @param [in] targetDesc Tensor descriptor for target tensor, must have shape (N). 
Each +value is between 0 and C - 1 + * @param [in] weightDesc Tensor descriptor for weight tensor. It is a manual rescaling +weight given to each class. It has to be a Tensor of size C + * @param [in] outputDesc Tensor descriptor for output tensor. If reduction is 'none', +then it must have shape (N). Otherwise, it is a scalar + * @param [in] p Has a default value of 1. The only supported values are 1 and 2 + * @param [in] margin Has a default value of 1 + * @param [in] reduction Reduction mode (none, sum, mean) + * @param [out] sizeInBytes Pointer to data to return the minimum workspace size + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t +miopenGetMultiMarginLossForwardWorkspaceSize(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + miopenTensorDescriptor_t targetDesc, + miopenTensorDescriptor_t weightDesc, + miopenTensorDescriptor_t outputDesc, + long p, + float margin, + miopenLossReductionMode_t reduction, + size_t* sizeInBytes); + +/*! @brief Execute a MultiMarginLoss forward layer + * + * @param [in] handle MIOpen handle + * @param [in] inputDesc Tensor descriptor for input tensor (N, C) where N is the +batch size and C is the number of classes. + * @param [in] input Data tensor input + * @param [in] targetDesc Tensor descriptor for target tensor, must have shape (N). +Each value is between 0 and C - 1 + * @param [in] target Data tensor target + * @param [in] weightDesc Tensor descriptor for weight tensor. It is a manual +rescaling weight given to each class. It has to be a Tensor of size C + * @param [in] weight Data tensor weight + * @param [in] outputDesc Tensor descriptor for output tensor. If reduction is 'none', +then it must have shape (N). Otherwise, it is a scalar. + * @param [out] output Data tensor output + * @param [in] p Has a default value of 1. The only supported values are 1 +and 2 + * @param [in] margin Has a default value of 1 + * @param [in] reduction Reduction mode. 
If reduction mode is mean or sum, you must + * provide param workspace and workspaceSizeInBytes. Call + * miopenGetMultiMarginLossForwardWorkspaceSize to get workspaceSizeInBytes + * @param [in] workspace Address of the allocated workspace data. Set = nullptr if +reduction = 'none' + * @param [in] workspaceSizeInBytes Size in bytes of the allocated workspace data. Set = 0 if +reduction = 'none' + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t targetDesc, + const void* target, + miopenTensorDescriptor_t weightDesc, + const void* weight, + miopenTensorDescriptor_t outputDesc, + void* output, + long p, + float margin, + miopenLossReductionMode_t reduction, + void* workspace, + size_t workspaceSizeInBytes); + +/** @} */ +// CLOSEOUT LossFunction DOXYGEN GROUP +#endif // MIOPEN_BETA_API + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c4ffeede18..ab4c0fa32f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -160,6 +160,8 @@ set( MIOpen_Source lrn_api.cpp mha/mha_descriptor.cpp mha/problem_description.cpp + multimarginloss/problem_description.cpp + multimarginloss_api.cpp op_args.cpp operator.cpp performance_config.cpp @@ -306,6 +308,7 @@ set( MIOpen_Source solver/layernorm/forward_t5layernorm.cpp solver/mha/mha_solver_backward.cpp solver/mha/mha_solver_forward.cpp + solver/multimarginloss/forward_multimarginloss.cpp solver/pooling/forward2d.cpp solver/pooling/forwardNaive.cpp solver/pooling/forwardNd.cpp @@ -480,12 +483,12 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/stride_array.hpp kernels/tensor_view.hpp kernels/utilities.inc + kernels/warp_reduce.hpp kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c16_stride1.inc 
kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c32_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1024vgprs_fp16_fp16acc_f2x3_c16_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_metadata.inc kernels/workaround_issue_1431.hpp - kernels/warp_reduce.hpp kernels/xform_bidirect_winograd_code.inc kernels/xform_data_filter.inc kernels/xform_kd_cov2.inc @@ -523,6 +526,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenLayerNorm.cpp kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl + kernels/MIOpenMultiMarginLoss.cpp kernels/MIOpenNeuron.cl kernels/MIOpenPReLU.cpp kernels/MIOpenPooling.cl @@ -535,8 +539,8 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenConv1x1J1_stride.cl kernels/MIOpenReduceCalculation.cpp kernels/MIOpenReduceExtreme.cpp - kernels/MIOpenRoPE.cpp kernels/MIOpenReduceSum.cpp + kernels/MIOpenRoPE.cpp kernels/MIOpenSoftmax.cl kernels/MIOpenSoftmaxAttn.cpp kernels/MIOpenUtilKernels3.cl @@ -652,6 +656,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN adam.cpp addlayernorm.cpp cat.cpp + exec_utils.cpp groupnorm.cpp getitem.cpp glu.cpp @@ -659,7 +664,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN layernorm.cpp lrn.cpp mlo_dir_conv.cpp - exec_utils.cpp + multimarginloss.cpp ocl/activ_ocl.cpp ocl/batchnormocl.cpp ocl/convolutionocl.cpp diff --git a/src/include/miopen/multimarginloss.hpp b/src/include/miopen/multimarginloss.hpp new file mode 100644 index 0000000000..fa718923a2 --- /dev/null +++ b/src/include/miopen/multimarginloss.hpp @@ -0,0 +1,63 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef MIOPEN_MULTIMARGINLOSS_HPP_ +#define MIOPEN_MULTIMARGINLOSS_HPP_ + +#include "miopen/miopen.h" +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +MIOPEN_INTERNALS_EXPORT std::size_t +GetMultiMarginLossForwardWorkspaceSize(Handle& handle, + const TensorDescriptor& iDesc, + const TensorDescriptor& tDesc, + const TensorDescriptor& wDesc, + const TensorDescriptor& oDesc, + long p, + float margin, + miopenLossReductionMode_t reduction); + +MIOPEN_INTERNALS_EXPORT miopenStatus_t MultiMarginLossForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& iDesc, + ConstData_t i, + const TensorDescriptor& tDesc, + ConstData_t t, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& oDesc, + Data_t o, + long p, + float margin, + miopenLossReductionMode_t reduction); + +} // namespace miopen +#endif // MIOPEN_MULTIMARGINLOSS_HPP_ diff --git a/src/include/miopen/multimarginloss/invoke_params.hpp b/src/include/miopen/multimarginloss/invoke_params.hpp new file mode 100644 index 0000000000..9d8a8e3498 --- /dev/null +++ b/src/include/miopen/multimarginloss/invoke_params.hpp @@ -0,0 +1,60 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once + +#include +#include + +namespace miopen { +namespace multimarginloss { + +struct InvokeParams : public miopen::InvokeParams +{ + InvokeParams() = default; + + const TensorDescriptor* iDesc = nullptr; + const TensorDescriptor* tDesc = nullptr; + const TensorDescriptor* wDesc = nullptr; + const TensorDescriptor* oDesc = nullptr; + + ConstData_t i = nullptr; + ConstData_t t = nullptr; + ConstData_t w = nullptr; + Data_t o = nullptr; + + long p; + float margin; + + Data_t workspace = nullptr; + std::size_t workspace_size = 0; + std::size_t GetWorkspaceSize() const { return workspace_size; } + Data_t GetWorkspace() const { return workspace; } +}; + +} // namespace multimarginloss + +} // namespace miopen diff --git a/src/include/miopen/multimarginloss/problem_description.hpp b/src/include/miopen/multimarginloss/problem_description.hpp new file mode 100644 index 0000000000..db1c8fa875 --- /dev/null +++ b/src/include/miopen/multimarginloss/problem_description.hpp @@ -0,0 +1,138 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include "miopen/miopen.h" +#include +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace multimarginloss { + +struct ForwardProblemDescription : ProblemDescriptionBase +{ + ForwardProblemDescription(const TensorDescriptor& iDesc_, + const TensorDescriptor& tDesc_, + const TensorDescriptor& wDesc_, + const TensorDescriptor& oDesc_, + const long p_, + const float margin_, + const miopenLossReductionMode_t reduction_) + : iDesc(iDesc_), + tDesc(tDesc_), + wDesc(wDesc_), + oDesc(oDesc_), + p(p_), + margin(margin_), + reduction(reduction_) + { + if(iDesc.GetType() != oDesc.GetType() || iDesc.GetType() != wDesc.GetType()) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Input, output, weight tensor types do not match."); + } + if(tDesc.GetType() != miopenInt64) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Target tensor type should be miopenInt64."); + } + if(iDesc.GetNumDims() != 2) + { + MIOPEN_THROW(miopenStatusBadParm, "MultiMarginLoss: Input tensor need to be 2D tensor"); + } + if(tDesc.GetNumDims() != 1 || tDesc.GetLengths()[0] != iDesc.GetLengths()[0]) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Target tensor need to be 1D tensor. 
If input " + "tensor has shape (N, C) then target tensor must have shape (N)"); + } + if(wDesc.GetNumDims() != 1 || wDesc.GetLengths()[0] != iDesc.GetLengths()[1]) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Weight tensor need to be 1D tensor. If input " + "tensor has shape (N, C) then weight tensor must have shape (C)"); + } + // Check output tensor dimension + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + // non-reduction case + if(oDesc.GetNumDims() != 1 || oDesc.GetLengths()[0] != iDesc.GetLengths()[0]) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Output tensor need to be " + "1D tensor. If input " + "tensor has shape (N, C) then output tensor must have shape (N)"); + } + } + else + { + // reduction case + if(oDesc.GetNumDims() != 1 || oDesc.GetLengths()[0] != 1) + { + MIOPEN_THROW(miopenStatusBadParm, + "MultiMarginLoss: Output tensor need to be a scalar."); + } + } + // Check p value + if(p != 1 && p != 2) + { + MIOPEN_THROW(miopenStatusBadParm, "MultiMarginLoss: p need to be equal 1 or 2."); + } + } + + const TensorDescriptor& GetiDesc() const { return iDesc; } + const TensorDescriptor& GettDesc() const { return tDesc; } + const TensorDescriptor& GetwDesc() const { return wDesc; } + const TensorDescriptor& GetoDesc() const { return oDesc; } + long Getp() const { return p; } + float Getmargin() const { return margin; } + miopenLossReductionMode_t Getreduction() const { return reduction; } + bool allContiguousTensor() const + { + return iDesc.IsContiguous() && tDesc.IsContiguous() && wDesc.IsContiguous() && + oDesc.IsContiguous(); + } + + NetworkConfig MakeNetworkConfig() const override; + +private: + TensorDescriptor iDesc; + TensorDescriptor tDesc; + TensorDescriptor wDesc; + TensorDescriptor oDesc; + long p; + float margin; + miopenLossReductionMode_t reduction; +}; + +} // namespace multimarginloss + +} // namespace miopen diff --git a/src/include/miopen/multimarginloss/solvers.hpp 
b/src/include/miopen/multimarginloss/solvers.hpp new file mode 100644 index 0000000000..82c5c79910 --- /dev/null +++ b/src/include/miopen/multimarginloss/solvers.hpp @@ -0,0 +1,65 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#pragma once + +#include +#include + +namespace miopen { + +namespace solver { + +namespace multimarginloss { + +using ForwardMultiMarginLossSolver = + NonTunableSolverBase; + +struct MultiMarginLossForward final : ForwardMultiMarginLossSolver +{ + const std::string& SolverDbId() const override + { + return GetSolverDbId(); + } + bool + IsImprovementOverROCm(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const; + bool + IsApplicable(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const override; + ConvSolution + GetSolution(const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const override; + std::size_t GetWorkspaceSize( + const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const override; + bool MayNeedWorkspace() const override { return true; } +}; + +} // namespace multimarginloss + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index ab824faa32..fdfb857319 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -32,7 +32,6 @@ #include #include -#include namespace miopen { @@ -61,7 +60,8 @@ enum class Primitive Adam, Item, RoPE, - ReLU + ReLU, + MultiMarginLoss }; struct MIOPEN_INTERNALS_EXPORT Id diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp index 1b095affb7..e85075a419 100644 --- a/src/include/miopen/tensor_view_utils.hpp +++ b/src/include/miopen/tensor_view_utils.hpp @@ -27,9 +27,8 @@ #ifndef MIOPEN_TENSOR_VIEW_UTIL_HPP_ #define MIOPEN_TENSOR_VIEW_UTIL_HPP_ -#include -#include #include "../../kernels/tensor_view.hpp" +#include namespace miopen { diff --git a/src/kernels/MIOpenMultiMarginLoss.cpp 
b/src/kernels/MIOpenMultiMarginLoss.cpp new file mode 100644 index 0000000000..2443f7863a --- /dev/null +++ b/src/kernels/MIOpenMultiMarginLoss.cpp @@ -0,0 +1,99 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" +#include "tensor_view.hpp" + +template +__device__ void multimarginlossforward2d(const DTYPE* __restrict__ I, + const uint64_t* __restrict__ T, + const DTYPE* __restrict__ W, + void* __restrict__ O, + const long p, + const float margin, + tensor_view_t<2> I_tv, + tensor_view_t<1> T_tv, + tensor_view_t<1> W_tv, + tensor_view_t<1> O_tv) +{ + const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; + + size_t N = I_tv.size[0], C = I_tv.size[1]; + size_t n = gid; + if(n >= N) + return; + + FLOAT_ACCUM loss = 0; + size_t y = T[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + { + // TODO: need to handle invalid target index value + return; + } + + FLOAT_ACCUM Iny = CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, y})]); + FLOAT_ACCUM Wy = CVT_FLOAT2ACCUM(W[W_tv.get_tensor_view_idx({y})]); + + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + FLOAT_ACCUM t = margin - Iny + CVT_FLOAT2ACCUM(I[I_tv.get_tensor_view_idx({n, c})]); + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = Wy * t; + loss += t; + } + loss /= C; + switch(REDUCTION_T) + { + case 0: static_cast(O)[O_tv.get_tensor_view_idx({n})] = CVT_ACCUM2FLOAT(loss); break; + case 1: static_cast(O)[n] = loss; break; + case 2: static_cast(O)[n] = loss / N; break; + default: break; + } +} + +extern "C" __global__ void MultiMarginLossForward2d(const FLOAT* __restrict__ I, + const uint64_t* __restrict__ T, + const FLOAT* __restrict__ W, + void* __restrict__ O, + const long p, + const float margin, + tensor_view_t<2> I_tv, + tensor_view_t<1> T_tv, + tensor_view_t<1> W_tv, + tensor_view_t<1> O_tv) +{ + // instantiate the kernel + multimarginlossforward2d(I, T, W, O, p, margin, I_tv, T_tv, W_tv, O_tv); +} diff --git a/src/kernels/MIOpenReduceSum.cpp b/src/kernels/MIOpenReduceSum.cpp index 5ed52008bf..367544cdbb 100644 --- 
a/src/kernels/MIOpenReduceSum.cpp +++ b/src/kernels/MIOpenReduceSum.cpp @@ -30,7 +30,6 @@ #include "float_types.h" #include "tensor_view.hpp" -#include "warp_reduce.hpp" #include "block_reduce.hpp" template @@ -47,12 +46,12 @@ ReduceSum(const FLOAT_ACCUM* input, TO* output, uint64_t N, tensor_view_t<1> out } extern "C" __global__ void ReduceSum(const FLOAT_ACCUM* __restrict__ input, - OUTPUT_TYPE* __restrict__ output, + FLOAT* __restrict__ output, uint64_t N, tensor_view_t<1> output_tv) { // instantiate the kernel - ReduceSum(input, output, N, output_tv); + ReduceSum(input, output, N, output_tv); } extern "C" __global__ void ReduceSumFLOATACCUM(const FLOAT_ACCUM* __restrict__ input, @@ -93,12 +92,12 @@ __device__ void Reduce1dSum(const FLOAT_ACCUM* __restrict__ input, } extern "C" __global__ void Reduce1dSum(const FLOAT_ACCUM* __restrict__ input, - OUTPUT_TYPE* __restrict__ output, + FLOAT* __restrict__ output, uint64_t output_numel, uint64_t inner_size, uint64_t outer_size, tensor_view_t<1> output_tv) { // instantiate the kernel - Reduce1dSum(input, output, output_numel, inner_size, outer_size, output_tv); + Reduce1dSum(input, output, output_numel, inner_size, outer_size, output_tv); } diff --git a/src/multimarginloss.cpp b/src/multimarginloss.cpp new file mode 100644 index 0000000000..53b28ce78b --- /dev/null +++ b/src/multimarginloss.cpp @@ -0,0 +1,102 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/miopen.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +std::size_t GetMultiMarginLossForwardWorkspaceSize(Handle& handle, + const TensorDescriptor& iDesc, + const TensorDescriptor& tDesc, + const TensorDescriptor& wDesc, + const TensorDescriptor& oDesc, + const long p, + const float margin, + miopenLossReductionMode_t reduction) +{ + auto ctx = ExecutionContext{&handle}; + const auto problem = multimarginloss::ForwardProblemDescription{ + iDesc, tDesc, wDesc, oDesc, p, margin, reduction}; + + const auto solvers = solver::SolverContainer{}; + + auto pair_size_vector = solvers.GetWorkspaceSizes(ctx, problem); + return pair_size_vector.empty() ? 
static_cast(-1) : pair_size_vector.front().second; +} + +miopenStatus_t MultiMarginLossForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& iDesc, + ConstData_t i, + const TensorDescriptor& tDesc, + ConstData_t t, + const TensorDescriptor& wDesc, + ConstData_t w, + const TensorDescriptor& oDesc, + Data_t o, + const long p, + const float margin, + miopenLossReductionMode_t reduction) +{ + const auto problem = multimarginloss::ForwardProblemDescription{ + iDesc, tDesc, wDesc, oDesc, p, margin, reduction}; + + const auto invoke_params = [&]() { + auto tmp = multimarginloss::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.iDesc = &iDesc; + tmp.i = i; + tmp.tDesc = &tDesc; + tmp.t = t; + tmp.wDesc = &wDesc; + tmp.w = w; + tmp.oDesc = &oDesc; + tmp.o = o; + tmp.p = p; + tmp.margin = margin; + tmp.workspace = workspace; + tmp.workspace_size = workspaceSizeInBytes; + + return tmp; + }(); + + const auto algo = AlgorithmName{"MultiMarginLossForward"}; + const auto solvers = solver::SolverContainer{}; + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/multimarginloss/problem_description.cpp b/src/multimarginloss/problem_description.cpp new file mode 100644 index 0000000000..b275d5a4fe --- /dev/null +++ b/src/multimarginloss/problem_description.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#include +#include + +#include + +namespace miopen { + +namespace multimarginloss { + +NetworkConfig ForwardProblemDescription::MakeNetworkConfig() const +{ + std::ostringstream ss; + ss << "multilmarginloss_fwd"; + ss << "itype" << iDesc.GetType(); + ss << "ilen"; + auto ilen = iDesc.GetLengths(); + for(unsigned long i : ilen) + ss << i << "_"; + ss << "cont" << iDesc.IsContiguous(); + ss << "reduction" << reduction; + return NetworkConfig{ss.str()}; +} + +} // namespace multimarginloss + +} // namespace miopen diff --git a/src/multimarginloss_api.cpp b/src/multimarginloss_api.cpp new file mode 100644 index 0000000000..0e0d34b371 --- /dev/null +++ b/src/multimarginloss_api.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include +#include +#include +#include + +extern "C" miopenStatus_t +miopenGetMultiMarginLossForwardWorkspaceSize(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + miopenTensorDescriptor_t targetDesc, + miopenTensorDescriptor_t weightDesc, + miopenTensorDescriptor_t outputDesc, + const long p, + const float margin, + miopenLossReductionMode_t reduction, + size_t* sizeInBytes) +{ + MIOPEN_LOG_FUNCTION( + handle, inputDesc, targetDesc, weightDesc, outputDesc, p, margin, reduction); + + return miopen::try_([&] { + miopen::deref(sizeInBytes) = + miopen::GetMultiMarginLossForwardWorkspaceSize(miopen::deref(handle), + miopen::deref(inputDesc), + miopen::deref(targetDesc), + miopen::deref(weightDesc), + miopen::deref(outputDesc), + p, + margin, + reduction); + }); +} + +extern "C" miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t targetDesc, + const void* target, + miopenTensorDescriptor_t weightDesc, + const void* weight, + miopenTensorDescriptor_t outputDesc, + void* output, + const long p, + const float margin, + miopenLossReductionMode_t reduction, + void* workspace, + size_t workspaceSizeInBytes) +{ + MIOPEN_LOG_FUNCTION(handle, + inputDesc, + input, + targetDesc, + target, + weightDesc, + weight, + outputDesc, + output, + p, + margin, + reduction, + workspace, + workspaceSizeInBytes); + + return miopen::try_([&] { + miopen::MultiMarginLossForward(miopen::deref(handle), + DataCast(workspace), + workspaceSizeInBytes, + miopen::deref(inputDesc), + DataCast(input), + 
miopen::deref(targetDesc), + DataCast(target), + miopen::deref(weightDesc), + DataCast(weight), + miopen::deref(outputDesc), + DataCast(output), + p, + margin, + reduction); + }); +} diff --git a/src/solver.cpp b/src/solver.cpp index 1149255363..3cc0a0c2e2 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include @@ -683,6 +684,11 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::Activation, glu::GLUForward{}.SolverDbId()); Register(registry, ++id, Primitive::Activation, glu::GLUBackward{}.SolverDbId()); + Register(registry, + ++id, + Primitive::MultiMarginLoss, + multimarginloss::MultiMarginLossForward{}.SolverDbId()); + // IMPORTANT: New solvers should be added to the end of the function! } diff --git a/src/solver/multimarginloss/forward_multimarginloss.cpp b/src/solver/multimarginloss/forward_multimarginloss.cpp new file mode 100644 index 0000000000..939222ca2f --- /dev/null +++ b/src/solver/multimarginloss/forward_multimarginloss.cpp @@ -0,0 +1,321 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/buffer_info.hpp" +#include "miopen/miopen.h" +#include "miopen/mlo_internal.hpp" +#include +#include +#include +#include +#include +#include +#include + +#define LOCAL_SIZE_MULTIMARGINLOSS 256 +#define LOCAL_SIZE_REDUCE 256 + +namespace miopen { + +namespace solver { + +namespace multimarginloss { + +bool MultiMarginLossForward::IsImprovementOverROCm( + const ExecutionContext& /*context*/, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + int C = problem.GetiDesc().GetLengths()[1]; + if(problem.allContiguousTensor()) + { + switch(problem.GetiDesc().GetType()) + { + case miopenFloat: return C <= 33; + case miopenHalf: return C <= 43; + case miopenBFloat16: return C <= 44; + // Have not tested with other types yet + default: return true; + } + } + else + { + switch(problem.GetiDesc().GetType()) + { + case miopenFloat: return C <= 31; + case miopenHalf: return C <= 38; + case miopenBFloat16: return C <= 40; + // Have not tested with other types yet + default: return true; + } + } +} + +bool MultiMarginLossForward::IsApplicable( + const ExecutionContext& context, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + if(!IsImprovementOverROCm(context, problem)) + return false; + return true; +} + +ConvSolution MultiMarginLossForward::GetSolution( + const ExecutionContext& /*context*/, + const 
miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + auto result = ConvSolution{miopenStatusSuccess}; + + // Start building result.construction_params + auto xgrid = problem.GetiDesc().GetLengths()[0]; + auto dtype = problem.GetiDesc().GetType(); + { + /* Phase 1: Calc loss for each element. */ + size_t xlocalsize = LOCAL_SIZE_MULTIMARGINLOSS; + size_t xgridsize = AlignUp(xgrid, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenMultiMarginLoss.cpp"; + kernel.kernel_name = "MultiMarginLossForward2d"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"REDUCTION_TYPE", static_cast(problem.Getreduction())}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + if(problem.Getreduction() != MIOPEN_LOSS_REDUCTION_NONE) + { + // If Reduction = NONE, then we should run second kernel to calculate mean/sum of result + // from first kernel above + /* Phase 2: Reduce FLOAT_ACCUM -> FLOAT_ACCUM */ + auto _size = xgrid; + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"REDUCE_SIZE", LOCAL_SIZE_REDUCE}, + }; + while(_size > LOCAL_SIZE_REDUCE) + { + size_t xlocalsize = 
LOCAL_SIZE_REDUCE; + size_t xgridsize = AlignUp(_size, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenReduceSum.cpp"; + kernel.kernel_name = "ReduceSumFLOATACCUM"; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + _size = (_size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE; + } + + // Last kernel reduce: FLOAT_ACCUM -> FLOAT + size_t xlocalsize = LOCAL_SIZE_REDUCE; + size_t xgridsize = AlignUp(_size, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + kernel.kernel_file = "MIOpenReduceSum.cpp"; + kernel.kernel_name = "ReduceSum"; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + // End building result.construction_params + + // Start building result.invoker_factory + if(problem.Getreduction() == MIOPEN_LOSS_REDUCTION_NONE) + { + // Reduction = None -> invoke 1 kernel + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); + auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); + auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); 
+ auto o_tv = get_inner_expanded_tv<1>(deref(params.oDesc)); + + kernel(params.i, + params.t, + params.w, + params.o, + params.p, + params.margin, + i_tv, + t_tv, + w_tv, + o_tv); + }; + }; + } + else + { + // Reduction != None -> invoke 2 kernels + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) params = raw_params.CastTo(); + auto i_tv = get_inner_expanded_tv<2>(deref(params.iDesc)); + auto t_tv = get_inner_expanded_tv<1>(deref(params.tDesc)); + auto w_tv = get_inner_expanded_tv<1>(deref(params.wDesc)); + auto o_tv = get_inner_expanded_tv<1>(deref(params.oDesc)); + + float elapsed = 0.0f; + HipEventPtr start; + HipEventPtr stop; + + const bool profiling = handle_.IsProfilingEnabled(); + if(profiling) + { + handle_.EnableProfiling(false); + start = miopen::make_hip_event(); + stop = miopen::make_hip_event(); + hipEventRecord(start.get(), handle_.GetStream()); + } + /* Phase 1: Calc loss for each element. 
*/ + { + decltype(auto) kernel = handle_.Run(kernels.front()); + kernel(params.i, + params.t, + params.w, + params.workspace, + params.p, + params.margin, + i_tv, + t_tv, + w_tv, + o_tv); + } + + /* Phase 2: Reduce */ + auto size = deref(params.iDesc).GetLengths()[0]; + auto data_size = get_data_size(miopenFloat); + auto wt = MultiBufferWorkspaceTraits{size * data_size, + (size + LOCAL_SIZE_REDUCE - 1) / + LOCAL_SIZE_REDUCE * data_size}; + auto reduce_in = params.workspace; + auto reduce_out = static_cast(static_cast(params.workspace) + + wt.GetOffset(1)); + + int kernelCnt = 1; + for(kernelCnt; kernelCnt < kernels.size() - 1; kernelCnt++) + { + decltype(auto) kernel = handle_.Run(kernels[kernelCnt]); + kernel(reduce_in, reduce_out, size); + std::swap(reduce_in, reduce_out); + size = (size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE; + } + + decltype(auto) kernel = handle_.Run(kernels[kernelCnt]); + kernel(reduce_in, params.o, size, o_tv); + + if(profiling) + { + hipEventRecord(stop.get(), handle_.GetStream()); + hipEventSynchronize(stop.get()); + hipEventElapsedTime(&elapsed, start.get(), stop.get()); + + // Clean up + hipEventDestroy(start.get()); + hipEventDestroy(stop.get()); + handle_.ResetKernelTime(); + handle_.AccumKernelTime(elapsed); + + handle_.EnableProfiling(true); + }; + }; + }; + } + // End building result.invoker_factory + + return result; +} + +std::size_t MultiMarginLossForward::GetWorkspaceSize( + const ExecutionContext& /*context*/, + const miopen::multimarginloss::ForwardProblemDescription& problem) const +{ + if(problem.Getreduction() == MIOPEN_LOSS_REDUCTION_NONE) + return 0; + + auto size = problem.GetiDesc().GetLengths()[0]; + auto data_size = get_data_size(miopenFloat); + return MultiBufferWorkspaceTraits{ + size * data_size, (size + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE * data_size} + .GetSize(); +} + +} // namespace multimarginloss + +} // namespace solver + +} // namespace miopen diff --git 
a/src/solver/prelu/backward_prelu_multi_weights.cpp b/src/solver/prelu/backward_prelu_multi_weights.cpp index 5fed375a2b..c8137bee2e 100644 --- a/src/solver/prelu/backward_prelu_multi_weights.cpp +++ b/src/solver/prelu/backward_prelu_multi_weights.cpp @@ -87,7 +87,6 @@ MultiWeightsBackward::GetSolution(const ExecutionContext& context, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, - {"OUTPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype}, {"REDUCE_SIZE", LOCAL_SIZE_MW_REDUCE_BWD}, }; result.construction_params.push_back( diff --git a/src/solver/prelu/backward_prelu_single_weight.cpp b/src/solver/prelu/backward_prelu_single_weight.cpp index c554c84914..5d9b0bb00a 100644 --- a/src/solver/prelu/backward_prelu_single_weight.cpp +++ b/src/solver/prelu/backward_prelu_single_weight.cpp @@ -99,7 +99,6 @@ SingleWeightBackward::GetSolution(const ExecutionContext& /*context*/, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, - {"OUTPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype}, {"REDUCE_SIZE", LOCAL_SIZE_SW_REDUCE_BWD}, }; while(size > LOCAL_SIZE_SW_REDUCE_BWD) diff --git a/test/cpu_multimarginloss.hpp b/test/cpu_multimarginloss.hpp new file mode 100644 index 0000000000..1df2ac6e75 --- /dev/null +++ b/test/cpu_multimarginloss.hpp @@ -0,0 +1,85 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#ifndef GUARD_CPU_MULTIMARGINLOSS_HPP +#define GUARD_CPU_MULTIMARGINLOSS_HPP + +#include "miopen/miopen.h" +#include "tensor_holder.hpp" +#include + +template +void cpu_multimarginloss_forward(tensor input, + tensor target, + tensor weight, + tensor& ref_output, + const long p, + const float margin, + miopenLossReductionMode_t reduction_mode) +{ + auto I_tv = miopen::get_inner_expanded_tv<2>(input.desc); + auto T_tv = miopen::get_inner_expanded_tv<1>(target.desc); + auto W_tv = miopen::get_inner_expanded_tv<1>(weight.desc); + auto O_tv = miopen::get_inner_expanded_tv<1>(ref_output.desc); + auto N = I_tv.size[0], C = I_tv.size[1]; + + double sum = 0; + for(size_t n = 0; n < N; n++) + { + double loss = 0; + uint64_t y = target[T_tv.get_tensor_view_idx({n})]; + if(y >= C) + continue; + for(size_t c = 0; c < C; c++) + { + if(y == c) + continue; + double t = margin - static_cast(input[I_tv.get_tensor_view_idx({n, y})]) + + static_cast(input[I_tv.get_tensor_view_idx({n, c})]); + + if(t < 0) + continue; + if(p == 2) + t = t * t; + t = static_cast(weight[W_tv.get_tensor_view_idx({y})]) * t; + loss += t / C; + } + if(reduction_mode == MIOPEN_LOSS_REDUCTION_NONE) + ref_output[O_tv.get_tensor_view_idx({n})] = loss; + else + sum += loss; + } + if(reduction_mode == MIOPEN_LOSS_REDUCTION_MEAN) + { + ref_output[0] = static_cast(sum / N); + } + else if(reduction_mode == MIOPEN_LOSS_REDUCTION_SUM) + { + ref_output[0] = static_cast(sum); + } +} + +#endif diff --git a/test/gtest/multimarginloss.cpp b/test/gtest/multimarginloss.cpp new file mode 100644 index 0000000000..df2653af41 --- /dev/null +++ b/test/gtest/multimarginloss.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include "multimarginloss.hpp"
+#include <miopen/env.hpp>
+
+#include <string>
+
+MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG)
+MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL)
+
+namespace multimarginloss {
+
+// Returns the value of MIOPEN_TEST_FLOAT_ARG as reported by env::value.
+std::string GetFloatArg() { return env::value(MIOPEN_TEST_FLOAT_ARG); }
+
+// Decides whether a precision-specific suite runs. The suite runs when the
+// MIOPEN_TEST_ALL check below is negative, or when MIOPEN_TEST_ALL is enabled
+// and MIOPEN_TEST_FLOAT_ARG equals `float_arg` (e.g. "--float", "--half",
+// "--bfloat16"). Shared by all three TEST_P bodies so the condition lives in
+// one place.
+bool IsEnabledFor(const std::string& float_arg)
+{
+    return !MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && GetFloatArg() == float_arg);
+}
+
+struct GPU_MultiMarginLoss_FP32 : MultiMarginLossForwardTest<float>
+{
+};
+
+struct GPU_MultiMarginLoss_FP16 : MultiMarginLossForwardTest<half_float::half>
+{
+};
+
+struct GPU_MultiMarginLoss_BFP16 : MultiMarginLossForwardTest<bfloat16>
+{
+};
+
+} // namespace multimarginloss
+
+using multimarginloss::GPU_MultiMarginLoss_BFP16;
+using multimarginloss::GPU_MultiMarginLoss_FP16;
+using multimarginloss::GPU_MultiMarginLoss_FP32;
+
+TEST_P(GPU_MultiMarginLoss_FP32, Test)
+{
+    if(!multimarginloss::IsEnabledFor("--float"))
+    {
+        GTEST_SKIP();
+    }
+    RunTest();
+    Verify();
+}
+
+TEST_P(GPU_MultiMarginLoss_FP16, Test)
+{
+    if(!multimarginloss::IsEnabledFor("--half"))
+    {
+        GTEST_SKIP();
+    }
+    RunTest();
+    Verify();
+}
+
+TEST_P(GPU_MultiMarginLoss_BFP16, Test)
+{
+    if(!multimarginloss::IsEnabledFor("--bfloat16"))
+    {
+        GTEST_SKIP();
+    }
+    RunTest();
+    Verify();
+}
+
+INSTANTIATE_TEST_SUITE_P(Smoke,
+                         GPU_MultiMarginLoss_FP32,
+                         testing::ValuesIn(MultiMarginLossTestConfigs()));
+// FP16 uses the reduced list (see MultiMarginLossFp16TestConfigs).
+INSTANTIATE_TEST_SUITE_P(Smoke,
+                         GPU_MultiMarginLoss_FP16,
+                         testing::ValuesIn(MultiMarginLossFp16TestConfigs()));
+INSTANTIATE_TEST_SUITE_P(Smoke,
+                         GPU_MultiMarginLoss_BFP16,
+                         testing::ValuesIn(MultiMarginLossTestConfigs()));
diff --git a/test/gtest/multimarginloss.hpp b/test/gtest/multimarginloss.hpp
new file mode 100644
index 0000000000..493fa725c1
--- /dev/null
+++
b/test/gtest/multimarginloss.hpp @@ -0,0 +1,242 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#pragma once
+
+#include "cpu_multimarginloss.hpp"
+#include "get_handle.hpp"
+#include "tensor_holder.hpp"
+#include "verify.hpp"
+#include <gtest/gtest.h>
+#include <miopen/allocator.hpp>
+#include <miopen/multimarginloss.hpp>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <ostream>
+#include <type_traits>
+#include <vector>
+
+// One parameterised test case.
+struct MultiMarginLossTestCase
+{
+    std::vector<size_t> dims;                 // input lengths; SetUp reads dims[0] as N, dims[1] as C
+    bool cont;                                // true: packed tensors; false: strided tensors
+    miopenLossReductionMode_t reduction_mode; // NONE, SUM or MEAN
+    long p;                                   // exponent selector passed to the operator
+
+    // Prints e.g. "dims:22x12 cont:1 reduction_mode:<value> p:1".
+    // An empty `dims` prints just "dims:" instead of reading past the end.
+    friend std::ostream& operator<<(std::ostream& os, const MultiMarginLossTestCase& tc)
+    {
+        os << "dims:";
+        for(size_t i = 0; i < tc.dims.size(); ++i)
+            os << (i == 0 ? "" : "x") << tc.dims[i];
+        os << " cont:" << tc.cont << " reduction_mode:" << tc.reduction_mode << " p:" << tc.p;
+        return os;
+    }
+};
+
+// Full list of cases, used by the FP32 and BFP16 suites.
+inline std::vector<MultiMarginLossTestCase> MultiMarginLossTestConfigs()
+{
+    // clang-format off
+    return {
+        {{22, 12},     true,  MIOPEN_LOSS_REDUCTION_MEAN, 1},
+        {{22, 12},     false, MIOPEN_LOSS_REDUCTION_SUM,  1},
+        {{22, 12},     true,  MIOPEN_LOSS_REDUCTION_NONE, 1},
+        {{9456, 13},   false, MIOPEN_LOSS_REDUCTION_MEAN, 2},
+        {{9456, 13},   true,  MIOPEN_LOSS_REDUCTION_SUM,  2},
+        {{9456, 13},   false, MIOPEN_LOSS_REDUCTION_NONE, 2},
+        {{543210, 7},  true,  MIOPEN_LOSS_REDUCTION_MEAN, 2},
+        {{543210, 7},  false, MIOPEN_LOSS_REDUCTION_SUM,  2},
+        {{543210, 7},  true,  MIOPEN_LOSS_REDUCTION_NONE, 2},
+        {{3995776, 6}, true,  MIOPEN_LOSS_REDUCTION_MEAN, 1},
+        {{3995776, 6}, true,  MIOPEN_LOSS_REDUCTION_SUM,  1},
+        {{3995776, 6}, true,  MIOPEN_LOSS_REDUCTION_NONE, 1},
+    };
+    // clang-format on
+}
+
+// Cases for the FP16 suite. The large-N cases with SUM or MEAN reduction are
+// left out because the reduced value overflows/underflows in half precision;
+// the large-N cases without reduction are kept.
+inline std::vector<MultiMarginLossTestCase> MultiMarginLossFp16TestConfigs()
+{
+    // clang-format off
+    return {
+        {{22, 12},     true,  MIOPEN_LOSS_REDUCTION_MEAN, 1},
+        {{22, 12},     false, MIOPEN_LOSS_REDUCTION_SUM,  1},
+        {{22, 12},     true,  MIOPEN_LOSS_REDUCTION_NONE, 1},
+        {{9456, 13},   false, MIOPEN_LOSS_REDUCTION_MEAN, 2},
+        {{9456, 13},   true,  MIOPEN_LOSS_REDUCTION_SUM,  2},
+        {{9456, 13},   false, MIOPEN_LOSS_REDUCTION_NONE, 2},
+        {{543210, 7},  true,  MIOPEN_LOSS_REDUCTION_NONE, 2},
+        {{3995776, 6}, true,  MIOPEN_LOSS_REDUCTION_NONE, 1},
+    };
+    // clang-format on
+}
+
+// Fixture: builds random inputs for one test case, runs the CPU reference and
+// the MIOpen implementation, and compares their outputs.
+template <typename T>
+struct MultiMarginLossForwardTest : public ::testing::TestWithParam<MultiMarginLossTestCase>
+{
+protected:
+    // Allocates and fills host tensors, uploads them, and sizes the workspace.
+    void SetUp() override
+    {
+        auto&& handle = get_handle();
+        config        = GetParam();
+
+        auto in_dims   = config.dims;
+        reduction_mode = config.reduction_mode;
+        p              = config.p;
+        margin         = prng::gen_A_to_B(0.5, 1.5);
+
+        // N and C are read from the first two lengths; fail clearly otherwise.
+        ASSERT_GE(in_dims.size(), 2U) << "MultiMarginLoss test input must be (N, C)";
+        size_t N = in_dims[0], C = in_dims[1];
+
+        if(config.cont)
+        {
+            input = tensor<T>{in_dims};
+            // input is (N, C) -> target is (N)
+            target = tensor<uint64_t>{std::vector<size_t>{N}};
+            // input is (N, C) -> weight is (C)
+            weight = tensor<T>{std::vector<size_t>{C}};
+        }
+        else
+        {
+            // Packed strides, then the outermost stride is doubled so the
+            // tensor is no longer contiguous.
+            std::vector<size_t> in_strides(in_dims.size());
+            in_strides.back() = 1;
+            for(int i = static_cast<int>(in_dims.size()) - 2; i >= 0; --i)
+                in_strides[i] = in_strides[i + 1] * in_dims[i + 1];
+            in_strides[0] *= 2;
+            input = tensor<T>{in_dims, in_strides};
+
+            std::vector<size_t> t_len     = {N};
+            std::vector<size_t> t_strides = {2};
+            target                        = tensor<uint64_t>{t_len, t_strides};
+
+            std::vector<size_t> w_lens    = {C};
+            std::vector<size_t> w_strides = {2};
+            weight                        = tensor<T>{w_lens, w_strides};
+        }
+
+        auto gen_in_value = [](auto...) {
+            return prng::gen_A_to_B(static_cast<T>(-1), static_cast<T>(1));
+        };
+        std::generate(input.begin(), input.end(), gen_in_value);
+        input_dev = handle.Write(input.data);
+
+        // Class indices are drawn from the range given by (0, C).
+        for(auto& ptr : target)
+        {
+            ptr = prng::gen_A_to_B<uint64_t>(0, C);
+        }
+        target_dev = handle.Write(target.data);
+
+        auto gen_weight_value = [](auto...) {
+            return prng::gen_A_to_B(static_cast<T>(-1), static_cast<T>(1));
+        };
+        std::generate(weight.begin(), weight.end(), gen_weight_value);
+        weight_dev = handle.Write(weight.data);
+
+        if(reduction_mode == MIOPEN_LOSS_REDUCTION_NONE)
+        {
+            // input is (N, C) -> output is (N)
+            output     = tensor<T>{std::vector<size_t>{N}};
+            ref_output = tensor<T>{std::vector<size_t>{N}};
+        }
+        else
+        {
+            // Tensor with 1 element to store result after reduce
+            output     = tensor<T>{std::vector<size_t>{1}};
+            ref_output = tensor<T>{std::vector<size_t>{1}};
+        }
+        // Both start at zero so elements the implementations do not write
+        // still compare equal.
+        std::fill(output.begin(), output.end(), 0);
+        std::fill(ref_output.begin(), ref_output.end(), 0);
+        output_dev = handle.Write(output.data);
+
+        ws_sizeInBytes = miopen::GetMultiMarginLossForwardWorkspaceSize(
+            handle, input.desc, target.desc, weight.desc, output.desc, p, margin, reduction_mode);
+        if(ws_sizeInBytes == static_cast<size_t>(-1))
+            GTEST_FAIL() << "Call GetMultiMarginLossForwardWorkspaceSize failed!";
+        if(ws_sizeInBytes > 0)
+        {
+            workspace = tensor<float>{std::vector<size_t>{ws_sizeInBytes / sizeof(float)}};
+            std::fill(workspace.begin(), workspace.end(), 0);
+            workspace_dev = handle.Write(workspace.data);
+        }
+        else
+        {
+            workspace_dev = nullptr;
+        }
+    }
+
+    // Computes the CPU reference into ref_output, runs the MIOpen
+    // implementation, and copies its result back into output.
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+        miopenStatus_t status;
+
+        cpu_multimarginloss_forward(
+            input, target, weight, ref_output, p, margin, reduction_mode);
+
+        status = miopen::MultiMarginLossForward(handle,
+                                                workspace_dev.get(),
+                                                ws_sizeInBytes,
+                                                input.desc,
+                                                input_dev.get(),
+                                                target.desc,
+                                                target_dev.get(),
+                                                weight.desc,
+                                                weight_dev.get(),
+                                                output.desc,
+                                                output_dev.get(),
+                                                p,
+                                                margin,
+                                                reduction_mode);
+
+        ASSERT_EQ(status, miopenStatusSuccess);
+
+        // Write from GPU to CPU
+        output.data = handle.Read<T>(output_dev, output.data.size());
+    }
+
+    // Checks that output matches ref_output within a per-precision tolerance.
+    void Verify()
+    {
+        // Computation error of fp16 is ~2^13 (=8192) bigger than
+        // the one of fp32 because mantissa is shorter by 13 bits.
+        auto tolerance = std::is_same<T, float>::value ? 1.5e-6 : 8.2e-3;
+        // bf16 mantissa has 7 bits, by 3 bits shorter than fp16.
+        if(std::is_same<T, bfloat16>::value)
+            tolerance *= 8.0;
+
+        // Check the lengths first so the error is only computed over ranges
+        // of equal size.
+        ASSERT_EQ(miopen::range_distance(ref_output), miopen::range_distance(output));
+        auto error = miopen::rms_range(ref_output, output);
+        EXPECT_LT(error, tolerance);
+    }
+
+    MultiMarginLossTestCase config;
+
+    tensor<T> input;
+    tensor<uint64_t> target;
+    tensor<T> weight;
+    tensor<T> output;
+    tensor<float> workspace;
+
+    tensor<T> ref_output;
+
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr target_dev;
+    miopen::Allocator::ManageDataPtr weight_dev;
+    miopen::Allocator::ManageDataPtr output_dev;
+    miopen::Allocator::ManageDataPtr workspace_dev;
+
+    // Overwritten in SetUp; initialised here so they never hold garbage.
+    long p                                   = 1;
+    float margin                             = 1.0f;
+    miopenLossReductionMode_t reduction_mode = MIOPEN_LOSS_REDUCTION_NONE;
+    size_t ws_sizeInBytes                    = 0;
+};
diff --git a/test/tensor_holder.hpp b/test/tensor_holder.hpp
index a5700ff2b8..e1b03880b8 100644
--- a/test/tensor_holder.hpp
+++ b/test/tensor_holder.hpp
@@ -137,6 +137,11 @@ struct miopen_type : std::integral_constant +struct miopen_type : std::integral_constant +{ +}; + template struct tensor {