diff --git a/tensorflow/lite/micro/tools/reorder_ops.py b/tensorflow/lite/micro/tools/reorder_ops.py index 8f8945cc114..9341076de89 100644 --- a/tensorflow/lite/micro/tools/reorder_ops.py +++ b/tensorflow/lite/micro/tools/reorder_ops.py @@ -223,8 +223,8 @@ def process_operator_assign_variable(context: Context) -> VarHandles: assign_op.inputs_indices[0]) for read_var_op in context.get_pending_read_var_ops_by_handle( var_handle_id, assign_op.subgraph.index): - context.append_to_reordered_operations(read_var_op) context.remove_pending_op(read_var_op) + process_read_var_pending_ops(context, read_var_op) process_pending_ops(context) @@ -232,18 +232,75 @@ def process_operator_assign_variable(context: Context) -> VarHandles: def process_operator_call_once(context: Context) -> VarHandles: - assert False - return set() + op = context.current_op() + assert op.builtin_options_type == tflite.BuiltinOptions.CallOnceOptions + assert op.builtin_options is not None + subgraph_index: int = op.builtin_options.subgraph + var_handles: VarHandles = process_subgraph(context, subgraph_index) + for var_handle_id in var_handles: + for read_var_op in context.get_pending_read_var_ops_by_handle( + var_handle_id, op.subgraph.index): + context.remove_pending_op(read_var_op) + process_read_var_pending_ops(context, read_var_op) + + context.append_to_reordered_operations(op) + + return var_handles def process_operator_if(context: Context) -> VarHandles: - assert False - return set() + op = context.current_op() + assert op.builtin_options_type == tflite.BuiltinOptions.IfOptions + assert op.builtin_options is not None + then_subgraph_index: int = op.builtin_options.thenSubgraphIndex + else_subgraph_index: int = op.builtin_options.elseSubgraphIndex + var_handles: VarHandles = process_subgraph(context, then_subgraph_index) + var_handles |= process_subgraph(context, else_subgraph_index) + for var_handle_id in var_handles: + for read_var_op in context.get_pending_read_var_ops_by_handle( + var_handle_id, op.subgraph.index): + context.remove_pending_op(read_var_op) + process_read_var_pending_ops(context, read_var_op) + + process_pending_ops(context) + + return var_handles def process_operator_while(context: Context) -> VarHandles: - assert False - return set() + op = context.current_op() + assert op.builtin_options_type == tflite.BuiltinOptions.WhileOptions + assert op.builtin_options is not None + cond_subgraph_index: int = op.builtin_options.condSubgraphIndex + body_subgraph_index: int = op.builtin_options.bodySubgraphIndex + var_handles: VarHandles = process_subgraph(context, cond_subgraph_index) + var_handles |= process_subgraph(context, body_subgraph_index) + for var_handle_id in var_handles: + for read_var_op in context.get_pending_read_var_ops_by_handle( + var_handle_id, op.subgraph.index): + context.remove_pending_op(read_var_op) + process_read_var_pending_ops(context, read_var_op) + + process_pending_ops(context) + + return var_handles + + +def process_read_var_pending_ops(context: Context, + read_var_op: model_facade._Operator) -> None: + """Process READ_VARIABLE operator against any pending operators. + Then add the READ_VARIABLE operator to the list of reordered operations. + """ + for tensor_input in read_var_op.inputs_indices: + pending_op = context.get_pending_op(tensor_input, + read_var_op.subgraph.index) + if pending_op is not None: + context.remove_pending_op(pending_op) + context.push_current_op(pending_op) + process_pending_ops(context) + context.pop_current_op() + + context.append_to_reordered_operations(read_var_op) def process_pending_ops(context: Context) -> None: