diff --git a/src/include/miopen/sigmoidfocalloss/solvers.hpp b/src/include/miopen/sigmoidfocalloss/solvers.hpp index 992ad5a9d6..9cb3bd15e8 100644 --- a/src/include/miopen/sigmoidfocalloss/solvers.hpp +++ b/src/include/miopen/sigmoidfocalloss/solvers.hpp @@ -50,6 +50,9 @@ struct SigmoidFocalLossFwd final : SigmoidFocalLossFwdSolverBase const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const override; + MultiBufferWorkspaceTraits GetMultiBufferWorkspaceTraits( + const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const; + std::size_t GetWorkspaceSize(const ExecutionContext& context, const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) diff --git a/src/solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp b/src/solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp index d3f874251f..f1a37fc54f 100644 --- a/src/solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp +++ b/src/solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp @@ -24,6 +24,7 @@ * *******************************************************************************/ +#include "miopen/buffer_info.hpp" #include #include #include @@ -36,7 +37,7 @@ #include #define LOCAL_SIZE 256 -#define LOCAL_SIZE_REDUCE_FWD 256 +#define LOCAL_SIZE_REDUCE 256 namespace miopen { @@ -83,11 +84,11 @@ ConvSolution SigmoidFocalLossFwd::GetSolution( do { result.construction_params.push_back(make_hip_kernel( - {LOCAL_SIZE_REDUCE_FWD}, {_size}, "MIOpenLossSum.cpp", "LossSum", build_params)); - _size = AlignUp(_size, LOCAL_SIZE_REDUCE_FWD) / LOCAL_SIZE_REDUCE_FWD; + {LOCAL_SIZE_REDUCE}, {_size}, "MIOpenLossSum.cpp", "LossSum", build_params)); + _size = AlignUp(_size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE; } while(_size > 1); - result.invoker_factory = [](const std::vector& kernels) { + result.invoker_factory = [this, problem](const std::vector& kernels) { return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { decltype(auto) params = raw_params.CastTo(); auto size = deref(params.inputDesc).GetElementSize(); @@ -127,11 +128,11 @@ ConvSolution SigmoidFocalLossFwd::GetSolution( } /* Execute reduce kernels */ + auto wt = GetMultiBufferWorkspaceTraits(problem); auto reduceIn = params.workspace; auto reduceOut = - static_cast(static_cast(params.workspace) + - deref(params.inputDesc).GetElementSize() * - get_data_size(deref(params.outputDesc).GetType())); + static_cast(static_cast(params.workspace) + wt.GetOffset(1)); + for(int i = 1; i < kernels.size(); ++i) { decltype(auto) kernel = handle_.Run(kernels[i]); @@ -144,7 +145,7 @@ ConvSolution SigmoidFocalLossFwd::GetSolution( { kernel(reduceIn, params.output, size); } - size = AlignUp(size, LOCAL_SIZE_REDUCE_FWD) / LOCAL_SIZE_REDUCE_FWD; + size = AlignUp(size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE; } if(profiling) @@ -169,13 +170,18 @@ ConvSolution SigmoidFocalLossFwd::GetSolution( std::size_t SigmoidFocalLossFwd::GetWorkspaceSize( const ExecutionContext& /*context*/, const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const +{ + return GetMultiBufferWorkspaceTraits(problem).GetSize(); +} + +MultiBufferWorkspaceTraits SigmoidFocalLossFwd::GetMultiBufferWorkspaceTraits( + const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const { size_t inputElements = problem.GetInputDesc().GetElementSize(); - size_t reduceElements = (inputElements + LOCAL_SIZE_REDUCE_FWD - 1) / LOCAL_SIZE_REDUCE_FWD; - size_t res = - (inputElements + reduceElements) * get_data_size(problem.GetOutputDesc().GetType()); + size_t reduceElements = (inputElements + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE; + size_t elementSize = get_data_size(problem.GetOutputDesc().GetType()); - return res; + return MultiBufferWorkspaceTraits{inputElements * elementSize, reduceElements * elementSize}; } } // namespace sigmoidfocalloss