diff --git a/src/kernels/MIOpenSigmoidFocalLoss.cpp b/src/kernels/MIOpenSigmoidFocalLoss.cpp
index b8f3630e8d..d12335a5f7 100644
--- a/src/kernels/MIOpenSigmoidFocalLoss.cpp
+++ b/src/kernels/MIOpenSigmoidFocalLoss.cpp
@@ -54,6 +54,12 @@ __device__ void sigmoidFocalLossFwd(const TIO* input,
                                     tensor_view_t<5> input_tv,
                                     tensor_view_t<5> target_tv)
 {
+    /*
+    Dims: input = target = workspace = {N, C, D, H, W}.
+    Each thread handles one element of the input and target tensors.
+    Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS (default = 256), 1, 1}.
+    Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
+    */
     size_t gid = threadIdx.x + blockIdx.x * blockDim.x;
 
     tensor_layout_t<5> idx(input_tv, gid);
@@ -63,6 +69,7 @@ __device__ void sigmoidFocalLossFwd(const TIO* input,
     FLOAT_ACCUM i = CVT_FLOAT2ACCUM(input[input_tv.get_tensor_view_idx(idx)]);
     FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
 
+    /* The formula follows the torchvision package: torchvision/ops/focal_loss.py */
     FLOAT_ACCUM p = 1 / (1 + exp(-i));
     FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
     FLOAT_ACCUM pT = p * t + (1 - p) * (1 - t);
@@ -105,6 +112,12 @@ __device__ void sigmoidFocalLossBwd(const TIO* input,
                                     tensor_view_t<5> dinput_tv,
                                     tensor_view_t<5> dtarget_tv)
 {
+    /*
+    Dims: input = target = doutput = dinput = dtarget = {N, C, D, H, W}.
+    Each thread handles one element of the input, target, and doutput tensors.
+    Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS (default = 256), 1, 1}.
+    Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
+    */
     size_t gid = threadIdx.x + blockIdx.x * blockDim.x;
 
     tensor_layout_t<5> idx(input_tv, gid);
@@ -116,6 +129,7 @@ __device__ void sigmoidFocalLossBwd(const TIO* input,
     FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
     FLOAT_ACCUM dO = CVT_FLOAT2ACCUM(doutput[doutput_tv.get_tensor_view_idx(doIdx)]);
 
+    /* The formula is derived by differentiating the forward formula. */
     FLOAT_ACCUM p = 1 / (1 + exp(-i));
     FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
     FLOAT_ACCUM pT = p * t + (1 - p) * (1 - t);
@@ -199,6 +213,12 @@ __device__ void sigmoidFocalLossUnreducedFwd(const TIO* input,
                                              tensor_view_t<5> target_tv,
                                              tensor_view_t<5> output_tv)
 {
+    /*
+    Dims: input = target = output = {N, C, D, H, W}.
+    Each thread handles one element of the input and target tensors.
+    Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS (default = 256), 1, 1}.
+    Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
+    */
     size_t gid = threadIdx.x + blockIdx.x * blockDim.x;
 
     tensor_layout_t<5> idx(input_tv, gid);
@@ -208,6 +228,7 @@ __device__ void sigmoidFocalLossUnreducedFwd(const TIO* input,
     FLOAT_ACCUM i = CVT_FLOAT2ACCUM(input[input_tv.get_tensor_view_idx(idx)]);
     FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
 
+    /* The formula follows the torchvision package: torchvision/ops/focal_loss.py */
     FLOAT_ACCUM p = 1 / (1 + exp(-i));
     FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
     FLOAT_ACCUM pT = p * t + (1 - p) * (1 - t);
@@ -249,6 +270,12 @@ __device__ void sigmoidFocalLossUnreducedBwd(const TIO* input,
                                              tensor_view_t<5> dinput_tv,
                                              tensor_view_t<5> dtarget_tv)
 {
+    /*
+    Dims: input = target = doutput = dinput = dtarget = {N, C, D, H, W}.
+    Each thread handles one element of the input, target, and doutput tensors.
+    Lws = {LOCAL_SIZE_SIGMOIDFOCALLOSS (default = 256), 1, 1}.
+    Gws = {AlignUp(N * C * D * H * W, lws.x), 1, 1}.
+    */
     size_t gid = threadIdx.x + blockIdx.x * blockDim.x;
 
     tensor_layout_t<5> idx(input_tv, gid);
@@ -259,6 +286,7 @@ __device__ void sigmoidFocalLossUnreducedBwd(const TIO* input,
     FLOAT_ACCUM t = CVT_FLOAT2ACCUM(target[target_tv.get_tensor_view_idx(idx)]);
     FLOAT_ACCUM dO = CVT_FLOAT2ACCUM(doutput[doutput_tv.get_tensor_view_idx(idx)]);
 
+    /* The formula is derived by differentiating the forward formula. */
     FLOAT_ACCUM p = 1 / (1 + exp(-i));
     FLOAT_ACCUM ceLoss = -(t * log(p) + (1 - t) * log(1 - p));
     FLOAT_ACCUM pT = p * t + (1 - p) * (1 - t);
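Note (not part of the patch): the new comments cite the torchvision formula for the forward pass and state that the backward pass is its derivative. The host-side C++ sketch below illustrates that relationship for a single element and checks the analytic gradient against a finite difference. It is only a minimal sketch under assumptions: the function names focalLossFwd / focalLossGradInput and the standalone alpha / gamma parameters are illustrative and are not taken from the kernel's actual signature.

// Per-element sigmoid focal loss (forward) and its gradient w.r.t. the input logit,
// mirroring the formulas in the kernels above. Illustrative sketch, not MIOpen code.
#include <cmath>
#include <cstdio>

// Forward: follows torchvision/ops/focal_loss.py (alpha < 0 disables alpha weighting).
double focalLossFwd(double x, double t, double alpha, double gamma)
{
    double p      = 1.0 / (1.0 + std::exp(-x));
    double ceLoss = -(t * std::log(p) + (1.0 - t) * std::log(1.0 - p));
    double pT     = p * t + (1.0 - p) * (1.0 - t);
    double loss   = ceLoss * std::pow(1.0 - pT, gamma);
    if(alpha >= 0)
        loss *= alpha * t + (1.0 - alpha) * (1.0 - t);
    return loss;
}

// Backward: differentiate the forward formula w.r.t. x, using
// d(ceLoss)/dx = p - t and d(pT)/dx = p * (1 - p) * (2t - 1).
double focalLossGradInput(double x, double t, double alpha, double gamma)
{
    double p    = 1.0 / (1.0 + std::exp(-x));
    double ce   = -(t * std::log(p) + (1.0 - t) * std::log(1.0 - p));
    double pT   = p * t + (1.0 - p) * (1.0 - t);
    double grad = std::pow(1.0 - pT, gamma - 1.0) *
                  (-gamma * p * (1.0 - p) * (2.0 * t - 1.0) * ce + (1.0 - pT) * (p - t));
    if(alpha >= 0)
        grad *= alpha * t + (1.0 - alpha) * (1.0 - t);
    return grad;
}

int main()
{
    // Finite-difference check of the analytic gradient at one sample point.
    double x = 0.3, t = 1.0, alpha = 0.25, gamma = 2.0, h = 1e-6;
    double num =
        (focalLossFwd(x + h, t, alpha, gamma) - focalLossFwd(x - h, t, alpha, gamma)) / (2 * h);
    std::printf("analytic = %.9f, numeric = %.9f\n", focalLossGradInput(x, t, alpha, gamma), num);
    return 0;
}

In the kernels, the same per-element value is multiplied by the incoming dO (and, for the reduced variant, divided by the reduction divisor), so this sketch only covers the element-wise part that the added comments describe.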