For ci testing #3007

Closed
wants to merge 8 commits into from
112 changes: 62 additions & 50 deletions tensorflow/lite/micro/kernels/concatenation.cc
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -33,6 +33,13 @@ constexpr int kOutputTensor = 0;

struct OpData {
ConcatenationParams params;

#ifdef USE_TFLM_COMPRESSION

// scratch buffers for compressed tensors
int scratch_indices[kMaxInputNum];

#endif // USE_TFLM_COMPRESSION
};

// Handles negative axis index, coerces to positive index value.
@@ -52,8 +59,6 @@ inline int CalculatePositiveAxis(int axis, const TfLiteTensor* output_tensor) {
inline void GetAllInputTensorShapes(const TfLiteContext* context,
const TfLiteNode* node,
RuntimeShape all_shapes[kMaxInputNum]) {
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
RuntimeShape shape = tflite::micro::GetTensorShape(t);
@@ -73,12 +78,22 @@ inline void GetShapesPointers(const RuntimeShape* shapes, size_t num,
template <typename T>
inline void GetAllInputTensorData(const TfLiteContext* context,
const TfLiteNode* node,
T* all_data[kMaxInputNum]) {
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
const T* all_data[kMaxInputNum]) {
#ifdef USE_TFLM_COMPRESSION
const OpData* data = static_cast<const OpData*>(node->user_data);
MicroContext* micro_context = GetMicroContext(context);
#endif // USE_TFLM_COMPRESSION

for (int i = 0; i < node->inputs->size; ++i) {
const TfLiteEvalTensor* t = tflite::micro::GetEvalInput(context, node, i);
#ifdef USE_TFLM_COMPRESSION
const CompressionTensorData* comp_td =
micro_context->GetTensorCompressionData(node, i);
all_data[i] = tflite::micro::GetTensorData<T>(micro_context, t, comp_td,
data->scratch_indices[i]);
#else // USE_TFLM_COMPRESSION
all_data[i] = tflite::micro::GetTensorData<T>(t);
#endif // USE_TFLM_COMPRESSION
}
}

@@ -88,16 +103,17 @@ void EvalUnquantized(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape inputs_shape[kMaxInputNum];
const RuntimeShape* inputs_shape_ptr[kMaxInputNum];
const data_type* inputs_data[kMaxInputNum];
TFLITE_DCHECK(context != nullptr);
TFLITE_DCHECK(node != nullptr);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData* data = static_cast<const OpData*>(node->user_data);
GetAllInputTensorShapes(context, node, inputs_shape);
GetShapesPointers(inputs_shape, node->inputs->size, inputs_shape_ptr);
GetAllInputTensorData(context, node, inputs_data);

TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);

TFLITE_DCHECK(node->user_data != nullptr);
const OpData* data = static_cast<const OpData*>(node->user_data);

reference_ops::Concatenation(data->params, inputs_shape_ptr, inputs_data,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<data_type>(output));
@@ -126,7 +142,6 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteType output_type = output_tensor->type;

micro_context->DeallocateTempTfLiteTensor(input_tensor);
micro_context->DeallocateTempTfLiteTensor(output_tensor);

// Check activation and input type
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
@@ -136,16 +151,22 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
input_type == kTfLiteInt64 || input_type == kTfLiteBool);

// Output type must match input type
TF_LITE_ENSURE_EQ(context, output_type, input_type);
TF_LITE_ENSURE_TYPES_EQ(context, output_type, input_type);

// This implementation does not support large number of input tensors
const int num_inputs = NumInputs(node);
TF_LITE_ENSURE(context, num_inputs <= kMaxInputNum);

// Shapes with dimensions >4 are not yet supported with static allocation.
// Calculate OpData.
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);

// Shapes with dimensions > kMaxSmallSize are not yet supported with static
// allocation.
for (int i = 0; i < num_inputs; ++i) {
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, i);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, input_type);
int num_dimensions = NumDimensions(input);

if (num_dimensions > RuntimeShape::kMaxSmallSize) {
@@ -155,62 +176,53 @@ TfLiteStatus ConcatenationPrepare(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape::kMaxSmallSize, num_dimensions);
return kTfLiteError;
}

if (input_type == kTfLiteInt8) {
// Make sure there is no re-scaling needed for Int8 quantized kernel. This
// is a restriction we introduced to Int8 kernels.
TF_LITE_ENSURE_EQ(context, static_cast<double>(input->params.scale),
static_cast<double>(output_tensor->params.scale));
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
output_tensor->params.zero_point);
} else if (input_type == kTfLiteInt16) {
// Make sure that all Int16 inputs have a null zero-point.
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
}

#ifdef USE_TFLM_COMPRESSION

// Compression scratch buffers.
// These will only be allocated if the tensor is compressed.
data->scratch_indices[i] =
micro_context->AllocateDecompressionScratchBuffer(node, i);

#endif // USE_TFLM_COMPRESSION

micro_context->DeallocateTempTfLiteTensor(input);
}

// Calculate OpData.
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);

TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
if (input_type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, output_tensor->params.zero_point, 0);
}

switch (output_type) { // Already know in/out types are same.
case kTfLiteBool:
case kTfLiteFloat32:
case kTfLiteInt8:
case kTfLiteInt16:
case kTfLiteInt32:
case kTfLiteInt64: {
data->params.axis = CalculatePositiveAxis(params->axis, output);
data->params.inputs_count = node->inputs->size;
break;
}
case kTfLiteInt8: {
data->params.axis = CalculatePositiveAxis(params->axis, output);
data->params.axis = CalculatePositiveAxis(params->axis, output_tensor);
data->params.inputs_count = node->inputs->size;

float* input_scales =
reinterpret_cast<float*>(context->AllocatePersistentBuffer(
context, node->inputs->size * sizeof(float)));

int32_t* input_zero_points =
reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
context, node->inputs->size * sizeof(int32_t)));

// Allocate persistent scale and zeropoint buffers.
// Store input scale and zero point values in OpParams:
for (int i = 0; i < node->inputs->size; ++i) {
TfLiteTensor* t = micro_context->AllocateTempInputTensor(node, i);
TF_LITE_ENSURE(context, t != nullptr);
input_scales[i] = t->params.scale;
input_zero_points[i] = t->params.zero_point;
micro_context->DeallocateTempTfLiteTensor(t);
}

data->params.input_scale = input_scales;
data->params.input_zeropoint = input_zero_points;
data->params.output_zeropoint = output->params.zero_point;
data->params.output_scale = output->params.scale;
break;
}
default:
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
MicroPrintf("Op Concatenation does not currently support type '%s'.",
TfLiteTypeGetName(output_type));
return kTfLiteError;
}

micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(output_tensor);

return kTfLiteOk;
}
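
For reviewers unfamiliar with the compression path, below is a minimal sketch of the pattern this change applies to concatenation, shown for a hypothetical single-input kernel. It assumes USE_TFLM_COMPRESSION is defined and uses only the MicroContext calls that appear in the diff (AllocateDecompressionScratchBuffer, GetTensorCompressionData, and the compression-aware GetTensorData overload); the DemoOpData/DemoPrepare/DemoEval names and include paths are illustrative, not part of this PR.

#ifdef USE_TFLM_COMPRESSION
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"

namespace tflite {
namespace {

struct DemoOpData {
  int scratch_index;  // handle for the decompression scratch buffer of input 0
};

TfLiteStatus DemoPrepare(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  DemoOpData* data = static_cast<DemoOpData*>(node->user_data);
  // Only yields a usable scratch buffer when input 0 is actually compressed.
  data->scratch_index =
      micro_context->AllocateDecompressionScratchBuffer(node, /*input_idx=*/0);
  return kTfLiteOk;
}

TfLiteStatus DemoEval(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  const DemoOpData* data = static_cast<const DemoOpData*>(node->user_data);
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, /*input_idx=*/0);
  // comp_td is null for uncompressed tensors, in which case GetTensorData is
  // expected to fall back to the tensor's regular buffer.
  const CompressionTensorData* comp_td =
      micro_context->GetTensorCompressionData(node, /*input_idx=*/0);
  const float* input_data = tflite::micro::GetTensorData<float>(
      micro_context, input, comp_td, data->scratch_index);
  (void)input_data;  // consume the (possibly decompressed) data here
  return kTfLiteOk;
}

}  // namespace
}  // namespace tflite
#endif  // USE_TFLM_COMPRESSION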