Skip to content

Commit

Permalink
Revert "SPACE_TO_BATCH_ND: update output tensor shape" (#2339)
Browse files — browse the repository at this point in the history
Temporarily reverting so that we can reland with SpaceToBatch, BatchToSpace, Conv2D and ExpandDims output tensor resizing atomically.

Reverts #2335

BUG=#2338
  • Loading branch information
rascani authored Nov 30, 2023
1 parent 2f97d82 commit 83bbff9
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 555 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

// Bails out of the enclosing TfLiteStatus-returning function with
// kTfLiteError as soon as a previous micro_test expectation has failed,
// so later steps that depend on earlier results are skipped.
// NOTE(review): only usable inside functions returning TfLiteStatus.
#define TF_LITE_MICRO_CHECK_FAIL() \
do { \
if (micro_test::did_test_fail) { \
return kTfLiteError; \
} \
} while (false)

namespace {

// Arena size is a guesstimate, followed by use of
Expand Down
114 changes: 11 additions & 103 deletions tensorflow/lite/micro/kernels/space_to_batch_nd.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -15,10 +15,7 @@ limitations under the License.

#include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"

#include <algorithm>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
Expand All @@ -27,11 +24,12 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {

namespace {

constexpr int kInputTensor = 0;
constexpr int kBlockShapeTensor = 1;
constexpr int kPaddingTensor = 2;
constexpr int kCropsTensor = 2;
constexpr int kOutputTensor = 0;

// Currently, only 3D NHC and 4D NHWC input/output op_context are supported.
Expand All @@ -46,68 +44,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return context->AllocatePersistentBuffer(context, sizeof(SpaceToBatchParams));
}

// Computes the SPACE_TO_BATCH_ND output shape from the input, block_shape
// and padding tensors, and resizes the output tensor dims to match.
//
// Returns kTfLiteOk when the output already has the right shape or was
// successfully resized; kTfLiteError when resizing would require growing a
// flatbuffer-allocated tensor (not supported) or any shape check fails.
TfLiteStatus ReshapeOutputTensor(TfLiteContext* context, const TfLiteNode* node,
                                 const TfLiteTensor* input,
                                 const TfLiteTensor* block_shape,
                                 const TfLiteTensor* padding,
                                 TfLiteTensor* output) {
  // Shape-defining inputs must be known at Prepare time.
  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(block_shape));
  TF_LITE_ENSURE(context, IsConstantOrPersistentTensor(padding));
  const int32_t* block_data = GetTensorData<int32_t>(block_shape);
  const int32_t* pad_data = GetTensorData<int32_t>(padding);

  TfLiteIntArray* in_dims = input->dims;
  const int num_spatial = in_dims->size - 2;
  // block_shape must be a 1-D tensor of length [num_spatial].
  TF_LITE_ENSURE_EQ(context, NumDimensions(block_shape), 1);
  TF_LITE_ENSURE_EQ(context, block_shape->dims->data[0], num_spatial);
  // padding must be a 2-D tensor of shape [num_spatial, 2].
  TF_LITE_ENSURE_EQ(context, NumDimensions(padding), 2);
  TF_LITE_ENSURE_EQ(context, padding->dims->data[0], num_spatial);
  TF_LITE_ENSURE_EQ(context, padding->dims->data[1], 2);

  // Start from the input shape (per TfLite reference behavior) and keep the
  // current output shape to decide whether any resize is needed at all.
  RuntimeShape new_shape = GetTensorShape(input);
  const RuntimeShape prior_shape = GetTensorShape(output);

  // Each padded spatial dimension must divide evenly by its block size; the
  // batch dimension grows by the product of the block sizes.
  int batch = in_dims->data[0];
  for (int d = 0; d < num_spatial; ++d) {
    const int padded_size =
        in_dims->data[d + 1] + pad_data[d * 2] + pad_data[d * 2 + 1];
    TF_LITE_ENSURE(context, block_data[d] != 0);
    TF_LITE_ENSURE_EQ(context, padded_size % block_data[d], 0);
    new_shape.SetDim(d + 1, padded_size / block_data[d]);
    batch *= block_data[d];
  }
  new_shape.SetDim(0, batch);
  // Channel dimension is carried over unchanged.
  new_shape.SetDim(in_dims->size - 1, in_dims->data[in_dims->size - 1]);

  // Nothing to do if the output tensor already matches.
  if (new_shape == prior_shape) {
    return kTfLiteOk;
  }
  // A tensor whose data lives in the flatbuffer cannot be grown in place.
  if (new_shape.FlatSize() > prior_shape.FlatSize() &&
      output->data.data != nullptr) {
    MicroPrintf(
        "SPACE_TO_BATCH_ND: resizing flatbuffer tensor data is not supported");
    return kTfLiteError;
  }

  // Relocate the dims array to writable memory, then stamp in the new shape.
  TF_LITE_ENSURE_EQ(context, in_dims->size, output->dims->size);
  TfLiteEvalTensor* output_eval =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  TF_LITE_ENSURE_STATUS(tflite::micro::CreateWritableTensorDimsWithCopy(
      context, output, output_eval));
  std::copy_n(new_shape.DimsData(), new_shape.DimensionsCount(),
              output->dims->data);

  return kTfLiteOk;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
MicroContext* micro_context = GetMicroContext(context);

Expand All @@ -116,47 +52,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* block_shape =
micro_context->AllocateTempInputTensor(node, kBlockShapeTensor);
TF_LITE_ENSURE(context, block_shape != nullptr);
TfLiteTensor* padding =
micro_context->AllocateTempInputTensor(node, kPaddingTensor);
TF_LITE_ENSURE(context, padding != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, input != nullptr && output != nullptr);

TF_LITE_ENSURE(context, NumDimensions(input) >= kInputOutputMinDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(output) >= kInputOutputMinDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(input) <= kInputOutputMaxDimensionNum);
TF_LITE_ENSURE(context, NumDimensions(output) <= kInputOutputMaxDimensionNum);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);

TF_LITE_ENSURE(context, node->user_data != nullptr);
SpaceToBatchParams& params =
*(static_cast<SpaceToBatchParams*>(node->user_data));

if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE(context, input->params.scale == output->params.scale);
TF_LITE_ENSURE(context,
input->params.zero_point == output->params.zero_point);
params.output_offset = output->params.zero_point;
} else {
params.output_offset = 0;
}

TfLiteStatus status =
ReshapeOutputTensor(context, node, input, block_shape, padding, output);

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(block_shape);
micro_context->DeallocateTempTfLiteTensor(padding);
micro_context->DeallocateTempTfLiteTensor(output);

return status;
return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Expand All @@ -168,8 +76,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* block_shape =
tflite::micro::GetEvalInput(context, node, kBlockShapeTensor);
const TfLiteEvalTensor* padding =
tflite::micro::GetEvalInput(context, node, kPaddingTensor);
const TfLiteEvalTensor* crops =
tflite::micro::GetEvalInput(context, node, kCropsTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);

Expand All @@ -180,8 +88,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(block_shape),
tflite::micro::GetTensorData<int32_t>(block_shape),
tflite::micro::GetTensorShape(padding),
tflite::micro::GetTensorData<int32_t>(padding),
tflite::micro::GetTensorShape(crops),
tflite::micro::GetTensorData<int32_t>(crops),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
Expand All @@ -191,8 +99,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(block_shape),
tflite::micro::GetTensorData<int32_t>(block_shape),
tflite::micro::GetTensorShape(padding),
tflite::micro::GetTensorData<int32_t>(padding),
tflite::micro::GetTensorShape(crops),
tflite::micro::GetTensorData<int32_t>(crops),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
Expand Down
Loading

0 comments on commit 83bbff9

Please sign in to comment.