Skip to content

Commit

Permalink
update isimprovementoverrocm
Browse files Browse the repository at this point in the history
  • Loading branch information
littlecutebird committed Aug 21, 2024
1 parent 86149f6 commit fdc4232
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
30 changes: 23 additions & 7 deletions src/solver/multimarginloss/forward_reduced_multimarginloss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,29 @@ bool MultiMarginLossForward::IsImprovementOverROCm(
const ExecutionContext& /*context*/,
const miopen::multimarginloss::ForwardProblemDescription& problem) const
{
if(problem.GetiDesc().GetLengths()[1] <= 30)
return true;
if((problem.GetiDesc().GetType() == miopenHalf ||
problem.GetiDesc().GetType() == miopenBFloat16) &&
problem.GetiDesc().IsContiguous() && problem.GetiDesc().GetLengths()[1] <= 40)
return true;
return false;
int C = problem.GetiDesc().GetLengths()[1];
if(problem.allContiguousTensor())
{
switch(problem.GetiDesc().GetType())
{
case miopenFloat: return C <= 33;
case miopenHalf: return C <= 43;
case miopenBFloat16: return C <= 44;
// Have not tested with other types yet
default: return true;
}
}
else
{
switch(problem.GetiDesc().GetType())
{
case miopenFloat: return C <= 31;
case miopenHalf: return C <= 38;
case miopenBFloat16: return C <= 40;
// Have not tested with other types yet
default: return true;
}
}
}

bool MultiMarginLossForward::IsApplicable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ bool MultiMarginLossUnreducedForward::IsImprovementOverROCm(
{
switch(problem.GetiDesc().GetType())
{
case miopenFloat: return C <= 32;
case miopenFloat: return C <= 33;
case miopenHalf: return C <= 43;
case miopenBFloat16: return C <= 44;
// Have not tested with other types yet
Expand Down

0 comments on commit fdc4232

Please sign in to comment.