From 4d008c817d5b5e952851935105a232b0b90d8ed7 Mon Sep 17 00:00:00 2001 From: hieule88 Date: Fri, 26 Jul 2024 11:05:43 +0700 Subject: [PATCH] re-fix default driver --- driver/interpolate_driver.hpp | 66 +++++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/driver/interpolate_driver.hpp b/driver/interpolate_driver.hpp index 437326512b..7b7af09c8d 100644 --- a/driver/interpolate_driver.hpp +++ b/driver/interpolate_driver.hpp @@ -189,34 +189,64 @@ int InterpolateDriver::GetandSetData() mode = static_cast(inflags.GetValueInt("mode")); align_corners = static_cast(inflags.GetValueInt("align_corners")); - if(mode != MIOPEN_INTERPOLATE_MODE_NEAREST) + if(config_scale_factors[0] == -1 && size[0] == -1) { - for(int i = 0; i < size.size(); i++) + config_scale_factors[0] = 1; + for(int i = 1; i < in_len.size() - 2; i++) { - scale_factors.push_back(config_scale_factors[i]); + config_scale_factors.push_back(1); } } - else + + if(config_scale_factors[0] != -1) { - for(int i = 0; i < size.size(); i++) + if(mode != MIOPEN_INTERPOLATE_MODE_NEAREST) { - scale_factors.push_back(config_scale_factors[i]); + for(int i = 0; i < in_len.size() - 2; i++) + { + scale_factors.push_back(config_scale_factors[i]); + } } - for(int i = size.size(); i < 3; i++) + else { - scale_factors.push_back(0); + for(int i = 0; i < in_len.size() - 2; i++) + { + scale_factors.push_back(config_scale_factors[i]); + } + for(int i = in_len.size() - 2; i < 3; i++) + { + scale_factors.push_back(0); + } } } auto out_len = std::vector({in_len[0], in_len[1]}); - for(int i = 0; i < size.size(); i++) + if(size[0] != -1) { - if(scale_factors[i] != 0) - out_len.push_back(ceil(static_cast(in_len[i + 2] * scale_factors[i]))); - else + for(int i = 0; i < size.size(); i++) { - scale_factors[i] = static_cast(size[i]) / in_len[i + 2]; - out_len.push_back(size[i]); + if(size[i] == 0) + out_len.push_back(ceil(static_cast(in_len[i + 2] * scale_factors[i]))); + else + { + if(config_scale_factors[0] == -1) + { + scale_factors.push_back(static_cast(size[i]) / in_len[i + 2]); + } + else + { + scale_factors[i] = static_cast(size[i]) / in_len[i + 2]; + } + out_len.push_back(size[i]); + } + } + } + else + { + for(int i = 0; i < in_len.size() - 2; i++) + { + out_len.push_back(ceil(static_cast(in_len[i + 2] * scale_factors[i]))); + scale_factors[i] = static_cast(out_len[i + 2]) / in_len[i + 2]; } } @@ -248,15 +278,15 @@ int InterpolateDriver::AddCmdLineArgs() "string"); inflags.AddInputFlag("size", 'S', - "32", + "-1", "Output Spatial Size: D,H,W. " - "Example: 32.", + "Default: -1 - Use scale factors instead", "string"); inflags.AddInputFlag("scale_factors", 's', - "32", + "-1", "Multiplier for spatial size: factor_D,factor_H,factor_W. " - "Example: 32", + "Default: -1 - Use size instead", "string"); inflags.AddInputFlag("mode", 'm',