Optimized TFLM LSTM integration code for HiFi5/4 for 8x8 and 16x8 LSTM
cad-audio committed Sep 4, 2023
1 parent 378aa87 commit 44adbe5
Showing 2 changed files with 416 additions and 113 deletions.
215 changes: 129 additions & 86 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
@@ -112,10 +112,7 @@ const int32_t kInt16Min = std::numeric_limits<int16_t>::min();

void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch,
int n_input, int16_t* output) {
#if !(defined(HIFI5) || defined(HIFI4))
for (int batch = 0; batch < n_batch; ++batch) {
for (int i = 0; i < n_input; ++i) {
const int index = batch * n_input + i;
@@ -124,6 +121,9 @@ void AddElementWise(const int16_t* input_1, const int16_t* input_2, int n_batch,
output[index] = static_cast<int16_t>(sum_clamped);
}
}
#else
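// HiFi5/HiFi4: one NNLib vector add over all n_batch * n_input elements,
// assumed to saturate to int16 like the reference loop above.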
WORD32 err;
err = xa_nn_elm_add_16x16_16(output, input_1, input_2, n_batch * n_input);
#endif
}

@@ -137,18 +137,13 @@ void AddElementWise(const float* input_1, const float* input_2, int n_batch,
}
}

#if !(defined(HIFI5) || defined(HIFI4))
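// Portable reference implementations for non-HiFi targets; the HiFi5/HiFi4
// NNLib versions in the #else branch below take flat data_size arguments
// instead of RuntimeShapes.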
void Sigmoid(const RuntimeShape& data_shape, int16_t* data) {
reference_integer_ops::Logistic(
0 /*data->input_multiplier*/, 0 /*data->input_left_shift */,
data_shape.FlatSize() /*NumElements(input->dims)*/,
data /* tflite::micro::GetTensorData<int16_t>(input) */,
data /*tflite::micro::GetTensorData<int16_t>(output) */);
}

void Sigmoid(const RuntimeShape& data_shape, float* data) {
@@ -160,24 +155,6 @@ void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
int16_t* output_data) {
int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
int32_t input_multiplier = 0;
if (tanh_input_left_shift < 0) /* handling negative shift value */
{
tanh_input_left_shift = -tanh_input_left_shift;
@@ -186,7 +163,6 @@ void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
reference_integer_ops::Tanh(input_multiplier, tanh_input_left_shift,
input_data_shape, input_data, output_data_shape,
output_data);
}

void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
@@ -200,35 +176,16 @@ void Tanh(int32_t cell_state_scale_power, const RuntimeShape& input_data_shape,
void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
const int16_t* input1_data, const int16_t* input2_data,
int8_t* output_data) {
return reference_integer_ops::MulElementwise(
shape.FlatSize(), params, input1_data, input2_data, output_data);
}

// Input and output have the same shape in LSTM
void Mul(const RuntimeShape& shape, const ArithmeticParams& params,
const int16_t* input1_data, const int16_t* input2_data,
int16_t* output_data) {
return reference_integer_ops::MulElementwise(
shape.FlatSize(), params, input1_data, input2_data, output_data);
}

// Input and output have the same shape in LSTM
@@ -244,15 +201,107 @@ void FullyConnected(const FullyConnectedParams& params,
const RuntimeShape& filter_shape, const int8_t* filter_data,
const RuntimeShape& bias_shape, const int32_t* bias_data,
const RuntimeShape& output_shape, int16_t* output_data) {
return tflite::reference_integer_ops::FullyConnected(
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
}

void FullyConnected(const FullyConnectedParams& params,
const RuntimeShape& input_shape, const int16_t* input_data,
const RuntimeShape& filter_shape, const int8_t* filter_data,
const RuntimeShape& bias_shape, const int64_t* bias_data,
const RuntimeShape& output_shape, int16_t* output_data) {
return tflite::reference_integer_ops::FullyConnected(
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
}

void FullyConnected(const FullyConnectedParams& params,
const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& filter_shape, const float* filter_data,
const RuntimeShape& bias_shape, const float* bias_data,
const RuntimeShape& output_shape, float* output_data) {
return tflite::reference_ops::FullyConnected(
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
}
#else // #if !(defined(HIFI5) || defined(HIFI4))
void Sigmoid(int16_t* data, int32_t data_size) {
WORD32 err;
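// In-place sigmoid on sym16 data. Multiplier and shift are 0: the gate input
// is assumed to already be in the Q3.12 format the kernel expects, matching
// the reference Logistic path above.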
err = xa_nn_vec_sigmoid_sym16s_sym16s(data, data, 0, 0, data_size);
}

void Sigmoid(float* data, int32_t data_size) {
int data_dims[2] = {1, data_size};
RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(data_dims));
reference_ops::Logistic(data_shape, data, data_shape, data);
}

void Tanh(int32_t cell_state_scale_power, int16_t* input_data,
int16_t* output_data, int32_t data_size) {
int32_t tanh_input_left_shift = (15 + cell_state_scale_power) - 3;
int32_t input_multiplier = 0;
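// (15 + cell_state_scale_power) - 3 rescales the cell state to the Q3.12
// input of the 16-bit tanh; a negative result is applied as a right shift
// below.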
if (tanh_input_left_shift < 0) /* handling negative shift value */
{
tanh_input_left_shift = -tanh_input_left_shift;
#if (defined(USE_HIFI_ACT_TIE) && (defined(AE_TANH16X4X2) || defined(AE_TANH16X4)))
input_multiplier = 1;
#else
input_multiplier = 3;
#endif
}
WORD32 err;
err = xa_nn_vec_tanh_sym16s_sym16s(output_data, input_data, input_multiplier,
tanh_input_left_shift, data_size);
}

void Tanh(int32_t cell_state_scale_power, float* input_data, float* output_data,
int32_t data_size) {
int data_dims[2] = {1, data_size};
RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(data_dims));
reference_ops::Tanh(data_shape, input_data, data_shape,
output_data);
}

// Input and output have the same shape in LSTM
void Mul(const ArithmeticParams& params, const int16_t* input1_data,
const int16_t* input2_data, int8_t* output_data, int32_t data_size) {
WORD32 err;
err = xa_nn_elm_mul_sym16sxsym16s_asym8s(output_data, params.output_offset,
params.output_shift, params.output_multiplier,
params.quantized_activation_min, params.quantized_activation_max,
input1_data, input2_data, data_size);
}

// Input and output have the same shape in LSTM
void Mul(const ArithmeticParams& params, const int16_t* input1_data,
const int16_t* input2_data, int16_t* output_data, int32_t data_size) {
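// A 1x1x1xN shape turns the 4D broadcast kernel into a plain elementwise
// multiply over data_size elements.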
int dims_4D[4] = {1, 1, 1, data_size};
WORD32 err;
err = xa_nn_elm_mul_broadcast_4D_sym16sxsym16s_sym16s(output_data,
dims_4D, params.output_shift, params.output_multiplier,
params.quantized_activation_min, params.quantized_activation_max,
input1_data, dims_4D, input2_data, dims_4D);
return;
}

// Input and output have the same shape in LSTM
void Mul(const ArithmeticParams& params, const float* input1_data,
const float* input2_data, float* output_data, int32_t data_size) {
int dims_2D[2] = {1, data_size};
RuntimeShape data_shape(2, reinterpret_cast<const int32_t*>(dims_2D));
return reference_ops::Mul(params, data_shape, input1_data, data_shape,
input2_data, data_shape, output_data);
}

void FullyConnected(const FullyConnectedParams& params,
const int8_t* input_data, const int8_t* filter_data,
const int32_t* bias_data, int16_t* output_data,
const int num_batches, const int output_depth,
const int accum_depth) {
WORD32 err;
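// One NNLib matrix*vector call per batch; the pragma promises at least one
// iteration so the Xtensa compiler can emit a zero-overhead loop.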
#pragma loop_count min=1
for (int b = 0; b < num_batches; b++) {
err = xa_nn_matXvec_out_stride_sym8sxasym8s_16(
output_data + b * output_depth, filter_data,
@@ -261,48 +310,39 @@ void FullyConnected(const FullyConnectedParams& params,
params.output_shift);
}
return;
}

void FullyConnected(const FullyConnectedParams& params,
const int16_t* input_data, const int8_t* filter_data,
const int64_t* bias_data, int16_t* output_data,
const int num_batches, const int output_depth,
const int accum_depth) {
WORD32 err;
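// Single batched NNLib matmul: sym8s filter x sym16s input -> sym16s output,
// with the 64-bit bias applied per output channel.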

err = xa_nn_matmul_sym8sxsym16s_sym16s(output_data, filter_data, input_data,
bias_data, output_depth, accum_depth, accum_depth, num_batches,
accum_depth, output_depth, 1, params.input_offset,
params.output_multiplier, params.output_shift, params.output_offset);
return;
}

void FullyConnected(const FullyConnectedParams& params,
const float* input_data, const float* filter_data,
const float* bias_data, float* output_data,
const int num_batches, const int output_depth,
const int accum_depth) {
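// Wrap the flat buffers in RuntimeShapes for the reference kernel;
// bias_shape is empty when bias_data is NULL.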
int input_dims[2] = {num_batches, accum_depth};
RuntimeShape input_shape(2, reinterpret_cast<const int32_t*>(input_dims));
RuntimeShape bias_shape(1, bias_data == NULL ? 0 : output_depth);
int filter_dims[2] = {output_depth, accum_depth};
RuntimeShape filter_shape(2, reinterpret_cast<const int32_t*>(filter_dims));
int output_dims[2] = {num_batches, output_depth};
RuntimeShape output_shape(2, reinterpret_cast<const int32_t*>(output_dims));
return tflite::reference_ops::FullyConnected(
params, input_shape, input_data, filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
}
#endif // #if !(defined(HIFI5) || defined(HIFI4))

void Clipping(const int v_size, const CellStateInfo& cell_state_info,
int16_t* vector) {
@@ -332,12 +372,13 @@ void UpdateLstmCell(const LstmStepManager& step_info,
const ArithmeticParams& forget_cell_mul_params,
const ArithmeticParams& input_mul_params,
const CellStateInfo& cell_state_info, int16_t* buffer) {

auto cell_state_shape = step_info.StateShape();
// Check offset validity to avoid memory overflow
TFLITE_DCHECK_LE(
step_info.CellStateOffset() + cell_state_shape.FlatSize(),
tflite::micro::GetTensorShape(cell_state).FlatSize());

WORD32 err;
// The multiplier is equivalent to 0.5 here, so 1 is added to the shifts to compensate
err = xa_nn_lstm_cell_state_update_16(
@@ -366,14 +407,15 @@ void UpdateLstmCell(const LstmStepManager& step_info,

auto cell_state_shape = step_info.StateShape();
// Forget Gate x Cell State
Mul(forget_cell_mul_params, forget_gate_output,
tflite::micro::GetTensorData<float>(cell_state) +
step_info.CellStateOffset(),
tflite::micro::GetTensorData<float>(cell_state) +
step_info.CellStateOffset(),
cell_state_shape.FlatSize());
// Input Gate x Cell Gate
Mul(input_mul_params, input_gate_output, cell_gate_output, buffer,
cell_state_shape.FlatSize());

// Update the cell state
AddElementWise(tflite::micro::GetTensorData<float>(cell_state) +
@@ -449,5 +491,6 @@ RuntimeShape LstmStepManager::StateShape() const {
return RuntimeShape(2, dims_data);
}


} // namespace lstm_internal
} // namespace tflite