From 61b36a7e99e544f452d6959f83ec9c814d8793f1 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 17 Dec 2024 17:45:51 +0000 Subject: [PATCH] FC per channel quantization fix for 3d input --- .../internal/reference/integer_ops/fully_connected.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h index 3a74402ed98..0ebcfefba2d 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h @@ -42,12 +42,13 @@ void FullyConnectedPerChannel( const int32_t output_activation_min = params.quantized_activation_min; const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2); + TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); const int filter_dim_count = filter_shape.DimensionsCount(); - const int batches = output_shape.Dims(0); - const int output_depth = output_shape.Dims(1); + const int output_dim_count = output_shape.DimensionsCount(); + const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); + const int output_depth = output_shape.Dims(output_dim_count - 1); TFLITE_DCHECK_LE(output_depth, filter_shape.Dims(filter_dim_count - 2)); const int accum_depth = filter_shape.Dims(filter_dim_count - 1); for (int b = 0; b < batches; ++b) {