Skip to content

Commit

Permalink
add comment to kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
BuiChiTrung committed Aug 31, 2024
1 parent 76499d4 commit 41b9300
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/kernels/MIOpenSigmoidFocalLoss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ __device__ void sigmoidFocalLossFwd(const TIO* input,
tensor_view_t<5> input_tv,
tensor_view_t<5> target_tv)
{
/*
Dim: input = target = workspace = {N, C, D, H, W}.
Each thread handles one element of the input/target tensors.
Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS (default = 256), 1, 1}.
Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
*/
size_t gid = threadIdx.x + blockIdx.x * blockDim.x;

tensor_layout_t<5> idx(input_tv, gid);
Expand All @@ -63,6 +69,7 @@ __device__ void sigmoidFocalLossFwd(const TIO* input,
FLOAT_ACCUM i = CVT_FLOAT2ACCUM(input[input_tv.get_tensor_view_idx(idx)]);
FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);

/* The formula follows the torchvision implementation: torchvision/ops/focal_loss.py */
FLOAT_ACCUM p = 1 / (1 + exp(-i));
FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
FLOAT_ACCUM pT = p * t + (1 - p) * (1 - t);
Expand Down Expand Up @@ -105,6 +112,12 @@ __device__ void sigmoidFocalLossBwd(const TIO* input,
tensor_view_t<5> dinput_tv,
tensor_view_t<5> dtarget_tv)
{
/*
Dim: input = target = doutput = dinput = dtarget = {N, C, D, H, W}.
Each thread handles one element of the input/target/doutput tensors.
Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS (default = 256), 1, 1}.
Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
*/
size_t gid = threadIdx.x + blockIdx.x * blockDim.x;

tensor_layout_t<5> idx(input_tv, gid);
Expand All @@ -116,6 +129,7 @@ __device__ void sigmoidFocalLossBwd(const TIO* input,
FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
FLOAT_ACCUM dO = CVT_FLOAT2ACCUM(doutput[doutput_tv.get_tensor_view_idx(doIdx)]);

/* The formula is derived by computing the gradient of the forward formula */
FLOAT_ACCUM p = 1 / (1 + exp(-i));
FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
FLOAT_ACCUM pT = p * t + (1 - p) * (1 - t);
Expand Down Expand Up @@ -199,6 +213,12 @@ __device__ void sigmoidFocalLossUnreducedFwd(const TIO* input,
tensor_view_t<5> target_tv,
tensor_view_t<5> output_tv)
{
/*
Dim: input = target = output = {N, C, D, H, W}.
Each thread handles one element of the input/target tensors.
Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS (default = 256), 1, 1}.
Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
*/
size_t gid = threadIdx.x + blockIdx.x * blockDim.x;

tensor_layout_t<5> idx(input_tv, gid);
Expand All @@ -208,6 +228,7 @@ __device__ void sigmoidFocalLossUnreducedFwd(const TIO* input,
FLOAT_ACCUM i = CVT_FLOAT2ACCUM(input[input_tv.get_tensor_view_idx(idx)]);
FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);

/* The formula follows the torchvision implementation: torchvision/ops/focal_loss.py */
FLOAT_ACCUM p = 1 / (1 + exp(-i));
FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
FLOAT_ACCUM pT = p * t + (1 - p) * (1 - t);
Expand Down Expand Up @@ -249,6 +270,12 @@ __device__ void sigmoidFocalLossUnreducedBwd(const TIO* input,
tensor_view_t<5> dinput_tv,
tensor_view_t<5> dtarget_tv)
{
/*
Dim: input = target = doutput = dinput = dtarget = {N, C, D, H, W}.
Each thread handles one element of the input/target/doutput tensors.
Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS (default = 256), 1, 1}.
Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
*/
size_t gid = threadIdx.x + blockIdx.x * blockDim.x;

tensor_layout_t<5> idx(input_tv, gid);
Expand All @@ -259,6 +286,7 @@ __device__ void sigmoidFocalLossUnreducedBwd(const TIO* input,
FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
FLOAT_ACCUM dO = CVT_FLOAT2ACCUM(doutput[doutput_tv.get_tensor_view_idx(idx)]);

/* The formula is derived by computing the gradient of the forward formula */
FLOAT_ACCUM p = 1 / (1 + exp(-i));
FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
FLOAT_ACCUM pT = p * t + (1 - p) * (1 - t);
Expand Down

0 comments on commit 41b9300

Please sign in to comment.