Skip to content

Commit

Permalink
use MultiBufferWorkspaceTraits
Browse files Browse the repository at this point in the history
  • Loading branch information
BuiChiTrung committed Aug 5, 2024
1 parent debd530 commit 182ea0b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
3 changes: 3 additions & 0 deletions src/include/miopen/sigmoidfocalloss/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ struct SigmoidFocalLossFwd final : SigmoidFocalLossFwdSolverBase
const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription&
problem) const override;

MultiBufferWorkspaceTraits GetMultiBufferWorkspaceTraits(
const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const;

std::size_t
GetWorkspaceSize(const ExecutionContext& context,
const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem)
Expand Down
30 changes: 18 additions & 12 deletions src/solver/sigmoidfocalloss/forward_reduce_sigmoid_focal_loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
*
*******************************************************************************/

#include "miopen/buffer_info.hpp"
#include <miopen/sigmoidfocalloss/problem_description.hpp>
#include <miopen/miopen.h>
#include <miopen/datatype.hpp>
Expand All @@ -36,7 +37,7 @@
#include <miopen/tensor_view_utils.hpp>

#define LOCAL_SIZE 256
#define LOCAL_SIZE_REDUCE_FWD 256
#define LOCAL_SIZE_REDUCE 256

namespace miopen {

Expand Down Expand Up @@ -83,11 +84,11 @@ ConvSolution SigmoidFocalLossFwd::GetSolution(
do
{
result.construction_params.push_back(make_hip_kernel(
{LOCAL_SIZE_REDUCE_FWD}, {_size}, "MIOpenLossSum.cpp", "LossSum", build_params));
_size = AlignUp(_size, LOCAL_SIZE_REDUCE_FWD) / LOCAL_SIZE_REDUCE_FWD;
{LOCAL_SIZE_REDUCE}, {_size}, "MIOpenLossSum.cpp", "LossSum", build_params));
_size = AlignUp(_size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE;
} while(_size > 1);

result.invoker_factory = [](const std::vector<Kernel>& kernels) {
result.invoker_factory = [this, problem](const std::vector<Kernel>& kernels) {
return [=](const Handle& handle_, const AnyInvokeParams& raw_params) {
decltype(auto) params = raw_params.CastTo<miopen::sigmoidfocalloss::FwdInvokeParams>();
auto size = deref(params.inputDesc).GetElementSize();
Expand Down Expand Up @@ -127,11 +128,11 @@ ConvSolution SigmoidFocalLossFwd::GetSolution(
}

/* Execute reduce kernels */
auto wt = GetMultiBufferWorkspaceTraits(problem);
auto reduceIn = params.workspace;
auto reduceOut =
static_cast<Data_t>(static_cast<char*>(params.workspace) +
deref(params.inputDesc).GetElementSize() *
get_data_size(deref(params.outputDesc).GetType()));
static_cast<Data_t>(static_cast<char*>(params.workspace) + wt.GetOffset(1));

for(int i = 1; i < kernels.size(); ++i)
{
decltype(auto) kernel = handle_.Run(kernels[i]);
Expand All @@ -144,7 +145,7 @@ ConvSolution SigmoidFocalLossFwd::GetSolution(
{
kernel(reduceIn, params.output, size);
}
size = AlignUp(size, LOCAL_SIZE_REDUCE_FWD) / LOCAL_SIZE_REDUCE_FWD;
size = AlignUp(size, LOCAL_SIZE_REDUCE) / LOCAL_SIZE_REDUCE;
}

if(profiling)
Expand All @@ -169,13 +170,18 @@ ConvSolution SigmoidFocalLossFwd::GetSolution(
std::size_t SigmoidFocalLossFwd::GetWorkspaceSize(
const ExecutionContext& /*context*/,
const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const
{
return GetMultiBufferWorkspaceTraits(problem).GetSize();
}

MultiBufferWorkspaceTraits SigmoidFocalLossFwd::GetMultiBufferWorkspaceTraits(
const miopen::sigmoidfocalloss::SigmoidFocalLossFwdProblemDescription& problem) const
{
size_t inputElements = problem.GetInputDesc().GetElementSize();
size_t reduceElements = (inputElements + LOCAL_SIZE_REDUCE_FWD - 1) / LOCAL_SIZE_REDUCE_FWD;
size_t res =
(inputElements + reduceElements) * get_data_size(problem.GetOutputDesc().GetType());
size_t reduceElements = (inputElements + LOCAL_SIZE_REDUCE - 1) / LOCAL_SIZE_REDUCE;
size_t elementSize = get_data_size(problem.GetOutputDesc().GetType());

return res;
return MultiBufferWorkspaceTraits{inputElements * elementSize, reduceElements * elementSize};
}

} // namespace sigmoidfocalloss
Expand Down

0 comments on commit 182ea0b

Please sign in to comment.