From 66cbff59ae8e071ba84f0c84edda7375c7a0d0cb Mon Sep 17 00:00:00 2001 From: Louis Jean Date: Wed, 1 Mar 2023 11:20:04 +0000 Subject: [PATCH] feat(api): add labels in service info labels are not returned per default, call service info with qery parameter labels=1 --- src/dto/info.hpp | 7 +++++++ src/http/controller.hpp | 17 ++++++++++++++--- src/jsonapi.cc | 8 ++++---- src/jsonapi.h | 11 +++++++---- src/mlservice.h | 20 +++++++++++++++++++- tests/ut-oatpp.cc | 4 +++- tests/ut-oatpp.h | 3 +++ 7 files changed, 57 insertions(+), 13 deletions(-) diff --git a/src/dto/info.hpp b/src/dto/info.hpp index 87c20f83e..1d96e317b 100644 --- a/src/dto/info.hpp +++ b/src/dto/info.hpp @@ -116,6 +116,13 @@ namespace dd DTO_FIELD(Object, model_stats); DTO_FIELD(Vector, jobs); + DTO_FIELD_INFO(labels) + { + info->description + = "Labels for classification / detection / segmentation services"; + } + DTO_FIELD(Vector, labels); + DTO_FIELD(String, repository); DTO_FIELD(Int32, width); DTO_FIELD(Int32, height); diff --git a/src/http/controller.hpp b/src/http/controller.hpp index 87cee62f1..e90854036 100644 --- a/src/http/controller.hpp +++ b/src/http/controller.hpp @@ -78,7 +78,7 @@ class DedeController : public oatpp::web::server::api::ApiController oatpp::String qs_status = queryParams.get("status"); bool status = false; if (qs_status) - status = boost::lexical_cast(std::string(qs_status)); + status = boost::lexical_cast(*qs_status); auto hit = _oja->_mlservices.begin(); while (hit != _oja->_mlservices.end()) @@ -96,9 +96,20 @@ class DedeController : public oatpp::web::server::api::ApiController info->summary = "Retrieve a service detail"; } ENDPOINT("GET", "services/{service-name}", get_service, - PATH(oatpp::String, service_name, "service-name")) + PATH(oatpp::String, service_name, "service-name"), + QUERIES(QueryParams, queryParams)) { - auto janswer = _oja->service_status(service_name); + oatpp::String qs_status = queryParams.get("status"); + bool status = true; + if (qs_status) + status = boost::lexical_cast(*qs_status); + + oatpp::String qs_labels = queryParams.get("labels"); + bool labels = false; + if (qs_labels) + labels = boost::lexical_cast(*qs_labels); + + auto janswer = _oja->service_status(service_name, status, labels); return _oja->jdoc_to_response(janswer); } diff --git a/src/jsonapi.cc b/src/jsonapi.cc index 7284268c5..2689848c3 100644 --- a/src/jsonapi.cc +++ b/src/jsonapi.cc @@ -972,7 +972,8 @@ namespace dd return jsc; } - JDoc JsonAPI::service_status(const std::string &snamein) + JDoc JsonAPI::service_status(const std::string &snamein, bool status, + bool labels) { std::string sname(snamein); std::transform(snamein.begin(), snamein.end(), sname.begin(), ::tolower); @@ -982,8 +983,8 @@ namespace dd if (!this->service_exists(sname)) return dd_service_not_found_1002(sname); auto hit = this->get_service_it(sname); - auto status_dto - = mapbox::util::apply_visitor(visitor_info(true), (*hit).second); + auto status_dto = mapbox::util::apply_visitor(visitor_info(status, labels), + (*hit).second); JDoc jst = dd_ok_200(); JVal jbody(rapidjson::kObjectType); oatpp_utils::dtoToJVal(status_dto, jst, jbody); @@ -994,7 +995,6 @@ namespace dd JDoc JsonAPI::service_delete(const std::string &snamein, const std::string &jstr) { - std::string sname(snamein); std::transform(snamein.begin(), snamein.end(), sname.begin(), ::tolower); diff --git a/src/jsonapi.h b/src/jsonapi.h index 5113dbd8d..f00c21c73 100644 --- a/src/jsonapi.h +++ b/src/jsonapi.h @@ -95,9 +95,10 @@ namespace dd // return a JSON document for every API call JDoc info(const std::string &jstr) const; JDoc service_create(const std::string &sname, const std::string &jstr); - JDoc service_status(const std::string &sname); + JDoc service_status(const std::string &sname, bool status = true, + bool labels = false); + JDoc service_labels(const std::string &sname); JDoc service_delete(const std::string &sname, const std::string &jstr); - JDoc service_predict(const std::string &jstr); JDoc service_train(const std::string &jstr); @@ -129,7 +130,8 @@ namespace dd class visitor_info { public: - visitor_info(const bool &status) : _status(status) + visitor_info(const bool &status, const bool &labels = false) + : _status(status), _labels(labels) { } ~visitor_info() @@ -138,9 +140,10 @@ namespace dd template oatpp::Object operator()(T &mllib) { - return mllib.info(_status); + return mllib.info(_status, _labels); } bool _status = false; + bool _labels = false; }; } diff --git a/src/mlservice.h b/src/mlservice.h index e2c1b6011..1181d68f8 100644 --- a/src/mlservice.h +++ b/src/mlservice.h @@ -193,7 +193,8 @@ namespace dd * \brief get info about the service * @return info data object */ - oatpp::Object info(const bool &status) const + oatpp::Object info(const bool &status, + const bool &labels = false) const { // general info auto serv_dto = DTO::Service::createShared(); @@ -272,6 +273,23 @@ namespace dd } } + // labels + if (labels) + { + auto labels_vec = oatpp::Vector::createShared(); + + if (!this->_mlmodel._hcorresp.empty()) + { + labels_vec->reserve(this->_mlmodel._hcorresp.size()); + + for (const auto &kv : this->_mlmodel._hcorresp) + { + labels_vec->push_back(kv.second); + } + } + serv_dto->labels = labels_vec; + } + // stats this->_stats.to(serv_dto); return serv_dto; diff --git a/tests/ut-oatpp.cc b/tests/ut-oatpp.cc index b8ed2b8d0..6cb5b40e7 100644 --- a/tests/ut-oatpp.cc +++ b/tests/ut-oatpp.cc @@ -79,7 +79,7 @@ void test_services(std::shared_ptr client) ASSERT_EQ(201, d["status"]["code"].GetInt()); // service info - response = client->get_services(serv.c_str()); + response = client->get_service_with_labels(serv.c_str(), "1"); message = response->readBodyToString(); ASSERT_TRUE(message != nullptr); std::cout << "jstr=" << *message << std::endl; @@ -99,6 +99,8 @@ void test_services(std::shared_ptr client) ASSERT_TRUE(d["body"]["parameters"].HasMember("output")); ASSERT_EQ(d["body"]["parameters"]["input"]["connector"].GetString(), std::string("image")); + ASSERT_TRUE(d["body"].HasMember("labels")); + ASSERT_EQ(d["body"]["labels"].Size(), 0); // info call response = client->get_info(); diff --git a/tests/ut-oatpp.h b/tests/ut-oatpp.h index ece859f43..923f1d3c3 100644 --- a/tests/ut-oatpp.h +++ b/tests/ut-oatpp.h @@ -97,6 +97,9 @@ class DedeApiTestClient : public oatpp::web::client::ApiClient API_CALL("GET", "/info", get_info) API_CALL("GET", "/services/{service-name}", get_services, PATH(oatpp::String, service_name, "service-name")) + API_CALL("GET", "/services/{service-name}", get_service_with_labels, + PATH(oatpp::String, service_name, "service-name"), + QUERY(String, labels)) API_CALL("POST", "/services/{service-name}", post_services, PATH(oatpp::String, service_name, "service-name"), BODY_STRING(oatpp::String, service_data))