Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SigmoidFocalLoss operation #3143

Open
wants to merge 26 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d4f50a2
resolve conflict
BuiChiTrung Aug 5, 2024
7a6dfa4
remove githooks
BuiChiTrung Jul 24, 2024
42aafe3
add .githooks
BuiChiTrung Jul 24, 2024
605542b
add githooks
BuiChiTrung Jul 24, 2024
0a6dfa2
add Tcheck type in driver
BuiChiTrung Jul 24, 2024
9d75374
fix cppcheck err
BuiChiTrung Jul 24, 2024
f91144c
add MIOPEN_INTERNALS_EXPORT
BuiChiTrung Jul 26, 2024
35c6ee6
change gtest naming format following new convention
BuiChiTrung Jul 30, 2024
2b16cdb
update drive random bound
BuiChiTrung Jul 30, 2024
ee5952a
try revert back unit-test file to check pipeline
BuiChiTrung Jul 30, 2024
4fcf689
try __hip_ds_swizzlef_N
BuiChiTrung Jul 30, 2024
debd530
change unit-test format
BuiChiTrung Jul 30, 2024
182ea0b
use MultiBufferWorkspaceTraits
BuiChiTrung Jul 31, 2024
19c6390
remove redundant files
BuiChiTrung Aug 5, 2024
ae2ee25
rollback src/include/miopen/solver/implicitgemm_ck_util.hpp
BuiChiTrung Aug 5, 2024
c1c602c
revert warp_shuffle using shlf_down
BuiChiTrung Aug 6, 2024
edcd7e7
include header in .cpp file
BuiChiTrung Aug 6, 2024
091aa5b
merge duplicate code to validate in CPU and driver
BuiChiTrung Aug 8, 2024
cf3bcc0
remove param reduction in test config
BuiChiTrung Aug 8, 2024
a41a005
change verify algo in CPU to naive accumulate in reduce kernels
BuiChiTrung Aug 19, 2024
bc233f1
merge upstream/develop
BuiChiTrung Aug 26, 2024
afc738f
merge code with reduce kernel used in develop branch
BuiChiTrung Aug 27, 2024
449a51d
apply clang-format
BuiChiTrung Aug 27, 2024
446322c
fix implicitgemm_ck_util.hpp
BuiChiTrung Aug 28, 2024
76499d4
merge with upstream/develop
BuiChiTrung Aug 29, 2024
41b9300
add comment to kernel
BuiChiTrung Aug 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ The MIOpen API library is structured as follows:
* :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental)
* :doc:`RotaryPositionalEmbeddings <../doxygen/html/group__RotaryPositionalEmbeddings>` (experimental)
* :doc:`ReLU <../doxygen/html/group___re_l_u>` (experimental)
* :doc:`SigmoidFocalLoss <../doxygen/html/group__loss_function>` (experimental)
1 change: 1 addition & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ add_executable(MIOpenDriver
dm_reducecalculation.cpp
dm_rnn.cpp
dm_rope.cpp
dm_sigmoid_focal_loss.cpp
dm_softmax.cpp
dm_t5layernorm.cpp
dm_tensorop.cpp
Expand Down
41 changes: 41 additions & 0 deletions driver/dm_sigmoid_focal_loss.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#include "registry_driver_maker.hpp"
#include "sigmoid_focal_loss_driver.hpp"

static Driver* makeDriver(const std::string& base_arg)
{
if(base_arg == "sigmoidfocalloss")
return new SigmoidFocalLossDriver<float, float>();
else if(base_arg == "sigmoidfocallossfp16")
return new SigmoidFocalLossDriver<float16, float>();
else if(base_arg == "sigmoidfocallossbfp16")
return new SigmoidFocalLossDriver<bfloat16, float>();
return nullptr;
}

REGISTER_DRIVER_MAKER(makeDriver);
5 changes: 3 additions & 2 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
"t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], "
"adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, "
"getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], "
"prelu[bfp16|fp16]\n");
"prelu[bfp16|fp16], sigmoidfocalloss[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

Expand Down Expand Up @@ -209,7 +209,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "getitemfp16" && arg != "getitembfp16" && arg != "reducecalculation" &&
arg != "reducecalculationfp16" && arg != "reducecalculationbfp16" && arg != "rope" &&
arg != "ropefp16" && arg != "ropebfp16" && arg != "prelu" && arg != "prelufp16" &&
arg != "prelubfp16" && arg != "--version")
arg != "prelubfp16" && arg != "sigmoidfocalloss" && arg != "sigmoidfocallossfp16" &&
arg != "sigmoidfocallossbfp16" && arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Usage();
Expand Down
135 changes: 135 additions & 0 deletions driver/mloSigmoidFocalLossHost.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#include <miopen/miopen.h>
#include <miopen/tensor_view_utils.hpp>

template <typename Tgpu, typename Tcheck>
void mloSigmoidFocalLossFwdRunHost(Tgpu* input,
miopenTensorDescriptor_t inputDesc,
Tgpu* target,
miopenTensorDescriptor_t targetDesc,
Tcheck* outputHost,
miopenTensorDescriptor_t outputDesc,
float alpha,
float gamma,
miopenLossReductionMode_t reduction,
float divisor)
{
auto input_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(inputDesc));
auto target_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(targetDesc));
auto output_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(outputDesc));
size_t inputSize = miopen::deref(inputDesc).GetElementSize();

for(size_t id = 0; id < inputSize; ++id)
{
tensor_layout_t<5> idx(input_tv, id);

Tcheck i = static_cast<Tcheck>(input[input_tv.get_tensor_view_idx(idx)]);
Tcheck t = static_cast<Tcheck>(target[target_tv.get_tensor_view_idx(idx)]);

Tcheck sig = 1 / (1 + exp(-i));
Tcheck ceLoss = -(t * log(sig) + (1 - t) * log(1 - sig));
Tcheck sigT = sig * t + (1 - sig) * (1 - t);
Tcheck loss = ceLoss * pow(1 - sigT, gamma);

if(alpha >= 0)
{
Tcheck alphaT = alpha * t + (1 - alpha) * (1 - t);
loss = alphaT * loss;
}

if(reduction == MIOPEN_LOSS_REDUCTION_NONE)
{
outputHost[output_tv.get_tensor_view_idx(idx)] = loss;
}
else
{
outputHost[0] += static_cast<Tcheck>(loss / divisor);
}
}
}

template <typename Tgpu, typename Tcheck>
void mloSigmoidFocalLossBwdRunHost(Tgpu* input,
miopenTensorDescriptor_t inputDesc,
Tgpu* target,
miopenTensorDescriptor_t targetDesc,
Tgpu* doutput,
miopenTensorDescriptor_t doutputDesc,
Tcheck* dinput,
miopenTensorDescriptor_t dinputDesc,
Tcheck* dtarget,
miopenTensorDescriptor_t dtargetDesc,
float alpha,
float gamma,
miopenLossReductionMode_t reduction,
float divisor)
{
auto input_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(inputDesc));
auto target_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(targetDesc));
auto doutput_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(doutputDesc));
auto dinput_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(dinputDesc));
auto dtarget_tv = miopen::get_inner_expanded_tv<5>(miopen::deref(dtargetDesc));

size_t inputSize = miopen::deref(inputDesc).GetElementSize();

tensor_layout_t<5> doIdx(input_tv, 0);
Tcheck dO = static_cast<Tcheck>(doutput[doutput_tv.get_tensor_view_idx(doIdx)]);

for(size_t id = 0; id < inputSize; ++id)
{
tensor_layout_t<5> idx(input_tv, id);

Tcheck i = static_cast<Tcheck>(input[input_tv.get_tensor_view_idx(idx)]);
Tcheck t = static_cast<Tcheck>(target[target_tv.get_tensor_view_idx(idx)]);
if(reduction == MIOPEN_LOSS_REDUCTION_NONE)
{
dO = static_cast<Tcheck>(doutput[doutput_tv.get_tensor_view_idx(idx)]);
}

Tcheck p = 1 / (1 + exp(-i));
Tcheck ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
Tcheck pT = p * t + (1 - p) * (1 - t);
Tcheck powPt = pow(1 - pT, gamma);
Tcheck alpha_t = alpha * t + (1 - alpha) * (1 - t);

if(dinput)
{
Tcheck dpdi = exp(-i) / pow(1 + exp(-i), 2);
Tcheck dcelossdi = (-t / p + (1 - t) / (1 - p)) * dpdi;
Tcheck dpowptdi = gamma * pow(1 - pT, gamma - 1) * (1 - 2 * t) * dpdi;

// L = ce_loss * pow_pt => dL/di = dceloss/di * pow_pt + ce_loss * dpowpt/di
Tcheck dLdi = dcelossdi * powPt + ceLoss * dpowptdi;
Tcheck grad = dO * dLdi;

if(alpha >= 0)
{
grad *= alpha_t;
}
if(reduction != MIOPEN_LOSS_REDUCTION_NONE)
{
grad /= divisor;
}
dinput[dinput_tv.get_tensor_view_idx(idx)] = static_cast<Tcheck>(grad);
}

if(dtarget)
{
Tcheck dcelossdt = -log(p) + log(1 - p);
Tcheck dpowptdt = gamma * pow(1 - pT, gamma - 1) * (1 - 2 * p);
// L = ce_loss * pow_pt => dL/dt = dceloss/dt * pow_pt + ce_loss * dpowpt/dt
Tcheck dLdt = dcelossdt * powPt + ceLoss * dpowptdt;
Tcheck gradTarget = dO * dLdt;

if(alpha >= 0)
{
// alpha_t * dL/dt + dalpha_t/dt * dL
gradTarget = alpha_t * dLdt + (2 * alpha - 1) * ceLoss * powPt;
}
if(reduction != MIOPEN_LOSS_REDUCTION_NONE)
{
gradTarget /= divisor;
}
dtarget[dtarget_tv.get_tensor_view_idx(idx)] = static_cast<Tcheck>(gradTarget);
}
}
}
Loading