Skip to content

Commit

Permalink
Fixed scratch size calculation for conv for HiFi targets for scenario…
Browse files Browse the repository at this point in the history
…s when input, filter and output heights are 1.
  • Loading branch information
pramods-cad committed Dec 9, 2024
1 parent 4a8bb6b commit 9280421
Showing 1 changed file with 42 additions and 6 deletions.
48 changes: 42 additions & 6 deletions tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,28 +77,64 @@ TfLiteStatus ConvPrepareHifi(TfLiteContext* context, TfLiteNode* node) {
}

const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_channels = output_shape.Dims(3);
const int stride_height = params->stride_height;
const int stride_width = params->stride_width;
const int pad_height = data->reference_op_data.padding.height;
const int pad_width = data->reference_op_data.padding.width;

int required_scratch = 0;
// TODO(b/277112516): Dilation is currently not supported on HiFi 4 NN Library
if ((params->dilation_width_factor == 1) &&
(params->dilation_height_factor == 1)) {
if (input->type == kTfLiteInt8) {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_ASYM8S);
if (input_height == 1 && filter_height == 1 && output_height == 1)
{
int inp_h, filt_h, filt_w, str_h, pad_h, out_h;
inp_h = input_width;
filt_h = filter_width;
filt_w = filter_height;
str_h = stride_width;
pad_h = pad_width;
out_h = output_width;
required_scratch = xa_nn_conv2d_std_getsize(
inp_h, input_depth, filt_h, filt_w, str_h,
pad_h, out_h, output_channels, PREC_ASYM8S);
}
else
{
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_ASYM8S);
}
TF_LITE_ENSURE(context, required_scratch > 0);
}
if (input->type == kTfLiteInt16) {
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_SYM16S);
if (input_height == 1 && filter_height == 1 && output_height == 1)
{
int inp_h, filt_h, filt_w, str_h, pad_h, out_h;
inp_h = input_width;
filt_h = filter_width;
filt_w = filter_height;
str_h = stride_width;
pad_h = pad_width;
out_h = output_width;
required_scratch = xa_nn_conv2d_std_getsize(
inp_h, input_depth, filt_h, filt_w, str_h,
pad_h, out_h, output_channels, PREC_SYM16S);
}
else
{
required_scratch = xa_nn_conv2d_std_getsize(
input_height, input_depth, filter_height, filter_width, stride_height,
pad_height, output_height, output_channels, PREC_SYM16S);
}
TF_LITE_ENSURE(context, required_scratch > 0);
}
}
Expand Down

0 comments on commit 9280421

Please sign in to comment.