Skip to content

Commit

Permalink
Fixes to transpose_conv for optional bias tensor when compression is …
Browse files Browse the repository at this point in the history
…enabled.
  • Loading branch information
ddavis-2015 committed Dec 13, 2024
1 parent 2335754 commit 0171244
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions tensorflow/lite/micro/kernels/transpose_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<float>(
micro_context, filter, filter_comp_td, data.filter_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
data.bias_scratch_index),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
Expand All @@ -327,7 +327,7 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, filter_comp_td, data.filter_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
Expand Down Expand Up @@ -384,7 +384,7 @@ TfLiteStatus TransposeConvEval(TfLiteContext* context, TfLiteNode* node) {
filter_comp_td,
data.filter_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int64_t>(
tflite::micro::GetOptionalTensorData<int64_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
Expand Down
10 changes: 5 additions & 5 deletions tensorflow/lite/micro/kernels/xtensa/transpose_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<float>(
micro_context, filter, filter_comp_td, data.filter_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(micro_context, bias, bias_comp_td,
data.bias_scratch_index),
tflite::micro::GetOptionalTensorData<float>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
Expand Down Expand Up @@ -419,7 +419,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
filter_comp_td,
data.filter_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
Expand All @@ -440,7 +440,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorData<int8_t>(
micro_context, filter, filter_comp_td, data.filter_scratch_index),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(
tflite::micro::GetOptionalTensorData<int32_t>(
micro_context, bias, bias_comp_td, data.bias_scratch_index),
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
Expand Down Expand Up @@ -558,7 +558,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
#else // USE_TFLM_COMPRESSION
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetOptionalTensorData<int64_t>(bias),
tflite::micro::GetTensorData<int64_t>(bias),
#endif // USE_TFLM_COMPRESSION
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output),
Expand Down

0 comments on commit 0171244

Please sign in to comment.