diff --git a/tensorflow/lite/micro/kernels/conv.cc b/tensorflow/lite/micro/kernels/conv.cc
index 0df35fce4eb..7be915ab51e 100644
--- a/tensorflow/lite/micro/kernels/conv.cc
+++ b/tensorflow/lite/micro/kernels/conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -45,15 +45,35 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   const auto& data = *(static_cast<const OpDataConv*>(node->user_data));
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32: {
       tflite::reference_ops::Conv(
           ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<float>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(micro_context, filter,
+                                              weights_comp_td,
+                                              data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<float>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<float>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<float>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr);
@@ -67,9 +87,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
          tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int16_t>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               weights_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int16_t>(output));
     } else if (bias->type == kTfLiteInt64) {
@@ -79,9 +108,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
          tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int16_t>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               weights_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int64_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
-          tflite::micro::GetOptionalTensorData<int64_t>(bias),
+          tflite::micro::GetTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int16_t>(output));
     } else {
@@ -119,9 +157,18 @@ TfLiteStatus ConvEval(TfLiteContext* context, TfLiteNode* node) {
          tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<int8_t>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               weights_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<int32_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<int8_t>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<int8_t>(output));
       break;
diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h
index 0c8073f48f0..0090053e03c 100644
--- a/tensorflow/lite/micro/kernels/conv.h
+++ b/tensorflow/lite/micro/kernels/conv.h
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -49,6 +49,14 @@ struct OpDataConv {
   // A buffer used to store unpacked filter values. This is used if the source
   // tensor is of n-bit precision that cannot be easily processed by kernels.
   int filter_buffer_index;
+
+#ifdef USE_TFLM_COMPRESSION
+
+  // scratch buffers for compressed tensors
+  int weights_scratch_index;
+  int bias_scratch_index;
+
+#endif  // USE_TFLM_COMPRESSION
 };
 
 extern const int kConvInputTensor;
diff --git a/tensorflow/lite/micro/kernels/conv_common.cc b/tensorflow/lite/micro/kernels/conv_common.cc
index 51c7a6ff2d6..9f0f2f79588 100644
--- a/tensorflow/lite/micro/kernels/conv_common.cc
+++ b/tensorflow/lite/micro/kernels/conv_common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -209,6 +209,23 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
         &data->filter_buffer_index);
   }
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  if (micro_context->IsTensorCompressed(node, kConvWeightsTensor) &&
+      filter->type == kTfLiteInt4) {
+    MicroPrintf("Compression not supported with INT4 tensors");
+    return kTfLiteError;
+  }
+  data->weights_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(node,
+                                                        kConvWeightsTensor);
+  data->bias_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(node, kConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(filter);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(output);
diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc
index 0fb9411a3f0..48eddeb9958 100644
--- a/tensorflow/lite/micro/kernels/conv_test.cc
+++ b/tensorflow/lite/micro/kernels/conv_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/conv_test.h" +#include + #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" @@ -46,6 +48,90 @@ static int kOutputShape[] = {4, 2, 1, 2, 3}; static const float kGoldenData[kOutputElements] = {18, 2, 5, 18, 2, 5, 17, 4, 3, 37, 4, 3}; +#ifdef USE_TFLM_COMPRESSION + +// compressed filter data for kBinQuant scheme, matches kFilterData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterData[] = { + 0x05, 0x38, 0x20, 0x90, 0x00, +}; +constexpr float kBinQuantFilterValueTable[] = { + 1, 2, 3, 4, -1, +}; +constexpr size_t kBinQuantFilterValueTableElements = + std::extent::value; +constexpr int kBinQuantFilterBitWidth = 3; +// compressed bias data for kBinQuant scheme, matches kBiasData +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x18}; +constexpr int kBinQuantBiasBitWidth = 2; + +// Common inputs and outputs for quantized compressed tensor tests. +// Values from TfLite conv_test.cc SimplePerChannelTest. +static int kInputShapeQ1[] = {4, 1, 2, 3, 2}; +static const float kInputDataQ1[] = { + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 +}; +constexpr size_t kInputElementsQ1 = std::extent::value; + +constexpr int kNumChannelsQ1 = 2; +static int kFilterShapeQ1[] = {4, 2, 2, 2, 2}; +// Original filter data: +// static constexpr float kFilterDataQ1[] = { +// // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] +// 1, 2, // out channel = 0, y = 0, x = 0 +// 3, 4, // out channel = 0, y = 0, x = 1 +// 3, 4, // out channel = 0, y = 1, x = 0 +// 5, 6, // out channel = 0, y = 1, x = 1 +// 7, 8, // out channel = 1, y = 0, x = 0 +// 5, 6, // out channel = 1, y = 0, x = 1 +// 3, 4, // out channel = 1, y = 1, x = 0 +// 1, 2, // out channel = 1, y = 1, x = 1 +// }; + +static int kBiasShapeQ1[] = {1, 2}; +static const float kBiasDataQ1[] = {3, -2}; +constexpr size_t kBiasElementsQ1 = std::extent::value; + +static int kOutputShapeQ1[] = {4, 1, 1, 2, 2}; +static const float kGoldenDataQ1[] = {31, 64, -57, -46}; +constexpr int kOutputElementsQ1 = std::extent::value; +static const float kGoldenDataQ1_16[] = {31, 63.99804688, -57, -46}; + +// compressed filter data for kBinQuant scheme, matches kFilterDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = { + 0x05, 0x34, 0xE5, 0xDE, 0x54, 0xC1, +}; +constexpr float kBinQuantFilterValueTableQ1[] = { + 1, 2, 3, 4, 5, 6, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, +}; +constexpr size_t kBinQuantFilterValueTableElementsQ1 = + std::extent::value; +constexpr int kBinQuantFilterBitWidthQ1 = 3; +// compressed bias data for kBinQuant scheme, matches kBiasDataQ1 +// Align the tensor data the same as a Buffer in the schema +alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00}; +constexpr int kBinQuantBiasBitWidthQ1 = 1; + +static TfLiteConvParams common_conv_params_q1 = { + kTfLitePaddingValid, // padding + 1, // stride_width + 1, // stride_height + kTfLiteActNone, // activation + 1, // dilation_width_factor + 1, // dilation_height_factor + kTfLiteNoType // quantized_bias_type +}; + +#endif // USE_TFLM_COMPRESSION + static TfLiteConvParams 
common_conv_params = { kTfLitePaddingValid, // padding 2, // stride_width @@ -122,6 +208,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannel) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelCompressed) { + const float input_scale = 0.5f; + const float output_scale = 0.5f; + const int input_zero_point = -1; + const int output_zero_point = -1; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int8_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int8_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int8_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestFloat) { float output_data[tflite::testing::kOutputElements]; @@ -136,6 +281,40 @@ TF_LITE_MICRO_TEST(SimpleTestFloat) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestFloatCompressed) { + tflite::testing::TestCompressionInfo filter_comp_info = {}; + tflite::testing::TestCompressionInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = tflite::testing::kBinQuantFilterValueTable; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElements; + filter_comp_info.bit_width = 
tflite::testing::kBinQuantFilterBitWidth; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = tflite::testing::kBiasData; + bias_comp_info.value_table_stride = tflite::testing::kBiasElements; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth; + + float output_data[tflite::testing::kOutputElements]; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvFloat( + tflite::testing::kInputShape, tflite::testing::kInputData, + tflite::testing::kFilterShape, + reinterpret_cast(tflite::testing::kBinQuantFilterData), + tflite::testing::kBiasShape, + reinterpret_cast(tflite::testing::kBinQuantBiasData), + tflite::testing::kOutputShape, tflite::testing::kGoldenData, + &tflite::testing::common_conv_params, tflite::Register_CONV_2D(), + output_data, &filter_comp_info, &bias_comp_info)); +} + +#endif + TF_LITE_MICRO_TEST(InputAndFilterSameWidthHeight) { const int output_dims_count = 2; float output_data[output_dims_count]; @@ -246,6 +425,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel64bBias) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel64bBiasCompressed) { + const float input_scale = 128.0f / 65536; + const float output_scale = 128.0f / 65536; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int64_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int16_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, 
input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1_16, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBias) { const int output_dims_count = 12; int16_t output_data[output_dims_count]; @@ -276,6 +514,65 @@ TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBias) { output_data)); } +#ifdef USE_TFLM_COMPRESSION + +TF_LITE_MICRO_TEST(SimpleTestQuantized16x8PerChannel32bBiasCompressed) { + const float input_scale = 128.0f / 65536; + const float output_scale = 128.0f / 65536; + const int input_zero_point = 0; + const int output_zero_point = 0; + constexpr float filter_scales[] = {tflite::testing::kNumChannelsQ1, 1.0f, + 2.0f}; + constexpr int filter_zero_points[] = {tflite::testing::kNumChannelsQ1, 0, 0}; + // bias scales and zero points will be computed + float bias_scales[std::extent::value] = {}; + int bias_zero_points[std::extent::value] = {}; + + int16_t input_quantized[tflite::testing::kInputElementsQ1]; + int8_t filter_quantized[tflite::testing::kBinQuantFilterValueTableElementsQ1]; + int32_t bias_quantized[tflite::testing::kBiasElementsQ1]; + int16_t golden_quantized[tflite::testing::kOutputElementsQ1]; + int16_t output_quantized[tflite::testing::kOutputElementsQ1]; + + tflite::testing::TestCompressionQuantizedInfo filter_comp_info = {}; + tflite::testing::TestCompressionQuantizedInfo bias_comp_info = {}; + + filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + filter_comp_info.value_table = filter_quantized; + filter_comp_info.value_table_stride = + tflite::testing::kBinQuantFilterValueTableElementsQ1 / + tflite::testing::kNumChannelsQ1; + filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1; + filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1; + filter_comp_info.data = tflite::testing::kBinQuantFilterValueTableQ1; + filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1; + filter_comp_info.scales = filter_scales; + filter_comp_info.zero_points = filter_zero_points; + + bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant; + bias_comp_info.value_table = bias_quantized; + bias_comp_info.value_table_stride = + tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1; + bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1; + bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1; + bias_comp_info.data = tflite::testing::kBiasDataQ1; + bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1; + bias_comp_info.scales = bias_scales; + bias_comp_info.zero_points = bias_zero_points; + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + tflite::testing::TestConvQuantizedPerChannelCompressed( + tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1, + input_quantized, input_scale, input_zero_point, + tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1_16, + golden_quantized, output_quantized, output_scale, output_zero_point, + &tflite::testing::common_conv_params_q1, tflite::Register_CONV_2D(), + &filter_comp_info, &bias_comp_info)); +} + +#endif // USE_TFLM_COMPRESSION + TF_LITE_MICRO_TEST(SimpleTestDilatedQuantizedPerChannel) { const int output_dims_count = 24; int8_t output_data[output_dims_count]; diff --git a/tensorflow/lite/micro/kernels/conv_test.h b/tensorflow/lite/micro/kernels/conv_test.h index c655f043bcc..642f4c76d7a 
100644 --- a/tensorflow/lite/micro/kernels/conv_test.h +++ b/tensorflow/lite/micro/kernels/conv_test.h @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/conv.h" #include "tensorflow/lite/micro/kernels/kernel_runner.h" #include "tensorflow/lite/micro/kernels/micro_ops.h" #include "tensorflow/lite/micro/test_helpers.h" @@ -26,35 +27,101 @@ limitations under the License. namespace tflite { namespace testing { -TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data); +constexpr int kConvMaxTensors = 4; +constexpr int kConvMaxInputTensors = 3; +template TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size, - int output_length, TfLiteConvParams* conv_params, - TFLMRegistration registration, int8_t* output_data); - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const float* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - float* output_data, float tolerance = 1e-5); - -TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, - const int8_t* expected_output_data, - int output_length, - TfLiteConvParams* conv_params, - TFLMRegistration registration, - int8_t* output_data, float tolerance = 1e-5); - -TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data, - int* filter_dims_data, const float* filter_data, - int* bias_dims_data, const float* bias_data, - int* output_dims_data, - const float* expected_output_data, - TfLiteConvParams* conv_params, - TFLMRegistration registration, float* output_data); + int output_length, const TfLiteConvParams* conv_params, + TFLMRegistration registration, T* output_data +#ifdef USE_TFLM_COMPRESSION + , + const CompressedTensorList* comp_list_p = nullptr +#endif // USE_TFLM_COMPRESSION +) { + // TODO(b/358165875): support optional bias tensor + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, conv_params +#ifdef USE_TFLM_COMPRESSION + , + nullptr, comp_list_p +#endif // USE_TFLM_COMPRESSION + ); + + const char* init_data = reinterpret_cast(conv_params); + TfLiteStatus status = runner.InitAndPrepare(init_data); + if (status != kTfLiteOk) { + return status; + } + return runner.Invoke(); +} + +template +TfLiteStatus ValidateConvGoldens( + TfLiteTensor* tensors, int tensors_size, const T* expected_output_data, + int output_length, const TfLiteConvParams* conv_params, + TFLMRegistration registration, T* output_data, float tolerance = 1e-5 +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +) { +#ifdef USE_TFLM_COMPRESSION + + TestCompressedList tcl; + if (filter_comp_info != nullptr) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*filter_comp_info, 
tensors[kConvWeightsTensor], + kConvWeightsTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + if (bias_comp_info) { + TF_LITE_MICRO_EXPECT_EQ( + tcl.AddInput(*bias_comp_info, tensors[kConvBiasTensor], + kConvBiasTensor), + kTfLiteOk); + TF_LITE_MICRO_CHECK_FAIL(); + } + const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList(); + +#endif // USE_TFLM_COMPRESSION + + TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length, + conv_params, registration, output_data +#ifdef USE_TFLM_COMPRESSION + , + comp_list_p +#endif // USE_TFLM_COMPRESSION + ); + if (status != kTfLiteOk) { + return status; + } + for (int i = 0; i < output_length; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], + tolerance); + } + return kTfLiteOk; +} + +TfLiteStatus TestConvFloat( + int* input_dims_data, const float* input_data, int* filter_dims_data, + const float* filter_data, int* bias_dims_data, const float* bias_data, + int* output_dims_data, const float* expected_output_data, + TfLiteConvParams* conv_params, TFLMRegistration registration, + float* output_data +#ifdef USE_TFLM_COMPRESSION + , + const TestCompressionInfo* filter_comp_info = nullptr, + const TestCompressionInfo* bias_comp_info = nullptr +#endif // USE_TFLM_COMPRESSION +); TfLiteStatus TestConvQuantizedPerChannel( int* input_dims_data, const float* input_data, int8_t* input_quantized, @@ -88,6 +155,74 @@ TfLiteStatus TestConvQuantizedPerChannel( float output_scale, int output_zero_point, TfLiteConvParams* conv_params, TFLMRegistration registration, int16_t* output_data); +#ifdef USE_TFLM_COMPRESSION + +template +TfLiteStatus TestConvQuantizedPerChannelCompressed( + int* input_dims_data, const float* input_data, TIO* input_quantized, + float input_scale, int input_zero_point, int* output_dims_data, + const float* expected_output_data, TIO* expected_output_quantized, + TIO* output_quantized, float output_scale, int output_zero_point, + const TfLiteConvParams* conv_params, TFLMRegistration registration, + const TestCompressionQuantizedInfo* filter_comp_info, + const TestCompressionQuantizedInfo* bias_comp_info) { + // TODO(b/358165875): account for optional bias tensor + // bool null_bias = comp_info->bias_data == nullptr ? 
true : false; + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + + TfLiteFloatArray* filter_scales = + FloatArrayFromFloats(filter_comp_info->scales); + TfLiteIntArray* filter_zero_points = + IntArrayFromInts(filter_comp_info->zero_points); + TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales); + TfLiteIntArray* bias_zero_points = + IntArrayFromInts(bias_comp_info->zero_points); + + TfLiteAffineQuantization filter_quant = {}; + TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor( + filter_comp_info->compressed, filter_dims, filter_scales, + filter_zero_points, &filter_quant, kConvQuantizedDimension, + false /* is_variable */, kTfLiteInt8); + SymmetricPerChannelQuantize( + filter_comp_info->data, filter_comp_info->value_table, + filter_scales->size * filter_comp_info->value_table_stride, + filter_scales->size, filter_scales->data); + + TfLiteAffineQuantization bias_quant = {}; + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_comp_info->compressed, bias_dims, input_scale, filter_scales, + bias_scales, bias_zero_points, &bias_quant, kConvQuantizedDimension, + false /* is_variable */, typeToTfLiteType()); + SymmetricPerChannelQuantize( + bias_comp_info->data, bias_comp_info->value_table, + bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size, + bias_scales->data); + + constexpr int tensors_size = kConvMaxTensors; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point), + filter_tensor, + bias_tensor, + CreateQuantizedTensor(output_quantized, output_dims, output_scale, + output_zero_point), + }; + + const int output_dims_count = ElementCount(*output_dims); + Quantize(expected_output_data, expected_output_quantized, output_dims_count, + output_scale, output_zero_point); + return ValidateConvGoldens(tensors, tensors_size, expected_output_quantized, + output_dims_count, conv_params, registration, + output_quantized, 1.0e-5f /* tolerance */, + filter_comp_info, bias_comp_info); +} + +#endif // USE_TFLM_COMPRESSION + } // namespace testing } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/conv_test_common.cc b/tensorflow/lite/micro/kernels/conv_test_common.cc index a0f733b8e42..3825e05373c 100644 --- a/tensorflow/lite/micro/kernels/conv_test_common.cc +++ b/tensorflow/lite/micro/kernels/conv_test_common.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,88 +18,18 @@ limitations under the License. 
 namespace tflite {
 namespace testing {
 
-template <typename T>
-TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
-                        int output_length, TfLiteConvParams* conv_params,
-                        TFLMRegistration registration, T* output_data) {
-  int inputs_array_data[] = {3, 0, 1, 2};
-  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
-  int outputs_array_data[] = {1, 3};
-  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
-
-  micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
-                             outputs_array, conv_params);
-
-  const char* init_data = reinterpret_cast<const char*>(conv_params);
-  TfLiteStatus status = runner.InitAndPrepare(init_data);
-  if (status != kTfLiteOk) {
-    return status;
-  }
-  return runner.Invoke();
-}
-
-template <typename T>
-TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
-                                 const T* expected_output_data,
-                                 int output_length,
-                                 TfLiteConvParams* conv_params,
-                                 TFLMRegistration registration, T* output_data,
-                                 float tolerance) {
-  TfLiteStatus status = InvokeConv(tensors, tensors_size, output_length,
-                                   conv_params, registration, output_data);
-  if (status != kTfLiteOk) {
-    return status;
-  }
-  for (int i = 0; i < output_length; ++i) {
-    TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i],
-                              tolerance);
-  }
-  return kTfLiteOk;
-}
-
-TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
-                        int output_length, TfLiteConvParams* conv_params,
-                        TFLMRegistration registration, float* output_data) {
-  return InvokeConv<float>(tensors, tensors_size, output_length, conv_params,
-                           registration, output_data);
-}
-
-TfLiteStatus InvokeConv(TfLiteTensor* tensors, int tensors_size,
-                        int output_length, TfLiteConvParams* conv_params,
-                        TFLMRegistration registration, int8_t* output_data) {
-  return InvokeConv<int8_t>(tensors, tensors_size, output_length, conv_params,
-                            registration, output_data);
-}
-
-TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
-                                 const float* expected_output_data,
-                                 int output_length,
-                                 TfLiteConvParams* conv_params,
-                                 TFLMRegistration registration,
-                                 float* output_data, float tolerance) {
-  return ValidateConvGoldens<float>(tensors, tensors_size,
-                                    expected_output_data, output_length,
-                                    conv_params, registration, output_data,
-                                    tolerance);
-}
-
-TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
-                                 const int8_t* expected_output_data,
-                                 int output_length,
-                                 TfLiteConvParams* conv_params,
-                                 TFLMRegistration registration,
-                                 int8_t* output_data, float tolerance) {
-  return ValidateConvGoldens<int8_t>(
-      tensors, tensors_size, expected_output_data, output_length, conv_params,
-      registration, output_data, tolerance);
-}
-
-TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data,
-                           int* filter_dims_data, const float* filter_data,
-                           int* bias_dims_data, const float* bias_data,
-                           int* output_dims_data,
-                           const float* expected_output_data,
-                           TfLiteConvParams* conv_params,
-                           TFLMRegistration registration, float* output_data) {
+TfLiteStatus TestConvFloat(
+    int* input_dims_data, const float* input_data, int* filter_dims_data,
+    const float* filter_data, int* bias_dims_data, const float* bias_data,
+    int* output_dims_data, const float* expected_output_data,
+    TfLiteConvParams* conv_params, TFLMRegistration registration,
+    float* output_data
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const TestCompressionInfo* filter_comp_info,
+    const TestCompressionInfo* bias_comp_info
+#endif  // USE_TFLM_COMPRESSION
+) {
   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
   TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
@@ -117,7 +47,12 @@ TfLiteStatus TestConvFloat(int* input_dims_data, const float* input_data,
 
   return ValidateConvGoldens(tensors, tensors_size, expected_output_data,
                              output_dims_count, conv_params, registration,
-                             output_data);
+                             output_data
+#ifdef USE_TFLM_COMPRESSION
+                             ,
+                             1e-5f, filter_comp_info, bias_comp_info
+#endif  // USE_TFLM_COMPRESSION
+  );
 }
 
 template
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc
index 384dba9f7ac..39618d41f66 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -52,14 +52,34 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   switch (input->type) {
     case kTfLiteFloat32: {
+#ifdef USE_TFLM_COMPRESSION
+
+      MicroContext* micro_context = GetMicroContext(context);
+
+      const CompressionTensorData* weights_comp_td =
+          micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
+      const CompressionTensorData* bias_comp_td =
+          micro_context->GetTensorCompressionData(node, kConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
       tflite::reference_ops::Conv(
          ConvParamsFloat(params, op_data.reference_op_data),
          tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<float>(input),
          tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(
+              micro_context, filter, weights_comp_td,
+              op_data.reference_op_data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<float>(
+              micro_context, bias, bias_comp_td,
+              op_data.reference_op_data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorData<float>(filter),
          tflite::micro::GetTensorShape(bias),
          tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<float>(output),
          tflite::micro::GetTensorShape(nullptr), nullptr);
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc
index f17809484d6..b5d4b5ea859 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv_hifi.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -177,9 +177,30 @@ TfLiteStatus ConvEvalHifiInt16(TfLiteContext* context, TfLiteNode* node,
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   const int16_t* input_data = tflite::micro::GetTensorData<int16_t>(input);
+#ifdef USE_TFLM_COMPRESSION
+  const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
+      micro_context, filter, weights_comp_td,
+      data.reference_op_data.weights_scratch_index);
+  const int64_t* bias_data = tflite::micro::GetOptionalTensorData<int64_t>(
+      micro_context, bias, bias_comp_td,
+      data.reference_op_data.bias_scratch_index);
+#else  // USE_TFLM_COMPRESSION
   const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
-  const int64_t* bias_data = tflite::micro::GetTensorData<int64_t>(bias);
+  const int64_t* bias_data =
+      tflite::micro::GetOptionalTensorData<int64_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
   int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
 
   int output_data_format = 0;
@@ -211,7 +232,6 @@ TfLiteStatus ConvEvalHifiInt16(TfLiteContext* context, TfLiteNode* node,
   } else {
     void* p_scratch = static_cast<void*>(
        context->GetScratchBuffer(context, data.scratch_tensor_index));
-
     for (int batch = 0; batch < batches; ++batch) {
       int16_t* p_out_temp;
       p_out_temp = &output_data[batch * out_length];
@@ -275,8 +295,26 @@ TfLiteStatus ConvEvalHifiInt8(TfLiteContext* context, TfLiteNode* node,
   const int output_height = output_shape.Dims(1);
   const int output_width = output_shape.Dims(2);
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
-  const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(bias);
+#ifdef USE_TFLM_COMPRESSION
+  const int32_t* bias_data = tflite::micro::GetOptionalTensorData<int32_t>(
+      micro_context, bias, bias_comp_td,
+      data.reference_op_data.bias_scratch_index);
+#else  // USE_TFLM_COMPRESSION
+  const int32_t* bias_data =
+      tflite::micro::GetOptionalTensorData<int32_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
   int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
 
   const int8_t* filter_data;
@@ -289,7 +327,13 @@ TfLiteStatus ConvEvalHifiInt8(TfLiteContext* context, TfLiteNode* node,
         tflite::micro::GetTensorShape(filter).FlatSize(),
         unpacked_filter_data);
     filter_data = unpacked_filter_data;
   } else {
+#ifdef USE_TFLM_COMPRESSION
+    filter_data = tflite::micro::GetTensorData<int8_t>(
+        micro_context, filter, weights_comp_td,
+        data.reference_op_data.weights_scratch_index);
+#else  // USE_TFLM_COMPRESSION
     filter_data = tflite::micro::GetTensorData<int8_t>(filter);
+#endif  // USE_TFLM_COMPRESSION
   }
 
   int output_data_format = 0;
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc b/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc
index 2492d4b348b..c50faa43e42 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv_int16_reference.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -45,6 +45,17 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) {
           ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
           : nullptr;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   if (bias == nullptr || bias->type == kTfLiteInt32) {
     reference_integer_ops::ConvPerChannel(
         ConvParamsQuantized(params, op_data),
@@ -52,9 +63,18 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) {
        tflite::micro::GetTensorShape(input),
        tflite::micro::GetTensorData<int16_t>(input),
        tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+        tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                             weights_comp_td,
+                                             op_data.weights_scratch_index),
+        tflite::micro::GetTensorShape(bias),
+        tflite::micro::GetOptionalTensorData<int32_t>(
+            micro_context, bias, bias_comp_td, op_data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
        tflite::micro::GetTensorData<int8_t>(filter),
        tflite::micro::GetTensorShape(bias),
        tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<int16_t>(output));
   } else if (bias->type == kTfLiteInt64) {
@@ -64,9 +84,18 @@ TfLiteStatus ConvReferenceEvalInt16(TfLiteContext* context, TfLiteNode* node) {
        tflite::micro::GetTensorShape(input),
        tflite::micro::GetTensorData<int16_t>(input),
        tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+        tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                             weights_comp_td,
+                                             op_data.weights_scratch_index),
+        tflite::micro::GetTensorShape(bias),
+        tflite::micro::GetTensorData<int64_t>(micro_context, bias, bias_comp_td,
+                                              op_data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
        tflite::micro::GetTensorData<int8_t>(filter),
        tflite::micro::GetTensorShape(bias),
-        tflite::micro::GetOptionalTensorData<int64_t>(bias),
+        tflite::micro::GetTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
        tflite::micro::GetTensorShape(output),
        tflite::micro::GetTensorData<int16_t>(output));
   } else {
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc b/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc
index 6ac07bab403..24adc64e19f 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv_int8_reference.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -45,6 +45,17 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) {
          ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
          : nullptr;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   const int8_t* filter_data;
   if (filter->type == kTfLiteInt4) {
     int8_t* unpacked_filter_data = static_cast<int8_t*>(
@@ -54,7 +65,12 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) {
        tflite::micro::GetTensorShape(filter).FlatSize(),
        unpacked_filter_data);
     filter_data = unpacked_filter_data;
   } else {
+#ifdef USE_TFLM_COMPRESSION
+    filter_data = tflite::micro::GetTensorData<int8_t>(
+        micro_context, filter, weights_comp_td, op_data.weights_scratch_index);
+#else  // USE_TFLM_COMPRESSION
     filter_data = tflite::micro::GetTensorData<int8_t>(filter);
+#endif  // USE_TFLM_COMPRESSION
   }
 
   reference_integer_ops::ConvPerChannel(
@@ -64,7 +80,12 @@ TfLiteStatus ConvReferenceEvalInt8(TfLiteContext* context, TfLiteNode* node) {
       tflite::micro::GetTensorData<int8_t>(input),
       tflite::micro::GetTensorShape(filter), filter_data,
       tflite::micro::GetTensorShape(bias),
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetOptionalTensorData<int32_t>(
+          micro_context, bias, bias_comp_td, op_data.bias_scratch_index),
+#else  // USE_TFLM_COMPRESSION
       tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
       tflite::micro::GetTensorShape(output),
       tflite::micro::GetTensorData<int8_t>(output));
 
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc b/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc
index 812ab60ebf2..0da261f0aa4 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv_vision.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -36,8 +36,10 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {
   MicroContext* micro_context = GetMicroContext(context);
   TfLiteTensor* input =
       micro_context->AllocateTempInputTensor(node, kConvInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
   TfLiteTensor* bias =
       micro_context->AllocateTempInputTensor(node, kConvBiasTensor);
+  TF_LITE_ENSURE(context, bias != nullptr);
 
   const uint32_t input_height = SizeOfDimension(input, 1);
   const uint32_t input_width = SizeOfDimension(input, 2);
@@ -47,8 +49,10 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {
 
   TfLiteTensor* output =
       micro_context->AllocateTempOutputTensor(node, kConvOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
   TfLiteTensor* filter =
       micro_context->AllocateTempInputTensor(node, kConvWeightsTensor);
+  TF_LITE_ENSURE(context, filter != nullptr);
 
   const uint32_t output_height = SizeOfDimension(output, 1);
   const uint32_t output_width = SizeOfDimension(output, 2);
@@ -104,6 +108,58 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {
     filter_int8 = *filter;
   }
 
+#ifdef USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = nullptr;
+  int32_t* bias_data = nullptr;
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
+  if (filter_comp_td != nullptr) {
+    const size_t filter_data_size =
+        NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8);
+    filter_data =
+        micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t));
+    if (filter_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* filter_eval =
+        tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
+    filter_data = static_cast<uint8_t*>(micro_context->DecompressTensorToBuffer(
+        *filter_eval, *filter_comp_td, filter_data));
+  } else {
+    filter_data = GetTensorData<uint8_t>(&filter_int8);
+  }
+
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kConvBiasTensor);
+  if (bias_comp_td != nullptr) {
+    const size_t bias_data_size =
+        NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32);
+    bias_data = reinterpret_cast<int32_t*>(
+        micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t)));
+    if (bias_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* bias_eval =
+        tflite::micro::GetEvalInput(context, node, kConvBiasTensor);
+    bias_data = static_cast<int32_t*>(micro_context->DecompressTensorToBuffer(
+        *bias_eval, *bias_comp_td, bias_data));
+  } else {
+    bias_data = GetTensorData<int32_t>(bias);
+  }
+
+  if (filter_data == nullptr || bias_data == nullptr) {
+    return kTfLiteError;
+  }
+
+#else  // USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = GetTensorData<uint8_t>(&filter_int8);
+  int32_t* bias_data = GetTensorData<int32_t>(bias);
+
+#endif  // USE_TFLM_COMPRESSION
+
   status = xiConvSetContext(
       data->p_context, data->context_size, input_depth, input_width,
       input_height, output_depth, output_width, output_height, filter_width,
@@ -112,8 +168,7 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {
       data->reference_op_data.output_multiplier,
       data->reference_op_data.output_shift,
       data->reference_op_data.output_activation_min,
-      data->reference_op_data.output_activation_max,
-      (uint8_t*)GetTensorData<int8_t>(&filter_int8),
+      data->reference_op_data.output_activation_max, filter_data,
      data->reference_op_data.padding.width,
      data->reference_op_data.padding.height);
   if (status) {
@@ -138,9 +193,7 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {
   status = xiConvDoCoeffReorder(
       data->p_context, data->context_size,
       reinterpret_cast<uint8_t*>(data->reorder_coefficient_bias),
-      data->reorder_coefficient_bias_size,
-      const_cast<uint8_t*>(GetTensorData<uint8_t>(&filter_int8)),
-      const_cast<int32_t*>(GetTensorData<int32_t>(bias)));
+      data->reorder_coefficient_bias_size, filter_data, bias_data);
   if (status) {
     return kTfLiteError;
   }
@@ -149,12 +202,21 @@ TfLiteStatus ConvPrepareVision(TfLiteContext* context, TfLiteNode* node) {
     micro_context->DeallocateTempBuffer(GetTensorData<uint8_t>(&filter_int8));
   }
 
+#ifdef USE_TFLM_COMPRESSION
+
+  if (filter_comp_td) {
+    micro_context->DeallocateTempBuffer(filter_data);
+  }
+  if (bias_comp_td) {
+    micro_context->DeallocateTempBuffer(reinterpret_cast<uint8_t*>(bias_data));
+  }
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
-  if (bias != nullptr) {
-    micro_context->DeallocateTempTfLiteTensor(bias);
-  }
+  micro_context->DeallocateTempTfLiteTensor(bias);
 
   return kTfLiteOk;
 }
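Reviewer note: every kernel variant touched by this patch follows the same decompress-on-read pattern — fetch the per-tensor CompressionTensorData once, then route the data pointer through the scratch buffer whose index was stored on OpDataConv during Prepare. The sketch below restates that pattern in isolation as a minimal example. It uses only names that appear in the diff (GetTensorCompressionData, the weights/bias scratch indices, and the MicroContext-taking GetTensorData/GetOptionalTensorData overloads); the function name and the explicit <float> template argument are illustrative assumptions, not part of the change.

```cpp
// Minimal sketch of the decompress-on-read pattern used throughout this patch.
// Assumes a float conv kernel; the hypothetical EvalFloatConvSketch is not a
// function added by this change.
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"

namespace tflite {

TfLiteStatus EvalFloatConvSketch(TfLiteContext* context, TfLiteNode* node,
                                 const OpDataConv& data,
                                 const TfLiteEvalTensor* filter,
                                 const TfLiteEvalTensor* bias) {
#ifdef USE_TFLM_COMPRESSION
  MicroContext* micro_context = GetMicroContext(context);

  // nullptr when the tensor is not compressed; the overloads below then fall
  // back to the raw tensor data instead of decompressing into the scratch
  // buffer allocated in Prepare via AllocateDecompressionScratchBuffer().
  const CompressionTensorData* weights_comp_td =
      micro_context->GetTensorCompressionData(node, kConvWeightsTensor);
  const CompressionTensorData* bias_comp_td =
      micro_context->GetTensorCompressionData(node, kConvBiasTensor);

  const float* filter_data = tflite::micro::GetTensorData<float>(
      micro_context, filter, weights_comp_td, data.weights_scratch_index);
  const float* bias_data = tflite::micro::GetOptionalTensorData<float>(
      micro_context, bias, bias_comp_td, data.bias_scratch_index);
#else   // USE_TFLM_COMPRESSION
  const float* filter_data = tflite::micro::GetTensorData<float>(filter);
  const float* bias_data = tflite::micro::GetOptionalTensorData<float>(bias);
#endif  // USE_TFLM_COMPRESSION

  (void)filter_data;
  (void)bias_data;
  // ... the reference convolution would be invoked here with these pointers ...
  return kTfLiteOk;
}

}  // namespace tflite
```

When USE_TFLM_COMPRESSION is not defined, the same call sites compile down to the plain single-argument overloads, so uncompressed builds and models pay no extra code-size or runtime cost.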