feat(torch): add inference visualisation at test time
Bycob committed Mar 28, 2023
1 parent 66cbff5 commit bce6a9b
Showing 5 changed files with 130 additions and 3 deletions.
60 changes: 60 additions & 0 deletions src/backends/torch/torchlib.cc
@@ -2002,6 +2002,7 @@ namespace dd
{
APIData ad_res;
APIData ad_bbox;
APIData ad_res_bbox;
APIData ad_out = ad.getobj("parameters").getobj("output");
int nclasses = _masked_lm ? inputc.vocab_size() : _nclasses;

@@ -2123,6 +2124,60 @@ namespace dd
++stop;
}

// Raw results
APIData bad;
// predictions
auto bboxes_acc = bboxes_tensor.accessor<float, 2>();
auto labels_acc = labels_tensor.accessor<int64_t, 1>();
auto score_acc = score_tensor.accessor<float, 1>();
std::vector<APIData> pred_vad;

for (int k = 0; k < labels_tensor.size(0); k++)
{
APIData pred_ad;
pred_ad.add("label", labels_acc[k]);
pred_ad.add("prob", static_cast<double>(score_acc[k]));
APIData bbox_ad;
bbox_ad.add("xmin", static_cast<double>(bboxes_acc[k][0]));
bbox_ad.add("ymin", static_cast<double>(bboxes_acc[k][1]));
bbox_ad.add("xmax", static_cast<double>(bboxes_acc[k][2]));
bbox_ad.add("ymax", static_cast<double>(bboxes_acc[k][3]));
pred_ad.add("bbox", bbox_ad);
pred_vad.push_back(pred_ad);
}
bad.add("predictions", pred_vad);
// targets
auto targ_bboxes_acc = targ_bboxes.accessor<float, 2>();
auto targ_labels_acc = targ_labels.accessor<int64_t, 1>();
std::vector<APIData> targ_vad;

for (int k = start; k < stop; k++)
{
APIData targ_ad;
targ_ad.add("label", targ_labels_acc[k]);
APIData bbox_ad;
bbox_ad.add("xmin",
static_cast<double>(targ_bboxes_acc[k][0]));
bbox_ad.add("ymin",
static_cast<double>(targ_bboxes_acc[k][1]));
bbox_ad.add("xmax",
static_cast<double>(targ_bboxes_acc[k][2]));
bbox_ad.add("ymax",
static_cast<double>(targ_bboxes_acc[k][3]));
targ_ad.add("bbox", bbox_ad);
targ_vad.push_back(targ_ad);
}
bad.add("targets", targ_vad);
// pred image
std::vector<cv::Mat> img_vec;
img_vec.push_back(torch_utils::tensorToImage(
batch.data.at(0).index(
{ torch::indexing::Slice(i, i + 1) }),
/* rgb = */ true));
bad.add("image", img_vec);
ad_res_bbox.add(std::to_string(entry_id), bad);

// Comparison against ground truth
auto vbad = get_bbox_stats(
targ_bboxes.index({ torch::indexing::Slice(start, stop) }),
targ_labels.index({ torch::indexing::Slice(start, stop) }),
@@ -2303,12 +2358,17 @@ namespace dd
ad_res.add("bbox", true);
ad_res.add("pos_count", entry_id);
ad_res.add("0", ad_bbox);
// raw bbox results
ad_res.add("raw_bboxes", ad_res_bbox);
}
else if (_segmentation)
ad_res.add("segmentation", true);
ad_res.add("batch_size",
entry_id); // here batch_size = tested entries count
SupervisedOutput::measure(ad_res, ad_out, out, test_id, test_name);
SupervisedOutput::create_visuals(
ad_res, ad_out, this->_mlmodel._repo + this->_mlmodel._visuals_dir,
test_id);
_module.train();
return 0;
}
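A minimal reader sketch (not in the diff) for the per-entry layout stored in ad_res_bbox above; this is the structure that create_visuals (added in supervisedoutputconnector.h below) consumes. The helper name is made up, the include names are assumed from the DeepDetect sources, and "prob" is read back as the double it was stored as:

#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include "apidata.h"

// hypothetical reader for one "raw_bboxes" entry, keyed by its entry id
size_t count_confident_preds(dd::APIData &ad_res_bbox, int entry_id, double thres)
{
  dd::APIData entry = ad_res_bbox.getobj(std::to_string(entry_id));
  // the debug image is stored under "image" as a one-element vector of cv::Mat
  cv::Mat img = entry.get("image").get<std::vector<cv::Mat>>().at(0);
  (void)img; // unused in this sketch
  // each prediction carries "label" (int64), "prob" (double) and a "bbox"
  // object whose xmin / ymin / xmax / ymax are stored as doubles
  std::vector<dd::APIData> preds = entry.getv("predictions");
  size_t n = 0;
  for (dd::APIData &pred : preds)
    if (pred.get("prob").get<double>() >= thres)
      ++n;
  return n;
}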
9 changes: 8 additions & 1 deletion src/backends/torch/torchutils.cc
@@ -256,7 +256,7 @@ namespace dd
torch_utils::copy_weights(jit_module, module, device, logger, strict);
}

cv::Mat tensorToImage(torch::Tensor tensor)
cv::Mat tensorToImage(torch::Tensor tensor, bool rgb)
{
// 4 channels: batch size, chan, width, height
auto dims = tensor.sizes();
@@ -285,6 +285,13 @@ namespace dd
}
}
}

// convert to bgr
if (rgb)
{
cv::cvtColor(vals_mat, vals_mat, cv::COLOR_RGB2BGR);
}

return vals_mat;
}
}
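A minimal usage sketch for the new rgb flag, under stated assumptions: the function lives in dd::torch_utils as the torchlib.cc call site suggests, the include path is simplified, and the input is an NCHW float tensor whose value range matches whatever tensorToImage already expects (see the XXX note in torchutils.h):

#include <opencv2/opencv.hpp>
#include <torch/torch.h>
#include "torchutils.h" // simplified include path

int main()
{
  // dummy NCHW float batch standing in for batch.data.at(0) in torchlib.cc
  torch::Tensor batch = torch::rand({ 4, 3, 224, 224 });
  // slice image 0 out as [1, 3, H, W], mirroring the Slice(i, i + 1) call above
  torch::Tensor one = batch.index({ torch::indexing::Slice(0, 1) });
  // rgb = true requests the RGB -> BGR swap so cv::imwrite receives BGR data
  cv::Mat img = dd::torch_utils::tensorToImage(one, /* rgb = */ true);
  cv::imwrite("debug_input.jpg", img);
  return 0;
}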
5 changes: 3 additions & 2 deletions src/backends/torch/torchutils.h
@@ -136,8 +136,9 @@ namespace dd

/** Converts a tensor to a CV image that can be saved on the disk.
* XXX(louis) this function is currently debug only, and makes strong
* assumptions on the input tensor format. */
cv::Mat tensorToImage(torch::Tensor tensor);
* assumptions on the input tensor format.
* \param rgb whether the tensor image is rgb */
cv::Mat tensorToImage(torch::Tensor tensor, bool rgb = false);
}
}
#endif
1 change: 1 addition & 0 deletions src/mlmodel.h
@@ -181,6 +181,7 @@ namespace dd
std::string _corresp; /**< file name of the class correspondences (e.g.
house / 23) */
std::string _best_model_filename = "/best_model.txt";
std::string _visuals_dir = "/visuals";

#ifdef USE_SIMSEARCH
#ifdef USE_ANNOY
58 changes: 58 additions & 0 deletions src/supervisedoutputconnector.h
@@ -23,6 +23,8 @@
#define SUPERVISEDOUTPUTCONNECTOR_H
#define TS_METRICS_EPSILON 1E-2

#include <opencv2/opencv.hpp>

#include "dto/output_connector.hpp"

template <typename T>
@@ -1373,6 +1375,62 @@ namespace dd
ad_out.add("measure", meas_obj);
}

/** Create visuals from test results and write them in the model
* repository. */
static void create_visuals(const APIData &ad_res, APIData &ad_out,
const std::string &visuals_folder, int test_id)
{
(void)ad_out;
int iteration = static_cast<int>(ad_res.get("iteration").get<double>());
bool bbox = ad_res.has("bbox") && ad_res.get("bbox").get<bool>();
std::string targ_dest_folder
= visuals_folder + "/target/test" + std::to_string(test_id);
std::string dest_folder = visuals_folder + "/iteration"
+ std::to_string(iteration) + "/test"
+ std::to_string(test_id);
fileops::create_dir(dest_folder, 0755);

cv::Scalar colors[]
= { { 255, 0, 0 }, { 0, 255, 0 }, { 0, 0, 255 },
{ 0, 255, 255 }, { 255, 0, 255 }, { 255, 255, 0 },
{ 255, 127, 127 }, { 127, 255, 127 }, { 127, 127, 255 } };
int ncolors = sizeof(colors) / sizeof(cv::Scalar);

if (bbox)
{
APIData images_data = ad_res.getobj("raw_bboxes");

for (size_t i = 0; i < images_data.size(); ++i)
{
APIData bad = images_data.getobj(std::to_string(i));
cv::Mat img = bad.get("image").get<std::vector<cv::Mat>>().at(0);

// pred
std::vector<APIData> preds = bad.getv("predictions");
for (size_t k = 0; k < preds.size(); ++k)
{
APIData &pred_ad = preds[k];
// float score = pred_ad.get("prob").get<float>();
int64_t label = pred_ad.get("label").get<int64_t>();
APIData bbox = pred_ad.getobj("bbox");
int xmin = static_cast<int>(bbox.get("xmin").get<double>());
int ymin = static_cast<int>(bbox.get("ymin").get<double>());
int xmax = static_cast<int>(bbox.get("xmax").get<double>());
int ymax = static_cast<int>(bbox.get("ymax").get<double>());

auto &color = colors[label % ncolors];
cv::rectangle(img, cv::Point{ xmin, ymin },
cv::Point{ xmax, ymax }, color, 3);
}

// write image
std::string out_img_path
= dest_folder + "/image" + std::to_string(i) + ".jpg";
cv::imwrite(out_img_path, img);
}
}
}

static void
timeSeriesMetrics(const APIData &ad, const int timeseries,
std::vector<double> &mape, std::vector<double> &smape,
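For reference, the single call site this commit adds is in torchlib.cc; combined with dest_folder above, the annotated prediction images end up under <model_repo>/visuals/iteration<iter>/test<test_id>/image<i>.jpg:

SupervisedOutput::create_visuals(
    ad_res, ad_out, this->_mlmodel._repo + this->_mlmodel._visuals_dir,
    test_id);

Note that ad_out is voided at the top of create_visuals, so the hook currently only needs the test results and the destination folder.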
