Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(torch): add inference visualisation at test time #1518

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,7 @@ namespace dd
{
APIData ad_res;
APIData ad_bbox;
APIData ad_res_bbox;
APIData ad_out = ad.getobj("parameters").getobj("output");
int nclasses = _masked_lm ? inputc.vocab_size() : _nclasses;

Expand Down Expand Up @@ -2159,6 +2160,59 @@ namespace dd
ad_bbox_per_iou[iou_thres].add(std::to_string(entry_id),
vbad);
}

// Raw results
APIData bad;
// predictions
auto bboxes_acc = bboxes_tensor.accessor<float, 2>();
auto labels_acc = labels_tensor.accessor<int64_t, 1>();
auto score_acc = score_tensor.accessor<float, 1>();
std::vector<APIData> pred_vad;

for (int k = 0; k < labels_tensor.size(0); k++)
{
APIData pred_ad;
pred_ad.add("label", labels_acc[k]);
pred_ad.add("prob", static_cast<double>(score_acc[k]));
APIData bbox_ad;
bbox_ad.add("xmin", static_cast<double>(bboxes_acc[k][0]));
bbox_ad.add("ymin", static_cast<double>(bboxes_acc[k][1]));
bbox_ad.add("xmax", static_cast<double>(bboxes_acc[k][2]));
bbox_ad.add("ymax", static_cast<double>(bboxes_acc[k][3]));
pred_ad.add("bbox", bbox_ad);
pred_vad.push_back(pred_ad);
}
bad.add("predictions", pred_vad);
// targets
auto targ_bboxes_acc = targ_bboxes.accessor<float, 2>();
auto targ_labels_acc = targ_labels.accessor<int64_t, 1>();
std::vector<APIData> targ_vad;

for (int k = start; k < stop; k++)
{
APIData targ_ad;
targ_ad.add("label", targ_labels_acc[k]);
APIData bbox_ad;
bbox_ad.add("xmin",
static_cast<double>(targ_bboxes_acc[k][0]));
bbox_ad.add("ymin",
static_cast<double>(targ_bboxes_acc[k][1]));
bbox_ad.add("xmax",
static_cast<double>(targ_bboxes_acc[k][2]));
bbox_ad.add("ymax",
static_cast<double>(targ_bboxes_acc[k][3]));
targ_ad.add("bbox", bbox_ad);
targ_vad.push_back(targ_ad);
}
bad.add("targets", targ_vad);
// pred image
std::vector<cv::Mat> img_vec;
img_vec.push_back(torch_utils::tensorToImage(
batch.data.at(0).index(
{ torch::indexing::Slice(i, i + 1) }),
/* rgb = */ true));
bad.add("image", img_vec);
ad_res_bbox.add(std::to_string(entry_id), bad);
++entry_id;
}
}
Expand Down Expand Up @@ -2340,12 +2394,17 @@ namespace dd
ad_bbox_per_iou[iou_thres]);
}
ad_res.add("0", ad_bbox);
// raw bbox results
ad_res.add("raw_bboxes", ad_res_bbox);
}
else if (_segmentation)
ad_res.add("segmentation", true);
ad_res.add("batch_size",
entry_id); // here batch_size = tested entries count
SupervisedOutput::measure(ad_res, ad_out, out, test_id, test_name);
SupervisedOutput::create_visuals(
ad_res, ad_out, this->_mlmodel._repo + this->_mlmodel._visuals_dir,
test_id);
_module.train();
return 0;
}
Expand Down
9 changes: 8 additions & 1 deletion src/backends/torch/torchutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ namespace dd
torch_utils::copy_weights(jit_module, module, device, logger, strict);
}

cv::Mat tensorToImage(torch::Tensor tensor)
cv::Mat tensorToImage(torch::Tensor tensor, bool rgb)
{
// 4 channels: batch size, chan, width, height
auto dims = tensor.sizes();
Expand Down Expand Up @@ -285,6 +285,13 @@ namespace dd
}
}
}

// convert to bgr
if (rgb)
{
cv::cvtColor(vals_mat, vals_mat, cv::COLOR_RGB2BGR);
}

return vals_mat;
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/backends/torch/torchutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,9 @@ namespace dd

/** Converts a tensor to a CV image that can be saved on the disk.
* XXX(louis) this function is currently debug only, and makes strong
* assumptions on the input tensor format. */
cv::Mat tensorToImage(torch::Tensor tensor);
* assumptions on the input tensor format.
* \param rgb whether the tensor image is rgb */
cv::Mat tensorToImage(torch::Tensor tensor, bool rgb = false);
}
}
#endif
1 change: 1 addition & 0 deletions src/mlmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ namespace dd
std::string _corresp; /**< file name of the class correspondences (e.g.
house / 23) */
std::string _best_model_filename = "/best_model.txt";
std::string _visuals_dir = "/visuals";

#ifdef USE_SIMSEARCH
#ifdef USE_ANNOY
Expand Down
58 changes: 58 additions & 0 deletions src/supervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <sstream>
#include <iomanip>

#include <opencv2/opencv.hpp>

#include "dto/output_connector.hpp"
#include "dto/predict_out.hpp"

Expand Down Expand Up @@ -1402,6 +1404,62 @@ namespace dd
ad_out.add("measure", meas_obj);
}

/** Create visuals from test results and write them in the model
* repository. */
static void create_visuals(const APIData &ad_res, APIData &ad_out,
const std::string &visuals_folder, int test_id)
{
(void)ad_out;
int iteration = static_cast<int>(ad_res.get("iteration").get<double>());
bool bbox = ad_res.has("bbox") && ad_res.get("bbox").get<bool>();
std::string targ_dest_folder
= visuals_folder + "/target/test" + std::to_string(test_id);
std::string dest_folder = visuals_folder + "/iteration"
+ std::to_string(iteration) + "/test"
+ std::to_string(test_id);
fileops::create_dir(dest_folder, 0755);

cv::Scalar colors[]
= { { 255, 0, 0 }, { 0, 255, 0 }, { 0, 0, 255 },
{ 0, 255, 255 }, { 255, 0, 255 }, { 255, 255, 0 },
{ 255, 127, 127 }, { 127, 255, 127 }, { 127, 127, 255 } };
int ncolors = sizeof(colors) / sizeof(cv::Scalar);

if (bbox)
{
APIData images_data = ad_res.getobj("raw_bboxes");

for (size_t i = 0; i < images_data.size(); ++i)
{
APIData bad = images_data.getobj(std::to_string(i));
cv::Mat img = bad.get("image").get<std::vector<cv::Mat>>().at(0);

// pred
std::vector<APIData> preds = bad.getv("predictions");
for (size_t k = 0; k < preds.size(); ++k)
{
APIData &pred_ad = preds[k];
// float score = pred_ad.get("prob").get<float>();
int64_t label = pred_ad.get("label").get<int64_t>();
APIData bbox = pred_ad.getobj("bbox");
int xmin = static_cast<int>(bbox.get("xmin").get<double>());
int ymin = static_cast<int>(bbox.get("ymin").get<double>());
int xmax = static_cast<int>(bbox.get("xmax").get<double>());
int ymax = static_cast<int>(bbox.get("ymax").get<double>());

auto &color = colors[label % ncolors];
cv::rectangle(img, cv::Point{ xmin, ymin },
cv::Point{ xmax, ymax }, color, 3);
}

// write image
std::string out_img_path
= dest_folder + "/image" + std::to_string(i) + ".jpg";
cv::imwrite(out_img_path, img);
}
}
}

static void
timeSeriesMetrics(const APIData &ad, const int timeseries,
std::vector<double> &mape, std::vector<double> &smape,
Expand Down