Skip to content

Commit

Permalink
Use proto value get in example code (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
lia-viam authored Oct 7, 2024
1 parent 90246f0 commit 688d80a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 57 deletions.
15 changes: 7 additions & 8 deletions src/viam/examples/modules/complex/gizmo/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ std::string find_arg1(ResourceConfig cfg) {
buffer << gizmo_name << ": Required parameter `arg1` not found in configuration";
throw std::invalid_argument(buffer.str());
}

const ProtoValue& arg1_val = arg1->second;
if (arg1_val.is_a<std::string>() && !arg1_val.get_unchecked<std::string>().empty()) {
return arg1_val.get_unchecked<std::string>();
const auto* const arg1_string = arg1->second.get<std::string>();
if (!arg1_string || arg1_string->empty()) {
std::ostringstream buffer;
buffer << gizmo_name << ": Required non-empty string parameter `arg1`"
<< "` is either not a string or is an empty string";
throw std::invalid_argument(buffer.str());
}
std::ostringstream buffer;
buffer << gizmo_name << ": Required non-empty string parameter `arg1`"
<< "` is either not a string or is an empty string";
throw std::invalid_argument(buffer.str());
return *arg1_string;
}

void MyGizmo::reconfigure(const Dependencies& deps, const ResourceConfig& cfg) {
Expand Down
18 changes: 8 additions & 10 deletions src/viam/examples/modules/simple/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,15 @@ class Printer : public GenericService, public Reconfigurable {
buffer << printer_name << ": Required parameter `to_print` not found in configuration";
throw std::invalid_argument(buffer.str());
}
const ProtoValue& to_print_val = to_print->second;
if (to_print_val.is_a<std::string>() &&
!to_print_val.get_unchecked<std::string>().empty()) {
return to_print_val.get_unchecked<std::string>();
const auto* const to_print_string = to_print->second.get<std::string>();
if (!to_print_string || to_print_string->empty()) {
std::ostringstream buffer;
buffer << printer_name
<< ": Required non-empty string parameter `to_print` is either not a string "
"or is an empty string";
throw std::invalid_argument(buffer.str());
}

std::ostringstream buffer;
buffer << printer_name
<< ": Required non-empty string parameter `to_print` is either not a string "
"or is an empty string";
throw std::invalid_argument(buffer.str());
return *to_print_string;
}

private:
Expand Down
68 changes: 29 additions & 39 deletions src/viam/examples/modules/tflite/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,59 +298,55 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
<< ": Required parameter `model_path` not found in configuration";
throw std::invalid_argument(buffer.str());
}

const vsdk::ProtoValue& model_path_val = model_path->second;
if (!model_path_val.is_a<std::string>() ||
model_path_val.get_unchecked<std::string>().empty()) {
const auto* const model_path_string = model_path->second.get<std::string>();
if (!model_path_string || model_path_string->empty()) {
std::ostringstream buffer;
buffer << service_name
<< ": Required non-empty string parameter `model_path` is either not a string "
"or is an empty string";
throw std::invalid_argument(buffer.str());
}
const std::string& model_path_string = model_path_val.get_unchecked<std::string>();

// Process any tensor name remappings provided in the config.
auto remappings = attributes.find("tensor_name_remappings");
if (remappings != attributes.end()) {
if (!remappings->second.is_a<vsdk::ProtoStruct>()) {
const auto remappings_attributes = remappings->second.get<vsdk::ProtoStruct>();
if (!remappings_attributes) {
std::ostringstream buffer;
buffer << service_name
<< ": Optional parameter `tensor_name_remappings` must be a dictionary";
throw std::invalid_argument(buffer.str());
}
const auto remappings_attributes =
remappings->second.get_unchecked<vsdk::ProtoStruct>();

const auto populate_remappings = [](const vsdk::ProtoValue& source, auto& target) {
if (!source.is_a<vsdk::ProtoStruct>()) {
const auto source_attributes = source.get<vsdk::ProtoStruct>();
if (!source_attributes) {
std::ostringstream buffer;
buffer << service_name
<< ": Fields `inputs` and `outputs` of `tensor_name_remappings` "
"must be "
<< ": Fields `inputs` and `outputs` of `tensor_name_remappings` must be "
"dictionaries";
throw std::invalid_argument(buffer.str());
}
for (const auto& kv : source.get_unchecked<vsdk::ProtoStruct>()) {
for (const auto& kv : *source_attributes) {
const auto& k = kv.first;
if (!kv.second.is_a<std::string>()) {
const auto* const kv_string = kv.second.get<std::string>();
if (!kv_string) {
std::ostringstream buffer;
buffer << service_name
<< ": Fields `inputs` and `outputs` of `tensor_name_remappings` "
"must "
"be dictionaries with string values";
buffer
<< service_name
<< ": Fields `inputs` and `outputs` of `tensor_name_remappings` must "
"be dictionaries with string values";
throw std::invalid_argument(buffer.str());
}
target[kv.first] = kv.second.get_unchecked<std::string>();
target[kv.first] = *kv_string;
}
};

const auto inputs_where = remappings_attributes.find("inputs");
if (inputs_where != remappings_attributes.end()) {
const auto inputs_where = remappings_attributes->find("inputs");
if (inputs_where != remappings_attributes->end()) {
populate_remappings(inputs_where->second, state->input_name_remappings);
}
const auto outputs_where = remappings_attributes.find("outputs");
if (outputs_where != remappings_attributes.end()) {
const auto outputs_where = remappings_attributes->find("outputs");
if (outputs_where != remappings_attributes->end()) {
populate_remappings(outputs_where->second, state->output_name_remappings);
}
}
Expand All @@ -366,11 +362,11 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
// buffer which we can use with `TfLiteModelCreate`. That
// still requires that the buffer be kept valid, but that's
// more easily done.
const std::ifstream in(model_path_string, std::ios::in | std::ios::binary);
const std::ifstream in(*model_path_string, std::ios::in | std::ios::binary);
if (!in) {
std::ostringstream buffer;
buffer << service_name << ": Failed to open file for `model_path` "
<< model_path_string;
<< *model_path_string;
throw std::invalid_argument(buffer.str());
}
std::ostringstream model_path_contents_stream;
Expand Down Expand Up @@ -405,27 +401,21 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
// object to carry that information.
auto num_threads = attributes.find("num_threads");
if (num_threads != attributes.end()) {
auto throwError = [&] {
const auto* num_threads_double = num_threads->second.get<double>();
if (!num_threads_double || !std::isnormal(*num_threads_double) ||
(*num_threads_double < 0) ||
(*num_threads_double >= std::numeric_limits<std::int32_t>::max()) ||
(std::trunc(*num_threads_double) != *num_threads_double)) {
std::ostringstream buffer;
buffer << service_name
<< ": Value for field `num_threads` is not a positive integer";
<< ": Value for field `num_threads` is not a positive integer: "
<< *num_threads_double;
throw std::invalid_argument(buffer.str());
};

if (!num_threads->second.is_a<double>()) {
throwError();
}

double num_threads_double = num_threads->second.get_unchecked<double>();
if (!std::isnormal(num_threads_double) || (num_threads_double < 0) ||
(num_threads_double >= std::numeric_limits<std::int32_t>::max()) ||
(std::trunc(num_threads_double) != num_threads_double)) {
throwError();
}

state->interpreter_options.reset(TfLiteInterpreterOptionsCreate());
TfLiteInterpreterOptionsSetNumThreads(state->interpreter_options.get(),
static_cast<int32_t>(num_threads_double));
static_cast<int32_t>(*num_threads_double));
}

// Build the single interpreter.
Expand Down

0 comments on commit 688d80a

Please sign in to comment.