diff --git a/src/solver/conv/conv_direct_naive_conv_bwd.cpp b/src/solver/conv/conv_direct_naive_conv_bwd.cpp index 5d9d481543..7434bb8787 100644 --- a/src/solver/conv/conv_direct_naive_conv_bwd.cpp +++ b/src/solver/conv/conv_direct_naive_conv_bwd.cpp @@ -40,9 +40,13 @@ using ProblemDescription = miopen::conv::ProblemDescription; bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const { - if(!miopen::debug::AlwaysEnableConvDirectNaive && - env::disabled(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD)) - return false; + if(!miopen::debug::AlwaysEnableConvDirectNaive) + { + if(env::disabled(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_BWD)) + return false; + if(!ctx.use_hip_kernels) + return false; + } if(!ConvDirectNaiveConvIsApplicableByKernelType(ctx, problem)) return false; diff --git a/src/solver/conv/conv_direct_naive_conv_fwd.cpp b/src/solver/conv/conv_direct_naive_conv_fwd.cpp index c10bca0105..8bf51476b3 100644 --- a/src/solver/conv/conv_direct_naive_conv_fwd.cpp +++ b/src/solver/conv/conv_direct_naive_conv_fwd.cpp @@ -39,9 +39,13 @@ using ProblemDescription = miopen::conv::ProblemDescription; bool ConvDirectNaiveConvFwd::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const { - if(!miopen::debug::AlwaysEnableConvDirectNaive && - env::disabled(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD)) - return false; + if(!miopen::debug::AlwaysEnableConvDirectNaive) + { + if(env::disabled(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD)) + return false; + if(!ctx.use_hip_kernels) + return false; + } if(!ConvDirectNaiveConvIsApplicableByKernelType(ctx, problem)) return false; diff --git a/src/solver/conv/conv_direct_naive_conv_wrw.cpp b/src/solver/conv/conv_direct_naive_conv_wrw.cpp index 936afceb3f..88eefd8122 100644 --- a/src/solver/conv/conv_direct_naive_conv_wrw.cpp +++ b/src/solver/conv/conv_direct_naive_conv_wrw.cpp @@ -40,9 +40,13 @@ using ProblemDescription = miopen::conv::ProblemDescription; bool ConvDirectNaiveConvWrw::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& problem) const { - if(!miopen::debug::AlwaysEnableConvDirectNaive && - env::disabled(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW)) - return false; + if(!miopen::debug::AlwaysEnableConvDirectNaive) + { + if(env::disabled(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_WRW)) + return false; + if(!ctx.use_hip_kernels) + return false; + } if(!ConvDirectNaiveConvIsApplicableByKernelType(ctx, problem)) return false;