diff --git a/docs/reference/index.rst b/docs/reference/index.rst index c2b74eabee..80dbff611e 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -37,5 +37,6 @@ The MIOpen API library is structured as follows: * :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental) * :doc:`RotaryPositionalEmbeddings <../doxygen/html/group__RotaryPositionalEmbeddings>` (experimental) * :doc:`ReLU <../doxygen/html/group___re_l_u>` (experimental) * :doc:`Kthvalue <../doxygen/html/group__kthvalue>` (experimental) * :doc:`GLU <../doxygen/html/group__glu>` (experimental) + * :doc:`SigmoidFocalLoss <../doxygen/html/group__loss_function>` (experimental) diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 60d6fe6ce6..fb1be229bf 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -57,6 +57,7 @@ add_executable(MIOpenDriver dm_reducecalculation.cpp dm_rnn.cpp dm_rope.cpp + dm_sigmoid_focal_loss.cpp dm_softmarginloss.cpp dm_softmax.cpp dm_t5layernorm.cpp diff --git a/driver/dm_sigmoid_focal_loss.cpp b/driver/dm_sigmoid_focal_loss.cpp new file mode 100644 index 0000000000..3ec7e9ac31 --- /dev/null +++ b/driver/dm_sigmoid_focal_loss.cpp @@ -0,0 +1,41 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include "registry_driver_maker.hpp" +#include "sigmoid_focal_loss_driver.hpp" + +static Driver* makeDriver(const std::string& base_arg) +{ + if(base_arg == "sigmoidfocalloss") + return new SigmoidFocalLossDriver<float, float>(); + else if(base_arg == "sigmoidfocallossfp16") + return new SigmoidFocalLossDriver<float16, float>(); + else if(base_arg == "sigmoidfocallossbfp16") + return new SigmoidFocalLossDriver<bfloat16, float>(); + return nullptr; +} + +REGISTER_DRIVER_MAKER(makeDriver); diff --git a/driver/driver.hpp b/driver/driver.hpp index d77d5d02d2..cb6eddb4cf 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -314,7 +314,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) "adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, " "getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], " "prelu[bfp16|fp16], kthvalue[bfp16|fp16], glu[bfp16|fp16], softmarginloss[bfp16|fp16], " - "multimarginloss[bfp16|fp16]\n"); + "multimarginloss[bfp16|fp16], sigmoidfocalloss[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -352,7 +352,8 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "kthvaluebfp16" && arg != "glu" && arg != "glufp16" && arg != "glubfp16" && arg != "softmarginloss" && arg != "softmarginlossfp16" && arg != "softmarginlossbfp16" && arg != "multimarginloss" && arg != "multimarginlossfp16" && arg != "multimarginlossbfp16" && - arg != "--version") + arg != "sigmoidfocalloss" && arg != "sigmoidfocallossfp16" && + arg != "sigmoidfocallossbfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); Usage(); diff --git a/driver/mloSigmoidFocalLossHost.hpp b/driver/mloSigmoidFocalLossHost.hpp new file mode 100644 index 0000000000..2f77cd10ee --- /dev/null +++ b/driver/mloSigmoidFocalLossHost.hpp @@ -0,0 +1,136 @@ +#pragma once +#include +#include + +template +void mloSigmoidFocalLossFwdRunHost(Tgpu* input, + miopenTensorDescriptor_t inputDesc, 
+ Tgpu* target, + miopenTensorDescriptor_t targetDesc, + Tcheck* outputHost, + miopenTensorDescriptor_t outputDesc, + float alpha, + float gamma, + miopenLossReductionMode_t reduction, + float divisor) +{ + auto input_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(inputDesc)); + auto target_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(targetDesc)); + auto output_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(outputDesc)); + size_t inputSize = miopen::deref(inputDesc).GetElementSize(); + + for(size_t id = 0; id < inputSize; ++id) + { + tensor_layout_t<5> idx(input_tv, id); + + Tcheck i = static_cast(input[input_tv.get_tensor_view_idx(idx)]); + Tcheck t = static_cast(target[target_tv.get_tensor_view_idx(idx)]); + + Tcheck sig = 1 / (1 + exp(-i)); + Tcheck ceLoss = -(t * log(sig) + (1 - t) * log(1 - sig)); + Tcheck sigT = sig * t + (1 - sig) * (1 - t); + Tcheck loss = ceLoss * pow(1 - sigT, gamma); + + if(alpha >= 0) + { + Tcheck alphaT = alpha * t + (1 - alpha) * (1 - t); + loss = alphaT * loss; + } + + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + outputHost[output_tv.get_tensor_view_idx(idx)] = loss; + } + else + { + outputHost[0] += static_cast(loss / divisor); + } + } +} + +template +void mloSigmoidFocalLossBwdRunHost(Tgpu* input, + miopenTensorDescriptor_t inputDesc, + Tgpu* target, + miopenTensorDescriptor_t targetDesc, + Tgpu* doutput, + miopenTensorDescriptor_t doutputDesc, + Tcheck* dinput, + miopenTensorDescriptor_t dinputDesc, + Tcheck* dtarget, + miopenTensorDescriptor_t dtargetDesc, + float alpha, + float gamma, + miopenLossReductionMode_t reduction, + float divisor) +{ + auto input_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(inputDesc)); + auto target_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(targetDesc)); + auto doutput_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(doutputDesc)); + auto dinput_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(dinputDesc)); + auto dtarget_tv = 
miopen::get_inner_expanded_tv<5>(miopen::deref(dtargetDesc)); + + size_t inputSize = miopen::deref(inputDesc).GetElementSize(); + + tensor_layout_t<5> doIdx(input_tv, 0); + Tcheck dO = static_cast(doutput[doutput_tv.get_tensor_view_idx(doIdx)]); + + for(size_t id = 0; id < inputSize; ++id) + { + tensor_layout_t<5> idx(input_tv, id); + + Tcheck i = static_cast(input[input_tv.get_tensor_view_idx(idx)]); + Tcheck t = static_cast(target[target_tv.get_tensor_view_idx(idx)]); + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + dO = static_cast(doutput[doutput_tv.get_tensor_view_idx(idx)]); + } + + Tcheck p = 1 / (1 + exp(-i)); + Tcheck ceLoss = -(t * log(p) + (1 - t) * log(1 - p)); + Tcheck pT = p * t + (1 - p) * (1 - t); + Tcheck powPt = pow(1 - pT, gamma); + Tcheck alpha_t = alpha * t + (1 - alpha) * (1 - t); + + if(dinput) + { + Tcheck dpdi = exp(-i) / pow(1 + exp(-i), 2); + Tcheck dcelossdi = (-t / p + (1 - t) / (1 - p)) * dpdi; + Tcheck dpowptdi = gamma * pow(1 - pT, gamma - 1) * (1 - 2 * t) * dpdi; + + // L = ce_loss * pow_pt => dL/di = dceloss/di * pow_pt + ce_loss * dpowpt/di + Tcheck dLdi = dcelossdi * powPt + ceLoss * dpowptdi; + Tcheck grad = dO * dLdi; + + if(alpha >= 0) + { + grad *= alpha_t; + } + if(reduction != MIOPEN_LOSS_REDUCTION_NONE) + { + grad /= divisor; + } + dinput[dinput_tv.get_tensor_view_idx(idx)] = static_cast(grad); + } + + if(dtarget) + { + Tcheck dcelossdt = -log(p) + log(1 - p); + Tcheck dpowptdt = gamma * pow(1 - pT, gamma - 1) * (1 - 2 * p); + // L = ce_loss * pow_pt => dL/dt = dceloss/dt * pow_pt + ce_loss * dpowpt/dt + Tcheck dLdt = dcelossdt * powPt + ceLoss * dpowptdt; + Tcheck gradTarget = dO * dLdt; + + if(alpha >= 0) + { + // Chain rule: dO * (alpha_t * dL/dt + dalpha_t/dt * L); the upstream gradient dO scales both terms + gradTarget = dO * (alpha_t * dLdt + (2 * alpha - 1) * ceLoss * powPt); + } + if(reduction != MIOPEN_LOSS_REDUCTION_NONE) + { + gradTarget /= divisor; + } + dtarget[dtarget_tv.get_tensor_view_idx(idx)] = static_cast(gradTarget); + } + } +} diff --git 
a/driver/sigmoid_focal_loss_driver.hpp b/driver/sigmoid_focal_loss_driver.hpp new file mode 100644 index 0000000000..188336af62 --- /dev/null +++ b/driver/sigmoid_focal_loss_driver.hpp @@ -0,0 +1,526 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once +#include "InputFlags.hpp" +#include "driver.hpp" +#include +#include +#include "tensor_driver.hpp" +#include "timer.hpp" +#include "mloSigmoidFocalLossHost.hpp" +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> +#include +#include + +const float MAX_FP16 = 65504; + +template +class SigmoidFocalLossDriver : public Driver +{ +public: + SigmoidFocalLossDriver() : Driver() + { + miopenCreateTensorDescriptor(&inputDesc); + miopenCreateTensorDescriptor(&targetDesc); + miopenCreateTensorDescriptor(&outputDesc); + miopenCreateTensorDescriptor(&doutputDesc); + miopenCreateTensorDescriptor(&dinputDesc); + miopenCreateTensorDescriptor(&dtargetDesc); + + data_type = miopen_type{}; + } + + std::vector ComputeStrides(std::vector input); + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + int RunBackwardCPU(); + + Tcheck GetTolerance(); + int VerifyBackward() override; + int VerifyForward() override; + ~SigmoidFocalLossDriver() override + { + miopenDestroyTensorDescriptor(inputDesc); + miopenDestroyTensorDescriptor(targetDesc); + miopenDestroyTensorDescriptor(outputDesc); + miopenDestroyTensorDescriptor(doutputDesc); + miopenDestroyTensorDescriptor(dinputDesc); + miopenDestroyTensorDescriptor(dtargetDesc); + } + +private: + InputFlags inflags; + + miopenTensorDescriptor_t inputDesc; + miopenTensorDescriptor_t targetDesc; + miopenTensorDescriptor_t outputDesc; + miopenTensorDescriptor_t doutputDesc; + miopenTensorDescriptor_t dinputDesc; + miopenTensorDescriptor_t dtargetDesc; + + std::unique_ptr input_dev; + std::unique_ptr target_dev; + std::unique_ptr output_dev; + std::unique_ptr doutput_dev; 
+ std::unique_ptr dinput_dev; + std::unique_ptr dtarget_dev; + std::unique_ptr workspace_dev; + + std::vector input; + std::vector target; + std::vector output; + std::vector outputHost; + std::vector doutput; + std::vector dinput; + std::vector dinputHost; + std::vector dtarget; + std::vector dtargetHost; + + float alpha; + float gamma; + float divisor = 1.0f; // GetandSetData only assigns this for reduced modes; keep it defined otherwise + bool isContiguous; + bool isTargetGradientComputed; + miopenLossReductionMode_t reduction; + + size_t workSpaceSizeInBytes = 0; // stays 0 if the workspace-size query does not write it +}; + +template +int SigmoidFocalLossDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + return miopenStatusSuccess; +} + +template +int SigmoidFocalLossDriver::GetandSetData() +{ + auto inDims = inflags.GetValueTensor("dim-lengths").lengths; + alpha = inflags.GetValueDouble("alpha"); + gamma = inflags.GetValueDouble("gamma"); + isContiguous = inflags.GetValueInt("is-contiguous") == 1 ? 
true : false; + reduction = static_cast(inflags.GetValueInt("reduction")); + + std::vector inStride = ComputeStrides(inDims); + + SetTensorNd(inputDesc, inDims, inStride, data_type); + SetTensorNd(targetDesc, inDims, inStride, data_type); + SetTensorNd(doutputDesc, inDims, data_type); + SetTensorNd(dinputDesc, inDims, data_type); + + if(isTargetGradientComputed) + { + SetTensorNd(dtargetDesc, inDims, data_type); + } + else + { + std::vector dtargetDim(1); + dtargetDim[0] = 1; + SetTensorNd(dtargetDesc, dtargetDim, data_type); + } + + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + SetTensorNd(outputDesc, inDims, data_type); + } + else + { + std::vector outDim(1); + outDim[0] = 1; + SetTensorNd(outputDesc, outDim, data_type); + divisor = 1; + if(reduction == MIOPEN_LOSS_REDUCTION_MEAN) + { + divisor = miopen::deref(inputDesc).GetElementSize(); + } + } + + return 0; +} + +// Equivalent to: tensor.transpose(0, -1).contiguous().transpose(0, -1) in case contiguous = False +template +std::vector SigmoidFocalLossDriver::ComputeStrides(std::vector inputDim) +{ + if(!isContiguous) + std::swap(inputDim.front(), inputDim.back()); + std::vector strides(inputDim.size()); + strides.back() = 1; + for(int i = inputDim.size() - 2; i >= 0; --i) + strides[i] = strides[i + 1] * inputDim[i + 1]; + if(!isContiguous) + std::swap(strides.front(), strides.back()); + return strides; +} + +template +int SigmoidFocalLossDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward (Default=1)", "int"); + inflags.AddTensorFlag( + "dim-lengths", 'D', "256x4x2", "The dimensional lengths of the input tensor"); + inflags.AddInputFlag("is-contiguous", 'c', "1", "is-contiguous (Default=1)", "int"); + inflags.AddInputFlag( + "reduction", 'R', "0", "reduction mode: 0(default) - unreduced, 1 - sum, 2 -mean", "int"); + inflags.AddInputFlag("alpha", 'A', "0.25", "Alpha (Default=0.25)", "float"); + inflags.AddInputFlag("gamma", 'G', "2", "Gamma (Default=2)", "float"); + 
inflags.AddInputFlag( + "target-gradient", 'T', "0", "Is target gradient computed (Default=0)", "int"); + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int"); + + return miopenStatusSuccess; +} + +template +int SigmoidFocalLossDriver::AllocateBuffersAndCopy() +{ + size_t in_sz = miopen::deref(inputDesc).GetElementSize(); + size_t target_sz = miopen::deref(targetDesc).GetElementSize(); + size_t out_sz = miopen::deref(outputDesc).GetElementSize(); + size_t dO_sz = miopen::deref(doutputDesc).GetElementSize(); + size_t dI_sz = miopen::deref(dinputDesc).GetElementSize(); + size_t dT_sz = miopen::deref(dtargetDesc).GetElementSize(); + + uint32_t ctx = 0; + + input_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(Tgpu))); + target_dev = std::unique_ptr(new GPUMem(ctx, target_sz, sizeof(Tgpu))); + output_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu))); + doutput_dev = std::unique_ptr(new GPUMem(ctx, dO_sz, sizeof(Tgpu))); + dinput_dev = std::unique_ptr(new GPUMem(ctx, dI_sz, sizeof(Tgpu))); + dtarget_dev = std::unique_ptr(new GPUMem(ctx, dT_sz, sizeof(Tgpu))); + + miopenGetSigmoidFocalLossForwardWorkspaceSize( + handle, inputDesc, targetDesc, outputDesc, reduction, &workSpaceSizeInBytes); + workspace_dev = std::make_unique(ctx, workSpaceSizeInBytes, sizeof(std::byte)); + + input = std::vector(in_sz, static_cast(0)); + target = std::vector(target_sz, static_cast(0)); + output = std::vector(out_sz, static_cast(0)); + outputHost = std::vector(out_sz, static_cast(0)); + doutput = std::vector(dO_sz, static_cast(0)); + dinput = std::vector(dI_sz, static_cast(0)); + dinputHost = std::vector(dI_sz, static_cast(0)); + dtarget = std::vector(dT_sz, 
static_cast(0)); + dtargetHost = std::vector(dT_sz, static_cast(0)); + + float randomBound = 2; + // For half, the random bound is smaller to avoid half overflow + if(data_type == miopenHalf && reduction != MIOPEN_LOSS_REDUCTION_NONE) + { + randomBound = 0.5; + } + for(int i = 0; i < in_sz; i++) + { + input[i] = + prng::gen_A_to_B(static_cast(-randomBound), static_cast(randomBound)); + target[i] = + prng::gen_A_to_B(static_cast(-randomBound), static_cast(randomBound)); + } + for(int i = 0; i < dO_sz; ++i) + { + doutput[i] = + prng::gen_A_to_B(static_cast(-randomBound), static_cast(randomBound)); + } + + if(input_dev->ToGPU(GetStream(), input.data()) != 0) + std::cerr << "Error copying (in) to GPU, size: " << input_dev->GetSize() << std::endl; + + if(target_dev->ToGPU(GetStream(), target.data()) != 0) + std::cerr << "Error copying (in) to GPU, size: " << target_dev->GetSize() << std::endl; + + if(output_dev->ToGPU(GetStream(), output.data()) != 0) + std::cerr << "Error copying (out) to GPU, size: " << output_dev->GetSize() << std::endl; + + if(doutput_dev->ToGPU(GetStream(), doutput.data()) != 0) + std::cerr << "Error copying (dO) to GPU, size: " << doutput_dev->GetSize() << std::endl; + + if(dinput_dev->ToGPU(GetStream(), dinput.data()) != 0) + std::cerr << "Error copying (dI) to GPU, size: " << dinput_dev->GetSize() << std::endl; + + if(dtarget_dev->ToGPU(GetStream(), dtarget.data()) != 0) + std::cerr << "Error copying (dT) to GPU, size: " << dtarget_dev->GetSize() << std::endl; + + return miopenStatusSuccess; +} + +template +int SigmoidFocalLossDriver::RunForwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenSigmoidFocalLossForward(GetHandle(), + workspace_dev->GetMem(), + workSpaceSizeInBytes, + inputDesc, + input_dev->GetMem(), + targetDesc, + target_dev->GetMem(), + outputDesc, + output_dev->GetMem(), + alpha, + gamma, + reduction); + float 
time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Sigmoid Focal Loss Fwd Elapsed: " << t.gettime_ms() / iter + << " ms" << std::endl; + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Sigmoid Focal Loss Fwd Elapsed: " << kernel_average_time + << " ms" << std::endl; + } + + if(output_dev->FromGPU(GetStream(), output.data()) != 0) + std::cerr << "Error copying (out_dev) from GPU, size: " << output_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int SigmoidFocalLossDriver::RunForwardCPU() +{ + mloSigmoidFocalLossFwdRunHost(input.data(), + inputDesc, + target.data(), + targetDesc, + outputHost.data(), + outputDesc, + alpha, + gamma, + reduction, + divisor); + return miopenStatusSuccess; +} + +template +int SigmoidFocalLossDriver::RunBackwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + void* p_dtarget = nullptr; + if(isTargetGradientComputed) + { + p_dtarget = dtarget_dev->GetMem(); + } + + miopenSigmoidFocalLossBackward(GetHandle(), + inputDesc, + input_dev->GetMem(), + targetDesc, + target_dev->GetMem(), + doutputDesc, + doutput_dev->GetMem(), + dinputDesc, + dinput_dev->GetMem(), + dtargetDesc, + p_dtarget, + alpha, + gamma, + reduction); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Sigmoid Focal Loss Bwd Elapsed: " << t.gettime_ms() / iter + << " ms" << std::endl; 
+ + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Sigmoid Focal Loss Bwd Elapsed: " << kernel_average_time + << " ms" << std::endl; + } + + if(dinput_dev->FromGPU(GetStream(), dinput.data()) != 0) + std::cerr << "Error copying (dI_dev) from GPU, size: " << dinput_dev->GetSize() + << std::endl; + if(isTargetGradientComputed && dtarget_dev->FromGPU(GetStream(), dtarget.data()) != 0) + std::cerr << "Error copying (dT_dev) from GPU, size: " << dtarget_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int SigmoidFocalLossDriver::RunBackwardCPU() +{ + Tcheck* p_dtarget = nullptr; + if(isTargetGradientComputed) + { + p_dtarget = dtargetHost.data(); + } + mloSigmoidFocalLossBwdRunHost(input.data(), + inputDesc, + target.data(), + targetDesc, + doutput.data(), + doutputDesc, + dinputHost.data(), + dinputDesc, + p_dtarget, + dtargetDesc, + alpha, + gamma, + reduction, + divisor); + + return miopenStatusSuccess; +} + +template +Tcheck SigmoidFocalLossDriver::GetTolerance() +{ + Tcheck tolerance; + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3; + // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. + if(std::is_same::value) + tolerance *= 8.0; + } + else + { + tolerance = std::is_same::value ? 
1.0e-2 : 8.2e-1; + } + + return tolerance; +} + +template +int SigmoidFocalLossDriver::VerifyForward() +{ + RunForwardCPU(); + + if(miopen::deref(inputDesc).GetType() == miopenHalf && + reduction != MIOPEN_LOSS_REDUCTION_NONE && abs(outputHost[0]) > MAX_FP16) + { + std::cout << "Float16 overflow - CPU output: " << outputHost[0] << std::endl; + } + + const Tcheck tolerance = GetTolerance(); + auto error = miopen::rms_range(outputHost, output); + + if(!std::isfinite(error) || error > tolerance) + { + std::cout << "Forward " << reduction << " Sigmoid Focal Loss FAILED: " << error << " > " + << tolerance << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Forward " << reduction << " Sigmoid Focal Loss Verifies OK on CPU reference (" + << error << "< " << tolerance << ')' << std::endl; + } + + return miopenStatusSuccess; +} + +template +int SigmoidFocalLossDriver::VerifyBackward() +{ + RunBackwardCPU(); + + const Tcheck tolerance = GetTolerance(); + auto dinputError = miopen::rms_range(dinputHost, dinput); + auto dtargetError = miopen::rms_range(dtargetHost, dtarget); + + if(!std::isfinite(dinputError) || dinputError > tolerance) + { + std::cout << "Backward " << reduction << " Sigmoid Focal Loss FAILED: " << dinputError + << " > " << tolerance << std::endl; + return EC_VerifyBwd; + } + else if(isTargetGradientComputed && (!std::isfinite(dtargetError) || dtargetError > tolerance)) + { + std::cout << "Backward " << reduction << " Sigmoid Focal Loss FAILED: " << dtargetError + << " > " << tolerance << std::endl; + return EC_VerifyBwd; + } + else + { + std::cout << "Backward " << reduction + << " Sigmoid Focal Loss Verifies OK on CPU reference (dinput: " << dinputError + << ", dtarget: " << dtargetError << "< " << tolerance << ')' << std::endl; + } + + return miopenStatusSuccess; +} diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 67652ab832..5524874405 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -8092,17 
+8092,17 @@ MIOPEN_EXPORT miopenStatus_t miopenSoftMarginLossBackward(miopenHandle_t handle, /** @} */ // CLOSEOUT LossFunction DOXYGEN GROUP -#endif +#endif // MIOPEN_BETA_API -#ifdef MIOPEN_BETA_API // MultiMarginLoss APIs +#ifdef MIOPEN_BETA_API /** @addtogroup LossFunction * * @{ */ /*! @brief Helper function to query the minimum workspace size required by the -MultiMarginLossForward call +MultiMarginLoss Forward call * * @param [in] handle MIOpen Handle * @param [in] inputDesc Tensor descriptor for input tensor (N, C) where N is the batch @@ -8176,6 +8176,98 @@ MIOPEN_EXPORT miopenStatus_t miopenMultiMarginLossForward(miopenHandle_t handle, // CLOSEOUT LossFunction DOXYGEN GROUP #endif // MIOPEN_BETA_API +// SigmoidFocalLoss APIs +#ifdef MIOPEN_BETA_API +/** @addtogroup LossFunction + * + * @{ + */ + +/*! @brief Helper function to query the minimum workspace size required by the SigmoidFocalLoss + * Forward call + * + * @param handle MIOpen Handle (input) + * @param inputDesc Tensor descriptor for input tensor (input) + * @param targetDesc Tensor descriptor for target tensor (input) + * @param outputDesc Tensor descriptor for output tensor (input) + * @param reduction Reduction (input) + * @param sizeInBytes Pointer to data to return the minimum workspace size + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t +miopenGetSigmoidFocalLossForwardWorkspaceSize(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + miopenTensorDescriptor_t targetDesc, + miopenTensorDescriptor_t outputDesc, + miopenLossReductionMode_t reduction, + size_t* sizeInBytes); + +/*! 
@brief Execute a SigmoidFocalLoss forward layer + * + * @param handle MIOpen handle (input) + * @param workspace Address of the allocated workspace data (input) + * @param workspaceSizeInBytes Size in bytes of the allocated workspace data (input) + * @param inputDesc Tensor descriptor for input tensor (input) + * @param input Data tensor input (input) + * @param targetDesc Tensor descriptor for target tensor (input) + * @param target Data tensor target (input) + * @param outputDesc Tensor descriptor for output tensor (input) + * @param output Data tensor output (output) + * @param alpha Alpha (input) + * @param gamma Gamma (input) + * @param reduction Reduction (input) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenSigmoidFocalLossForward(miopenHandle_t handle, + void* workspace, + size_t workspaceSizeInBytes, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t targetDesc, + const void* target, + miopenTensorDescriptor_t outputDesc, + void* output, + float alpha, + float gamma, + miopenLossReductionMode_t reduction); + +/*! 
@brief Execute a SigmoidFocalLoss backward layer + * + * @param handle MIOpen handle (input) + * @param inputDesc Tensor descriptor for input tensor (input) + * @param input Data tensor input (input) + * @param targetDesc Tensor descriptor for target tensor (input) + * @param target Data tensor target (input) + * @param doutputDesc Tensor descriptor for output gradient (input) + * @param doutput Gradient of output (input) + * @param dinputDesc Tensor descriptor for input gradient (input) + * @param dinput Gradient of input (output) + * @param dtargetDesc Tensor descriptor for target gradient (input) + * @param dtarget Gradient of target (output) + * @param alpha Alpha (input) + * @param gamma Gamma (input) + * @param reduction Reduction (input) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenSigmoidFocalLossBackward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t targetDesc, + const void* target, + miopenTensorDescriptor_t doutputDesc, + const void* doutput, + miopenTensorDescriptor_t dinputDesc, + void* dinput, + miopenTensorDescriptor_t dtargetDesc, + void* dtarget, + float alpha, + float gamma, + miopenLossReductionMode_t reduction); + +/** @} */ +// CLOSEOUT LossFunction DOXYGEN GROUP +#endif // MIOPEN_BETA_API + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 92e4f4264a..357bffb708 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -196,6 +196,8 @@ set( MIOpen_Source rope_api.cpp rope/problem_description.cpp scalar.cpp + sigmoidfocalloss/problem_description.cpp + sigmoid_focal_loss_api.cpp softmarginloss/problem_description.cpp softmarginloss_api.cpp softmax.cpp @@ -333,6 +335,10 @@ set( MIOpen_Source solver/reduce/forward_sum.cpp solver/rope/backward_rope.cpp solver/rope/forward_rope.cpp + solver/sigmoidfocalloss/backward_reduce_sigmoid_focal_loss.cpp + solver/sigmoidfocalloss/backward_unreduce_sigmoid_focal_loss.cpp + 
solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp + solver/sigmoidfocalloss/forward_unreduce_sigmoid_focal_loss.cpp solver/softmarginloss/backward_softmarginloss.cpp solver/softmarginloss/forward_softmarginloss.cpp solver/softmax/attn_softmax.cpp @@ -589,6 +595,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/gcnAsmBNBwdTrainSpatial.s kernels/MIOpenTensorKernels.cl kernels/MIOpenTensorKernelsHip.cpp + kernels/MIOpenSigmoidFocalLoss.cpp kernels/MIOpenSubTensorOpWithScalarKernel.cl kernels/MIOpenSubTensorOpWithSubTensorKernel.cl kernels/MIOpenSubTensorOpWithCastTensorKernel.cl @@ -703,6 +710,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN reducecalculation.cpp reduceextreme.cpp rope.cpp + sigmoid_focal_loss.cpp softmarginloss.cpp transformers_adam_w.cpp ${PROJECT_BINARY_DIR}/db_path.cpp diff --git a/src/include/miopen/sigmoid_focal_loss.hpp b/src/include/miopen/sigmoid_focal_loss.hpp new file mode 100644 index 0000000000..353c8b479a --- /dev/null +++ b/src/include/miopen/sigmoid_focal_loss.hpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_SIGMOID_FOCAL_LOSS_HPP_ +#define MIOPEN_SIGMOID_FOCAL_LOSS_HPP_ + +#include +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +MIOPEN_INTERNALS_EXPORT size_t +GetSigmoidFocalLossForwardWorkspaceSize(Handle& handle, + const TensorDescriptor& inputDesc, + const TensorDescriptor& targetDesc, + const TensorDescriptor& outputDesc, + miopenLossReductionMode_t reduction); + +MIOPEN_INTERNALS_EXPORT miopenStatus_t SigmoidFocalLossForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& targetDesc, + ConstData_t target, + const TensorDescriptor& outputDesc, + Data_t output, + float alpha, + float gamma, + miopenLossReductionMode_t reduction); + +MIOPEN_INTERNALS_EXPORT miopenStatus_t +SigmoidFocalLossBackward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& targetDesc, + ConstData_t target, + const TensorDescriptor& doutputDesc, + ConstData_t doutput, + const TensorDescriptor& dinputDesc, + Data_t dinput, + const TensorDescriptor& dtargetDesc, + Data_t dtarget, + float alpha, + float gamma, + miopenLossReductionMode_t reduction); + +} // namespace miopen +#endif // MIOPEN_SIGMOID_FOCAL_LOSS_HPP_ diff --git a/src/include/miopen/sigmoidfocalloss/invoke_params.hpp 
b/src/include/miopen/sigmoidfocalloss/invoke_params.hpp new file mode 100644 index 0000000000..e2801cead2 --- /dev/null +++ b/src/include/miopen/sigmoidfocalloss/invoke_params.hpp @@ -0,0 +1,79 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace miopen {
+
+namespace sigmoidfocalloss {
+
+struct SigmoidFocalLossInvokeParams : public miopen::InvokeParams // fields common to forward and backward
+{
+    SigmoidFocalLossInvokeParams() = default;
+
+    const TensorDescriptor* inputDesc  = nullptr; // borrowed pointer to the caller's descriptor
+    const TensorDescriptor* targetDesc = nullptr; // borrowed pointer to the caller's descriptor
+
+    ConstData_t input                   = nullptr; // read-only device buffer
+    ConstData_t target                  = nullptr; // read-only device buffer
+    Data_t workspace                    = nullptr; // writable scratch buffer
+    std::size_t workspace_size          = 0;       // size of `workspace` in bytes
+    float alpha                         = 0.25f;   // float literal, consistent with gamma
+    float gamma                         = 2.0f;
+    miopenLossReductionMode_t reduction = MIOPEN_LOSS_REDUCTION_NONE;
+
+    std::size_t GetWorkspaceSize() const { return workspace_size; }
+    Data_t GetWorkspace() const { return workspace; }
+};
+
+struct FwdInvokeParams : SigmoidFocalLossInvokeParams // adds the forward output
+{
+    FwdInvokeParams() = default;
+
+    const TensorDescriptor* outputDesc = nullptr;
+    Data_t output                      = nullptr; // written by the forward kernels
+};
+
+struct BwdInvokeParams : SigmoidFocalLossInvokeParams // adds the gradients used by the backward pass
+{
+    BwdInvokeParams() = default;
+
+    const TensorDescriptor* doutputDesc = nullptr;
+    const TensorDescriptor* dinputDesc  = nullptr;
+    const TensorDescriptor* dtargetDesc = nullptr;
+
+    ConstData_t doutput = nullptr; // upstream gradient, read-only
+    Data_t dinput       = nullptr; // written by the backward kernels, so it must be mutable
+    Data_t dtarget      = nullptr; // written by the backward kernels, so it must be mutable
+};
+
+} // namespace sigmoidfocalloss
+
+} // namespace miopen
diff --git a/src/include/miopen/sigmoidfocalloss/problem_description.hpp b/src/include/miopen/sigmoidfocalloss/problem_description.hpp
new file mode 100644
index 0000000000..3590b5c3d4
--- /dev/null
+++ b/src/include/miopen/sigmoidfocalloss/problem_description.hpp
@@ -0,0 +1,118 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace miopen {
+
+struct NetworkConfig;
+
+namespace sigmoidfocalloss {
+
+bool checkSameLength(const TensorDescriptor& x, const TensorDescriptor& y); // declared only here; used by the base constructor below
+
+struct SigmoidFocalLossProblemDescription : ProblemDescriptionBase // state common to the forward and backward problems
+{
+    SigmoidFocalLossProblemDescription(const TensorDescriptor& inputDesc_,
+                                       const TensorDescriptor& targetDesc_,
+                                       const miopenLossReductionMode_t reduction_)
+        : inputDesc(inputDesc_), targetDesc(targetDesc_), reduction(reduction_) // descriptors are copied into the members
+    {
+        if(!checkSameLength(inputDesc, targetDesc)) // reject input/target pairs the helper reports as mismatched
+            MIOPEN_THROW(miopenStatusBadParm,
+                         "SigmoidFocalLoss: Input, target tensor sizes do not match.");
+    }
+
+    const TensorDescriptor& GetInputDesc() const { return inputDesc; }
+    const TensorDescriptor& GetTargetDesc() const { return targetDesc; }
+
+public:
+    TensorDescriptor inputDesc;
+    TensorDescriptor targetDesc;
+    miopenLossReductionMode_t reduction;
+};
+
+struct SigmoidFocalLossFwdProblemDescription : SigmoidFocalLossProblemDescription // forward problem: base state plus the output descriptor
+{
+    SigmoidFocalLossFwdProblemDescription(const TensorDescriptor& inputDesc_,
+                                          const TensorDescriptor& targetDesc_,
+                                          const TensorDescriptor& outputDesc_,
+                                          const miopenLossReductionMode_t reduction_)
+        : SigmoidFocalLossProblemDescription(inputDesc_, targetDesc_, reduction_), // base constructor runs the length check first
+          outputDesc(outputDesc_)
+    {
+        miopenDataType_t dtype = inputDesc.GetType(); // the input's data type is the reference for the check below
+        if(dtype != targetDesc.GetType() || dtype != outputDesc.GetType())
+            MIOPEN_THROW(miopenStatusBadParm,
+                         "SigmoidFocalLoss: Input, target, output tensor type do not match.");
+    }
+
+    NetworkConfig MakeNetworkConfig() const override; // defined out of line
+    const TensorDescriptor& GetOutputDesc() const { return outputDesc; }
+
+public:
+    TensorDescriptor outputDesc;
+};
+
+struct SigmoidFocalLossBwdProblemDescription : SigmoidFocalLossProblemDescription // backward problem: base state plus three gradient descriptors
+{
+    SigmoidFocalLossBwdProblemDescription(const TensorDescriptor& inputDesc_,
+                                          const TensorDescriptor& targetDesc_,
+                                          const TensorDescriptor& doutputDesc_,
+                                          const TensorDescriptor& dinputDesc_,
+                                          const TensorDescriptor& dtargetDesc_,
+                                          const miopenLossReductionMode_t reduction_)
+        : SigmoidFocalLossProblemDescription(inputDesc_, targetDesc_, reduction_), // base constructor runs the length check first
+          doutputDesc(doutputDesc_),
+          dinputDesc(dinputDesc_),
+          dtargetDesc(dtargetDesc_)
+    {
+        miopenDataType_t dtype = inputDesc.GetType(); // all five descriptors must report this data type
+        if(dtype != targetDesc.GetType() || dtype != doutputDesc.GetType() ||
+           dtype != dinputDesc.GetType() || dtype != dtargetDesc.GetType())
+            MIOPEN_THROW(miopenStatusBadParm,
+                         "SigmoidFocalLoss: Input, target, doutput, dinput, dtarget tensor type do "
+                         "not match.");
+    }
+
+    NetworkConfig MakeNetworkConfig() const override; // defined out of line
+    const TensorDescriptor& GetDoutputDesc() const { return doutputDesc; }
+    const TensorDescriptor& GetDinputDesc() const { return dinputDesc; }
+    const TensorDescriptor& GetDtargetDesc() const { return dtargetDesc; }
+
+public:
+    TensorDescriptor doutputDesc;
+    TensorDescriptor dinputDesc;
+    TensorDescriptor dtargetDesc;
+};
+
+} // namespace sigmoidfocalloss
+
+} // namespace miopen
diff --git a/src/include/miopen/sigmoidfocalloss/solvers.hpp b/src/include/miopen/sigmoidfocalloss/solvers.hpp
new file mode 100644
index 0000000000..67d566c935
--- /dev/null
+++ b/src/include/miopen/sigmoidfocalloss/solvers.hpp
@@ -0,0 +1,125 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#pragma once
+
+#include
+#include
+#include
+
+namespace miopen {
+
+namespace solver {
+
+namespace sigmoidfocalloss {
+
+using SigmoidFocalLossFwdSolverBase =
+    NonTunableSolverBase;
+
+struct SigmoidFocalLossFwd final : SigmoidFocalLossFwdSolverBase // the only solver in this header that declares workspace queries
+{
+    const std::string& SolverDbId() const override { return GetSolverDbId(); }
+
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription&
+                          problem) const override;
+
+    ConvSolution GetSolution(const ExecutionContext& context,
+                             const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription&
+                                 problem) const override;
+
+    MultiBufferWorkspaceTraits GetMultiBufferWorkspaceTraits(
+        const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const; // not an override, unlike the other members
+
+    std::size_t
+    GetWorkspaceSize(const ExecutionContext& context,
+                     const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem)
+        const override;
+
+    bool MayNeedWorkspace() const override { return true; }
+};
+
+using SigmoidFocalLossBwdSolverBase =
+    NonTunableSolverBase;
+
+struct SigmoidFocalLossBwd final : SigmoidFocalLossBwdSolverBase // backward solver; declares no workspace members
+{
+    const std::string& SolverDbId() const override { return GetSolverDbId(); }
+
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription&
+                          problem) const override;
+
+    ConvSolution GetSolution(const ExecutionContext& context,
+                             const miopen::sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription&
+                                 problem) const override;
+};
+
+using SigmoidFocalLossUnreducedFwdSolverBase =
+    NonTunableSolverBase;
+
+struct SigmoidFocalLossUnreducedFwd final : SigmoidFocalLossUnreducedFwdSolverBase // shares the forward problem type with SigmoidFocalLossFwd
+{
+    const std::string& SolverDbId() const override
+    {
+        return GetSolverDbId();
+    }
+
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription&
+                          problem) const override;
+
+    ConvSolution GetSolution(const ExecutionContext& context,
+                             const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription&
+                                 problem) const override;
+};
+
+using SigmoidFocalLossUnreducedBwdSolverBase =
+    NonTunableSolverBase;
+
+struct SigmoidFocalLossUnreducedBwd final : SigmoidFocalLossUnreducedBwdSolverBase // shares the backward problem type with SigmoidFocalLossBwd
+{
+    const std::string& SolverDbId() const override
+    {
+        return GetSolverDbId();
+    }
+
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription&
+                          problem) const override;
+    ConvSolution GetSolution(const ExecutionContext& context,
+                             const miopen::sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription&
+                                 problem) const override;
+};
+
+} // namespace sigmoidfocalloss
+
+} // namespace solver
+
+} // namespace miopen
diff --git a/src/include/miopen/sigmoidfocalloss/utils.hpp b/src/include/miopen/sigmoidfocalloss/utils.hpp
new file mode 100644
index 0000000000..0dddceea7e
--- /dev/null
+++ b/src/include/miopen/sigmoidfocalloss/utils.hpp
@@ -0,0 +1,49 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+
+const auto make_hip_kernel = [](std::vector localsize, // builds a HIP KernelInfo from launch sizes and names
+                                std::vector gridsize,
+                                std::string kernel_file,
+                                std::string kernel_name,
+                                miopen::KernelBuildParameters build_params) {
+    while(localsize.size() < 3) // pad the work-group size to three dimensions with 1s
+        localsize.push_back(1);
+    while(gridsize.size() < 3) // pad the grid size to three dimensions with 1s
+        gridsize.push_back(1);
+    for(std::size_t i = 0; i < localsize.size(); ++i) // unsigned index: no signed/unsigned comparison
+        gridsize[i] = AlignUp(gridsize[i], localsize[i]); // grid extent is aligned to the work-group extent via AlignUp
+    return miopen::solver::KernelInfo{build_params.GenerateFor(miopen::kbp::HIP{}),
+                                      localsize,
+                                      gridsize,
+                                      kernel_file,
+                                      kernel_name};
+};
diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp
index f79a5f5a54..3524c33451 100644
--- a/src/include/miopen/solver_id.hpp
+++ b/src/include/miopen/solver_id.hpp
@@ -62,8 +62,7 @@ enum class Primitive
     RoPE,
     ReLU,
     Kthvalue,
-    SoftMarginLoss,
-    MultiMarginLoss
+    Loss
 };
 
 struct MIOPEN_INTERNALS_EXPORT Id
diff --git a/src/kernels/MIOpenSigmoidFocalLoss.cpp b/src/kernels/MIOpenSigmoidFocalLoss.cpp
new file mode 100644
index 0000000000..d12335a5f7
--- /dev/null
+++ b/src/kernels/MIOpenSigmoidFocalLoss.cpp
@@ -0,0 +1,357 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS
+#include
+#include
+#endif
+
+#include "float_types.h"
+#include "tensor_view.hpp"
+
+#ifndef IN_OUT_TYPE
+#define IN_OUT_TYPE float
+#endif
+
+#ifndef CVT_ACCUM2FLOAT
+#define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x))
+#endif
+
+#ifndef CVT_FLOAT2ACCUM
+#define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x))
+#endif
+
+template
+__device__ void sigmoidFocalLossFwd(const TIO* input,
+                                    const TIO* target,
+                                    FLOAT_ACCUM* workspace,
+                                    float alpha,
+                                    float gamma,
+                                    float divisor,
+                                    tensor_view_t<5> input_tv,
+                                    tensor_view_t<5> target_tv)
+{
+    /*
+    Dim: input = target = workspace = {N, C, D, H, W}.
+    Each thread handles one element of input/target and stores loss / divisor in workspace[gid].
+    Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS(default = 256), 1, 1}.
+    Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
+    */
+    size_t gid = threadIdx.x + blockIdx.x * blockDim.x;
+
+    tensor_layout_t<5> idx(input_tv, gid);
+    if(idx.layout[0] >= input_tv.size[0]) // threads past the last element do nothing
+        return;
+
+    FLOAT_ACCUM i = CVT_FLOAT2ACCUM(input[input_tv.get_tensor_view_idx(idx)]);
+    FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
+
+    /* The formula follows torchvision package: torchvision/ops/focal_loss.py */
+    FLOAT_ACCUM p      = 1 / (1 + exp(-i));
+    FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
+    FLOAT_ACCUM pT     = p * t + (1 - p) * (1 - t);
+    FLOAT_ACCUM loss   = ceLoss * pow(1 - pT, gamma);
+
+    if(alpha >= 0) // a negative alpha disables the alpha weighting
+    {
+        FLOAT_ACCUM alpha_t = alpha * t + (1 - alpha) * (1 - t);
+        loss                = alpha_t * loss;
+    }
+
+    workspace[gid] = loss / divisor;
+}
+
+extern "C" __global__ void SigmoidFocalLossFwd(const IN_OUT_TYPE* input,
+                                               const IN_OUT_TYPE* target,
+                                               FLOAT_ACCUM* workspace,
+                                               float alpha,
+                                               float gamma,
+                                               float divisor,
+                                               tensor_view_t<5> input_tv,
+                                               tensor_view_t<5> target_tv)
+{
+    sigmoidFocalLossFwd(
+        input, target, workspace, alpha, gamma, divisor, input_tv, target_tv);
+}
+
+template
+__device__ void sigmoidFocalLossBwd(const TIO* input,
+                                    const TIO* target,
+                                    const TIO* doutput,
+                                    TIO* dinput,
+                                    TIO* dtarget,
+                                    float alpha,
+                                    float gamma,
+                                    float divisor,
+                                    tensor_view_t<5> input_tv,
+                                    tensor_view_t<5> target_tv,
+                                    tensor_view_t<5> doutput_tv,
+                                    tensor_view_t<5> dinput_tv,
+                                    tensor_view_t<5> dtarget_tv)
+{
+    /*
+    Dim: input = target = dinput = dtarget = {N, C, D, H, W}; doutput is read at flat index 0 only.
+    Each thread handles one element; dinput / dtarget are skipped when their pointer is null.
+    Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS(default = 256), 1, 1}.
+    Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
+    */
+    size_t gid = threadIdx.x + blockIdx.x * blockDim.x;
+
+    tensor_layout_t<5> idx(input_tv, gid);
+    tensor_layout_t<5> doIdx(doutput_tv, 0);
+    if(idx.layout[0] >= input_tv.size[0]) // threads past the last element do nothing
+        return;
+
+    FLOAT_ACCUM i  = CVT_FLOAT2ACCUM(input[input_tv.get_tensor_view_idx(idx)]);
+    FLOAT_ACCUM t  = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
+    FLOAT_ACCUM dO = CVT_FLOAT2ACCUM(doutput[doutput_tv.get_tensor_view_idx(doIdx)]);
+
+    /* Formula is formed by compute fwd's formula gradient */
+    FLOAT_ACCUM p       = 1 / (1 + exp(-i));
+    FLOAT_ACCUM ceLoss  = -(t * log(p) + (1 - t) * log(1 - p));
+    FLOAT_ACCUM pT      = p * t + (1 - p) * (1 - t);
+    FLOAT_ACCUM powPt   = pow(1 - pT, gamma);
+    FLOAT_ACCUM alpha_t = alpha * t + (1 - alpha) * (1 - t);
+
+    if(dinput)
+    {
+        FLOAT_ACCUM dpdi = exp(-i) / pow(1 + exp(-i), 2);
+        // dceloss/di = dceloss/dp * dp/di
+        FLOAT_ACCUM dcelossdi = (-t / p + (1 - t) / (1 - p)) * dpdi;
+        // dpowt/di = dpowt/dpT * dpT/dp * dp/di
+        FLOAT_ACCUM dpowptdi = gamma * pow(1 - pT, gamma - 1) * (1 - 2 * t) * dpdi;
+
+        // L = ce_loss * pow_pt => dL/di = dceloss/di * pow_pt + ce_loss * dpowpt/di
+        FLOAT_ACCUM dLdi = dcelossdi * powPt + ceLoss * dpowptdi;
+        FLOAT_ACCUM grad = dO * dLdi;
+
+        if(alpha >= 0)
+        {
+            grad *= alpha_t;
+        }
+        grad /= divisor;
+        dinput[dinput_tv.get_tensor_view_idx(idx)] = CVT_ACCUM2FLOAT(grad);
+    }
+
+    if(dtarget)
+    {
+        FLOAT_ACCUM dcelossdt = -log(p) + log(1 - p);
+        FLOAT_ACCUM dpowptdt  = gamma * pow(1 - pT, gamma - 1) * (1 - 2 * p);
+        // L = ce_loss * pow_pt => dL/dt = dceloss/dt * pow_pt + ce_loss * dpowpt/dt
+        FLOAT_ACCUM dLdt       = dcelossdt * powPt + ceLoss * dpowptdt;
+        FLOAT_ACCUM gradTarget = dO * dLdt;
+
+        if(alpha >= 0)
+        {
+            // dO * (alpha_t * dL/dt + dalpha_t/dt * L): the upstream gradient scales both terms
+            gradTarget = dO * (alpha_t * dLdt + (2 * alpha - 1) * ceLoss * powPt);
+        }
+        gradTarget /= divisor;
+        dtarget[dtarget_tv.get_tensor_view_idx(idx)] = CVT_ACCUM2FLOAT(gradTarget);
+    }
+}
+
+extern "C" __global__ void SigmoidFocalLossBwd(const IN_OUT_TYPE* input,
+                                               const IN_OUT_TYPE* target,
+                                               const IN_OUT_TYPE* doutput,
+                                               IN_OUT_TYPE* dinput,
+                                               IN_OUT_TYPE* dtarget,
+                                               float alpha,
+                                               float gamma,
+                                               float divisor,
+                                               tensor_view_t<5> input_tv,
+                                               tensor_view_t<5> target_tv,
+                                               tensor_view_t<5> doutput_tv,
+                                               tensor_view_t<5> dinput_tv,
+                                               tensor_view_t<5> dtarget_tv)
+{
+    sigmoidFocalLossBwd(input,
+                        target,
+                        doutput,
+                        dinput,
+                        dtarget,
+                        alpha,
+                        gamma,
+                        divisor,
+                        input_tv,
+                        target_tv,
+                        doutput_tv,
+                        dinput_tv,
+                        dtarget_tv);
+}
+
+template
+__device__ void sigmoidFocalLossUnreducedFwd(const TIO* input,
+                                             const TIO* target,
+                                             TIO* output,
+                                             float alpha,
+                                             float gamma,
+                                             tensor_view_t<5> input_tv,
+                                             tensor_view_t<5> target_tv,
+                                             tensor_view_t<5> output_tv)
+{
+    /*
+    Dim: input = target = output = {N, C, D, H, W}.
+    Each thread handles one element of input/target and writes its loss to the same position of output.
+    Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS(default = 256), 1, 1}.
+    Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
+    */
+    size_t gid = threadIdx.x + blockIdx.x * blockDim.x;
+
+    tensor_layout_t<5> idx(input_tv, gid);
+    if(idx.layout[0] >= input_tv.size[0]) // threads past the last element do nothing
+        return;
+
+    FLOAT_ACCUM i = CVT_FLOAT2ACCUM(input[input_tv.get_tensor_view_idx(idx)]);
+    FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
+
+    /* The formula follows torchvision package: torchvision/ops/focal_loss.py */
+    FLOAT_ACCUM p      = 1 / (1 + exp(-i));
+    FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
+    FLOAT_ACCUM pT     = p * t + (1 - p) * (1 - t);
+    FLOAT_ACCUM loss   = ceLoss * pow(1 - pT, gamma);
+
+    if(alpha >= 0) // a negative alpha disables the alpha weighting
+    {
+        FLOAT_ACCUM alpha_t = alpha * t + (1 - alpha) * (1 - t);
+        loss                = alpha_t * loss;
+    }
+
+    output[output_tv.get_tensor_view_idx(idx)] = CVT_ACCUM2FLOAT(loss);
+}
+
+extern "C" __global__ void SigmoidFocalLossUnreducedFwd(const IN_OUT_TYPE* input,
+                                                        const IN_OUT_TYPE* target,
+                                                        IN_OUT_TYPE* output,
+                                                        float alpha,
+                                                        float gamma,
+                                                        tensor_view_t<5> input_tv,
+                                                        tensor_view_t<5> target_tv,
+                                                        tensor_view_t<5> output_tv)
+{
+    sigmoidFocalLossUnreducedFwd(
+        input, target, output, alpha, gamma, input_tv, target_tv, output_tv);
+}
+
+template
+__device__ void sigmoidFocalLossUnreducedBwd(const TIO* input,
+                                             const TIO* target,
+                                             const TIO* doutput,
+                                             TIO* dinput,
+                                             TIO* dtarget,
+                                             float alpha,
+                                             float gamma,
+                                             tensor_view_t<5> input_tv,
+                                             tensor_view_t<5> target_tv,
+                                             tensor_view_t<5> doutput_tv,
+                                             tensor_view_t<5> dinput_tv,
+                                             tensor_view_t<5> dtarget_tv)
+{
+    /*
+    Dim: input = target = doutput = dinput = dtarget = {N, C, D, H, W}.
+    Each thread handles one element; dinput / dtarget are skipped when their pointer is null.
+    Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS(default = 256), 1, 1}.
+    Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
+    */
+    size_t gid = threadIdx.x + blockIdx.x * blockDim.x;
+
+    tensor_layout_t<5> idx(input_tv, gid);
+    if(idx.layout[0] >= input_tv.size[0]) // threads past the last element do nothing
+        return;
+
+    FLOAT_ACCUM i  = CVT_FLOAT2ACCUM(input[input_tv.get_tensor_view_idx(idx)]);
+    FLOAT_ACCUM t  = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
+    FLOAT_ACCUM dO = CVT_FLOAT2ACCUM(doutput[doutput_tv.get_tensor_view_idx(idx)]);
+
+    /* Formula is formed by compute fwd's formula gradient */
+    FLOAT_ACCUM p       = 1 / (1 + exp(-i));
+    FLOAT_ACCUM ceLoss  = -(t * log(p) + (1 - t) * log(1 - p));
+    FLOAT_ACCUM pT      = p * t + (1 - p) * (1 - t);
+    FLOAT_ACCUM powPt   = pow(1 - pT, gamma);
+    FLOAT_ACCUM alpha_t = alpha * t + (1 - alpha) * (1 - t);
+
+    if(dinput)
+    {
+        FLOAT_ACCUM dpdi = exp(-i) / pow(1 + exp(-i), 2);
+        // dceloss/di = dceloss/dp * dp/di
+        FLOAT_ACCUM dcelossdi = (-t / p + (1 - t) / (1 - p)) * dpdi;
+        // dpowt/di = dpowt/dpT * dpT/dp * dp/di
+        FLOAT_ACCUM dpowptdi = gamma * pow(1 - pT, gamma - 1) * (1 - 2 * t) * dpdi;
+
+        // L = ce_loss * pow_pt => dL/di = dceloss/di * pow_pt + ce_loss * dpowpt/di
+        FLOAT_ACCUM dLdi = dcelossdi * powPt + ceLoss * dpowptdi;
+        FLOAT_ACCUM grad = dO * dLdi;
+
+        if(alpha >= 0)
+        {
+            grad *= alpha_t;
+        }
+        dinput[dinput_tv.get_tensor_view_idx(idx)] = CVT_ACCUM2FLOAT(grad);
+    }
+
+    if(dtarget)
+    {
+        FLOAT_ACCUM dcelossdt = -log(p) + log(1 - p);
+        FLOAT_ACCUM dpowptdt  = gamma * pow(1 - pT, gamma - 1) * (1 - 2 * p);
+        // L = ce_loss * pow_pt => dL/dt = dceloss/dt * pow_pt + ce_loss * dpowpt/dt
+        FLOAT_ACCUM dLdt       = dcelossdt * powPt + ceLoss * dpowptdt;
+        FLOAT_ACCUM gradTarget = dO * dLdt;
+
+        if(alpha >= 0)
+        {
+            // dO * (alpha_t * dL/dt + dalpha_t/dt * L): the upstream gradient scales both terms
+            gradTarget = dO * (alpha_t * dLdt + (2 * alpha - 1) * ceLoss * powPt);
+        }
+        dtarget[dtarget_tv.get_tensor_view_idx(idx)] = CVT_ACCUM2FLOAT(gradTarget);
+    }
+}
+
+extern "C" __global__ void SigmoidFocalLossUnreducedBwd(const IN_OUT_TYPE* input,
+                                                        const IN_OUT_TYPE* target,
+                                                        const IN_OUT_TYPE* doutput,
+                                                        IN_OUT_TYPE* dinput,
+                                                        IN_OUT_TYPE* dtarget,
+                                                        float alpha,
+                                                        float gamma,
+                                                        tensor_view_t<5> input_tv,
+                                                        tensor_view_t<5> target_tv,
+                                                        tensor_view_t<5> doutput_tv,
+                                                        tensor_view_t<5> dinput_tv,
+                                                        tensor_view_t<5> dtarget_tv)
+{
+    sigmoidFocalLossUnreducedBwd(input,
+                                 target,
+                                 doutput,
+                                 dinput,
+                                 dtarget,
+                                 alpha,
+                                 gamma,
+                                 input_tv,
+                                 target_tv,
+                                 doutput_tv,
+                                 dinput_tv,
+                                 dtarget_tv);
+}
diff --git a/src/sigmoid_focal_loss.cpp b/src/sigmoid_focal_loss.cpp
new file mode 100644
index 0000000000..3858f0a918
--- /dev/null
+++ b/src/sigmoid_focal_loss.cpp
@@ -0,0 +1,171 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace miopen {
+
+size_t GetSigmoidFocalLossForwardWorkspaceSize(Handle& handle, // workspace bytes needed by the forward pass
+                                               const TensorDescriptor& inputDesc,
+                                               const TensorDescriptor& targetDesc,
+                                               const TensorDescriptor& outputDesc,
+                                               miopenLossReductionMode_t reduction)
+{
+    if(reduction == MIOPEN_LOSS_REDUCTION_NONE)
+    {
+        return 0; // the unreduced path requests no workspace
+    }
+
+    auto ctx           = ExecutionContext{&handle};
+    const auto problem = sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription{
+        inputDesc, targetDesc, outputDesc, reduction};
+
+    const auto algo    = AlgorithmName{"SigmoidFocalLossFwd"}; // NOTE(review): not referenced again in this function
+    const auto solvers = solver::SolverContainer{};
+
+    auto pair_size_vector = solvers.GetWorkspaceSizes(ctx, problem);
+
+    return pair_size_vector.empty() ? static_cast(-1) : pair_size_vector.front().second; // -1 cast when no size is reported, else the first reported size
+}
+
+miopenStatus_t SigmoidFocalLossForward(Handle& handle, // packs the arguments and runs the forward solver chosen by `reduction`
+                                       Data_t workspace,
+                                       size_t workspaceSizeInBytes,
+                                       const TensorDescriptor& inputDesc,
+                                       ConstData_t input,
+                                       const TensorDescriptor& targetDesc,
+                                       ConstData_t target,
+                                       const TensorDescriptor& outputDesc,
+                                       Data_t output,
+                                       float alpha,
+                                       float gamma,
+                                       miopenLossReductionMode_t reduction)
+{
+    const auto problem = sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription{
+        inputDesc, targetDesc, outputDesc, reduction};
+
+    const auto invoke_params = [&]() { // immediately-invoked lambda so the filled struct can be const
+        auto tmp           = sigmoidfocalloss::FwdInvokeParams{};
+        tmp.inputDesc      = &inputDesc;
+        tmp.targetDesc     = &targetDesc;
+        tmp.outputDesc     = &outputDesc;
+        tmp.input          = input;
+        tmp.target         = target;
+        tmp.output         = output;
+        tmp.workspace      = workspace;
+        tmp.workspace_size = workspaceSizeInBytes;
+        tmp.alpha          = alpha;
+        tmp.gamma          = gamma;
+        tmp.reduction      = reduction;
+        return tmp;
+    }();
+
+    if(reduction == MIOPEN_LOSS_REDUCTION_NONE)
+    {
+        const auto algo = AlgorithmName{"SigmoidFocalLossUnreducedFwd"}; // element-wise output
+        const auto solvers =
+            solver::SolverContainer{};
+
+        solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
+    }
+    else
+    {
+        const auto algo = AlgorithmName{"SigmoidFocalLossFwd"}; // any reduction mode other than NONE
+        const auto solvers =
+            solver::SolverContainer{};
+
+        solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
+    }
+
+    return miopenStatusSuccess;
+}
+
+miopenStatus_t SigmoidFocalLossBackward(Handle& handle, // packs the arguments and runs the backward solver chosen by `reduction`
+                                        const TensorDescriptor& inputDesc,
+                                        ConstData_t input,
+                                        const TensorDescriptor& targetDesc,
+                                        ConstData_t target,
+                                        const TensorDescriptor& doutputDesc,
+                                        ConstData_t doutput,
+                                        const TensorDescriptor& dinputDesc,
+                                        Data_t dinput,
+                                        const TensorDescriptor& dtargetDesc,
+                                        Data_t dtarget,
+                                        float alpha,
+                                        float gamma,
+                                        const miopenLossReductionMode_t reduction)
+{
+    const auto problem = sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription{
+        inputDesc, targetDesc, doutputDesc, dinputDesc, dtargetDesc, reduction};
+
+    const auto invoke_params = [&]() { // immediately-invoked lambda so the filled struct can be const
+        auto tmp        = sigmoidfocalloss::BwdInvokeParams{};
+        tmp.inputDesc   = &inputDesc;
+        tmp.targetDesc  = &targetDesc;
+        tmp.doutputDesc = &doutputDesc;
+        tmp.dinputDesc  = &dinputDesc;
+        tmp.dtargetDesc = &dtargetDesc;
+        tmp.input       = input;
+        tmp.target      = target;
+        tmp.doutput     = doutput;
+        tmp.dinput      = dinput;
+        tmp.dtarget     = dtarget;
+        tmp.alpha       = alpha;
+        tmp.gamma       = gamma;
+        tmp.reduction   = reduction;
+        return tmp;
+    }();
+
+    if(reduction == MIOPEN_LOSS_REDUCTION_NONE)
+    {
+        const auto algo = AlgorithmName{"SigmoidFocalLossUnreducedBwd"}; // element-wise gradient
+        const auto solvers =
+            solver::SolverContainer{};
+
+        solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
+    }
+    else
+    {
+        const auto algo = AlgorithmName{"SigmoidFocalLossBwd"}; // any reduction mode other than NONE
+        const auto solvers =
+            solver::SolverContainer{};
+
+        solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
+    }
+
+    return miopenStatusSuccess;
+}
+
+} // namespace miopen
diff --git a/src/sigmoid_focal_loss_api.cpp b/src/sigmoid_focal_loss_api.cpp
new file mode 100644
index 0000000000..2cc511bb28
--- /dev/null
+++ b/src/sigmoid_focal_loss_api.cpp
@@ -0,0 +1,192 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+inline std::ostream& operator<<(std::ostream& os, const std::vector& v) // prints a vector as {a,b,c}
+{
+    os << '{';
+    for(size_t i = 0; i < v.size(); ++i) // unsigned index: avoids a signed/unsigned comparison with size()
+    {
+        if(i != 0)
+            os << ',';
+        os << v[i];
+    }
+    os << '}';
+    return os;
+}
+
+static void LogCmdSigmoidFocalLoss(const miopenTensorDescriptor_t inputDesc, // logs a driver command line when command logging is enabled
+                                   const miopenTensorDescriptor_t targetDesc,
+                                   bool is_fwd)
+{
+    if(miopen::IsLoggingCmd())
+    {
+        std::stringstream ss;
+        auto dtype = miopen::deref(inputDesc).GetType();
+        if(dtype == miopenHalf)
+        {
+            ss << "sigmoidFocalLossfp16";
+        }
+        else if(dtype == miopenFloat)
+        {
+            ss << "sigmoidFocalLossfp32";
+        }
+        else if(dtype == miopenBFloat16)
+        {
+            ss << "sigmoidFocalLossbfp16";
+        } // any other dtype leaves the command without a driver name
+
+        MIOPEN_LOG_FUNCTION(inputDesc, targetDesc);
+        ss << " -n " << miopen::deref(inputDesc).GetLengths()[0];
+        ss << " -T " << miopen::deref(inputDesc).GetLengths();
+        ss << " -Si " << miopen::deref(inputDesc).GetStrides();
+        ss << " -St " << miopen::deref(targetDesc).GetStrides();
+        ss << " -F " << ((is_fwd) ? "1" : "2");
+
+        MIOPEN_LOG_DRIVER_CMD(ss.str());
+    }
+}
+
+extern "C" miopenStatus_t
+miopenGetSigmoidFocalLossForwardWorkspaceSize(miopenHandle_t handle,
+                                              const miopenTensorDescriptor_t inputDesc,
+                                              const miopenTensorDescriptor_t targetDesc,
+                                              const miopenTensorDescriptor_t outputDesc,
+                                              miopenLossReductionMode_t reduction,
+                                              size_t* sizeInBytes)
+{
+
+    MIOPEN_LOG_FUNCTION(handle, inputDesc, targetDesc, outputDesc, reduction, sizeInBytes); // reduction decides the result, so log it too
+
+    return miopen::try_([&] {
+        miopen::deref(sizeInBytes) =
+            miopen::GetSigmoidFocalLossForwardWorkspaceSize(miopen::deref(handle),
+                                                            miopen::deref(inputDesc),
+                                                            miopen::deref(targetDesc),
+                                                            miopen::deref(outputDesc),
+                                                            reduction);
+    });
+}
+
+extern "C" miopenStatus_t miopenSigmoidFocalLossForward(miopenHandle_t handle, // C entry point: logs, then forwards to miopen::SigmoidFocalLossForward
+                                                        void* workspace,
+                                                        size_t workspaceSizeInBytes,
+                                                        const miopenTensorDescriptor_t inputDesc,
+                                                        const void* input,
+                                                        const miopenTensorDescriptor_t targetDesc,
+                                                        const void* target,
+                                                        const miopenTensorDescriptor_t outputDesc,
+                                                        void* output,
+                                                        const float alpha,
+                                                        const float gamma,
+                                                        const miopenLossReductionMode_t reduction)
+{
+    MIOPEN_LOG_FUNCTION(handle,
+                        workspace,
+                        workspaceSizeInBytes,
+                        inputDesc,
+                        input,
+                        targetDesc,
+                        target,
+                        outputDesc,
+                        output,
+                        alpha,
+                        gamma,
+                        reduction);
+
+    LogCmdSigmoidFocalLoss(inputDesc, targetDesc, true);
+
+    return miopen::try_([&] {
+        miopen::SigmoidFocalLossForward(miopen::deref(handle),
+                                        DataCast(workspace),
+                                        workspaceSizeInBytes,
+                                        miopen::deref(inputDesc),
+                                        DataCast(input),
+                                        miopen::deref(targetDesc),
+                                        DataCast(target),
+                                        miopen::deref(outputDesc),
+                                        DataCast(output),
+                                        alpha,
+                                        gamma,
+                                        reduction);
+    });
+}
+
+extern "C" miopenStatus_t miopenSigmoidFocalLossBackward(miopenHandle_t handle, // C entry point: logs, then forwards to miopen::SigmoidFocalLossBackward
+                                                         miopenTensorDescriptor_t inputDesc,
+                                                         const void* input,
+                                                         miopenTensorDescriptor_t targetDesc,
+                                                         const void* target,
+                                                         miopenTensorDescriptor_t doutputDesc,
+                                                         const void* doutput,
+                                                         miopenTensorDescriptor_t dinputDesc,
+                                                         void* dinput,
+                                                         miopenTensorDescriptor_t dtargetDesc,
+                                                         void* dtarget,
+                                                         float alpha,
+                                                         float gamma,
+                                                         const miopenLossReductionMode_t reduction)
+{
+    MIOPEN_LOG_FUNCTION(handle,
+                        inputDesc,
+                        input,
+                        targetDesc,
+                        target,
+                        doutputDesc,
+                        doutput,
+                        dinputDesc,
+                        dinput,
+                        dtargetDesc,
+                        dtarget,
+                        alpha,
+                        gamma,
+                        reduction);
+
+    LogCmdSigmoidFocalLoss(inputDesc, targetDesc, false);
+
+    return miopen::try_([&] {
+        miopen::SigmoidFocalLossBackward(miopen::deref(handle),
+                                         miopen::deref(inputDesc),
+                                         DataCast(input),
+                                         miopen::deref(targetDesc),
+                                         DataCast(target),
+                                         miopen::deref(doutputDesc),
+                                         DataCast(doutput),
+                                         miopen::deref(dinputDesc),
+                                         DataCast(dinput),
+                                         miopen::deref(dtargetDesc),
+                                         DataCast(dtarget),
+                                         alpha,
+                                         gamma,
+                                         reduction);
+    });
+}
diff --git a/src/sigmoidfocalloss/problem_description.cpp b/src/sigmoidfocalloss/problem_description.cpp
new file mode 100644
index 0000000000..825df9286e
--- /dev/null
+++ b/src/sigmoidfocalloss/problem_description.cpp
@@ -0,0 +1,88 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#include
+#include
+
+#include
+
+namespace miopen {
+
+namespace sigmoidfocalloss {
+
+bool checkSameLength(const TensorDescriptor& x, const TensorDescriptor& y) // true when both descriptors have the same rank and per-dimension lengths
+{
+    if(x.GetNumDims() != y.GetNumDims())
+        return false;
+    for(int32_t i = 0; i < x.GetNumDims(); ++i)
+    {
+        if(x.GetLengths()[i] != y.GetLengths()[i])
+            return false;
+    }
+    return true;
+}
+
+NetworkConfig SigmoidFocalLossBwdProblemDescription::MakeNetworkConfig() const // key string for the backward problem
+{
+    auto input_dtype  = inputDesc.GetType();
+    auto target_dtype = targetDesc.GetType();
+    auto size         = inputDesc.GetElementSize();
+    auto dim_num      = inputDesc.GetNumDims();
+
+    std::ostringstream ss;
+
+    ss << "sfl_bwd"; // prefix that separates backward keys from the "sfl_fwd" forward keys
+    ss << "reduction" << reduction;
+    ss << "i_dtype" << input_dtype;
+    ss << "t_dtype" << target_dtype;
+    ss << "dim_num" << dim_num;
+    ss << "size" << size;
+
+    return NetworkConfig{ss.str()};
+}
+
+NetworkConfig SigmoidFocalLossFwdProblemDescription::MakeNetworkConfig() const // key string for the forward problem
+{
+    auto input_dtype  = inputDesc.GetType();
+    auto target_dtype = targetDesc.GetType();
+    auto size         = inputDesc.GetElementSize();
+    auto dim_num      = inputDesc.GetNumDims();
+
+    std::ostringstream ss;
+
+    ss << "sfl_fwd"; // same fields as the backward key, with a forward prefix
+    ss << "reduction" << reduction;
+    ss << "i_dtype" << input_dtype;
+    ss << "t_dtype" << target_dtype;
+    ss << "dim_num" << dim_num;
+    ss << "size" << size;
+
+    return NetworkConfig{ss.str()};
+}
+
+} // namespace sigmoidfocalloss
+
+} // namespace miopen
diff --git a/src/solver.cpp b/src/solver.cpp
index 1f6873d5f7..167a085872 100644
--- a/src/solver.cpp
+++ b/src/solver.cpp
@@ -30,19 +30,20 @@ #include #include #include +#include #include #include -#include #include
#include +#include +#include #include #include #include #include -#include +#include #include #include -#include #include #include @@ -688,27 +689,35 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::RoPE, rope::RoPEForward{}.SolverDbId()); Register(registry, ++id, Primitive::RoPE, rope::RoPEBackward{}.SolverDbId()); + Register(registry, ++id, Primitive::ReLU, prelu::MultiWeightsBackward{}.SolverDbId()); Register(registry, ++id, Primitive::ReLU, prelu::SingleWeightBackward{}.SolverDbId()); + Register(registry, ++id, Primitive::Kthvalue, kthvalue::KthvalueFwd{}.SolverDbId()); Register(registry, ++id, Primitive::Activation, glu::GLUForward{}.SolverDbId()); Register(registry, ++id, Primitive::Activation, glu::GLUBackward{}.SolverDbId()); + Register(registry, ++id, Primitive::Loss, softmarginloss::SoftMarginLossForward{}.SolverDbId()); + Register( + registry, ++id, Primitive::Loss, softmarginloss::SoftMarginLossBackward{}.SolverDbId()); + + Register( + registry, ++id, Primitive::Loss, multimarginloss::MultiMarginLossForward{}.SolverDbId()); + + Register(registry, ++id, Primitive::Mha, mha::MhaCKFlashAttentionV2Forward{}.SolverDbId()); + Register(registry, ++id, - Primitive::SoftMarginLoss, - softmarginloss::SoftMarginLossForward{}.SolverDbId()); - Register(registry, - ++id, - Primitive::SoftMarginLoss, - softmarginloss::SoftMarginLossBackward{}.SolverDbId()); + Primitive::Loss, + sigmoidfocalloss::SigmoidFocalLossUnreducedFwd{}.SolverDbId()); Register(registry, ++id, - Primitive::MultiMarginLoss, - multimarginloss::MultiMarginLossForward{}.SolverDbId()); + Primitive::Loss, + sigmoidfocalloss::SigmoidFocalLossUnreducedBwd{}.SolverDbId()); + Register(registry, ++id, Primitive::Loss, sigmoidfocalloss::SigmoidFocalLossFwd{}.SolverDbId()); + Register(registry, ++id, Primitive::Loss, sigmoidfocalloss::SigmoidFocalLossBwd{}.SolverDbId()); - Register(registry, ++id, Primitive::Mha, 
mha::MhaCKFlashAttentionV2Forward{}.SolverDbId()); // IMPORTANT: New solvers should be added to the end of the function, and don't leave a white // space between this comment and the newly registered solver(s)! } diff --git a/src/solver/sigmoidfocalloss/backward_reduce_sigmoid_focal_loss.cpp b/src/solver/sigmoidfocalloss/backward_reduce_sigmoid_focal_loss.cpp new file mode 100644 index 0000000000..4e5046da49 --- /dev/null +++ b/src/solver/sigmoidfocalloss/backward_reduce_sigmoid_focal_loss.cpp @@ -0,0 +1,119 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#define LOCAL_SIZE 256
+
+namespace miopen {
+
+namespace solver {
+
+namespace sigmoidfocalloss {
+
+bool SigmoidFocalLossBwd::IsApplicable( // accepts inputs of up to 5 dimensions
+    const ExecutionContext& /*context*/,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription& problem) const
+{
+    if(problem.GetInputDesc().GetNumDims() > 5)
+        return false;
+    return true;
+}
+
+ConvSolution SigmoidFocalLossBwd::GetSolution( // one "SigmoidFocalLossBwd" kernel plus its invoker
+    const ExecutionContext& context,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription& problem) const
+{
+    std::ignore = context;
+
+    auto result = ConvSolution{miopenStatusSuccess};
+
+    auto in_dtype     = miopen::GetDataType(problem.GetInputDesc().GetType());
+    auto dtype        = problem.GetDinputDesc().GetType();
+    auto target_dtype = miopen::GetDataType(problem.GetTargetDesc().GetType());
+
+    const auto build_params = KernelBuildParameters{
+        {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)},
+        {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)},
+        {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)},
+        {"IN_OUT_TYPE", in_dtype == "bfloat16" ? "ushort" : in_dtype},
+        {"TARGET_TYPE", target_dtype == "bfloat16" ? "ushort" : target_dtype}, // target's own type name, not the input's
+        {"LOCAL_SIZE", LOCAL_SIZE},
+    };
+
+    result.construction_params.push_back(make_hip_kernel({LOCAL_SIZE},
+                                                         {problem.GetInputDesc().GetElementSize()},
+                                                         "MIOpenSigmoidFocalLoss.cpp",
+                                                         "SigmoidFocalLossBwd",
+                                                         build_params));
+
+    result.invoker_factory = [](const std::vector& kernels) {
+        return [=](const Handle& handle_, const AnyInvokeParams& raw_params) {
+            decltype(auto) kernel = handle_.Run(kernels.front());
+            decltype(auto) params = raw_params.CastTo();
+            auto input_tv         = get_inner_expanded_tv<5>(deref(params.inputDesc));
+            auto target_tv        = get_inner_expanded_tv<5>(deref(params.targetDesc));
+            auto doutput_tv       = get_inner_expanded_tv<5>(deref(params.doutputDesc));
+            auto dinput_tv        = get_inner_expanded_tv<5>(deref(params.dinputDesc));
+            auto dtarget_tv       = get_inner_expanded_tv<5>(deref(params.dtargetDesc));
+            float divisor         = 1;
+            if(params.reduction == MIOPEN_LOSS_REDUCTION_MEAN)
+            {
+                divisor = deref(params.inputDesc).GetElementSize(); // mean reduction: divisor is the input element count
+            }
+
+            kernel(params.input,
+                   params.target,
+                   params.doutput,
+                   params.dinput,
+                   params.dtarget,
+                   params.alpha,
+                   params.gamma,
+                   divisor,
+                   input_tv,
+                   target_tv,
+                   doutput_tv,
+                   dinput_tv,
+                   dtarget_tv);
+        };
+    };
+
+    return result;
+}
+
+} // namespace sigmoidfocalloss
+
+} // namespace solver
+
+} // namespace miopen
diff --git a/src/solver/sigmoidfocalloss/backward_unreduce_sigmoid_focal_loss.cpp b/src/solver/sigmoidfocalloss/backward_unreduce_sigmoid_focal_loss.cpp
new file mode 100644
index 0000000000..8d34198d73
--- /dev/null
+++ b/src/solver/sigmoidfocalloss/backward_unreduce_sigmoid_focal_loss.cpp
@@ -0,0 +1,113 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#define LOCAL_SIZE 256
+
+namespace miopen {
+
+namespace solver {
+
+namespace sigmoidfocalloss {
+
+bool SigmoidFocalLossUnreducedBwd::IsApplicable( // accepts inputs of up to 5 dimensions
+    const ExecutionContext& /*context*/,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription& problem) const
+{
+    if(problem.GetInputDesc().GetNumDims() > 5)
+        return false;
+    return true;
+}
+
+ConvSolution SigmoidFocalLossUnreducedBwd::GetSolution( // one "SigmoidFocalLossUnreducedBwd" kernel plus its invoker
+    const ExecutionContext& context,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossBwdProblemDescription& problem) const
+{
+    std::ignore = context;
+
+    auto result = ConvSolution{miopenStatusSuccess};
+
+    auto in_dtype     = miopen::GetDataType(problem.GetInputDesc().GetType());
+    auto dtype        = problem.GetDinputDesc().GetType();
+    auto target_dtype = miopen::GetDataType(problem.GetTargetDesc().GetType());
+
+    const auto build_params = KernelBuildParameters{
+        {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)},
+        {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)},
+        {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)},
+        {"IN_OUT_TYPE", in_dtype == "bfloat16" ? "ushort" : in_dtype},
+        {"TARGET_TYPE", target_dtype == "bfloat16" ? "ushort" : target_dtype}, // target's own type name, not the input's
+        {"LOCAL_SIZE", LOCAL_SIZE},
+    };
+
+    result.construction_params.push_back(make_hip_kernel({LOCAL_SIZE},
+                                                         {problem.GetInputDesc().GetElementSize()},
+                                                         "MIOpenSigmoidFocalLoss.cpp",
+                                                         "SigmoidFocalLossUnreducedBwd",
+                                                         build_params));
+
+    result.invoker_factory = [](const std::vector& kernels) {
+        return [=](const Handle& handle_, const AnyInvokeParams& raw_params) {
+            decltype(auto) kernel = handle_.Run(kernels.front());
+            decltype(auto) params = raw_params.CastTo();
+            auto input_tv         = get_inner_expanded_tv<5>(deref(params.inputDesc));
+            auto target_tv        = get_inner_expanded_tv<5>(deref(params.targetDesc));
+            auto doutput_tv       = get_inner_expanded_tv<5>(deref(params.doutputDesc));
+            auto dinput_tv        = get_inner_expanded_tv<5>(deref(params.dinputDesc));
+            auto dtarget_tv       = get_inner_expanded_tv<5>(deref(params.dtargetDesc));
+
+            kernel(params.input,
+                   params.target,
+                   params.doutput,
+                   params.dinput,
+                   params.dtarget,
+                   params.alpha,
+                   params.gamma,
+                   input_tv,
+                   target_tv,
+                   doutput_tv,
+                   dinput_tv,
+                   dtarget_tv);
+        };
+    };
+
+    return result;
+}
+
+} // namespace sigmoidfocalloss
+
+} // namespace solver
+
+} // namespace miopen
diff --git a/src/solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp b/src/solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp
new file mode 100644
index 0000000000..5af00b9701
--- /dev/null
+++ b/src/solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp
@@ -0,0 +1,202 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include "miopen/buffer_info.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#define LOCAL_SIZE_SIGMOIDFOCALLOSS 256
+#define LOCAL_SIZE_REDUCE 256
+
+namespace miopen {
+
+namespace solver {
+
+namespace sigmoidfocalloss {
+
+bool SigmoidFocalLossFwd::IsApplicable( // accepts inputs of up to 5 dimensions
+    const ExecutionContext& /*context*/,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const
+{
+    if(problem.GetInputDesc().GetNumDims() > 5)
+        return false;
+    return true;
+}
+
+ConvSolution SigmoidFocalLossFwd::GetSolution( // loss kernel followed by a chain of sum-reduction kernels
+    const ExecutionContext& context,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const
+{
+    std::ignore = context;
+    auto result = ConvSolution{miopenStatusSuccess};
+
+    auto in_dtype     = miopen::GetDataType(problem.GetInputDesc().GetType());
+    auto dtype        = problem.GetOutputDesc().GetType();
+    auto target_dtype = miopen::GetDataType(problem.GetTargetDesc().GetType());
+    auto size         = problem.GetInputDesc().GetElementSize();
+
+    const auto build_params = KernelBuildParameters{
+        {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)},
+        {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)},
+        {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)},
+        {"IN_OUT_TYPE", in_dtype == "bfloat16" ? "ushort" : in_dtype},
+        {"TARGET_TYPE", target_dtype == "bfloat16" ? "ushort" : target_dtype}, // target's own type name, not the input's
+        {"REDUCE_SIZE", LOCAL_SIZE_REDUCE},
+    };
+
+    /* Prepare params for loss kernel */
+    result.construction_params.push_back(make_hip_kernel({LOCAL_SIZE_SIGMOIDFOCALLOSS},
+                                                         {size},
+                                                         "MIOpenSigmoidFocalLoss.cpp",
+                                                         "SigmoidFocalLossFwd",
+                                                         build_params));
+
+    /* Prepare params for reduce kernels */
+    auto _size = size;
+    while(_size > LOCAL_SIZE_REDUCE) // one intermediate pass per factor of LOCAL_SIZE_REDUCE
+    {
+        result.construction_params.push_back(make_hip_kernel({LOCAL_SIZE_REDUCE},
+                                                             {_size},
+                                                             "MIOpenReduceSum.cpp",
+                                                             "ReduceSumFLOATACCUM",
+                                                             build_params));
+        // {LOCAL_SIZE_REDUCE}, {_size}, "MIOpenLossSum.cpp", "LossSum", build_params));
+        _size = AlignUp(_size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE;
+    }
+
+    result.construction_params.push_back(make_hip_kernel(
+        {LOCAL_SIZE_REDUCE}, {_size}, "MIOpenReduceSum.cpp", "ReduceSum", build_params));
+
+    result.invoker_factory = [this, problem](const std::vector& kernels) {
+        return [=](const Handle& handle_, const AnyInvokeParams& raw_params) {
+            decltype(auto) params = raw_params.CastTo();
+            auto size             = deref(params.inputDesc).GetElementSize();
+
+            auto elapsed = 0.f;
+            HipEventPtr start;
+            HipEventPtr stop;
+
+            const bool profiling = handle_.IsProfilingEnabled();
+            if(profiling)
+            {
+                handle_.EnableProfiling(false); // time the whole kernel sequence with one event pair instead
+                start = miopen::make_hip_event();
+                stop  = miopen::make_hip_event();
+                hipEventRecord(start.get(), handle_.GetStream());
+            }
+
+            /* Execute loss kernel */
+            {
+                decltype(auto) kernel = handle_.Run(kernels.front());
+                auto input_tv         = get_inner_expanded_tv<5>(deref(params.inputDesc));
+                auto target_tv        = get_inner_expanded_tv<5>(deref(params.targetDesc));
+                float divisor         = 1;
+                if(params.reduction == MIOPEN_LOSS_REDUCTION_MEAN)
+                {
+                    divisor = size;
+                }
+
+                kernel(params.input,
+                       params.target,
+                       params.workspace,
+                       params.alpha,
+                       params.gamma,
+                       divisor,
+                       input_tv,
+                       target_tv);
+            }
+
+            /* Execute reduce kernels */
+            auto wt       = GetMultiBufferWorkspaceTraits(problem);
+            auto reduceIn = params.workspace;
+            auto reduceOut =
+                static_cast(static_cast(params.workspace) + wt.GetOffset(1));
+
+            for(size_t i = 1; i < kernels.size(); ++i) // unsigned index: avoids a signed/unsigned comparison with size()
+            {
+                decltype(auto) kernel = handle_.Run(kernels[i]);
+                if(i + 1 != kernels.size())
+                {
+                    kernel(reduceIn, reduceOut, size);
+                    std::swap(reduceIn, reduceOut); // ping-pong between the two workspace regions
+                }
+                else
+                {
+                    auto output_tv = get_inner_expanded_tv<1>(deref(params.outputDesc));
+                    kernel(reduceIn, params.output, size, output_tv);
+                }
+                size = AlignUp(size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE;
+            }
+
+            if(profiling)
+            {
+                hipEventRecord(stop.get(), handle_.GetStream());
+                hipEventSynchronize(stop.get());
+                hipEventElapsedTime(&elapsed, start.get(), stop.get());
+
+                hipEventDestroy(start.get()); // NOTE(review): start/stop are HipEventPtr objects; confirm this cannot release the events twice
+                hipEventDestroy(stop.get());
+                handle_.ResetKernelTime();
+                handle_.AccumKernelTime(elapsed);
+
+                handle_.EnableProfiling(true);
+            };
+        };
+    };
+
+    return result;
+}
+
+std::size_t SigmoidFocalLossFwd::GetWorkspaceSize( // total bytes of both workspace regions
+    const ExecutionContext& /*context*/,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const
+{
+    return GetMultiBufferWorkspaceTraits(problem).GetSize();
+}
+
+MultiBufferWorkspaceTraits SigmoidFocalLossFwd::GetMultiBufferWorkspaceTraits( // region 0: per-element losses; region 1: partial sums
+    const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const
+{
+    size_t inputElements  = problem.GetInputDesc().GetElementSize();
+    size_t reduceElements = (inputElements + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE;
+    size_t elementSize    = get_data_size(miopenFloat);
+
+    return MultiBufferWorkspaceTraits{inputElements * elementSize, reduceElements * elementSize};
+}
+
+} // namespace sigmoidfocalloss
+
+} // namespace solver
+
+} // namespace miopen
diff --git a/src/solver/sigmoidfocalloss/forward_unreduce_sigmoid_focal_loss.cpp b/src/solver/sigmoidfocalloss/forward_unreduce_sigmoid_focal_loss.cpp
new file mode 100644
index 0000000000..91e8b48e49
--- /dev/null
+++ b/src/solver/sigmoidfocalloss/forward_unreduce_sigmoid_focal_loss.cpp
@@ -0,0 +1,107 @@
+/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#define LOCAL_SIZE 256
+
+namespace miopen {
+
+namespace solver {
+
+namespace sigmoidfocalloss {
+
+bool SigmoidFocalLossUnreducedFwd::IsApplicable( // accepts inputs of up to 5 dimensions
+    const ExecutionContext& /*context*/,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const
+{
+    if(problem.GetInputDesc().GetNumDims() > 5)
+        return false;
+    return true;
+}
+
+ConvSolution SigmoidFocalLossUnreducedFwd::GetSolution( // one "SigmoidFocalLossUnreducedFwd" kernel plus its invoker
+    const ExecutionContext& context,
+    const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const
+{
+    std::ignore = context;
+
+    auto result = ConvSolution{miopenStatusSuccess};
+
+    auto in_dtype     = miopen::GetDataType(problem.GetInputDesc().GetType());
+    auto dtype        = problem.GetOutputDesc().GetType();
+    auto target_dtype = miopen::GetDataType(problem.GetTargetDesc().GetType());
+
+    const auto build_params = KernelBuildParameters{
+        {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)},
+        {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)},
+        {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)},
+        {"IN_OUT_TYPE", in_dtype == "bfloat16" ? "ushort" : in_dtype},
+        {"TARGET_TYPE", target_dtype == "bfloat16" ? "ushort" : target_dtype}, // target's own type name, not the input's
+        {"LOCAL_SIZE", LOCAL_SIZE},
+    };
+
+    result.construction_params.push_back(make_hip_kernel({LOCAL_SIZE},
+                                                         {problem.GetInputDesc().GetElementSize()},
+                                                         "MIOpenSigmoidFocalLoss.cpp",
+                                                         "SigmoidFocalLossUnreducedFwd",
+                                                         build_params));
+
+    result.invoker_factory = [](const std::vector& kernels) {
+        return [=](const Handle& handle_, const AnyInvokeParams& raw_params) {
+            decltype(auto) kernel = handle_.Run(kernels.front());
+            decltype(auto) params = raw_params.CastTo();
+            auto input_tv         = get_inner_expanded_tv<5>(deref(params.inputDesc));
+            auto target_tv        = get_inner_expanded_tv<5>(deref(params.targetDesc));
+            auto output_tv        = get_inner_expanded_tv<5>(deref(params.outputDesc));
+
+            kernel(params.input,
+                   params.target,
+                   params.output,
+                   params.alpha,
+                   params.gamma,
+                   input_tv,
+                   target_tv,
+                   output_tv);
+        };
+    };
+
+    return result;
+}
+
+} // namespace sigmoidfocalloss
+
+} // namespace solver
+
+} // namespace miopen
diff --git a/test/cpu_sigmoid_focal_loss.hpp b/test/cpu_sigmoid_focal_loss.hpp
new file mode 100644
index 0000000000..fe21c94e27
--- /dev/null
+++ b/test/cpu_sigmoid_focal_loss.hpp
@@ -0,0 +1,157 @@
+#pragma once
+
+#include "miopen/miopen.h"
+#include "tensor_holder.hpp"
+#include "tensor_view.hpp"
+#include
+#include
+
+template
+void cpu_sigmoid_focal_loss_forward(tensor input,
+                                    tensor target,
+                                    tensor& outputHost,
+                                    float alpha,
+                                    float gamma,
+                                    miopenLossReductionMode_t reduction,
+                                    float divisor)
+{
+    auto input_tv     = miopen::get_inner_expanded_tv<5>(input.desc);
+    auto target_tv    = miopen::get_inner_expanded_tv<5>(target.desc);
+    auto output_tv    = miopen::get_inner_expanded_tv<5>(outputHost.desc);
+    size_t inputSize  = input.desc.GetElementSize();
+    float outputFloat = 0;
+
+    for(size_t id = 0; id < inputSize; ++id)
+    {
+        tensor_layout_t<5> idx(input_tv, id);
+
+        float i = static_cast(input[input_tv.get_tensor_view_idx(idx)]);
+        float t = static_cast(target[target_tv.get_tensor_view_idx(idx)]);
+
+        float sig = 1 / (1 +
std::exp(-i)); + float ceLoss = -(t * std::log(sig) + (1 - t) * std::log(1 - sig)); + float sigT = sig * t + (1 - sig) * (1 - t); + float loss = ceLoss * std::pow(1 - sigT, gamma); + + if(alpha >= 0) + { + float alphaT = alpha * t + (1 - alpha) * (1 - t); + loss = alphaT * loss; + } + + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + outputHost[output_tv.get_tensor_view_idx(idx)] = static_cast(loss); + } + else + { + outputFloat += loss / divisor; + } + } + + if(reduction != MIOPEN_LOSS_REDUCTION_NONE) + { + outputHost[0] = static_cast(outputFloat); + } +} + +template +void cpu_sigmoid_focal_loss_backward(tensor input, + tensor target, + tensor doutput, + tensor& dinput, + tensor& dtarget, + float alpha, + float gamma, + miopenLossReductionMode_t reduction, + float divisor) +{ + auto input_tv = miopen::get_inner_expanded_tv<5>(input.desc); + auto target_tv = miopen::get_inner_expanded_tv<5>(target.desc); + auto doutput_tv = miopen::get_inner_expanded_tv<5>(doutput.desc); + auto dinput_tv = miopen::get_inner_expanded_tv<5>(dinput.desc); + auto dtarget_tv = miopen::get_inner_expanded_tv<5>(dtarget.desc); + + size_t inputSize = input.desc.GetElementSize(); + + tensor_layout_t<5> doIdx(input_tv, 0); + float dO = static_cast(doutput[doutput_tv.get_tensor_view_idx(doIdx)]); + + for(size_t id = 0; id < inputSize; ++id) + { + tensor_layout_t<5> idx(input_tv, id); + + float i = static_cast(input[input_tv.get_tensor_view_idx(idx)]); + float t = static_cast(target[target_tv.get_tensor_view_idx(idx)]); + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + dO = static_cast(doutput[doutput_tv.get_tensor_view_idx(idx)]); + } + + float p = 1 / (1 + std::exp(-i)); + float ceLoss = -(t * std::log(p) + (1 - t) * std::log(1 - p)); + float pT = p * t + (1 - p) * (1 - t); + float powPt = std::pow(1 - pT, gamma); + float alpha_t = alpha * t + (1 - alpha) * (1 - t); + + if(dinput.data.size() > 0) + { + float dpdi = std::exp(-i) / std::pow(1 + std::exp(-i), 2); + float dcelossdi = (-t / p 
+ (1 - t) / (1 - p)) * dpdi; + float dpowptdi = gamma * std::pow(1 - pT, gamma - 1) * (1 - 2 * t) * dpdi; + + // L = ce_loss * pow_pt => dL/di = dceloss/di * pow_pt + ce_loss * dpowpt/di + float dLdi = dcelossdi * powPt + ceLoss * dpowptdi; + float grad = dO * dLdi; + + if(alpha >= 0) + { + grad *= alpha_t; + } + if(reduction != MIOPEN_LOSS_REDUCTION_NONE) + { + grad /= divisor; + } + dinput[dinput_tv.get_tensor_view_idx(idx)] = static_cast(grad); + } + + if(dtarget.data.size() > 0) + { + float dcelossdt = -std::log(p) + std::log(1 - p); + float dpowptdt = gamma * std::pow(1 - pT, gamma - 1) * (1 - 2 * p); + // L = ce_loss * pow_pt => dL/dt = dceloss/dt * pow_pt + ce_loss * dpowpt/dt + float dLdt = dcelossdt * powPt + ceLoss * dpowptdt; + float gradTarget = dO * dLdt; + + if(alpha >= 0) + { + // alpha_t * dL/dt + dalpha_t/dt * dL + gradTarget = alpha_t * dLdt + (2 * alpha - 1) * ceLoss * powPt; + } + if(reduction != MIOPEN_LOSS_REDUCTION_NONE) + { + gradTarget /= divisor; + } + dtarget[dtarget_tv.get_tensor_view_idx(idx)] = static_cast(gradTarget); + } + } +} + +template +float get_tolerance(miopenLossReductionMode_t reduction) +{ + float tolerance; + if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + { + tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3; + // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. + if(std::is_same::value) + tolerance *= 8.0; + } + else + { + tolerance = std::is_same::value ? 1.0e-2 : 8.2e-1; + } + + return tolerance; +} diff --git a/test/gtest/sigmoid_focal_loss.cpp b/test/gtest/sigmoid_focal_loss.cpp new file mode 100644 index 0000000000..48982ee8db --- /dev/null +++ b/test/gtest/sigmoid_focal_loss.cpp @@ -0,0 +1,165 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include "sigmoid_focal_loss.hpp"
+#include
+
+namespace sigmoidfocalloss { // fixture aliases: reduced/unreduced x fwd/bwd x data type (per alias name)
+using GPU_SigmoidFocalLoss_fwd_FP32           = SigmoidFocalLossFwdTest;
+using GPU_SigmoidFocalLoss_fwd_FP16           = SigmoidFocalLossFwdTest;
+using GPU_SigmoidFocalLoss_fwd_BFP16          = SigmoidFocalLossFwdTest;
+using GPU_SigmoidFocalLoss_bwd_FP32           = SigmoidFocalLossBwdTest;
+using GPU_SigmoidFocalLoss_bwd_FP16           = SigmoidFocalLossBwdTest;
+using GPU_SigmoidFocalLoss_bwd_BFP16          = SigmoidFocalLossBwdTest;
+using GPU_SigmoidFocalLossUnreduced_fwd_FP32  = SigmoidFocalLossUnreducedFwdTest;
+using GPU_SigmoidFocalLossUnreduced_fwd_FP16  = SigmoidFocalLossUnreducedFwdTest;
+using GPU_SigmoidFocalLossUnreduced_fwd_BFP16 = SigmoidFocalLossUnreducedFwdTest;
+using GPU_SigmoidFocalLossUnreduced_bwd_FP32  = SigmoidFocalLossUnreducedBwdTest;
+using GPU_SigmoidFocalLossUnreduced_bwd_FP16  = SigmoidFocalLossUnreducedBwdTest;
+using GPU_SigmoidFocalLossUnreduced_bwd_BFP16 = SigmoidFocalLossUnreducedBwdTest;
+}; // namespace sigmoidfocalloss
+
+using namespace sigmoidfocalloss;
+
+TEST_P(GPU_SigmoidFocalLoss_fwd_FP32, Test) // every test body below follows this same two-step shape
+{
+    RunTest(); // fixture method declared in sigmoid_focal_loss.hpp
+    Verify();  // fixture method declared in sigmoid_focal_loss.hpp
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLoss_fwd_FP32,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLoss_fwd_FP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLoss_fwd_FP16,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLoss_fwd_BFP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLoss_fwd_BFP16,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLoss_bwd_FP32, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLoss_bwd_FP32,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLoss_bwd_FP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLoss_bwd_FP16,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs())); // same config list for every suite here
+
+TEST_P(GPU_SigmoidFocalLoss_bwd_BFP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLoss_bwd_BFP16,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLossUnreduced_fwd_FP32, Test) // unreduced fixtures start here
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLossUnreduced_fwd_FP32,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLossUnreduced_fwd_FP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLossUnreduced_fwd_FP16,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLossUnreduced_fwd_BFP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLossUnreduced_fwd_BFP16,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLossUnreduced_bwd_FP32, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLossUnreduced_bwd_FP32,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLossUnreduced_bwd_FP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLossUnreduced_bwd_FP16,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
+
+TEST_P(GPU_SigmoidFocalLossUnreduced_bwd_BFP16, Test)
+{
+    RunTest();
+    Verify();
+};
+
+INSTANTIATE_TEST_SUITE_P(Full,
+                         GPU_SigmoidFocalLossUnreduced_bwd_BFP16,
+                         testing::ValuesIn(SigmoidFocalLossTestConfigs()));
diff --git a/test/gtest/sigmoid_focal_loss.hpp b/test/gtest/sigmoid_focal_loss.hpp
new file mode 100644
index 0000000000..ab59893e52
--- /dev/null
+++ b/test/gtest/sigmoid_focal_loss.hpp
@@ -0,0 +1,509 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#include "cpu_sigmoid_focal_loss.hpp"
+#include "get_handle.hpp"
+#include "miopen/allocator.hpp"
+#include "random.hpp"
+#include "tensor_holder.hpp"
+#include "verify.hpp"
+#include
+#include
+#include
+
+struct SigmoidFocalLossTestCase
+{
+    std::vector dims;
+    bool isContiguous;
+    float alpha;
+    float gamma;
+    friend std::ostream& operator<<(std::ostream& os, const SigmoidFocalLossTestCase& tc)
+    {
+        os << "dims: ";
+        for(auto dim : tc.dims)
+        {
+            os << dim << " ";
+        }
+        return os << "is_contiguous: " << tc.isContiguous << " alpha: " << tc.alpha
+                  << " gamma: " << tc.gamma;
+    }
+
+    std::vector GetDims() const { return dims; }
+
+    SigmoidFocalLossTestCase() {} // fixtures default-construct 'config', then assign GetParam()
+
+    SigmoidFocalLossTestCase(std::vector dim_,
+                             bool isContiguous_ = true,
+                             float alpha_       = 0.25,
+                             float gamma_       = 2)
+        : dims(dim_), isContiguous(isContiguous_), alpha(alpha_), gamma(gamma_)
+    {
+    }
+
+    std::vector ComputeStrides(std::vector inputDim) const
+    {
+        if(!isContiguous) // non-contiguous: compute strides with first/last dims exchanged...
+            std::swap(inputDim.front(), inputDim.back());
+        std::vector strides(inputDim.size());
+        strides.back() = 1; // requires at least one dimension
+        for(int i = inputDim.size() - 2; i >= 0; --i)
+            strides[i] = strides[i + 1] * inputDim[i + 1];
+        if(!isContiguous) // ...then exchange the matching strides back
+            std::swap(strides.front(), strides.back());
+        return strides;
+    }
+};
+
+inline std::vector SigmoidFocalLossTestConfigs()
+{
+    return {
+        SigmoidFocalLossTestCase({1}),                      // 1D cont
+        SigmoidFocalLossTestCase({4000}),                   // 1D cont
+        SigmoidFocalLossTestCase({100, 500}),               // 2D cont
+        SigmoidFocalLossTestCase({100, 500}, false),        // 2D non-cont
+        SigmoidFocalLossTestCase({10, 20, 200}),            // 3D cont
+        SigmoidFocalLossTestCase({10, 20, 200}, false),     // 3D non-cont
+        SigmoidFocalLossTestCase({8, 3, 20, 100}),          // 4D cont
+        SigmoidFocalLossTestCase({8, 3, 20, 100}, false),   // 4D non-cont
+        SigmoidFocalLossTestCase({2, 2, 3, 4, 100}),        // 5D cont
+        SigmoidFocalLossTestCase({2, 2, 3, 4, 100}, false), // 5D non-cont
+        SigmoidFocalLossTestCase({10},
+                                 true,
+                                 0.6,
+                                 3), // 1D cont, custom alpha, gamma
+    };
+}
+
+template
+struct SigmoidFocalLossUnreducedFwdTest : public ::testing::TestWithParam
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle = get_handle();
+        config        = GetParam();
+        reduction     = MIOPEN_LOSS_REDUCTION_NONE;
+
+        auto in_dims    = config.GetDims();
+        auto in_strides = config.ComputeStrides(in_dims);
+
+        auto in_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 50); };
+        input             = tensor{in_dims, in_strides}.generate(in_gen_value);
+
+        auto tar_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 50); };
+        target             = tensor{in_dims, in_strides}.generate(tar_gen_value);
+
+        output = tensor{in_dims}; // unreduced: output has the input's dims
+        std::fill(output.begin(), output.end(), 0);
+
+        outputHost = tensor{in_dims};
+        std::fill(outputHost.begin(), outputHost.end(), 0);
+
+        input_dev  = handle.Write(input.data);
+        target_dev = handle.Write(target.data);
+        output_dev = handle.Write(output.data);
+    }
+
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+        miopenStatus_t status;
+
+        status = miopen::SigmoidFocalLossForward(handle,
+                                                 nullptr, // no workspace is passed for the unreduced mode
+                                                 0,
+                                                 input.desc,
+                                                 input_dev.get(),
+                                                 target.desc,
+                                                 target_dev.get(),
+                                                 output.desc,
+                                                 output_dev.get(),
+                                                 config.alpha,
+                                                 config.gamma,
+                                                 reduction);
+        cpu_sigmoid_focal_loss_forward(
+            input, target, outputHost, config.alpha, config.gamma, reduction, 1);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+        output.data = handle.Read(output_dev, output.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = get_tolerance(reduction);
+
+        auto error = miopen::rms_range(outputHost, output);
+
+        EXPECT_TRUE(miopen::range_distance(outputHost) == miopen::range_distance(output));
+        EXPECT_TRUE(error < threshold)
+            << "Error output beyond tolerance Error: " << error << ", Threshold: " << threshold;
+    }
+    SigmoidFocalLossTestCase config;
+    miopenLossReductionMode_t reduction;
+
+    tensor input;
+    tensor target;
+    tensor output;
+
+    tensor outputHost;
+
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr target_dev;
+    miopen::Allocator::ManageDataPtr output_dev;
+};
+
+template
+struct SigmoidFocalLossUnreducedBwdTest : public ::testing::TestWithParam
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle = get_handle();
+        config        = GetParam();
+        reduction     = MIOPEN_LOSS_REDUCTION_NONE;
+
+        auto in_dims      = config.GetDims();
+        auto in_strides   = config.ComputeStrides(in_dims);
+        auto in_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 50); };
+        input             = tensor{in_dims, in_strides}.generate(in_gen_value);
+
+        auto tar_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 50); };
+        target             = tensor{in_dims, in_strides}.generate(tar_gen_value);
+
+        auto dOut_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 50); };
+        dOutput             = tensor{in_dims, in_strides}.generate(dOut_gen_value);
+
+        dInput = tensor{in_dims};
+        std::fill(dInput.begin(), dInput.end(), 0);
+
+        dInputHost = tensor{in_dims};
+        std::fill(dInputHost.begin(), dInputHost.end(), 0);
+
+        dTarget = tensor{in_dims};
+        std::fill(dTarget.begin(), dTarget.end(), 0);
+
+        dTargetHost = tensor{in_dims};
+        std::fill(dTargetHost.begin(), dTargetHost.end(), 0);
+
+        input_dev   = handle.Write(input.data);
+        target_dev  = handle.Write(target.data);
+        dOutput_dev = handle.Write(dOutput.data);
+        dInput_dev  = handle.Write(dInput.data);
+        dTarget_dev = handle.Write(dTarget.data);
+    }
+
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+
+        miopenStatus_t status;
+
+        status = miopen::SigmoidFocalLossBackward(handle,
+                                                  input.desc,
+                                                  input_dev.get(),
+                                                  target.desc,
+                                                  target_dev.get(),
+                                                  dOutput.desc,
+                                                  dOutput_dev.get(),
+                                                  dInput.desc,
+                                                  dInput_dev.get(),
+                                                  dTarget.desc,
+                                                  dTarget_dev.get(),
+                                                  config.alpha,
+                                                  config.gamma,
+                                                  reduction);
+        cpu_sigmoid_focal_loss_backward(input,
+                                        target,
+                                        dOutput,
+                                        dInputHost,
+                                        dTargetHost,
+                                        config.alpha,
+                                        config.gamma,
+                                        reduction,
+                                        1);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+
+        dInput.data  = handle.Read(dInput_dev, dInput.data.size());
+        dTarget.data = handle.Read(dTarget_dev, dTarget.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = get_tolerance(reduction);
+
+        auto dInputError = miopen::rms_range(dInputHost, dInput);
+
+        EXPECT_TRUE(miopen::range_distance(dInputHost) == miopen::range_distance(dInput));
+        EXPECT_TRUE(dInputError < threshold)
+            << "dInput error output beyond tolerance Error: " << dInputError
+            << ", Threshold: " << threshold;
+
+        auto dTargetError = miopen::rms_range(dTargetHost, dTarget);
+
+        EXPECT_TRUE(miopen::range_distance(dTargetHost) == miopen::range_distance(dTarget));
+        EXPECT_TRUE(dTargetError < threshold)
+            << "dTarget error output beyond tolerance Error: " << dTargetError
+            << ", Threshold: " << threshold;
+    }
+    SigmoidFocalLossTestCase config;
+    miopenLossReductionMode_t reduction;
+
+    tensor input;
+    tensor target;
+    tensor dOutput;
+    tensor dInput;
+    tensor dTarget;
+
+    tensor dInputHost;
+    tensor dTargetHost;
+
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr target_dev;
+    miopen::Allocator::ManageDataPtr dOutput_dev;
+    miopen::Allocator::ManageDataPtr dInput_dev;
+    miopen::Allocator::ManageDataPtr dTarget_dev;
+};
+
+template
+struct SigmoidFocalLossFwdTest : public ::testing::TestWithParam
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle = get_handle();
+        config        = GetParam();
+
+        reduction = miopenLossReductionMode_t(int(prng::gen_0_to_B(2) + 1)); // presumably a reduced mode - TODO confirm enum values
+
+        auto in_dims    = config.GetDims();
+        auto in_strides = config.ComputeStrides(in_dims);
+
+        auto in_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 20); };
+        input             = tensor{in_dims, in_strides}.generate(in_gen_value);
+
+        auto tar_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 20); };
+        target             = tensor{in_dims, in_strides}.generate(tar_gen_value);
+
+        output = tensor(1); // created first: the workspace query below reads output.desc
+        std::fill(output.begin(), output.end(), 0);
+
+        size_t workspaceSizeBytes = miopen::GetSigmoidFocalLossForwardWorkspaceSize(
+            handle, input.desc, target.desc, output.desc, reduction);
+        size_t workspaceElements = workspaceSizeBytes / sizeof(float);
+
+        workspace = tensor(workspaceElements);
+        std::fill(workspace.begin(), workspace.end(), 0);
+
+        outputHost = tensor(1);
+        std::fill(outputHost.begin(), outputHost.end(), 0);
+
+        divisor = 1;
+        if(reduction == MIOPEN_LOSS_REDUCTION_MEAN) // mean divides by the element count
+        {
+            divisor *= input.desc.GetElementSize();
+        }
+
+        input_dev     = handle.Write(input.data);
+        target_dev    = handle.Write(target.data);
+        workspace_dev = handle.Write(workspace.data);
+        output_dev    = handle.Write(output.data);
+    }
+
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+
+        miopenStatus_t status;
+
+        status = miopen::SigmoidFocalLossForward(handle,
+                                                 workspace_dev.get(),
+                                                 workspace.GetDataByteSize(),
+                                                 input.desc,
+                                                 input_dev.get(),
+                                                 target.desc,
+                                                 target_dev.get(),
+                                                 output.desc,
+                                                 output_dev.get(),
+                                                 config.alpha,
+                                                 config.gamma,
+                                                 reduction);
+        cpu_sigmoid_focal_loss_forward(
+            input, target, outputHost, config.alpha, config.gamma, reduction, divisor);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+
+        output.data = handle.Read(output_dev, output.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = get_tolerance(reduction);
+
+        auto error = miopen::rms_range(outputHost, output);
+
+        EXPECT_TRUE(miopen::range_distance(outputHost) == miopen::range_distance(output));
+        EXPECT_TRUE(error < threshold)
+            << "Error output beyond tolerance Error: " << error << ", Threshold: " << threshold
+            << " Reduction: " << reduction;
+    }
+    SigmoidFocalLossTestCase config;
+    miopenLossReductionMode_t reduction;
+
+    tensor input;
+    tensor target;
+    tensor workspace;
+    tensor output;
+
+    tensor outputHost;
+
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr target_dev;
+    miopen::Allocator::ManageDataPtr workspace_dev;
+    miopen::Allocator::ManageDataPtr output_dev;
+
+    float divisor;
+};
+
+template
+struct SigmoidFocalLossBwdTest : public ::testing::TestWithParam
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle   = get_handle();
+        config          = GetParam();
+        auto in_dims    = config.GetDims();
+        auto in_strides = config.ComputeStrides(in_dims);
+
+        reduction = miopenLossReductionMode_t(int(prng::gen_0_to_B(2) + 1)); // presumably a reduced mode - TODO confirm enum values
+
+        auto in_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 50); };
+        input             = tensor{in_dims, in_strides}.generate(in_gen_value);
+
+        auto tar_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 50); };
+        target             = tensor{in_dims, in_strides}.generate(tar_gen_value);
+
+        dOutput    = tensor(1); // a single upstream-gradient element in this fixture
+        dOutput[0] = prng::gen_descreet_uniform_sign(0.1, 50);
+
+        dInput = tensor{in_dims};
+        std::fill(dInput.begin(), dInput.end(), 0);
+
+        dInputHost = tensor{in_dims};
+        std::fill(dInputHost.begin(), dInputHost.end(), 0);
+
+        dTarget = tensor{in_dims};
+        std::fill(dTarget.begin(), dTarget.end(), 0);
+
+        dTargetHost = tensor{in_dims};
+        std::fill(dTargetHost.begin(), dTargetHost.end(), 0);
+
+        divisor = 1;
+        if(reduction == MIOPEN_LOSS_REDUCTION_MEAN) // mean divides by the element count
+        {
+            divisor *= input.desc.GetElementSize();
+        }
+        input_dev   = handle.Write(input.data);
+        target_dev  = handle.Write(target.data);
+        dOutput_dev = handle.Write(dOutput.data);
+        dInput_dev  = handle.Write(dInput.data);
+        dTarget_dev = handle.Write(dTarget.data);
+    }
+
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+
+        miopenStatus_t status;
+
+        status = miopen::SigmoidFocalLossBackward(handle,
+                                                  input.desc,
+                                                  input_dev.get(),
+                                                  target.desc,
+                                                  target_dev.get(),
+                                                  dOutput.desc,
+                                                  dOutput_dev.get(),
+                                                  dInput.desc,
+                                                  dInput_dev.get(),
+                                                  dTarget.desc,
+                                                  dTarget_dev.get(),
+                                                  config.alpha,
+                                                  config.gamma,
+                                                  reduction);
+        cpu_sigmoid_focal_loss_backward(input,
+                                        target,
+                                        dOutput,
+                                        dInputHost,
+                                        dTargetHost,
+                                        config.alpha,
+                                        config.gamma,
+                                        reduction,
+                                        divisor);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+
+        dInput.data  = handle.Read(dInput_dev, dInput.data.size());
+        dTarget.data = handle.Read(dTarget_dev, dTarget.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = get_tolerance(reduction);
+
+        auto dInputError = miopen::rms_range(dInputHost, dInput);
+
+        EXPECT_TRUE(miopen::range_distance(dInputHost) == miopen::range_distance(dInput));
+        EXPECT_TRUE(dInputError < threshold)
+            << "dInput error output beyond tolerance Error: " << dInputError
+            << ", Threshold: " << threshold;
+
+        auto dTargetError = miopen::rms_range(dTargetHost, dTarget);
+
+        EXPECT_TRUE(miopen::range_distance(dTargetHost) == miopen::range_distance(dTarget));
+        EXPECT_TRUE(dTargetError < threshold)
+            << "dTarget error output beyond tolerance Error: " << dTargetError
+            << ", Threshold: " << threshold;
+    }
+    SigmoidFocalLossTestCase config;
+    miopenLossReductionMode_t reduction;
+
+    tensor input;
+    tensor target;
+    tensor dOutput;
+    tensor dInput;
+    tensor dTarget;
+
+    tensor dInputHost;
+    tensor dTargetHost;
+
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr target_dev;
+    miopen::Allocator::ManageDataPtr dOutput_dev;
+    miopen::Allocator::ManageDataPtr dInput_dev;
+    miopen::Allocator::ManageDataPtr dTarget_dev;
+
+    float divisor;
+};