diff --git a/docs/api.md b/docs/api.md
index f0277afef..3a89b30a8 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -1052,7 +1052,7 @@ sentences | bool | yes | false
 characters | bool | yes | false | character-level text processing, as opposed to word-based text processing
 sequence | int | yes | N/A | for character-level text processing, the fixed length of each sample of text
 read_forward | bool | yes | false | for character-level text processing, whether to read content from left to right
-alphabet | string | yes | abcdefghijklmnopqrstuvwxyz 0123456789 ,;.!?:'"/\\\ \|_@#$%^&*~\`+-=<>()[]{} | for character-level text processing, the alphabet of recognized symbols
+alphabet | string | yes | abcdefghijklmnopqrstuvwxyz 0123456789 ,;.!?:'"/\\\ \|\_@#$%^&\*~\`+-=<>()[]{} | for character-level text processing, the alphabet of recognized symbols
 sparse | bool | yes | false | whether to use sparse features (and sparse computations with Caffe for huge memory savings, for xgboost use `svm` connector instead)
 
 - SVM (`svm`)
@@ -1069,6 +1069,7 @@ network | object | yes | empty | Output netw
 measure | array | yes | empty | Output measures requested, from `acc`: accuracy, `acc-k`: top-k accuracy, replace k with number (e.g. `acc-5`), `f1`: f1, precision and recall, `mcll`: multi-class log loss, `auc`: area under the curve, `cmdiag`: diagonal of confusion matrix (requires `f1`), `cmfull`: full confusion matrix (requires `f1`), `mcc`: Matthews correlation coefficient
 confidence_threshold | double | yes | 0.0 | only returns classifications or detections with probability strictly above threshold
 bbox | bool | yes | false | returns bounding boxes around object when using an object detection model, such that (xmin,ymax) yields the top left corner and (xmax,ymin) the lower right corner of a box.
+best_bbox | int | yes | -1 | if > 0, returns only the `best_bbox` bounding boxes with the highest confidence
 regression | bool | yes | false | whether the output of a model is a regression target (i.e. vector of one or more floats)
 rois | string | yes | empty | set the ROI layer from which to extract the features from bounding boxes. Both the boxes and features are returned when using an object detection model with ROI pooling layer
 index | bool | yes | false | whether to index the output from prediction, for similarity search
diff --git a/src/backends/caffe/caffelib.cc b/src/backends/caffe/caffelib.cc
index 8cd064412..5299bbcf5 100644
--- a/src/backends/caffe/caffelib.cc
+++ b/src/backends/caffe/caffelib.cc
@@ -2802,6 +2802,7 @@ namespace dd
     int blank_label = -1;
     std::string roi_layer;
     double confidence_threshold = 0.0;
+    int best_bbox = -1;
     if (ad_output.has("confidence_threshold"))
       {
         try
@@ -2816,6 +2817,8 @@ namespace dd
                 ad_output.get("confidence_threshold").get<int>());
           }
       }
+    if (ad_output.has("best_bbox"))
+      best_bbox = ad_output.get("best_bbox").get<int>();
 
     if (inputc._timeserie
         && ad.getobj("parameters").getobj("input").has("timesteps"))
@@ -3247,6 +3250,9 @@ namespace dd
         int curi = -1;
         while (true && k < results_height)
           {
+            if (best_bbox > 0
+                && bboxes.size() >= static_cast<size_t>(best_bbox))
+              break;
             if (outr[0] == -1)
               {
                 // skipping invalid detection
diff --git a/src/backends/ncnn/ncnnlib.cc b/src/backends/ncnn/ncnnlib.cc
index 69f580a78..f311d8140 100644
--- a/src/backends/ncnn/ncnnlib.cc
+++ b/src/backends/ncnn/ncnnlib.cc
@@ -252,6 +252,10 @@ namespace dd
         for (int i = 0; i < inputc._out.at(b).h; i++)
           {
             const float *values = inputc._out.at(b).row(i);
+            if (output_params->best_bbox > 0
+                && bboxes.size()
+                       >= static_cast<size_t>(output_params->best_bbox))
+              break;
             if (values[1] < output_params->confidence_threshold)
               break; // output is sorted by confidence
diff --git a/src/backends/tensorrt/tensorrtlib.cc b/src/backends/tensorrt/tensorrtlib.cc
index a7e728cc6..33d0d27bd 100644
--- a/src/backends/tensorrt/tensorrtlib.cc
+++ b/src/backends/tensorrt/tensorrtlib.cc
@@ -811,6 +811,11 @@ namespace dd
         int curi = -1;
         while (true && k < results_height)
           {
+            if (output_params->best_bbox > 0
+                && bboxes.size() >= static_cast<size_t>(
+                       output_params->best_bbox))
+              break;
+
             if (outr[0] == -1)
               {
                 // skipping invalid detection
diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc
index 9d8fbe548..bd43013b6 100644
--- a/src/backends/torch/torchlib.cc
+++ b/src/backends/torch/torchlib.cc
@@ -1131,6 +1131,7 @@ namespace dd
     bool bbox = _bbox;
     double confidence_threshold = 0.0;
     int best_count = _nclasses;
+    int best_bbox = -1;
 
     if (params.has("mllib"))
       {
@@ -1188,7 +1189,13 @@ namespace dd
           }
       }
     if (output_params.has("best"))
-      best_count = output_params.get("best").get<int>();
+      {
+        best_count = output_params.get("best").get<int>();
+      }
+    if (output_params.has("best_bbox"))
+      {
+        best_bbox = output_params.get("best_bbox").get<int>();
+      }
 
     bool lstm_continuation = false;
     TInputConnectorStrategy inputc(this->_inputc);
@@ -1388,6 +1395,10 @@ namespace dd
             for (int j = 0; j < labels_tensor.size(0); ++j)
               {
+                if (best_bbox > 0
+                    && bboxes.size() >= static_cast<size_t>(best_bbox))
+                  break;
+
                 double score = score_acc[j];
 
                 if (score < confidence_threshold)
                   continue;
diff --git a/src/dto/output_connector.hpp b/src/dto/output_connector.hpp
index f662d324c..242408ff5 100644
--- a/src/dto/output_connector.hpp
+++ b/src/dto/output_connector.hpp
@@ -43,6 +43,7 @@ namespace dd
     DTO_FIELD(Boolean, ctc) = false;
     DTO_FIELD(Float32, confidence_threshold) = 0.0;
     DTO_FIELD(Int32, best);
+    DTO_FIELD(Int32, best_bbox) = -1;
 
     /* ncnn */
     DTO_FIELD(Int32, blank_label) = -1;
diff --git a/tests/ut-torchapi.cc b/tests/ut-torchapi.cc
index 023233546..8e492fbf1 100644
--- a/tests/ut-torchapi.cc
+++ b/tests/ut-torchapi.cc
@@ -253,7 +253,6 @@ TEST(torchapi, service_predict_object_detection)
"\"width\":224},\"output\":{\"bbox\":true, " "\"confidence_threshold\":0.8}},\"data\":[\"" + detect_repo + "cat.jpg\"]}"; - // TODO changer image test ? joutstr = japi.jrender(japi.service_predict(jpredictstr)); JDoc jd; @@ -274,6 +273,23 @@ TEST(torchapi, service_predict_object_detection) && bbox["ymax"].GetDouble() > 300); // Check confidence threshold ASSERT_TRUE(preds[preds.Size() - 1]["prob"].GetDouble() >= 0.8); + + // best + jpredictstr = "{\"service\":\"detectserv\",\"parameters\":{" + "\"input\":{\"height\":224," + "\"width\":224},\"output\":{\"bbox\":true, " + "\"best_bbox\":3}},\"data\":[\"" + + detect_repo + "cat.jpg\"]}"; + joutstr = japi.jrender(japi.service_predict(jpredictstr)); + std::cout << "joutstr=" << joutstr << std::endl; + jd.Parse(joutstr.c_str()); + + ASSERT_TRUE(!jd.HasParseError()); + ASSERT_EQ(200, jd["status"]["code"]); + ASSERT_TRUE(jd["body"]["predictions"].IsArray()); + + auto &preds_best = jd["body"]["predictions"][0]["classes"]; + ASSERT_EQ(preds_best.Size(), 3); } TEST(torchapi, service_predict_txt_classification)