Skip to content

Commit

Permalink
support operators that invoke subgraphs.
Browse files Browse the repository at this point in the history
bug fixes involving READ_VARIABLE and VAR_HAND:E ordering.
  • Loading branch information
ddavis-2015 committed Dec 22, 2024
1 parent c0604d8 commit 23e1c7e
Showing 1 changed file with 64 additions and 7 deletions.
71 changes: 64 additions & 7 deletions tensorflow/lite/micro/tools/reorder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,84 @@ def process_operator_assign_variable(context: Context) -> VarHandles:
assign_op.inputs_indices[0])
for read_var_op in context.get_pending_read_var_ops_by_handle(
var_handle_id, assign_op.subgraph.index):
context.append_to_reordered_operations(read_var_op)
context.remove_pending_op(read_var_op)
process_read_var_pending_ops(context, read_var_op)

process_pending_ops(context)

return set([var_handle_id])


def process_operator_call_once(context: Context) -> VarHandles:
assert False
return set()
op = context.current_op()
assert op.builtin_options_type == tflite.BuiltinOptions.CallOnceOptions
assert op.builtin_options is not None
subgraph_index: int = op.builtin_options.subgraph
var_handles: VarHandles = process_subgraph(context, subgraph_index)
for var_handle_id in var_handles:
for read_var_op in context.get_pending_read_var_ops_by_handle(
var_handle_id, op.subgraph.index):
context.remove_pending_op(read_var_op)
process_read_var_pending_ops(context, read_var_op)

context.append_to_reordered_operations(op)

return var_handles


def process_operator_if(context: Context) -> VarHandles:
assert False
return set()
op = context.current_op()
assert op.builtin_options_type == tflite.BuiltinOptions.IfOptions
assert op.builtin_options is not None
then_subgraph_index: int = op.builtin_options.thenSubgraphIndex
else_subgraph_index: int = op.builtin_options.elseSubgraphIndex
var_handles: VarHandles = process_subgraph(context, then_subgraph_index)
var_handles |= process_subgraph(context, else_subgraph_index)
for var_handle_id in var_handles:
for read_var_op in context.get_pending_read_var_ops_by_handle(
var_handle_id, op.subgraph.index):
context.remove_pending_op(read_var_op)
process_read_var_pending_ops(context, read_var_op)

process_pending_ops(context)

return var_handles


def process_operator_while(context: Context) -> VarHandles:
assert False
return set()
op = context.current_op()
assert op.builtin_options_type == tflite.BuiltinOptions.WhileOptions
assert op.builtin_options is not None
cond_subgraph_index: int = op.builtin_options.condSubgraphIndex
body_subgraph_index: int = op.builtin_options.bodySubgraphIndex
var_handles: VarHandles = process_subgraph(context, cond_subgraph_index)
var_handles |= process_subgraph(context, body_subgraph_index)
for var_handle_id in var_handles:
for read_var_op in context.get_pending_read_var_ops_by_handle(
var_handle_id, op.subgraph.index):
context.remove_pending_op(read_var_op)
process_read_var_pending_ops(context, read_var_op)

process_pending_ops(context)

return var_handles


def process_read_var_pending_ops(context: Context,
read_var_op: model_facade._Operator) -> None:
"""Process READ_VARIABLE operator against any pending operators.
Then add the READ_VARIABLE operator to the list of reordered operations.
"""
for tensor_input in read_var_op.inputs_indices:
pending_op = context.get_pending_op(tensor_input,
read_var_op.subgraph.index)
if pending_op is not None:
context.remove_pending_op(pending_op)
context.push_current_op(pending_op)
process_pending_ops(context)
context.pop_current_op()

context.append_to_reordered_operations(read_var_op)


def process_pending_ops(context: Context) -> None:
Expand Down

0 comments on commit 23e1c7e

Please sign in to comment.