Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix upstream sync #2728

Merged
merged 13 commits into from
Oct 29, 2024
1 change: 1 addition & 0 deletions tensorflow/lite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ cc_library(
srcs = ["array.cc"],
hdrs = ["array.h"],
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/core/c:common",
],
)
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/lite/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/lite/array.h"

#include "tensorflow/lite/c/common.h"

namespace tflite {
namespace array_internal {

Expand Down
9 changes: 7 additions & 2 deletions tensorflow/lite/kernels/internal/reference/batch_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
const float* scaling_factors,
const int32_t* input_offset, int32_t* row_sums,
const RuntimeShape& output_shape, float* output_data,
bool* compute_row_sums) {
bool* compute_row_sums,
const float* per_channel_scales) {
const RuntimeShape extended_lhs_shape =
RuntimeShape::ExtendedShape(5, lhs_shape);
const RuntimeShape extended_rhs_shape =
Expand Down Expand Up @@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
int32_t row_sum = woff_ptr2[i];
total -= row_sum * batch_offset;
int idx = lhs_rows * j + i;
out_ptr[idx] += batch_scaling_factor * total;
float scale = batch_scaling_factor;
if (per_channel_scales) {
scale *= per_channel_scales[i];
}
out_ptr[idx] += scale * total;
}
}
}
Expand Down
Loading