Skip to content

Commit

Permalink
Sync from upstream TF.
Browse files Browse the repository at this point in the history
  • Loading branch information
TFLM-bot committed Nov 7, 2024
1 parent 9245002 commit d71186d
Show file tree
Hide file tree
Showing 9 changed files with 339 additions and 692 deletions.
6 changes: 6 additions & 0 deletions tensorflow/compiler/mlir/lite/schema/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ enum BuiltinOperator : int32 {
STABLEHLO_COMPOSITE = 206, // WARNING: No runtime support
STABLEHLO_SHIFT_LEFT = 207,
STABLEHLO_CBRT = 208, // WARNING: No runtime support
STABLEHLO_CASE = 209,
}
// LINT.ThenChange(nnapi_linter/linter.proto)

Expand Down Expand Up @@ -633,6 +634,7 @@ union BuiltinOptions2{
ReduceWindowOptions (deprecated),
StableHLOCompositeOptions,
StablehloShiftLeftOptions,
StablehloCaseOptions,
}

table StablehloGatherOptions{
Expand Down Expand Up @@ -777,6 +779,10 @@ table StablehloScatterOptions {
update_computation_subgraph_index: int;
}

table StablehloCaseOptions{
branch_subgraph_indices : [int];
}

enum RngAlgorithm : byte {
// An algorithm auto-selected by the system according to device type.
DEFAULT = 0,
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/builtin_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ typedef enum {
kTfLiteBuiltinStablehloComposite = 206,
kTfLiteBuiltinStablehloShiftLeft = 207,
kTfLiteBuiltinStablehloCbrt = 208,
kTfLiteBuiltinStablehloCase = 209,
} TfLiteBuiltinOperator;

#ifdef __cplusplus
Expand Down
47 changes: 47 additions & 0 deletions tensorflow/lite/core/api/flatbuffer_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/schema/schema_generated.h"

// TODO(sosagarcia): Rework all function implementations to wrap around the
// compiler flatbuffer_conversions.
// LINT.IfChange
namespace tflite {

namespace {
Expand Down Expand Up @@ -928,6 +931,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseStablehloShiftLeft(op, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_STABLEHLO_CASE: {
return ParseStablehloCase(op, error_reporter, allocator, builtin_data);
}
// TODO: skip param parsing for now since ops below don't have kernels
case BuiltinOperator_STABLEHLO_SLICE:
case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM:
Expand Down Expand Up @@ -2421,6 +2427,46 @@ TfLiteStatus ParseStablehloShiftLeft(const Operator* op,
return kTfLiteOk;
}

TfLiteStatus ParseStablehloCase(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);

SafeBuiltinDataAllocator safe_allocator(allocator);
auto params = safe_allocator.Allocate<TfLiteStablehloCaseParams>();

const StablehloCaseOptions* schema_params =
op->builtin_options_2_as_StablehloCaseOptions();
if (schema_params) {
auto LoadAttr =
[&error_reporter](
int32_t* params_array, const size_t params_array_size_bytes,
const flatbuffers::Vector<int32_t>* const flatbuffer_vector,
const char* const attr_name) -> TfLiteStatus {
TfLiteStatus status = FlatBufferIntVectorToArray(
params_array_size_bytes, flatbuffer_vector, params_array,
error_reporter, "stablehlo.case");
if (status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "Check the '%s' attribute.",
attr_name);
}
return status;
};

TF_LITE_ENSURE_STATUS(LoadAttr(params->branch_subgraph_indices,
sizeof(params->branch_subgraph_indices),
schema_params->branch_subgraph_indices(),
"branch subgraph indices"));
params->num_branches = schema_params->branch_subgraph_indices()->size();
*builtin_data = params.release();
return kTfLiteOk;
}
TF_LITE_REPORT_ERROR(error_reporter,
"Could not get 'stablehlo.case' operation parameters.");
return kTfLiteError;
}

// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
Expand Down Expand Up @@ -2943,3 +2989,4 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}

} // namespace tflite
// LINT.ThenChange(//tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc)
5 changes: 5 additions & 0 deletions tensorflow/lite/core/api/flatbuffer_conversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,11 @@ TfLiteStatus ParseStablehloShiftLeft(const Operator* op,
BuiltinDataAllocator* allocator,
void** builtin_data);

TfLiteStatus ParseStablehloCase(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);

} // namespace tflite

#endif // TENSORFLOW_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
Loading

0 comments on commit d71186d

Please sign in to comment.