CMSIS-NN: Move kernel sums to prepare for FC and SVDF INT8 #2233

Merged 2 commits on Sep 19, 2023
35 changes: 34 additions & 1 deletion tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_log.h"

namespace tflite {
@@ -42,6 +43,8 @@ struct OpData {
// Index to buffer for optimizations if applicable.
int buffer_idx;

int32_t* kernel_sums;

int32_t batches;
int32_t accum_depth;
int32_t output_depth;
@@ -124,6 +127,35 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(&input_dims);
} else {
buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);

if (buf_size > 0) {
data->kernel_sums = static_cast<int32_t*>(
context->AllocatePersistentBuffer(context, buf_size));

int8_t* filter_data = GetTensorData<int8_t>(filter);

if (filter->type == kTfLiteInt4) {
size_t filter_size = GetTensorShape(filter).FlatSize();
int8_t* unpacked_filter_buf =
reinterpret_cast<int8_t*>(micro_context->AllocateTempBuffer(
filter_size, tflite::MicroArenaBufferAlignment()));

tflite::tensor_utils::UnpackDenseInt4IntoInt8(
filter_data, filter_size, unpacked_filter_buf);
filter_data = unpacked_filter_buf;
}

arm_vector_sum_s8(data->kernel_sums, filter_dims.n, data->output_depth,
filter_data);

if (filter->type == kTfLiteInt4) {
micro_context->DeallocateTempBuffer(
reinterpret_cast<uint8_t*>(filter_data));
}

// Do not request a scratch buffer since using persistent memory
buf_size = 0;
}
}
}
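
A note on the hunk above: this is the core of the fully-connected change. The per-channel kernel sums that arm_fully_connected_s8 needs are now computed once in Prepare() by arm_vector_sum_s8 and stored in a persistent buffer, rather than being rebuilt in a scratch buffer on every invocation; resetting buf_size to 0 afterwards suppresses the scratch-buffer request. A minimal sketch of what the precomputation amounts to, assuming the usual row-major int8 filter layout (the helper name and loop below are illustrative, not code from this PR):

#include <cstdint>

// One int32 sum per output channel, computed once at Prepare time. The sums
// depend only on the constant filter data, so the kernel can reuse them at
// inference time to fold the input zero-point offset into the accumulator.
void ComputeKernelSums(const int8_t* filter_data, int output_depth,
                       int accum_depth, int32_t* kernel_sums) {
  for (int row = 0; row < output_depth; ++row) {
    int32_t sum = 0;
    for (int col = 0; col < accum_depth; ++col) {
      sum += filter_data[row * accum_depth + col];
    }
    kernel_sums[row] = sum;  // consumed via ctx.buf in EvalQuantizedInt8
  }
}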

@@ -252,6 +284,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
fc_params.activation.min = data.reference_op_data.output_activation_min;
fc_params.activation.max = data.reference_op_data.output_activation_max;

ctx.buf = data.kernel_sums;
TF_LITE_ENSURE_EQ(
context,
arm_fully_connected_s8(
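One more note on fully_connected.cc: in the int4 branch of Prepare, packed filters store two signed 4-bit weights per byte, so they are unpacked into a temporary int8 buffer before arm_vector_sum_s8 runs, and the temporary is released immediately afterwards. A hedged sketch of the unpacking this relies on, assuming TFLite's low-nibble-first packing (the helper below is illustrative, not TFLM's UnpackDenseInt4IntoInt8):

#include <cstdint>

// Illustrative int4 -> int8 unpacking: two signed 4-bit values per byte,
// first element assumed in the low nibble, each sign-extended to 8 bits.
void UnpackInt4ToInt8(const int8_t* packed, int num_elements,
                      int8_t* unpacked) {
  for (int i = 0; i < num_elements; ++i) {
    const uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    if (nibble & 0x08) nibble |= 0xF0;  // sign-extend negative values
    unpacked[i] = static_cast<int8_t>(nibble);
  }
}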
207 changes: 199 additions & 8 deletions tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc
@@ -31,9 +31,195 @@ limitations under the License.
namespace tflite {
namespace {

struct CmsisNnOpDataSvdf {
int32_t effective_scale_1_a;
int32_t effective_scale_2_a;
// b versions of each scale are kept at int since the numbers are just the
// shift value - typically between [-32, 32].
int effective_scale_1_b;
int effective_scale_2_b;
int scratch_tensor_index;
int scratch_output_tensor_index;

// Cached tensor zero point values for quantized operations.
int input_zero_point;
int output_zero_point;
int activation_state_zero_point;
int32_t* kernel_sums;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-return context->AllocatePersistentBuffer(context, sizeof(OpDataSvdf));
+return context->AllocatePersistentBuffer(context, sizeof(CmsisNnOpDataSvdf));
}

TfLiteStatus CmsisNnPrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);

const auto* params = static_cast<const TfLiteSVDFParams*>(node->builtin_data);

MicroContext* micro_context = GetMicroContext(context);

// Validate Tensor Inputs (dtype depends on quantization):
// [0] = Input, {2, batch_size, input_size}
// [1] = Weights Feature, {2, num_filters, input_size}
// [2] = Weights Time, {2, num_filters, memory_size}
// [3] = Bias (optional), {1, num_units}
// [4] = Activation State (variable),
// {2, batch_size, memory_size * num_filters}
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kSvdfInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* weights_feature =
micro_context->AllocateTempInputTensor(node, kSvdfWeightsFeatureTensor);
TF_LITE_ENSURE(context, weights_feature != nullptr);
TfLiteTensor* weights_time =
micro_context->AllocateTempInputTensor(node, kSvdfWeightsTimeTensor);
TF_LITE_ENSURE(context, weights_time != nullptr);
TfLiteTensor* bias =
micro_context->AllocateTempInputTensor(node, kSvdfBiasTensor);
TfLiteTensor* activation_state = micro_context->AllocateTempInputTensor(
node, kSvdfInputActivationStateTensor);
TF_LITE_ENSURE(context, activation_state != nullptr);

// Define input constants based on input tensor definition above:
const int rank = params->rank;
const int input_size = input->dims->data[1];
const int batch_size = input->dims->data[0];
const int num_filters = weights_feature->dims->data[0];
TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
const int num_units = num_filters / rank;
const int memory_size = weights_time->dims->data[1];

// Validate Input Tensor:
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);

// Validate Tensor Output:
// [0] = float/int8_t, {2, batch_size, num_units}
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kSvdfOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 2);
TF_LITE_ENSURE_EQ(context, output->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, output->dims->data[1], num_units);

// Validate Weights Feature Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(weights_feature), 2);
TF_LITE_ENSURE_EQ(context, weights_feature->dims->data[1], input_size);

// Validate Weights Time Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(weights_time), 2);
TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);
TF_LITE_ENSURE_EQ(context, weights_time->dims->data[1], memory_size);

// Validate Optional Bias Input Tensor:
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
}

// Validate Activation State Input Tensor:
TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[0], batch_size);
TF_LITE_ENSURE_EQ(context, activation_state->dims->data[1],
memory_size * num_filters);
// Since is_variable is not part of TFLiteEvalTensor, check is_variable here.
TF_LITE_ENSURE_EQ(context, activation_state->is_variable, true);

TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

TFLITE_DCHECK(node->user_data != nullptr);
CmsisNnOpDataSvdf* data = static_cast<CmsisNnOpDataSvdf*>(node->user_data);

if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
TF_LITE_ENSURE(context, (weights_time->type == kTfLiteInt16) ||
(weights_time->type == kTfLiteInt8));
TF_LITE_ENSURE(context, (activation_state->type == kTfLiteInt16) ||
(activation_state->type == kTfLiteInt8));
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
}

TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);

const double effective_scale_1 = static_cast<double>(
input->params.scale * weights_feature->params.scale /
activation_state->params.scale);
const double effective_scale_2 =
static_cast<double>(activation_state->params.scale *
weights_time->params.scale / output->params.scale);

// TODO(b/162018098): Use TF_LITE_ENSURE_NEAR when it is ready.
// TODO(#1751): account for optional bias tensor
TF_LITE_ENSURE(
context,
std::abs(static_cast<double>(bias->params.scale) -
static_cast<double>(activation_state->params.scale *
weights_time->params.scale)) < 1e-5);

QuantizeMultiplier(effective_scale_1, &(data->effective_scale_1_a),
&(data->effective_scale_1_b));
QuantizeMultiplier(effective_scale_2, &(data->effective_scale_2_a),
&(data->effective_scale_2_b));

data->input_zero_point = input->params.zero_point;
data->output_zero_point = output->params.zero_point;
data->activation_state_zero_point = activation_state->params.zero_point;

TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);

const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
context, batch_size * num_filters * sizeof(int32_t),
&(data->scratch_tensor_index));
TF_LITE_ENSURE_OK(context, scratch_status);

const TfLiteStatus scratch_output_status =
context->RequestScratchBufferInArena(
context, batch_size * num_units * sizeof(int32_t),
&(data->scratch_output_tensor_index));
TF_LITE_ENSURE_OK(context, scratch_output_status);

cmsis_nn_dims weights_feature_dims;
weights_feature_dims.n = num_filters;
weights_feature_dims.h = input_size;

const int32_t buf_size = arm_svdf_s8_get_buffer_size(&weights_feature_dims);

if (buf_size > 0) {
data->kernel_sums = static_cast<int32_t*>(
context->AllocatePersistentBuffer(context, buf_size));

arm_vector_sum_s8(data->kernel_sums, input_size, num_filters,
GetTensorData<int8_t>(weights_feature));
}

} else {
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, weights_time->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, activation_state->type, kTfLiteFloat32);
if (bias != nullptr) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
}
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);

TFLITE_DCHECK(context->RequestScratchBufferInArena != nullptr);
const TfLiteStatus scratch_status = context->RequestScratchBufferInArena(
context, batch_size * num_filters * sizeof(float),
&(data->scratch_tensor_index));
TF_LITE_ENSURE_OK(context, scratch_status);
}

micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(weights_feature);
micro_context->DeallocateTempTfLiteTensor(weights_time);
micro_context->DeallocateTempTfLiteTensor(activation_state);
micro_context->DeallocateTempTfLiteTensor(output);
// TODO(#1751): account for optional bias tensor
micro_context->DeallocateTempTfLiteTensor(bias);
return kTfLiteOk;
}
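
A few details of the Prepare path above are worth spelling out. The kernel-sums block mirrors the fully_connected.cc change: arm_svdf_s8_get_buffer_size sizes a persistent buffer that arm_vector_sum_s8 fills once from weights_feature. The bias-scale check enforces bias_scale ≈ activation_state_scale * weights_time_scale, which is required because the int32 bias is added directly into an accumulator carrying exactly that product scale. Finally, QuantizeMultiplier converts each real-valued effective scale into an int32 fixed-point multiplier plus a small shift (the "b" values in the struct comment). A sketch of that decomposition, consistent with TFLite's convention S ≈ M * 2^(shift - 31) and reconstructed from the surrounding definitions rather than copied from this PR:

#include <cmath>
#include <cstdint>

// Decompose a scale S into (multiplier, shift) such that
// S ~= multiplier * 2^(shift - 31), with multiplier in [2^30, 2^31).
void DecomposeScale(double s, int32_t* multiplier, int* shift) {
  if (s == 0.0) {
    *multiplier = 0;
    *shift = 0;
    return;
  }
  const double fraction = std::frexp(s, shift);  // s = fraction * 2^shift
  int64_t q = static_cast<int64_t>(std::round(fraction * (1LL << 31)));
  if (q == (1LL << 31)) {  // rounding pushed the fraction up to exactly 1.0
    q /= 2;
    ++(*shift);
  }
  *multiplier = static_cast<int32_t>(q);
}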

@@ -44,7 +230,7 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor,
-const OpDataSvdf& data) {
+const CmsisNnOpDataSvdf& data) {
cmsis_nn_dims input_dims;
input_dims.n = input_tensor->dims->data[0];
input_dims.h = input_tensor->dims->data[1];
@@ -102,9 +288,12 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,

switch (weights_time_tensor->type) {
case kTfLiteInt8: {
cmsis_nn_context ctx;
ctx.buf = data.kernel_sums;

arm_svdf_s8(
-&scratch_ctx, &scratch_output_ctx, &svdf_params, &in_quant_params,
-&out_quant_params, &input_dims,
+&ctx, &scratch_ctx, &scratch_output_ctx, &svdf_params,
+&in_quant_params, &out_quant_params, &input_dims,
tflite::micro::GetTensorData<int8_t>(input_tensor), &state_dims,
tflite::micro::GetTensorData<int8_t>(activation_state_tensor),
&weights_feature_dims,
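
The signature change here is the visible cost of the optimization: arm_svdf_s8 now takes an additional leading cmsis_nn_context whose buf member carries the kernel sums filled in CmsisNnPrepareSvdf, so the kernel no longer derives them on each invocation. cmsis_nn_context is just a small buffer handle in CMSIS-NN; only buf is populated here.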
@@ -141,7 +330,8 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
TfLiteStatus EvalSvdf(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
-const OpDataSvdf& data = *(static_cast<const OpDataSvdf*>(node->user_data));
+const CmsisNnOpDataSvdf& data =
+    *(static_cast<const CmsisNnOpDataSvdf*>(node->user_data));

const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kSvdfInputTensor);
@@ -184,7 +374,8 @@ TfLiteStatus EvalSvdf(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus EvalSvdfInt8(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
-const OpDataSvdf& data = *(static_cast<const OpDataSvdf*>(node->user_data));
+const CmsisNnOpDataSvdf& data =
+    *(static_cast<const CmsisNnOpDataSvdf*>(node->user_data));

const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kSvdfInputTensor);
@@ -213,11 +404,11 @@ TfLiteStatus EvalSvdfInt8(TfLiteContext* context, TfLiteNode* node) {
} // namespace

TFLMRegistration Register_SVDF() {
-return tflite::micro::RegisterOp(Init, PrepareSvdf, EvalSvdf);
+return tflite::micro::RegisterOp(Init, CmsisNnPrepareSvdf, EvalSvdf);
}

TFLMRegistration Register_SVDF_INT8() {
-return tflite::micro::RegisterOp(Init, PrepareSvdf, EvalSvdfInt8);
+return tflite::micro::RegisterOp(Init, CmsisNnPrepareSvdf, EvalSvdfInt8);
}

} // namespace tflite
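
The registration split at the end is what the new Prepare enables: Register_SVDF wires the generic EvalSvdf, which covers both the float and int8 paths, while Register_SVDF_INT8 binds EvalSvdfInt8 directly, so int8-only applications skip the runtime type dispatch (and, with a linker that strips unused code, the float path). Both registrations share Init and CmsisNnPrepareSvdf.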
@@ -47,9 +47,9 @@ if [ -d ${DOWNLOADED_CMSIS_NN_PATH} ]; then
echo >&2 "${DOWNLOADED_CMSIS_NN_PATH} already exists, skipping the download."
else

ZIP_PREFIX_NN="dc64e488f6655aa2792d2aceca316c896f78b4db"
ZIP_PREFIX_NN="58f177057699d6d0a8d3af34c01c271202b6f85e"
CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
CMSIS_NN_MD5="80f9cf0bcc10a4aefb6531ae53942044"
CMSIS_NN_MD5="b9f05caa9cd9bb4b545bf4d7f8fc5274"

# wget is much faster than git clone of the entire repo. So we wget a specific
# version and can then apply a patch, as needed.