Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(compression): extend interpreter to handle compressed tensors #3002

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tensorflow/lite/micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,14 @@ tflm_cc_library(
"micro_context.h",
],
deps = [
":compression",
":micro_common",
":micro_graph",
":micro_log",
":micro_profiler",
"//tensorflow/lite:type_to_tflitetype",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro/kernels:decompress",
],
)

Expand Down Expand Up @@ -145,6 +149,7 @@ tflm_cc_library(
":memory_helpers",
":micro_allocator",
":micro_common",
":micro_context",
":micro_graph",
":micro_log",
":micro_profiler",
Expand Down Expand Up @@ -180,6 +185,7 @@ tflm_cc_library(
"micro_allocator.h",
],
deps = [
":compression",
":flatbuffer_utils",
":memory_helpers",
":micro_arena_constants",
Expand All @@ -192,6 +198,7 @@ tflm_cc_library(
"//tensorflow/lite/micro/arena_allocator:non_persistent_arena_buffer_allocator",
"//tensorflow/lite/micro/arena_allocator:persistent_arena_buffer_allocator",
"//tensorflow/lite/micro/arena_allocator:simple_memory_allocator",
"//tensorflow/lite/micro/compression:metadata_saved",
"//tensorflow/lite/micro/memory_planner:greedy_memory_planner",
"//tensorflow/lite/micro/memory_planner:linear_memory_planner",
"//tensorflow/lite/micro/memory_planner:micro_memory_planner",
Expand Down Expand Up @@ -245,7 +252,9 @@ tflm_cc_library(
"test_helpers.h",
],
deps = [
":compression",
":memory_helpers",
":micro_log",
":micro_utils",
":op_resolvers",
"//tensorflow/lite:type_to_tflitetype",
Expand Down
79 changes: 74 additions & 5 deletions tensorflow/lite/micro/fake_micro_context.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -23,10 +23,23 @@ limitations under the License.

namespace tflite {

FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
SingleArenaBufferAllocator* allocator,
MicroGraph* micro_graph)
: graph_(*micro_graph), tensors_(tensors), allocator_(allocator) {}
// Constructs a test-double MicroContext over caller-owned storage.
// `tensors` and `allocator` are borrowed (caller retains ownership) and
// `micro_graph` must outlive this object — it is stored by reference.
// When built with USE_TFLM_COMPRESSION, `compressed_tensors` optionally
// supplies per-tensor compression metadata; nullptr means no tensor in
// this context is compressed.
FakeMicroContext::FakeMicroContext(
    TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
    MicroGraph* micro_graph
#ifdef USE_TFLM_COMPRESSION
    ,
    const CompressedTensorList* compressed_tensors
#endif  // USE_TFLM_COMPRESSION
    )
    : graph_(*micro_graph),
      tensors_(tensors),
      allocator_(allocator)
#ifdef USE_TFLM_COMPRESSION
      ,
      compressed_tensors_(compressed_tensors)
#endif  // USE_TFLM_COMPRESSION
{
}

TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor(int tensor_index) {
allocated_temp_count_++;
Expand Down Expand Up @@ -112,4 +125,60 @@ void* FakeMicroContext::external_context() { return nullptr; }

MicroGraph& FakeMicroContext::graph() { return graph_; }

#ifdef USE_TFLM_COMPRESSION

// Available during Prepare & Eval. Returns false if tensor is not
// compressed.
//
// `tensor_idx` is a kernel-relative input index into `node->inputs`; the
// subgraph-level tensor index it maps to may be negative (kTfLiteOptionalTensor)
// for an absent optional input, which is treated as not compressed.
bool FakeMicroContext::IsTensorCompressed(const TfLiteNode* node,
                                          int tensor_idx) {
  // Bound-check `tensor_idx` on both ends before using it as a subscript;
  // previously only the upper bound was verified, so a negative index would
  // have read out of bounds.
  if (compressed_tensors_ != nullptr && tensor_idx >= 0 &&
      tensor_idx < node->inputs->size) {
    int index = node->inputs->data[tensor_idx];
    if (index >= 0 && compressed_tensors_->tensors[index] != nullptr) {
      return true;
    }
  }

  return false;
}

// Only available during Prepare. The kernel is responsible for storing the
// scratch buffer handle.
//
// Requests an arena scratch buffer large enough to hold the uncompressed
// contents of input `tensor_idx` of `node`. Returns the scratch buffer
// handle, or -1 if the tensor is not compressed, the index is out of range,
// or the arena request fails.
int FakeMicroContext::AllocateDecompressionScratchBuffer(const TfLiteNode* node,
                                                         int tensor_idx) {
  // Reject a negative kernel-relative index as well as one past the end;
  // previously only the upper bound was checked before the subscript below.
  if (compressed_tensors_ == nullptr || tensor_idx < 0 ||
      tensor_idx >= node->inputs->size) {
    return -1;
  }
  int index = node->inputs->data[tensor_idx];
  if (index < 0 || compressed_tensors_->tensors[index] == nullptr) {
    return -1;
  }
  // tensor->bytes is the uncompressed size, which is what decompression
  // needs as destination capacity.
  TfLiteTensor* tensor = &tensors_[index];
  int scratch_index = -1;
  TfLiteStatus result =
      RequestScratchBufferInArena(tensor->bytes, &scratch_index);
  if (result != kTfLiteOk) {
    return -1;
  }

  return scratch_index;
}

// Available during Prepare & Eval. Returns nullptr if tensor is not
// compressed.
//
// Looks up the compression metadata for input `tensor_idx` of `node`.
// The returned pointer (possibly nullptr for an uncompressed tensor) is
// owned by the CompressedTensorList supplied at construction.
const CompressionTensorData* FakeMicroContext::GetTensorCompressionData(
    const TfLiteNode* node, int tensor_idx) {
  // Bound-check `tensor_idx` on both ends, matching IsTensorCompressed and
  // AllocateDecompressionScratchBuffer; previously a negative index would
  // have indexed node->inputs->data out of bounds.
  if (compressed_tensors_ == nullptr || tensor_idx < 0 ||
      tensor_idx >= node->inputs->size) {
    return nullptr;
  }

  int index = node->inputs->data[tensor_idx];
  if (index < 0) {
    return nullptr;
  }

  return compressed_tensors_->tensors[index];
}

#endif // USE_TFLM_COMPRESSION

} // namespace tflite
36 changes: 34 additions & 2 deletions tensorflow/lite/micro/fake_micro_context.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,7 +30,12 @@ class FakeMicroContext : public MicroContext {
~FakeMicroContext() = default;

FakeMicroContext(TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
MicroGraph* micro_graph);
MicroGraph* micro_graph
#ifdef USE_TFLM_COMPRESSION
,
const CompressedTensorList* compressed_tensors = nullptr
#endif // USE_TFLM_COMPRESSION
);

void* AllocatePersistentBuffer(size_t bytes) override;
TfLiteStatus RequestScratchBufferInArena(size_t bytes,
Expand All @@ -50,6 +55,24 @@ class FakeMicroContext : public MicroContext {
void* external_context() override;
MicroGraph& graph() override;

#ifdef USE_TFLM_COMPRESSION

// Available during Prepare & Eval. Returns false if tensor is not
// compressed.
bool IsTensorCompressed(const TfLiteNode* node, int tensor_idx) override;

// Only available during Prepare. The kernel is responsible for storing the
// scratch buffer handle.
int AllocateDecompressionScratchBuffer(const TfLiteNode* node,
int tensor_idx) override;

// Available during Prepare & Eval. Returns nullptr if tensor is not
// compressed.
const CompressionTensorData* GetTensorCompressionData(
const TfLiteNode* node, int tensor_idx) override;

#endif // USE_TFLM_COMPRESSION

private:
static constexpr int kNumScratchBuffers_ = 12;

Expand All @@ -62,6 +85,15 @@ class FakeMicroContext : public MicroContext {

SingleArenaBufferAllocator* allocator_;

#ifdef USE_TFLM_COMPRESSION

//
// Compression
//
const CompressedTensorList* compressed_tensors_;

#endif // USE_TFLM_COMPRESSION

TF_LITE_REMOVE_VIRTUAL_DELETE
};

Expand Down
17 changes: 13 additions & 4 deletions tensorflow/lite/micro/kernels/kernel_runner.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/test_helpers.h"

namespace tflite {
namespace micro {
Expand All @@ -38,12 +37,22 @@ KernelRunner::KernelRunner(const TFLMRegistration& registration,
TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs, TfLiteIntArray* outputs,
const void* builtin_data,
TfLiteIntArray* intermediates)
TfLiteIntArray* intermediates
#ifdef USE_TFLM_COMPRESSION
,
const CompressedTensorList* compressed_tensors
#endif // USE_TFLM_COMPRESSION
)
: registration_(registration),
allocator_(SingleArenaBufferAllocator::Create(kKernelRunnerBuffer_,
kKernelRunnerBufferSize_)),
mock_micro_graph_(allocator_),
fake_micro_context_(tensors, allocator_, &mock_micro_graph_) {
fake_micro_context_(tensors, allocator_, &mock_micro_graph_
#ifdef USE_TFLM_COMPRESSION
,
compressed_tensors
#endif // USE_TFLM_COMPRESSION
) {
// Prepare TfLiteContext:
context_.impl_ = static_cast<void*>(&fake_micro_context_);
context_.ReportError = MicroContextReportOpError;
Expand Down
9 changes: 7 additions & 2 deletions tensorflow/lite/micro/kernels/kernel_runner.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,7 +36,12 @@ class KernelRunner {
KernelRunner(const TFLMRegistration& registration, TfLiteTensor* tensors,
int tensors_size, TfLiteIntArray* inputs,
TfLiteIntArray* outputs, const void* builtin_data,
TfLiteIntArray* intermediates = nullptr);
TfLiteIntArray* intermediates = nullptr
#ifdef USE_TFLM_COMPRESSION
,
const CompressedTensorList* compressed_tensors = nullptr
#endif // USE_TFLM_COMPRESSION
);

// Calls init and prepare on the kernel (i.e. TFLMRegistration) struct.
// Any exceptions will be DebugLog'd and returned as a status code.
Expand Down
41 changes: 40 additions & 1 deletion tensorflow/lite/micro/kernels/kernel_util.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,13 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/micro_context.h"

#ifdef USE_TFLM_COMPRESSION

#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_utils.h"

#endif // USE_TFLM_COMPRESSION

namespace tflite {
namespace micro {

Expand Down Expand Up @@ -91,6 +98,38 @@ const T* GetOptionalTensorData(const TfLiteEvalTensor* tensor) {
: reinterpret_cast<const T*>(tensor->data.raw);
}

#ifdef USE_TFLM_COMPRESSION

// Overload of GetTensorData that is compression-aware. For an uncompressed
// tensor (compression_data == nullptr) this simply returns tensor->data.
// For a compressed tensor it decompresses into either the kernel-provided
// scratch buffer (scratch_buffer_handle != -1) or freshly allocated
// decompression memory, and returns a pointer to the uncompressed data.
template <typename T>
const T* GetTensorData(MicroContext* micro_context,
                       const TfLiteEvalTensor* tensor,
                       const CompressionTensorData* compression_data,
                       int scratch_buffer_handle) {
  if (tensor == nullptr) {
    return nullptr;
  }
  if (compression_data == nullptr) {
    // Fast path: nothing to decompress.
    return reinterpret_cast<const T*>(tensor->data.data);
  }

  // Pick the destination buffer: the kernel's scratch buffer when one was
  // allocated during Prepare, otherwise request decompression memory sized
  // for the uncompressed tensor.
  void* destination =
      (scratch_buffer_handle != -1)
          ? micro_context->GetScratchBuffer(scratch_buffer_handle)
          : micro_context->AllocateDecompressionMemory(
                EvalTensorBytes(tensor), MicroArenaBufferAlignment());
  TFLITE_DCHECK(destination != nullptr);

  return reinterpret_cast<const T*>(micro_context->DecompressTensorToBuffer(
      *tensor, *compression_data, destination));
}

#endif // USE_TFLM_COMPRESSION

// Returns the shape of a TfLiteEvalTensor struct.
const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);

Expand Down
Loading
Loading