Skip to content

Commit

Permalink
Sync from upstream TF.
Browse files Browse the repository at this point in the history
  • Loading branch information
TFLM-bot committed Sep 19, 2023
1 parent 127f88f commit 4d91d38
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 38 deletions.
91 changes: 53 additions & 38 deletions tensorflow/lite/core/api/flatbuffer_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -872,44 +872,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return kTfLiteOk;
}
case BuiltinOperator_STABLEHLO_SCATTER: {
auto params = safe_allocator.Allocate<TfLiteStablehloScatterParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* shlo_scatter_params =
op->builtin_options_2_as_StablehloScatterOptions()) {
params->indices_are_sorted = shlo_scatter_params->indices_are_sorted();

TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray<int64_t>(
shlo_scatter_params->update_window_dims()->size() * sizeof(int64_t),
shlo_scatter_params->update_window_dims(),
params->update_window_dims, error_reporter, "stablehlo_scatter"));
params->num_update_window_dims =
shlo_scatter_params->update_window_dims()->size();

TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray<int64_t>(
shlo_scatter_params->inserted_window_dims()->size() *
sizeof(int64_t),
shlo_scatter_params->inserted_window_dims(),
params->inserted_window_dims, error_reporter, "stablehlo_scatter"));
params->num_inserted_window_dims =
shlo_scatter_params->inserted_window_dims()->size();

TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray<int64_t>(
shlo_scatter_params->scatter_dims_to_operand_dims()->size() *
sizeof(int64_t),
shlo_scatter_params->scatter_dims_to_operand_dims(),
params->scatter_dims_to_operand_dims, error_reporter,
"stablehlo_scatter"));
params->num_scatter_dims_to_operand_dims =
shlo_scatter_params->scatter_dims_to_operand_dims()->size();

params->index_vector_dim = shlo_scatter_params->index_vector_dim();
params->unique_indices = shlo_scatter_params->unique_indices();
params->update_computation_subgraph_index =
shlo_scatter_params->update_computation_subgraph_index();
}

*builtin_data = params.release();
return kTfLiteOk;
return ParseStablehloScatter(op, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR: {
return ParseStablehloRngBitGenerator(op, error_reporter, allocator,
Expand Down Expand Up @@ -2100,6 +2063,58 @@ TfLiteStatus ParseResizeNearestNeighbor(const Operator* op,
return kTfLiteOk;
}

TfLiteStatus ParseStablehloScatter(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteStablehloScatterParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteStablehloScatterParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);

const StablehloScatterOptions* schema_params =
op->builtin_options_2_as_StablehloScatterOptions();
if (schema_params) {
params->indices_are_sorted = schema_params->indices_are_sorted();

TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray<int64_t>(
schema_params->update_window_dims()->size() * sizeof(int64_t),
schema_params->update_window_dims(), params->update_window_dims,
error_reporter, "stablehlo_scatter"));
params->num_update_window_dims =
schema_params->update_window_dims()->size();

TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray<int64_t>(
schema_params->inserted_window_dims()->size() * sizeof(int64_t),
schema_params->inserted_window_dims(), params->inserted_window_dims,
error_reporter, "stablehlo_scatter"));
params->num_inserted_window_dims =
schema_params->inserted_window_dims()->size();

TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray<int64_t>(
schema_params->scatter_dims_to_operand_dims()->size() * sizeof(int64_t),
schema_params->scatter_dims_to_operand_dims(),
params->scatter_dims_to_operand_dims, error_reporter,
"stablehlo_scatter"));
params->num_scatter_dims_to_operand_dims =
schema_params->scatter_dims_to_operand_dims()->size();

params->index_vector_dim = schema_params->index_vector_dim();
params->unique_indices = schema_params->unique_indices();
params->update_computation_subgraph_index =
schema_params->update_computation_subgraph_index();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better undertand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}

TfLiteStatus ParseStablehloRngBitGenerator(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
Expand Down
5 changes: 5 additions & 0 deletions tensorflow/lite/core/api/flatbuffer_conversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,11 @@ TfLiteStatus ParseRightShift(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);

TfLiteStatus ParseStablehloScatter(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);

TfLiteStatus ParseStablehloRngBitGenerator(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
Expand Down

0 comments on commit 4d91d38

Please sign in to comment.