Skip to content

Commit

Permalink
re-fix default driver
Browse files Browse the repository at this point in the history
  • Loading branch information
hieule88 committed Jul 26, 2024
1 parent 07598ab commit 4d008c8
Showing 1 changed file with 48 additions and 18 deletions.
66 changes: 48 additions & 18 deletions driver/interpolate_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,34 +189,64 @@ int InterpolateDriver<Tgpu, Tref>::GetandSetData()
mode = static_cast<miopenInterpolateMode_t>(inflags.GetValueInt("mode"));
align_corners = static_cast<bool>(inflags.GetValueInt("align_corners"));

if(mode != MIOPEN_INTERPOLATE_MODE_NEAREST)
if(config_scale_factors[0] == -1 && size[0] == -1)
{
for(int i = 0; i < size.size(); i++)
config_scale_factors[0] = 1;
for(int i = 1; i < in_len.size() - 2; i++)
{
scale_factors.push_back(config_scale_factors[i]);
config_scale_factors.push_back(1);
}
}
else

if(config_scale_factors[0] != -1)
{
for(int i = 0; i < size.size(); i++)
if(mode != MIOPEN_INTERPOLATE_MODE_NEAREST)
{
scale_factors.push_back(config_scale_factors[i]);
for(int i = 0; i < in_len.size() - 2; i++)
{
scale_factors.push_back(config_scale_factors[i]);
}
}
for(int i = size.size(); i < 3; i++)
else
{
scale_factors.push_back(0);
for(int i = 0; i < in_len.size() - 2; i++)
{
scale_factors.push_back(config_scale_factors[i]);
}
for(int i = in_len.size() - 2; i < 3; i++)
{
scale_factors.push_back(0);
}
}
}

auto out_len = std::vector<int>({in_len[0], in_len[1]});
for(int i = 0; i < size.size(); i++)
if(size[0] != -1)
{
if(scale_factors[i] != 0)
out_len.push_back(ceil(static_cast<int>(in_len[i + 2] * scale_factors[i])));
else
for(int i = 0; i < size.size(); i++)
{
scale_factors[i] = static_cast<float>(size[i]) / in_len[i + 2];
out_len.push_back(size[i]);
if(size[i] == 0)
out_len.push_back(ceil(static_cast<int>(in_len[i + 2] * scale_factors[i])));
else
{
if(config_scale_factors[0] == -1)
{
scale_factors.push_back(static_cast<float>(size[i]) / in_len[i + 2]);
}
else
{
scale_factors[i] = static_cast<float>(size[i]) / in_len[i + 2];
}
out_len.push_back(size[i]);
}
}
}
else
{
for(int i = 0; i < in_len.size() - 2; i++)
{
out_len.push_back(ceil(static_cast<int>(in_len[i + 2] * scale_factors[i])));
scale_factors[i] = static_cast<float>(out_len[i + 2]) / in_len[i + 2];
}
}

Expand Down Expand Up @@ -248,15 +278,15 @@ int InterpolateDriver<Tgpu, Tref>::AddCmdLineArgs()
"string");
inflags.AddInputFlag("size",
'S',
"32",
"-1",
"Output Spatial Size: D,H,W. "
"Example: 32.",
"Default: -1 - Use scale factors instead",
"string");
inflags.AddInputFlag("scale_factors",
's',
"32",
"-1",
"Multiplier for spatial size: factor_D,factor_H,factor_W. "
"Example: 32",
"Default: -1 - Use size instead",
"string");
inflags.AddInputFlag("mode",
'm',
Expand Down

0 comments on commit 4d008c8

Please sign in to comment.