CMSIS-NN: Move kernel sums to prepare for FC and SVDF INT8
This updates the third-party download script to pull in the latest CMSIS-NN.

Change-Id: I45b75531be996e07bb17bc864e64154ae752e8b2
mansnils committed Sep 19, 2023
1 parent 2f2c744 commit 3fb89e9
Showing 3 changed files with 235 additions and 11 deletions.
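Background on the change: for an int8 fully connected layer, the accumulator for output row r is bias[r] + sum_k (x[k] + input_offset) * w[r][k], which expands to bias[r] + sum_k x[k] * w[r][k] + input_offset * sum_k w[r][k]. The per-row weight sums depend only on the filter, so they can be computed once at prepare time into persistent arena memory instead of being rebuilt in a scratch buffer on every invocation, which is what this commit does for the FC and SVDF INT8 kernels. A minimal scalar sketch of the sums that arm_vector_sum_s8 fills in the diffs below (the CMSIS-NN kernel is vectorized; this reference loop and its names are illustrative only):

#include <cstdint>

// Reference (non-vectorized) equivalent of the kernel-sum precomputation:
// one int32 sum per row of the int8 weight matrix, matching the
// (sums, cols, rows, weights) argument order used in the diffs below.
void vector_sum_s8_ref(int32_t* sums, int cols, int rows, const int8_t* w) {
  for (int r = 0; r < rows; ++r) {
    int32_t acc = 0;
    for (int c = 0; c < cols; ++c) {
      acc += w[r * cols + c];  // plain row sum of the weights
    }
    sums[r] = acc;
  }
}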
35 changes: 34 additions & 1 deletion tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {
@@ -42,6 +43,8 @@ struct OpData {
// Index to buffer for optimizations if applicable.
int buffer_idx;

int32_t* kernel_sums;

int32_t batches;
int32_t accum_depth;
int32_t output_depth;
@@ -124,6 +127,35 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(&input_dims);
} else {
buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);

if (buf_size > 0) {
data->kernel_sums = static_cast<int32_t*>(
context->AllocatePersistentBuffer(context, buf_size));

int8_t* filter_data = GetTensorData<int8_t>(filter);

if (filter->type == kTfLiteInt4) {
size_t filter_size = GetTensorShape(filter).FlatSize();
int8_t* unpacked_filter_buf =
reinterpret_cast<int8_t*>(micro_context->AllocateTempBuffer(
filter_size, tflite::MicroArenaBufferAlignment()));

tflite::tensor_utils::UnpackDenseInt4IntoInt8(
filter_data, filter_size, unpacked_filter_buf);
filter_data = unpacked_filter_buf;
}

arm_vector_sum_s8(data->kernel_sums, filter_dims.n, data->output_depth,
filter_data);

if (filter->type == kTfLiteInt4) {
micro_context->DeallocateTempBuffer(
reinterpret_cast<uint8_t*>(filter_data));
}

// Do not request a scratch buffer since persistent memory is used instead.
buf_size = 0;
}
}
}

@@ -252,6 +284,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
fc_params.activation.min = data.reference_op_data.output_activation_min;
fc_params.activation.max = data.reference_op_data.output_activation_max;

ctx.buf = data.kernel_sums;
TF_LITE_ENSURE_EQ(
context,
arm_fully_connected_s8(
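At eval time the precomputed sums reach CMSIS-NN through cmsis_nn_context::buf (the ctx.buf = data.kernel_sums line above), so the kernel no longer requests a scratch buffer to rebuild them on every call. A hedged sketch of how such sums fold the input offset out of the inner loop (invented names, not the CMSIS-NN implementation; requantization back to int8 is omitted):

#include <cstdint>

// One output row of an int8 fully connected layer, with the input-offset
// correction hoisted out of the dot product via the precomputed kernel sum.
int32_t fc_row_acc_ref(const int8_t* x, const int8_t* w_row, int depth,
                       int32_t bias, int32_t input_offset,
                       int32_t kernel_sum) {
  int32_t acc = bias + input_offset * kernel_sum;  // hoisted correction term
  for (int k = 0; k < depth; ++k) {
    acc += static_cast<int32_t>(x[k]) * w_row[k];  // pure int8 dot product
  }
  return acc;
}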
207 changes: 199 additions & 8 deletions tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc
@@ -31,9 +31,195 @@ limitations under the License.
namespace tflite {
namespace {

struct CmsisNnOpDataSvdf {
int32_t effective_scale_1_a;
int32_t effective_scale_2_a;
// The b versions of each scale are kept as int since they are just the
// shift value, typically in the range [-32, 32].
int effective_scale_1_b;
int effective_scale_2_b;
int scratch_tensor_index;
int scratch_output_tensor_index;

// Cached tensor zero point values for quantized operations.
int input_zero_point;
int output_zero_point;
int activation_state_zero_point;
int32_t* kernel_sums;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-return context->AllocatePersistentBuffer(context, sizeof(OpDataSvdf));
+return context->AllocatePersistentBuffer(context, sizeof(CmsisNnOpDataSvdf));
}

TfLiteStatus CmsisNnPrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);

const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);

MicroContext* micro_context = GetMicroContext(context);

// Validate Tensor Inputs (dtype depends on quantization):
// [0] = Input, {2, batch_size, input_size}
// [1] = Weights Feature, {2, num_filters, input_size}
// [2] = Weights Time, {2, num_filters, memory_size}
// [3] = Bias (optional), {1, num_units}
// [4] = Activation State (variable),
// {2, batch_size, memory_size * num_filters}
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kSvdfInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* weights_feature =
micro_context->AllocateTempInputTensor(node, kSvdfWeightsFeatureTensor);
TF_LITE_ENSURE(context, weights_feature != nullptr);
TfLiteTensor* weights_time =
micro_context->AllocateTempInputTensor(node, kSvdfWeightsTimeTensor);
TF_LITE_ENSURE(context, weights_time != nullptr);
TfLiteTensor* bias =
micro_context->AllocateTempInputTensor(node, kSvdfBiasTensor);
TfLiteTensor* activation_state = micro_context->AllocateTempInputTensor(
node, kSvdfInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);

// Define input constants based on input tensor definition above:
const int rank = params->rank;
const int input_size = input->dims->data[1];
const int batch_size = input->dims->data[0];
const int num_filters = weights_feature->dims->data[0];
TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
const int num_units = num_filters / rank;
const int memory_size = weights_time->dims->data[1];

// Validate Input Tensor:
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);

// Validate Tensor Output:
// [0] = float/int8_t, {2, batch_size, num_units}
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kSvdfOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);

// Validate Weights Feature Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);

// Validate Weights Time Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);

// Validate Optional Bias Input Tensor:
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
}

// Validate Activation State Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
memory_size * num_filters);
// Since is_variable is not part of TfLiteEvalTensor, check is_variable here.
TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true);

TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

TFLITE_DCHECK(node->user_data != nullptr);
CmsisNnOpDataSvdf* data = static_cast<CmsisNnOpDataSvdf*>(node->user_data);

if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
TF_LITE_ENSURE(context, (weights_time->type == kTfLiteInt16) ||
(weights_time->type == kTfLiteInt8));
TF_LITE_ENSURE(context, (activation_state->type == kTfLiteInt16) ||
(activation_state->type == kTfLiteInt8));
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
}

TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);

const double effective_scale_1 = static_cast<double>(
input->params.scale * weights_feature->params.scale /
activation_state->params.scale);
const double effective_scale_2 =
static_cast<double>(activation_state->params.scale *
weights_time->params.scale / output->params.scale);

// TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready.
// TODO(#1751): account for optional bias tensor
TF_LITE_ENSURE(
context,
std::abs(static_cast<double>(bias->params.scale) -
static_cast<double>(activation_state->params.scale *
weights_time->params.scale)) < 1e-5);

QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
&(data->effective_scale_1_b));
QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
&(data->effective_scale_2_b));

data->input_zero_point = input->params.zero_point;
data->output_zero_point = output->params.zero_point;
data->activation_state_zero_point = activation_state->params.zero_point;

TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);

const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
context, batch_size * num_filters * sizeof(int32_t),
&(data->scratch_tensor_index));
TF_LITE_ENSURE_OK(context, scratch_status);

const TfLiteStatus scratch_output_status =
context->RequestScratchBufferInArena(
context, batch_size * num_units * sizeof(int32_t),
&(data->scratch_output_tensor_index));
TF_LITE_ENSURE_OK(context, scratch_output_status);

cmsis_nn_dims weights_feature_dims;
weights_feature_dims.n = num_filters;
weights_feature_dims.h = input_size;

const int32_t buf_size = arm_svdf_s8_get_buffer_size(&weights_feature_dims);

if (buf_size > 0) {
data->kernel_sums = static_cast<int32_t*>(
context->AllocatePersistentBuffer(context, buf_size));

arm_vector_sum_s8(data->kernel_sums, input_size, num_filters,
GetTensorData<int8_t>(weights_feature));
}

} else {
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
}
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);

TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
context, batch_size * num_filters * sizeof(float),
&(data->scratch_tensor_index));
TF_LITE_ENSURE_OK(context, scratch_status);
}

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(weights_feature);
micro_context->DeallocateTempTfLiteTensor(weights_time);
micro_context->DeallocateTempTfLiteTensor(activation_state);
micro_context->DeallocateTempTfLiteTensor(output);
// TODO(#1751): account for optional bias tensor
micro_context->DeallocateTempTfLiteTensor(bias);
return kTfLiteOk;
}

TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
@@ -44,7 +230,7 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor,
-const OpDataSvdf& data) {
+const CmsisNnOpDataSvdf& data) {
cmsis_nn_dims input_dims;
input_dims.n = input_tensor->dims->data[0];
input_dims.h = input_tensor->dims->data[1];
@@ -102,9 +288,12 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,

switch (weights_time_tensor->type) {
case kTfLiteInt8: {
cmsis_nn_context ctx;
ctx.buf = data.kernel_sums;

arm_svdf_s8(
-&scratch_ctx, &scratch_output_ctx, &svdf_params, &in_quant_params,
-&out_quant_params, &input_dims,
+&ctx, &scratch_ctx, &scratch_output_ctx, &svdf_params,
+&in_quant_params, &out_quant_params, &input_dims,
tflite::micro::GetTensorData<int8_t>(input_tensor), &state_dims,
tflite::micro::GetTensorData<int8_t>(activation_state_tensor),
&weights_feature_dims,
@@ -141,7 +330,8 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
TfLiteStatus EvalSvdf(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
-const OpDataSvdf& data = *(static_cast<const OpDataSvdf*>(node->user_data));
+const CmsisNnOpDataSvdf& data =
+    *(static_cast<const CmsisNnOpDataSvdf*>(node->user_data));

const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kSvdfInputTensor);
@@ -184,7 +374,8 @@ TfLiteStatus EvalSvdf(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus EvalSvdfInt8(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
-const OpDataSvdf& data = *(static_cast<const OpDataSvdf*>(node->user_data));
+const CmsisNnOpDataSvdf& data =
+    *(static_cast<const CmsisNnOpDataSvdf*>(node->user_data));

const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kSvdfInputTensor);
@@ -213,11 +404,11 @@ TfLiteStatus EvalSvdfInt8(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_SVDF() {
-return tflite::micro::RegisterOp(Init, PrepareSvdf, EvalSvdf);
+return tflite::micro::RegisterOp(Init, CmsisNnPrepareSvdf, EvalSvdf);
}

TFLMRegistration Register_SVDF_INT8() {
-return tflite::micro::RegisterOp(Init, PrepareSvdf, EvalSvdfInt8);
+return tflite::micro::RegisterOp(Init, CmsisNnPrepareSvdf, EvalSvdfInt8);
}

} // namespace tflite
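The SVDF diff above mirrors the fully connected one: CmsisNnPrepareSvdf fills a persistent kernel_sums buffer over the feature weights (num_filters rows of input_size columns) and EvalIntegerSVDF hands it to arm_svdf_s8 through the new leading cmsis_nn_context argument. One related detail from fully_connected.cc above: an int4-packed filter is first unpacked into a temporary int8 buffer before summing. A hedged reference for that unpacking, assuming TFLM's dense int4 layout of two signed nibbles per byte with the low nibble first (tflite::tensor_utils::UnpackDenseInt4IntoInt8 is the real helper; this loop is illustrative only):

#include <cstdint>

// Dense int4 -> int8 unpack: element i lives in byte i/2, even elements in
// the low nibble and odd elements in the high nibble, sign-extended to 8 bits.
// (Assumed layout; the in-tree helper is authoritative.)
void unpack_dense_int4_ref(const int8_t* packed, int num_elems, int8_t* out) {
  for (int i = 0; i < num_elems; ++i) {
    const uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    out[i] = static_cast<int8_t>((nibble & 0x08) ? (nibble | 0xF0) : nibble);
  }
}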
4 changes: 2 additions & 2 deletions in the CMSIS-NN third-party download script
@@ -47,9 +47,9 @@ if [ -d ${DOWNLOADED_CMSIS_NN_PATH} ]; then
echo >&2 "${DOWNLOADED_CMSIS_NN_PATH} already exists, skipping the download."
else

ZIP_PREFIX_NN="dc64e488f6655aa2792d2aceca316c896f78b4db"
ZIP_PREFIX_NN="58f177057699d6d0a8d3af34c01c271202b6f85e"
CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
CMSIS_NN_MD5="80f9cf0bcc10a4aefb6531ae53942044"
CMSIS_NN_MD5="b9f05caa9cd9bb4b545bf4d7f8fc5274"

# wget is much faster than git clone of the entire repo. So we wget a specific
# version and can then apply a patch, as needed.