Skip to content

Commit

Permalink
Adds GetCurrentOperatorIndex to micro_interpreter (#2605)
Browse files Browse the repository at this point in the history
In this way an AOT can add metadata that could then be accessed with the couple GetCurrentSubgraphIndex, GetCurrentOperatorIndex.
These can be accessed from tflite::GetMicroContext(context)->graph()

Refers to #2593
BUG=#2593
  • Loading branch information
DanieleParravicini-Synthara authored Jun 24, 2024
1 parent 8cd499a commit dc4dcb7
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 36 deletions.
97 changes: 62 additions & 35 deletions tensorflow/lite/micro/micro_interpreter_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ MicroInterpreterGraph::MicroInterpreterGraph(
model_(model),
allocator_(allocator),
current_subgraph_index_(0),
current_operator_index_(0),
resource_variables_(resource_variables) {
if (model != nullptr) {
subgraphs_ = model->subgraphs();
Expand All @@ -54,17 +55,21 @@ MicroInterpreterGraph::~MicroInterpreterGraph() {}

TfLiteStatus MicroInterpreterGraph::InitSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
uint32_t previous_operator_idx = current_operator_index_;

for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
for (current_operator_index_ = 0; current_operator_index_ < operators_size;
++current_operator_index_) {
TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.node);
const TFLMRegistration* registration =
subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.registration;
size_t init_data_size;
const char* init_data;
if (registration->builtin_code == BuiltinOperator_CUSTOM) {
Expand All @@ -81,52 +86,62 @@ TfLiteStatus MicroInterpreterGraph::InitSubgraphs() {
}
}
current_subgraph_index_ = previous_subgraph_idx;
current_operator_index_ = previous_operator_idx;

return kTfLiteOk;
}

TfLiteStatus MicroInterpreterGraph::PrepareSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;

uint32_t previous_operator_idx = current_operator_index_;
for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
for (current_operator_index_ = 0; current_operator_index_ < operators_size;
++current_operator_index_) {
TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.node);
const TFLMRegistration* registration =
subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.registration;
if (registration->prepare != nullptr) {
TfLiteStatus prepare_status = registration->prepare(context_, node);
if (prepare_status != kTfLiteOk) {
MicroPrintf("Node %s (number %df) failed to prepare with status %d",
OpNameFromRegistration(registration), i, prepare_status);
OpNameFromRegistration(registration),
current_operator_index_, prepare_status);
return kTfLiteError;
}
}
allocator_->FinishPrepareNodeAllocations(/*node_id=*/i);
allocator_->FinishPrepareNodeAllocations(
/*node_id=*/current_operator_index_);
}
}
current_subgraph_index_ = previous_subgraph_idx;

current_operator_index_ = previous_operator_idx;
return kTfLiteOk;
}

TfLiteStatus MicroInterpreterGraph::ResetSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
uint32_t previous_operator_idx = current_operator_index_;

for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
for (current_operator_index_ = 0; current_operator_index_ < operators_size;
++current_operator_index_) {
TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.node);
const TFLMRegistration* registration =
subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.registration;
// registration is allocated outside the interpreter, so double check to
// make sure it's not nullptr;
if (registration != nullptr && registration->reset != nullptr) {
Expand All @@ -135,23 +150,28 @@ TfLiteStatus MicroInterpreterGraph::ResetSubgraphs() {
}
}
current_subgraph_index_ = previous_subgraph_idx;
current_operator_index_ = previous_operator_idx;

return kTfLiteOk;
}

TfLiteStatus MicroInterpreterGraph::FreeSubgraphs() {
int previous_subgraph_idx = current_subgraph_index_;
uint32_t previous_operator_idx = current_operator_index_;

for (size_t subgraph_idx = 0; subgraph_idx < subgraphs_->size();
subgraph_idx++) {
current_subgraph_index_ = subgraph_idx;
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
for (current_operator_index_ = 0; current_operator_index_ < operators_size;
++current_operator_index_) {
TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.node);
const TFLMRegistration* registration =
subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.registration;
// registration is allocated outside the interpreter, so double check to
// make sure it's not nullptr;
if (registration != nullptr && registration->free != nullptr) {
Expand All @@ -160,12 +180,14 @@ TfLiteStatus MicroInterpreterGraph::FreeSubgraphs() {
}
}
current_subgraph_index_ = previous_subgraph_idx;
current_operator_index_ = previous_operator_idx;

return kTfLiteOk;
}

TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) {
int previous_subgraph_idx = current_subgraph_index_;
uint32_t previous_operator_idx = current_operator_index_;
current_subgraph_index_ = subgraph_idx;

if (static_cast<size_t>(subgraph_idx) >= subgraphs_->size()) {
Expand All @@ -174,12 +196,15 @@ TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) {
return kTfLiteError;
}
uint32_t operators_size = NumSubgraphOperators(model_, subgraph_idx);
for (size_t i = 0; i < operators_size; ++i) {
TfLiteNode* node =
&(subgraph_allocations_[subgraph_idx].node_and_registrations[i].node);
const TFLMRegistration* registration = subgraph_allocations_[subgraph_idx]
.node_and_registrations[i]
.registration;
for (current_operator_index_ = 0; current_operator_index_ < operators_size;
++current_operator_index_) {
TfLiteNode* node = &(subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.node);
const TFLMRegistration* registration =
subgraph_allocations_[subgraph_idx]
.node_and_registrations[current_operator_index_]
.registration;

// This ifdef is needed (even though ScopedMicroProfiler itself is a no-op with
// -DTF_LITE_STRIP_ERROR_STRINGS) because the function OpNameFromRegistration is
Expand All @@ -201,13 +226,15 @@ TfLiteStatus MicroInterpreterGraph::InvokeSubgraph(int subgraph_idx) {

if (invoke_status == kTfLiteError) {
MicroPrintf("Node %s (number %d) failed to invoke with status %d",
OpNameFromRegistration(registration), i, invoke_status);
OpNameFromRegistration(registration), current_operator_index_,
invoke_status);
return kTfLiteError;
} else if (invoke_status != kTfLiteOk) {
return invoke_status;
}
}
current_subgraph_index_ = previous_subgraph_idx;
current_operator_index_ = previous_operator_idx;
return kTfLiteOk;
}

Expand Down
8 changes: 7 additions & 1 deletion tensorflow/lite/micro/micro_interpreter_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,12 @@ class MicroInterpreterGraph : public MicroGraph {
// to be the subgraph of that operator.
int GetCurrentSubgraphIndex() { return current_subgraph_index_; }

// Gets the list of alloctions for each subgraph. This is the source of truth
// Get the current operator index inside a subgraph.
// The couple GetCurrentSubgraphIndex GetCurrentSubgraphIndex creates a unique
// identifier of the operator inside the subgraph
int GetCurrentOperatorIndex() { return current_operator_index_; }

// Gets the list of allocations for each subgraph. This is the source of truth
// for all per-subgraph allocation data.
SubgraphAllocations* GetAllocations() { return subgraph_allocations_; }

Expand All @@ -99,6 +104,7 @@ class MicroInterpreterGraph : public MicroGraph {
MicroAllocator* allocator_;
SubgraphAllocations* subgraph_allocations_ = nullptr;
int current_subgraph_index_;
uint32_t current_operator_index_;
MicroResourceVariables* resource_variables_;
const flatbuffers::Vector<flatbuffers::Offset<SubGraph>>* subgraphs_;

Expand Down

0 comments on commit dc4dcb7

Please sign in to comment.