diff --git a/tensorflow/lite/micro/kernels/fully_connected.cc b/tensorflow/lite/micro/kernels/fully_connected.cc
index a1d22cfbd8e..b044a4bbab2 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected.cc
@@ -60,7 +60,7 @@ TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) {
       (input->type == kTfLiteInt8 &&
        (filter->type != kTfLiteInt8 && filter->type != kTfLiteInt4)) ||
       (input->type == kTfLiteInt16 && filter->type != kTfLiteInt8)) {
-    MicroPrintf("Input type: %s with filter type : %s not supported.",
+    MicroPrintf("Input type: %s with filter type: %s not supported.",
                 TfLiteTypeGetName(input->type),
                 TfLiteTypeGetName(filter->type));
     return kTfLiteError;
@@ -79,6 +79,23 @@ TfLiteStatus FullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) {
       context, params->activation, input->type, input, filter, bias, output,
       data));
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  if (micro_context->IsTensorCompressed(node, kFullyConnectedWeightsTensor) &&
+      filter->type == kTfLiteInt4) {
+    MicroPrintf("Compression not supported with INT4 tensors");
+    return kTfLiteError;
+  }
+  data->weights_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(
+          node, kFullyConnectedWeightsTensor);
+  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+      node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
   if (bias != nullptr) {
@@ -102,8 +119,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor* output =
       tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
 
-  TFLITE_DCHECK(node->user_data != nullptr);
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+  const CompressionTensorData* weights_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kFullyConnectedWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kFullyConnectedBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
+  TFLITE_DCHECK(node->user_data != nullptr);
   const auto& data =
       *(static_cast<const OpDataFullyConnected*>(node->user_data));
 
@@ -115,9 +143,18 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(micro_context, filter,
+                                              weights_comp_td,
+                                              data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
+                                              data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -152,9 +189,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
                 tflite::micro::GetTensorShape(input),
                 tflite::micro::GetTensorData<int8_t>(input),
                 tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                tflite::micro::GetTensorData<int8_t>(
+                    micro_context, filter, weights_comp_td,
+                    data.weights_scratch_index),
+                tflite::micro::GetTensorShape(bias),
+                tflite::micro::GetTensorData<int32_t>(
+                    micro_context, bias, bias_comp_td,
+                    data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
                 tflite::micro::GetTensorData<int8_t>(filter),
                 tflite::micro::GetTensorShape(bias),
                 tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
                 tflite::micro::GetTensorShape(output),
                 tflite::micro::GetTensorData<int8_t>(output))
           : tflite::reference_integer_ops::FullyConnected(
@@ -162,9 +209,19 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
                 tflite::micro::GetTensorShape(input),
                 tflite::micro::GetTensorData<int8_t>(input),
                 tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+                tflite::micro::GetTensorData<int8_t>(
+                    micro_context, filter, weights_comp_td,
+                    data.weights_scratch_index),
+                tflite::micro::GetTensorShape(bias),
+                tflite::micro::GetTensorData<int32_t>(
+                    micro_context, bias, bias_comp_td,
+                    data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
                 tflite::micro::GetTensorData<int8_t>(filter),
                 tflite::micro::GetTensorShape(bias),
                 tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
                 tflite::micro::GetTensorShape(output),
                 tflite::micro::GetTensorData<int8_t>(output));
       break;
@@ -186,9 +243,18 @@ TfLiteStatus FullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<int16_t>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                               weights_comp_td,
+                                               data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int64_t>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<int8_t>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<int16_t>(output));
       break;
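Taken together, the hunks above implement the kernel-side pattern TFLM uses for optionally compressed weights: `Prepare()` reserves a per-tensor decompression scratch buffer (a no-op when the tensor is not compressed), and `Eval()` routes every weight/bias access through a compression-aware `GetTensorData()` overload. Below is a minimal sketch of that pattern, condensed from this patch only (error handling, the non-float branches, and the reference-kernel call are elided; it assumes `USE_TFLM_COMPRESSION` is defined):

```cpp
// Sketch condensed from the patch above; not a drop-in kernel.
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"

namespace tflite {

TfLiteStatus PrepareSketch(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
  // Reserves scratch large enough for the decompressed tensor; only
  // allocates if the tensor is actually compressed.
  data->weights_scratch_index =
      micro_context->AllocateDecompressionScratchBuffer(
          node, kFullyConnectedWeightsTensor);
  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
      node, kFullyConnectedBiasTensor);
  return kTfLiteOk;
}

TfLiteStatus EvalSketch(TfLiteContext* context, TfLiteNode* node) {
  MicroContext* micro_context = GetMicroContext(context);
  const auto& data =
      *(static_cast<const OpDataFullyConnected*>(node->user_data));
  const TfLiteEvalTensor* filter =
      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
  // Returns nullptr for uncompressed tensors, in which case the overload
  // below degrades to a plain data-pointer fetch.
  const CompressionTensorData* weights_comp_td =
      micro_context->GetTensorCompressionData(node,
                                              kFullyConnectedWeightsTensor);
  const float* weights = tflite::micro::GetTensorData<float>(
      micro_context, filter, weights_comp_td, data.weights_scratch_index);
  (void)weights;  // fed to the reference kernel, as in the hunks above
  return kTfLiteOk;
}

}  // namespace tflite
```

Because every branch of the type switch repeats this access pattern, the diff touches each `reference_ops`/`reference_integer_ops` call site but changes nothing about the arithmetic itself.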
diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h
index 670488ab618..64213f0fb63 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.h
+++ b/tensorflow/lite/micro/kernels/fully_connected.h
@@ -50,6 +50,14 @@ struct OpDataFullyConnected {
   int32_t* per_channel_output_shift;
   bool is_per_channel;
 #endif
+
+#ifdef USE_TFLM_COMPRESSION
+
+  // scratch buffers for compressed tensors
+  int weights_scratch_index;
+  int bias_scratch_index;
+
+#endif  // USE_TFLM_COMPRESSION
 };
 
 extern const int kFullyConnectedInputTensor;
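The two new indices are only meaningful together with the compression-aware `GetTensorData()` overload used in the .cc hunks above, which is not itself part of this diff. The sketch below is a hypothetical simplification of that overload's contract; the helper name `DecompressTensorToBuffer` and the exact `MicroContext` calls are assumptions for illustration, not taken from this patch:

```cpp
// Illustrative only: the real overload lives in the TFLM kernel utilities
// and may differ in detail from this simplification.
template <typename T>
const T* GetTensorData(MicroContext* micro_context,
                       const TfLiteEvalTensor* tensor,
                       const CompressionTensorData* compression_data,
                       int scratch_buffer_handle) {
  if (tensor == nullptr) {
    return nullptr;  // optional input (e.g. a null bias)
  }
  if (compression_data == nullptr) {
    // Uncompressed tensor: plain pass-through.
    return reinterpret_cast<const T*>(tensor->data.data);
  }
  // Compressed tensor: decompress into the scratch area reserved during
  // Prepare() and hand the kernel a pointer to the decompressed copy.
  // (Assumed helper name; the decompression entry point is not in this diff.)
  void* scratch = micro_context->GetScratchBuffer(scratch_buffer_handle);
  return reinterpret_cast<const T*>(micro_context->DecompressTensorToBuffer(
      *tensor, *compression_data, scratch));
}
```

This explains why the scratch indices live in `OpDataFullyConnected`: decompression happens on every invocation, so the buffer must be reserved once in `Prepare()` and found again cheaply in `Eval()`.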
diff --git a/tensorflow/lite/micro/kernels/fully_connected_test.cc b/tensorflow/lite/micro/kernels/fully_connected_test.cc
index 2ad132055b8..2cf3427c874 100644
--- a/tensorflow/lite/micro/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/micro/kernels/fully_connected_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -42,6 +42,29 @@ const float simple_weights_data[] = {
     1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
 };
 
+int simple_bias_dims[] = {1, 3};
+const float simple_bias_data[] = {1, 2, 3};
+
+#ifdef USE_TFLM_COMPRESSION
+
+// compressed filter data for kBinQuant scheme
+// Align the tensor data the same as a Buffer in the schema
+alignas(16) constexpr uint8_t kBinQuantWeightData[] = {
+    0x01, 0x23, 0x45, 0x67, 0x89, 0x01, 0x23, 0x45,
+    0x67, 0x89, 0x01, 0x23, 0x45, 0x67, 0x89};
+constexpr float kBinQuantWeightValueTable[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+constexpr size_t kBinQuantWeightValueTableElements =
+    std::extent<decltype(kBinQuantWeightValueTable)>::value;
+constexpr int kBinQuantWeightBitWidth = 4;
+// compressed bias data for kBinQuant scheme
+// Align the tensor data the same as a Buffer in the schema
+alignas(16) constexpr uint8_t kBinQuantBiasData[] = {0x18};
+constexpr int kBinQuantBiasBitWidth = 2;
+constexpr size_t simple_bias_size =
+    std::extent<decltype(simple_bias_data)>::value;
+
+#endif  // USE_TFLM_COMPRESSION
+
 // TODO(b/258710417): INT4 isn't currently supported on Hexagon.
 #if !defined(HEXAGON)
 const float simple_int4_weights_data[] = {
@@ -53,8 +76,6 @@ const float simple_golden_null_bias_int4_weights[] = {
     -28, -28, -28, 0, 0, 0,
 };
 #endif
-int simple_bias_dims[] = {1, 3};
-const float simple_bias_data[] = {1, 2, 3};
 const float simple_golden[] = {
     24, 25, 26, 58, 59, 60,
 };
@@ -241,11 +262,19 @@ const float representative_64x16_golden[] = {
 const int representative_64x16_output_size = 16;
 int representative_64x16_output_dims[] = {2, 1, 16};
 
-template <typename T>
+constexpr int kMaxTensors = 4;
+
+template <typename T, typename TW = void, typename TB = void>
 TfLiteStatus ValidateFullyConnectedGoldens(
     TfLiteTensor* tensors, const int tensors_size, bool null_bias,
     const TfLiteFusedActivation activation, const float tolerance,
-    const int output_len, const T* golden, T* output_data) {
+    const int output_len, const T* golden, T* output_data
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const TestCompressionInfo<TW>* weight_comp_info = nullptr,
+    const TestCompressionInfo<TB>* bias_comp_info = nullptr
+#endif  // USE_TFLM_COMPRESSION
+) {
   TfLiteFullyConnectedParams builtin_data = {
       activation, kTfLiteFullyConnectedWeightsFormatDefault, false, false,
       kTfLiteNoType};
@@ -272,10 +301,37 @@ TfLiteStatus ValidateFullyConnectedGoldens(
   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
+#ifdef USE_TFLM_COMPRESSION
+
+  TestCompressedList<kMaxTensors> tcl;
+
+  if (weight_comp_info != nullptr) {
+    TF_LITE_MICRO_EXPECT_EQ(
+        tcl.AddInput(*weight_comp_info, tensors[kFullyConnectedWeightsTensor],
+                     kFullyConnectedWeightsTensor),
+        kTfLiteOk);
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+  if (bias_comp_info != nullptr) {
+    TF_LITE_MICRO_EXPECT_EQ(
+        tcl.AddInput(*bias_comp_info, tensors[kFullyConnectedBiasTensor],
+                     kFullyConnectedBiasTensor),
+        kTfLiteOk);
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+  const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList();
+
+#endif  // USE_TFLM_COMPRESSION
+
   const TFLMRegistration registration = Register_FULLY_CONNECTED();
   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
                              outputs_array,
-                             reinterpret_cast<void*>(&builtin_data));
+                             reinterpret_cast<void*>(&builtin_data), nullptr
+#ifdef USE_TFLM_COMPRESSION
+                             ,
+                             comp_list_p
+#endif  // USE_TFLM_COMPRESSION
+  );
 
   TfLiteStatus status = runner.InitAndPrepare();
   if (status != kTfLiteOk) {
@@ -293,11 +349,18 @@ TfLiteStatus ValidateFullyConnectedGoldens(
   return kTfLiteOk;
 }
 
+template <typename TW = void, typename TB = void>
 TfLiteStatus TestFullyConnectedFloat(
     int* input_dims_data, const float* input_data, int* weights_dims_data,
     const float* weights_data, int* bias_dims_data, const float* bias_data,
     const float* golden, int* output_dims_data,
-    TfLiteFusedActivation activation, float* output_data) {
+    TfLiteFusedActivation activation, float* output_data
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const TestCompressionInfo<TW>* weight_comp_info = nullptr,
+    const TestCompressionInfo<TB>* bias_comp_info = nullptr
+#endif  // USE_TFLM_COMPRESSION
+) {
   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
   TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
@@ -305,16 +368,15 @@ TfLiteStatus TestFullyConnectedFloat(
   const int output_dims_count = ElementCount(*output_dims);
 
   bool null_bias = bias_data == nullptr ? true : false;
 
-  constexpr int array_size = 4;  // Avoid variable length array warning.
-  const int inputs_size = bias_data == nullptr ? 2 : 3;
+  const int inputs_size = null_bias ? 2 : 3;
   constexpr int outputs_size = 1;
   const int tensors_size = inputs_size + outputs_size;
-  TfLiteTensor tensors[array_size];
+  TfLiteTensor tensors[kMaxTensors];
   tensors[0] = CreateTensor(input_data, input_dims);
   tensors[1] = CreateTensor(weights_data, weights_dims);
-  if (bias_data == nullptr) {
+  if (null_bias) {
     tensors[2] = CreateTensor(output_data, output_dims);
   } else {
     tensors[2] = CreateTensor(bias_data, bias_dims);
@@ -323,7 +385,12 @@ TfLiteStatus TestFullyConnectedFloat(
 
   return ValidateFullyConnectedGoldens(tensors, tensors_size, null_bias,
                                        activation, 1e-4f, output_dims_count,
-                                       golden, output_data);
+                                       golden, output_data
+#ifdef USE_TFLM_COMPRESSION
+                                       ,
+                                       weight_comp_info, bias_comp_info
+#endif  // USE_TFLM_COMPRESSION
+  );
 }
 
 template <typename T>
@@ -345,7 +412,7 @@ TfLiteStatus TestFullyConnectedQuantized(
   bool null_bias = bias_data == nullptr ? true : false;
 
   constexpr int array_size = 4;  // Avoid variable length array warning.
-  const int inputs_size = bias_data == nullptr ? 2 : 3;
+  const int inputs_size = null_bias ? 2 : 3;
   constexpr int outputs_size = 1;
   const int tensors_size = inputs_size + outputs_size;
   TfLiteTensor tensors[array_size];
@@ -355,7 +422,7 @@ TfLiteStatus TestFullyConnectedQuantized(
   tensors[1] = CreateQuantizedTensor(
       weights_data, weights_quantized, weights_dims, weights_scale,
       weights_zero_point, false, weights_packed_type);
-  if (bias_data == nullptr) {
+  if (null_bias) {
     tensors[2] = CreateQuantizedTensor(output_data, output_dims, output_scale,
                                        output_zero_point);
   } else {
@@ -373,6 +440,71 @@ TfLiteStatus TestFullyConnectedQuantized(
       golden_quantized, output_data);
 }
 
+#ifdef USE_TFLM_COMPRESSION
+
+template <typename TIO, typename TW, typename TB>
+TfLiteStatus TestFullyConnectedQuantizedCompressed(
+    int* input_dims_data, const float* input_data, TIO* input_quantized,
+    float input_scale, int input_zero_point, int* output_dims_data,
+    const float* expected_output_data, TIO* expected_output_quantized,
+    TIO* output_quantized, float output_scale, int output_zero_point,
+    const TfLiteFusedActivation activation,
+    const TestCompressionQuantizedInfo<TW>* weight_comp_info,
+    const TestCompressionQuantizedInfo<TB>* bias_comp_info) {
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* weight_dims = IntArrayFromInts(weight_comp_info->dims_data);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+
+  TfLiteFloatArray* weight_scales =
+      FloatArrayFromFloats(weight_comp_info->scales);
+  TfLiteIntArray* weight_zero_points =
+      IntArrayFromInts(weight_comp_info->zero_points);
+
+  TfLiteTensor weight_tensor = CreateQuantizedTensor(
+      weight_comp_info->compressed, weight_dims, weight_scales->data[0],
+      weight_zero_points->data[0], false, kTfLiteInt8);
+  SymmetricQuantize(weight_comp_info->data, weight_comp_info->value_table,
+                    weight_comp_info->value_table_stride,
+                    weight_scales->data[0]);
+
+  TfLiteTensor bias_tensor = {};
+  if (bias_comp_info != nullptr) {
+    bias_tensor = CreateQuantizedTensor(bias_comp_info->compressed, bias_dims,
+                                        input_scale * weight_scales->data[0],
+                                        0, false, typeToTfLiteType<TB>());
+    SymmetricQuantize(bias_comp_info->data, bias_comp_info->value_table,
+                      bias_comp_info->value_table_stride,
+                      bias_tensor.params.scale);
+  }
+
+  TfLiteTensor output_tensor = CreateQuantizedTensor(
+      output_quantized, output_dims, output_scale, output_zero_point);
+
+  const int tensors_size =
+      (bias_comp_info == nullptr) ? kMaxTensors - 1 : kMaxTensors;
+  TfLiteTensor tensors[kMaxTensors] = {};
+  tensors[0] = CreateQuantizedTensor(input_data, input_quantized, input_dims,
+                                     input_scale, input_zero_point);
+  tensors[1] = weight_tensor;
+  if (bias_comp_info == nullptr) {
+    tensors[2] = output_tensor;
+  } else {
+    tensors[2] = bias_tensor;
+    tensors[3] = output_tensor;
+  }
+
+  const int output_dims_count = ElementCount(*output_dims);
+  Quantize(expected_output_data, expected_output_quantized, output_dims_count,
+           output_scale, output_zero_point);
+  return ValidateFullyConnectedGoldens(
+      tensors, tensors_size, bias_comp_info == nullptr, activation, 0.0f,
+      output_dims_count, expected_output_quantized, output_quantized,
+      weight_comp_info, bias_comp_info);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 }  // namespace
 }  // namespace testing
 }  // namespace tflite
@@ -393,6 +525,40 @@ TF_LITE_MICRO_TEST(SimpleTest) {
       kTfLiteOk);
 }
 
+#ifdef USE_TFLM_COMPRESSION
+
+TF_LITE_MICRO_TEST(SimpleTestCompressed) {
+  float output_data[tflite::testing::simple_output_size];
+
+  tflite::testing::TestCompressionInfo<const float> weight_comp_info = {};
+  tflite::testing::TestCompressionInfo<const float> bias_comp_info = {};
+
+  weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  weight_comp_info.value_table = tflite::testing::kBinQuantWeightValueTable;
+  weight_comp_info.value_table_stride =
+      tflite::testing::kBinQuantWeightValueTableElements;
+  weight_comp_info.bit_width = tflite::testing::kBinQuantWeightBitWidth;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = tflite::testing::simple_bias_data;
+  bias_comp_info.value_table_stride = tflite::testing::simple_bias_size;
+  bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth;
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      tflite::testing::TestFullyConnectedFloat(
+          tflite::testing::simple_input_dims,
+          tflite::testing::simple_input_data,
+          tflite::testing::simple_weights_dims,
+          reinterpret_cast<const float*>(tflite::testing::kBinQuantWeightData),
+          tflite::testing::simple_bias_dims,
+          reinterpret_cast<const float*>(tflite::testing::kBinQuantBiasData),
+          tflite::testing::simple_golden, tflite::testing::simple_output_dims,
+          kTfLiteActNone, output_data, &weight_comp_info, &bias_comp_info),
+      kTfLiteOk);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 TF_LITE_MICRO_TEST(SimpleTestNullBias) {
   float output_data[tflite::testing::simple_output_size];
   TF_LITE_MICRO_EXPECT_EQ(
@@ -434,6 +600,58 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8) {
       kTfLiteOk);
 }
 
+#ifdef USE_TFLM_COMPRESSION
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8Compressed) {
+  const float input_scale = 1.0f;
+  const int input_zero_point = -1;
+  constexpr float weights_scale[] = {1, 1.0f};
+  constexpr int weights_zero_point[] = {1, 0};
+  const float output_scale = 0.5f;
+  const int output_zero_point = -1;
+
+  int8_t input_quantized[tflite::testing::simple_input_size];
+  int8_t weights_quantized[tflite::testing::kBinQuantWeightValueTableElements];
+  int32_t bias_quantized[tflite::testing::simple_output_size];
+  int8_t golden_quantized[tflite::testing::simple_output_size];
+  int8_t output_data[tflite::testing::simple_output_size];
+
+  tflite::testing::TestCompressionQuantizedInfo<int8_t> weight_comp_info = {};
+  tflite::testing::TestCompressionQuantizedInfo<int32_t> bias_comp_info = {};
+
+  weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  weight_comp_info.value_table = weights_quantized;
+  weight_comp_info.value_table_stride =
+      tflite::testing::kBinQuantWeightValueTableElements;
+  weight_comp_info.bit_width = tflite::testing::kBinQuantWeightBitWidth;
+  weight_comp_info.compressed = tflite::testing::kBinQuantWeightData;
+  weight_comp_info.data = tflite::testing::kBinQuantWeightValueTable;
+  weight_comp_info.dims_data = tflite::testing::simple_weights_dims;
+  weight_comp_info.scales = weights_scale;
+  weight_comp_info.zero_points = weights_zero_point;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_quantized;
+  bias_comp_info.value_table_stride = tflite::testing::simple_bias_size;
+  bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth;
+  bias_comp_info.compressed = tflite::testing::kBinQuantBiasData;
+  bias_comp_info.data = tflite::testing::simple_bias_data;
+  bias_comp_info.dims_data = tflite::testing::simple_bias_dims;
+  // bias scales and bias zero_points are not used
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      tflite::testing::TestFullyConnectedQuantizedCompressed(
+          tflite::testing::simple_input_dims,
+          tflite::testing::simple_input_data, input_quantized, input_scale,
+          input_zero_point, tflite::testing::simple_output_dims,
+          tflite::testing::simple_golden, golden_quantized, output_data,
+          output_scale, output_zero_point, kTfLiteActNone, &weight_comp_info,
+          &bias_comp_info),
+      kTfLiteOk);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 #if !defined(HEXAGON)
 TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) {
   const float input_scale = 128.0 / 65536;
@@ -443,7 +661,6 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) {
   const float output_scale = 128.0 / 65536;
   const int output_zero_point = 0;
 
-  const float simple_golden[] = {24, 25, 26, 58, 59, 60};
   int16_t input_quantized[tflite::testing::simple_input_size];
   int8_t weights_quantized[tflite::testing::simple_weights_size];
   int64_t bias_quantized[tflite::testing::simple_output_size];
@@ -457,12 +674,66 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) {
           input_zero_point, tflite::testing::simple_weights_dims,
           tflite::testing::simple_weights_data, weights_quantized,
           weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
-          tflite::testing::simple_bias_data, bias_quantized, simple_golden,
-          golden_quantized, tflite::testing::simple_output_dims, output_scale,
-          output_zero_point, kTfLiteActNone, output_data),
+          tflite::testing::simple_bias_data, bias_quantized,
+          tflite::testing::simple_golden, golden_quantized,
+          tflite::testing::simple_output_dims, output_scale, output_zero_point,
+          kTfLiteActNone, output_data),
       kTfLiteOk);
 }
-#endif
+
+#ifdef USE_TFLM_COMPRESSION
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16Compressed) {
+  const float input_scale = 128.0 / 65536;
+  const int input_zero_point = 0;
+  constexpr float weights_scale[] = {1, 1.0f};
+  constexpr int weights_zero_point[] = {1, 0};
+  const float output_scale = 128.0 / 65536;
+  const int output_zero_point = 0;
+
+  int16_t input_quantized[tflite::testing::simple_input_size];
+  int8_t weights_quantized[tflite::testing::kBinQuantWeightValueTableElements];
+  int64_t bias_quantized[tflite::testing::simple_output_size];
+  int16_t golden_quantized[tflite::testing::simple_output_size];
+  int16_t output_data[tflite::testing::simple_output_size];
+
+  tflite::testing::TestCompressionQuantizedInfo<int8_t> weight_comp_info = {};
+  tflite::testing::TestCompressionQuantizedInfo<int64_t> bias_comp_info = {};
+
+  weight_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  weight_comp_info.value_table = weights_quantized;
+  weight_comp_info.value_table_stride =
+      tflite::testing::kBinQuantWeightValueTableElements;
+  weight_comp_info.bit_width = tflite::testing::kBinQuantWeightBitWidth;
+  weight_comp_info.compressed = tflite::testing::kBinQuantWeightData;
+  weight_comp_info.data = tflite::testing::kBinQuantWeightValueTable;
+  weight_comp_info.dims_data = tflite::testing::simple_weights_dims;
+  weight_comp_info.scales = weights_scale;
+  weight_comp_info.zero_points = weights_zero_point;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_quantized;
+  bias_comp_info.value_table_stride = tflite::testing::simple_bias_size;
+  bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidth;
+  bias_comp_info.compressed = tflite::testing::kBinQuantBiasData;
+  bias_comp_info.data = tflite::testing::simple_bias_data;
+  bias_comp_info.dims_data = tflite::testing::simple_bias_dims;
+  // bias scales and bias zero_points are not used
+
+  TF_LITE_MICRO_EXPECT_EQ(
+      tflite::testing::TestFullyConnectedQuantizedCompressed(
+          tflite::testing::simple_input_dims,
+          tflite::testing::simple_input_data, input_quantized, input_scale,
+          input_zero_point, tflite::testing::simple_output_dims,
+          tflite::testing::simple_golden, golden_quantized, output_data,
+          output_scale, output_zero_point, kTfLiteActNone, &weight_comp_info,
+          &bias_comp_info),
+      kTfLiteOk);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
+#endif  // !defined(HEXAGON)
 
 TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) {
   const float input_scale = 1.0f;
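The test vectors above are small enough to decode by hand. `kBinQuantWeightData` packs thirty 4-bit value-table indices, two per byte, which reproduce `simple_weights_data`'s three rows of 1..10 via `kBinQuantWeightValueTable`; `kBinQuantBiasData` packs three 2-bit indices into the single byte 0x18 (binary 00 01 10 00), yielding the biases {1, 2, 3}. The MSB-first bit ordering used below is inferred from these vectors and the goldens rather than stated in the patch. A standalone sanity check:

```cpp
// Standalone check of the bin-quant packing assumed by the test vectors
// above: indices are packed MSB-first into bytes, then looked up in the
// value table. Build with: g++ -std=c++17 binquant_check.cc && ./a.out
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t weights[] = {0x01, 0x23, 0x45, 0x67, 0x89, 0x01, 0x23, 0x45,
                             0x67, 0x89, 0x01, 0x23, 0x45, 0x67, 0x89};
  const float table[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  // 30 weights at 4 bits each: two table indices per byte, high nibble first.
  for (int i = 0; i < 30; ++i) {
    const uint8_t byte = weights[i / 2];
    const int index = (i % 2 == 0) ? (byte >> 4) : (byte & 0x0F);
    // simple_weights_data repeats 1..10 for each of the three output units.
    assert(table[index] == static_cast<float>((i % 10) + 1));
  }
  // 3 biases at 2 bits each: 0x18 = 0b00'01'10'00 -> indices 0, 1, 2 (+ pad).
  const uint8_t bias_byte = 0x18;
  const float bias_table[] = {1, 2, 3};
  for (int i = 0; i < 3; ++i) {
    const int index = (bias_byte >> (6 - 2 * i)) & 0x03;
    assert(bias_table[index] == static_cast<float>(i + 1));
  }
  printf("bin-quant test vectors decode to the expected weights and biases\n");
  return 0;
}
```

The size win is what the scheme is after: the 30-element float weight tensor shrinks from 120 bytes to 15 bytes of indices plus a 10-entry value table.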
diff --git a/tensorflow/lite/micro/memory_arena_threshold_test.cc b/tensorflow/lite/micro/memory_arena_threshold_test.cc
index 6bc23bc37d0..34c62cda412 100644
--- a/tensorflow/lite/micro/memory_arena_threshold_test.cc
+++ b/tensorflow/lite/micro/memory_arena_threshold_test.cc
@@ -63,7 +63,6 @@ constexpr int kKeywordModelOnlyTotalSize = 14472;
 // TODO(b/207157610): replace magic number that depends on OPs
 constexpr int kKeywordModelOnlyTailSize = 13800;
 constexpr int kKeywordModelPersistentTfLiteTensorDataSize = 128;
-constexpr int kKeywordModelPersistentBufferDataSize = 832;
 #else
 // Total size contributed by the keyword model excluding the
 // RecordingMicroAllocator's overhead.
@@ -74,7 +73,6 @@ constexpr int kKeywordModelOnlyTotalSize = 14936;
 // TODO(b/207157610): replace magic number that depends on OPs
 constexpr int kKeywordModelOnlyTailSize = 14264;
 constexpr int kKeywordModelPersistentTfLiteTensorDataSize = 224;
-constexpr int kKeywordModelPersistentBufferDataSize = 840;
 #endif
 constexpr int kKeywordModelHeadSize = 672;
 constexpr int kKeywordModelTfLiteTensorVariableBufferDataSize = 10240;
@@ -87,6 +85,12 @@ uint8_t test_conv_tensor_arena[kTestConvModelArenaSize];
 constexpr int kTestConvModelTensorCount = 15;
 constexpr int kTestConvModelNodeAndRegistrationCount = 7;
 
+#if defined(USE_TFLM_COMPRESSION)
+constexpr int kKeywordModelPersistentBufferDataSize = 920;
+#else
+constexpr int kKeywordModelPersistentBufferDataSize = 840;
+#endif
+
 // NOTE: These values are measured on x86-64:
 // TODO(b/158651472): Consider auditing these values on non-64 bit systems.
 #ifdef TF_LITE_STATIC_MEMORY
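A note on the threshold change: with compression compiled in, the keyword model's measured persistent buffer allocation rises from 840 to 920 bytes. The 80-byte delta plausibly corresponds to the decompression scratch-buffer bookkeeping that the fully connected Prepare() now performs for the model's nodes, though the patch does not break the number down. Like the neighboring constants, these are magic values measured on x86-64 (see the TODOs above) and will need re-measuring whenever allocator internals or the op set change.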