From 6147f20441aaa0be00920a407dacd06a44364188 Mon Sep 17 00:00:00 2001 From: Adrian Lundell <36153706+AdrianLundell@users.noreply.github.com> Date: Tue, 17 Sep 2024 19:58:53 +0200 Subject: [PATCH] Update CMSIS-NN vector_sum_s8 calls to new API (#2685) BUG=Add rhs_offset argument --- tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc | 8 ++++++-- tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc | 4 +++- .../kernels/cmsis_nn/unidirectional_sequence_lstm.cc | 2 +- .../lite/micro/tools/make/ext_libs/cmsis_nn_download.sh | 4 ++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc index 7c373b53639..8c85bc34870 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc @@ -147,8 +147,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { data->kernel_sums = static_cast<int32_t*>( context->AllocatePersistentBuffer(context, buf_size)); + int32_t input_offset = -data->reference_op_data.input_zero_point; + int32_t filter_offset = -data->reference_op_data.filter_zero_point; arm_vector_sum_s8(data->kernel_sums, filter_dims.n, data->output_depth, - filter_data, 1, nullptr); + filter_data, input_offset, filter_offset, + tflite::GetTensorData<int32_t>(bias)); // Do not request a scratch buffer since using persistent memory buf_size = 0; @@ -321,7 +324,8 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, // If behaving like batch matmul we calculate kernel sums in eval.
arm_vector_sum_s8( static_cast<int32_t*>(ctx.buf), filter_dims.n, data.output_depth, - tflite::micro::GetTensorData<int8_t>(filter), 1, nullptr); + tflite::micro::GetTensorData<int8_t>(filter), fc_params.input_offset, + fc_params.filter_offset, bias_data); } TF_LITE_ENSURE_EQ( diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc b/tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc index bf64016b13f..d39ae616c0f 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc @@ -193,7 +193,9 @@ TfLiteStatus CmsisNnPrepareSvdf(TfLiteContext* context, TfLiteNode* node) { context->AllocatePersistentBuffer(context, buf_size)); arm_vector_sum_s8(data->kernel_sums, input_size, num_filters, - GetTensorData<int8_t>(weights_feature), 1, nullptr); + GetTensorData<int8_t>(weights_feature), + -data->input_zero_point, + -data->activation_state_zero_point, nullptr); } } else { diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc b/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc index 49da4d916d0..fbbdab33ca0 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc @@ -53,7 +53,7 @@ LSTMBuffers CMSIS_NN_CreateLSTMBuffers(TfLiteContext* context, void CMSIS_NN_VectorSum(int32_t* kernel_sum, const int32_t size1, const int32_t size2, const int8_t* weights, const int32_t offset, const int32_t* biases) { - arm_vector_sum_s8(kernel_sum, size1, size2, weights, offset, biases); + arm_vector_sum_s8(kernel_sum, size1, size2, weights, offset, 0, biases); } void CMSIS_NN_VectorSum(int64_t* kernel_sum, const int32_t size1, diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh index 393c184d1e7..04e76dd508c 100755 --- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh +++
b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh @@ -38,9 +38,9 @@ source ${TENSORFLOW_ROOT}tensorflow/lite/micro/tools/make/bash_helpers.sh DOWNLOADS_DIR=${1} DOWNLOADED_CMSIS_NN_PATH=${DOWNLOADS_DIR}/cmsis_nn -ZIP_PREFIX_NN="95f293df19c9a38806868fe12a64a4f9b457f9c1" +ZIP_PREFIX_NN="f2cb41ca1450a4eb4307b2779dd5aae9028285a5" CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip" -CMSIS_NN_MD5="5e0c4cd60a5f074c4d26d1be236caefd" +CMSIS_NN_MD5="4d0e623432d6f8d3b201cbcd89218adf" should_download=$(check_should_download ${DOWNLOADS_DIR})