Skip to content

Commit

Permalink
RSDK-4196 Add support for flat tensors in mlmodel (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
acmorrow authored Aug 1, 2023
1 parent 212ef2a commit 5f19d94
Show file tree
Hide file tree
Showing 9 changed files with 525 additions and 129 deletions.
2 changes: 2 additions & 0 deletions .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# performance-move-const-arg: TODO(RSDK-3019)
# misc-unused-parameters: TODO(RSDK-3008) remove this and fix all lints.
# readability-function-cognitive-complexity: No, complexity is subjective and sometimes necessary.
# readability-else-after-return: No, this causes code complexification
Checks: >
-*,
bugprone-*,
Expand All @@ -30,6 +31,7 @@ Checks: >
-performance-move-const-arg,
-misc-unused-parameters,
-readability-function-cognitive-complexity,
-readability-else-after-return,
WarningsAsErrors: '*'
FormatStyle: none
CheckOptions:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService {
if (!in) {
std::ostringstream buffer;
buffer << service_name << ": Failed to open file for `model_path` "
<< model_path_string;
<< *model_path_string;
throw std::invalid_argument(buffer.str());
}
std::ostringstream model_path_contents_stream;
Expand Down
106 changes: 69 additions & 37 deletions src/viam/sdk/services/mlmodel/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ namespace sdk {

namespace {

struct tensor_storage_and_views {
mlmodel_details::tensor_storage storage;
MLModelService::named_tensor_views views;
};
constexpr bool kUseFlatTensors = false;

} // namespace

Expand All @@ -40,49 +37,84 @@ std::shared_ptr<MLModelService::named_tensor_views> MLModelServiceClient::infer(
namespace pb = ::google::protobuf;
namespace mlpb = ::viam::service::mlmodel::v1;

pb::Arena arena;
auto* const req = pb::Arena::CreateMessage<mlpb::InferRequest>(&arena);
auto arena = std::make_unique<pb::Arena>();
auto* const req = pb::Arena::CreateMessage<mlpb::InferRequest>(arena.get());

req->set_name(this->name());
auto* const resp = pb::Arena::CreateMessage<mlpb::InferResponse>(arena.get());
grpc::ClientContext ctx;

auto& mutable_input_data = *req->mutable_input_data();
auto& mutable_input_data_fields = *mutable_input_data.mutable_fields();
if (!kUseFlatTensors) {
struct tensor_storage_and_views {
mlmodel_details::tensor_storage storage;
MLModelService::named_tensor_views views;
};
auto tsav = std::make_shared<tensor_storage_and_views>();

// TODO: Currently, this doesn't validate that we are passing the
// right input type in. We could query the metadata here and
// consult it.
for (const auto& kv : inputs) {
pb::Value& value = mutable_input_data_fields[kv.first];
mlmodel_details::tensor_to_pb_value(kv.second, &value);
}
auto& mutable_input_data = *req->mutable_input_data();
auto& mutable_input_data_fields = *mutable_input_data.mutable_fields();

auto* const resp = pb::Arena::CreateMessage<mlpb::InferResponse>(&arena);
// TODO: Currently, this doesn't validate that we are passing the
// right input type in. We could query the metadata here and
// consult it.
for (const auto& kv : inputs) {
pb::Value& value = mutable_input_data_fields[kv.first];
mlmodel_details::tensor_to_pb_value(kv.second, &value);
}

grpc::ClientContext ctx;
const auto result = stub_->Infer(&ctx, *req, resp);
if (!result.ok()) {
throw std::runtime_error(result.error_message());
}

const auto result = stub_->Infer(&ctx, *req, resp);
if (!result.ok()) {
throw std::runtime_error(result.error_message());
}
// TODO(RSDK-3298): This is an extra RPC on every inference, but
// it is not clear that caching it is safe.
const auto md = metadata();

const auto& output_fields = resp->output_data().fields();
for (const auto& output : md.outputs) {
const auto where = output_fields.find(output.name);
// Ignore any outputs for which we don't have metadata, since
// we can't know what type they should decode to.
if (where != output_fields.end()) {
mlmodel_details::pb_value_to_tensor(
output, where->second, &tsav->storage, &tsav->views);
}
}
auto* const tsav_views = &tsav->views;
return {std::move(tsav), tsav_views};
} else {
struct arena_and_views {
// NOTE: It is not necessary to capture the `resp` pointer
// here, since the lifetime of that object is subsumed by
// the arena.
std::unique_ptr<pb::Arena> arena;
MLModelService::named_tensor_views views;
};
auto aav = std::make_shared<arena_and_views>();
aav->arena = std::move(arena);

// TODO(RSDK-3298): This is an extra RPC on every inference, but
// it is not clear that caching it is safe.
const auto md = metadata();

auto tsav = std::make_shared<tensor_storage_and_views>();
const auto& output_fields = resp->output_data().fields();
for (const auto& output : md.outputs) {
const auto where = output_fields.find(output.name);
// Ignore any outputs for which we don't have metadata, since
// we can't know what type they should decode to.
if (where != output_fields.end()) {
mlmodel_details::pb_value_to_tensor(
output, where->second, &tsav->storage, &tsav->views);
auto& input_tensors = *req->mutable_input_tensors()->mutable_tensors();
for (const auto& kv : inputs) {
auto& emplaced = input_tensors[kv.first];
mlmodel_details::copy_sdk_tensor_to_api_tensor(kv.second, &emplaced);
}
}

auto* const tsav_views = &tsav->views;
return {std::move(tsav), tsav_views};
const auto result = stub_->Infer(&ctx, *req, resp);
if (!result.ok()) {
throw std::runtime_error(result.error_message());
}

for (const auto& kv : resp->output_tensors().tensors()) {
// NOTE: We don't need to pass in tensor storage here,
// because the backing store for the views is the Arena we
// moved into our result above.
auto tensor = mlmodel_details::make_sdk_tensor_from_api_tensor(kv.second);
aav->views.emplace(kv.first, std::move(tensor));
}
auto* const tsav_views = &aav->views;
return {std::move(aav), tsav_views};
}
}

struct MLModelService::metadata MLModelServiceClient::metadata() {
Expand Down
47 changes: 47 additions & 0 deletions src/viam/sdk/services/mlmodel/mlmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,53 @@ const char* MLModelService::tensor_info::data_type_to_string(const data_types da
return nullptr;
}

MLModelService::tensor_info::data_types MLModelService::tensor_info::tensor_views_to_data_type(
    const tensor_views& view) {
    // Maps the alternative currently held by `view` to the matching
    // `data_types` enumerator. One overload per element type supported by
    // `tensor_view`; `boost::apply_visitor` dispatches to the right one.
    struct element_type_visitor : boost::static_visitor<data_types> {
        data_types operator()(const MLModelService::tensor_view<std::int8_t>&) const {
            return data_types::k_int8;
        }

        data_types operator()(const MLModelService::tensor_view<std::uint8_t>&) const {
            return data_types::k_uint8;
        }

        data_types operator()(const MLModelService::tensor_view<std::int16_t>&) const {
            return data_types::k_int16;
        }

        data_types operator()(const MLModelService::tensor_view<std::uint16_t>&) const {
            return data_types::k_uint16;
        }

        data_types operator()(const MLModelService::tensor_view<std::int32_t>&) const {
            return data_types::k_int32;
        }

        data_types operator()(const MLModelService::tensor_view<std::uint32_t>&) const {
            return data_types::k_uint32;
        }

        data_types operator()(const MLModelService::tensor_view<std::int64_t>&) const {
            return data_types::k_int64;
        }

        data_types operator()(const MLModelService::tensor_view<std::uint64_t>&) const {
            return data_types::k_uint64;
        }

        data_types operator()(const MLModelService::tensor_view<float>&) const {
            return data_types::k_float32;
        }

        data_types operator()(const MLModelService::tensor_view<double>&) const {
            return data_types::k_float64;
        }
    };
    return boost::apply_visitor(element_type_visitor{}, view);
}

MLModelService::MLModelService(std::string name) : Service(std::move(name)) {}

namespace {
Expand Down
2 changes: 2 additions & 0 deletions src/viam/sdk/services/mlmodel/mlmodel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ class MLModelService : public Service {

static boost::optional<data_types> string_to_data_type(const std::string& str);
static const char* data_type_to_string(data_types data_type);

static data_types tensor_views_to_data_type(const tensor_views& view);
};

struct metadata {
Expand Down
Loading

0 comments on commit 5f19d94

Please sign in to comment.