From 123c630e2abaed6fcd7e1bbbc9c74f31af9b0017 Mon Sep 17 00:00:00 2001 From: ais-lab Date: Wed, 3 Aug 2022 12:43:24 +0200 Subject: [PATCH] new save file functions added --- include/darknet.h | 2 +- src/darknet.c | 2 +- src/detector.c | 32 ++++++--------------------- src/image.h | 1 + src/utils.c | 55 +++++++++++++++++++++++++++++++++++++++++++++++ src/utils.h | 5 +++++ 6 files changed, 70 insertions(+), 27 deletions(-) diff --git a/include/darknet.h b/include/darknet.h index d72027cc45f..05c5bc21a36 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -1049,7 +1049,7 @@ LIB_API float *network_predict_image_letterbox(network *net, image im); LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, int letter_box, network *existing_net); LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, float thresh, float iou_thresh, int mjpeg_port, int show_imgs, int benchmark_layers, char* chart_path); LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, - float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box, int benchmark_layers); + float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box, int benchmark_layers, int save_labels_actual); LIB_API int network_width(network *net); LIB_API int network_height(network *net); LIB_API void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh, int norm); diff --git a/src/darknet.c b/src/darknet.c index 13ab75f3d38..a74210fca74 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -495,7 +495,7 @@ int main(int argc, char **argv) float thresh = find_float_arg(argc, argv, "-thresh", .24); int ext_output = find_arg(argc, argv, "-ext_output"); char *filename = (argc > 4) ? argv[4]: 0; - test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, 0.5, 0, ext_output, 0, NULL, 0, 0); + test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, 0.5, 0, ext_output, 0, NULL, 0, 0, 0); } else if (0 == strcmp(argv[1], "cifar")){ run_cifar(argc, argv); } else if (0 == strcmp(argv[1], "go")){ diff --git a/src/detector.c b/src/detector.c index 0b947b69089..ba18baf21f4 100644 --- a/src/detector.c +++ b/src/detector.c @@ -1624,7 +1624,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, - float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box, int benchmark_layers) + float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box, int benchmark_layers, int save_labels_actual) { list *options = read_data_cfg(datacfg); char *name_list = option_find_str(options, "names", "data/names.list"); @@ -1729,29 +1729,10 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam } // pseudo labeling concept - fast.ai - if (save_labels) - { - char labelpath[4096]; - replace_image_to_label(input, labelpath); - - FILE* fw = fopen(labelpath, "wb"); - int i; - for (i = 0; i < nboxes; ++i) { - char buff[1024]; - int class_id = -1; - float prob = 0; - for (j = 0; j < l.classes; ++j) { - if (dets[i].prob[j] > thresh && dets[i].prob[j] > prob) { - prob = dets[i].prob[j]; - class_id = j; - } - } - if (class_id >= 0) { - sprintf(buff, "%d %2.4f %2.4f %2.4f %2.4f\n", class_id, dets[i].bbox.x, dets[i].bbox.y, dets[i].bbox.w, dets[i].bbox.h); - fwrite(buff, sizeof(char), strlen(buff), fw); - } - } - fclose(fw); + if (save_labels){ + save_outputs(input, nboxes, dets, thresh, l.classes); + } else if(save_labels_actual){ + save_outputs_actual(im, input, nboxes, dets, thresh, names); } free_detections(dets, nboxes); @@ -1997,6 +1978,7 @@ void run_detector(int argc, char **argv) // and for recall mode (extended output table-like format with results for best_class fit) int ext_output = find_arg(argc, argv, "-ext_output"); int save_labels = find_arg(argc, argv, "-save_labels"); + int save_labels_actual = find_arg(argc, argv, "-save_labels_actual"); char* chart_path = find_char_arg(argc, argv, "-chart", 0); if (argc < 4) { fprintf(stderr, "usage: %s %s [train/test/valid/demo/map] [data] [cfg] [weights (optional)]\n", argv[0], argv[1]); @@ -2035,7 +2017,7 @@ void run_detector(int argc, char **argv) if (strlen(weights) > 0) if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0; char *filename = (argc > 6) ? argv[6] : 0; - if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile, letter_box, benchmark_layers); + if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile, letter_box, benchmark_layers, save_labels_actual); else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, thresh, iou_thresh, mjpeg_port, show_imgs, benchmark_layers, chart_path); else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile); else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights); diff --git a/src/image.h b/src/image.h index 90e6a04868b..771ee2c33b0 100644 --- a/src/image.h +++ b/src/image.h @@ -33,6 +33,7 @@ void write_label(image a, int r, int c, image *characters, char *string, float * void draw_detections(image im, int num, float thresh, box *boxes, float **probs, char **names, image **labels, int classes); void draw_detections_v3(image im, detection *dets, int num, float thresh, char **names, image **alphabet, int classes, int ext_output); image image_distance(image a, image b); +int compare_by_lefts(const void *a_ptr, const void *b_ptr); void scale_image(image m, float s); // image crop_image(image im, int dx, int dy, int w, int h); image random_crop_image(image im, int w, int h); diff --git a/src/utils.c b/src/utils.c index 0c271b7bf41..f95e5569f00 100644 --- a/src/utils.c +++ b/src/utils.c @@ -221,6 +221,61 @@ void find_replace(const char* str, char* orig, char* rep, char* output) free(buffer); } +void save_outputs(const char* input, int nboxes, detection *dets, float thresh, int classes){ + char labelpath[4096]; + replace_image_to_label(input, labelpath); + + FILE* fw = fopen(labelpath, "wb"); + int i,j; + for (i = 0; i < nboxes; ++i) { + char buff[1024]; + int class_id = -1; + float prob = 0; + for (j = 0; j < classes; ++j) { + if (dets[i].prob[j] > thresh && dets[i].prob[j] > prob) { + prob = dets[i].prob[j]; + class_id = j; + } + } + if (class_id >= 0) { + sprintf(buff, "1 -%d %2.4f %2.4f %2.4f %2.4f\n", class_id, dets[i].bbox.x, dets[i].bbox.y, dets[i].bbox.w, dets[i].bbox.h); + fwrite(buff, sizeof(char), strlen(buff), fw); + } + } + fclose(fw); +} + +void save_outputs_actual(image im, const char* input, int nboxes, detection *dets, float thresh, char **names){ + char labelpath[4096]; + replace_image_to_label(input, labelpath); + + int selected_detections_num; + detection_with_class* selected_detections = get_actual_detections(dets, nboxes, thresh, &selected_detections_num, names); + + qsort(selected_detections, selected_detections_num, sizeof(*selected_detections), compare_by_lefts); + + + FILE* fw = fopen(labelpath, "wb"); + int i,j; + for (i = 0; i < selected_detections_num; ++i) { + char buff[1024]; + + const int best_class = selected_detections[i].best_class; + + sprintf(buff, "%i %.0f %4.0f %4.0f %4.0f %4.0f \n", + best_class, // best class id + (selected_detections[i].det.prob[best_class] * 100), // probability + round((selected_detections[i].det.bbox.x - selected_detections[i].det.bbox.w / 2)*im.w), // left_x + round((selected_detections[i].det.bbox.y - selected_detections[i].det.bbox.h / 2)*im.h), // top_y + round(selected_detections[i].det.bbox.w*im.w), // width + round(selected_detections[i].det.bbox.h*im.h)); // height + fwrite(buff, sizeof(char), strlen(buff), fw); + } + fclose(fw); + free(selected_detections); + +} + void trim(char *str) { char* buffer = (char*)xcalloc(8192, sizeof(char)); diff --git a/src/utils.h b/src/utils.h index f84054b8b39..ce9d101e9d7 100644 --- a/src/utils.h +++ b/src/utils.h @@ -2,6 +2,9 @@ #define UTILS_H #include "darknet.h" #include "list.h" +#include "box.h" +#include "image.h" + #include #include @@ -44,6 +47,8 @@ int read_int(int fd); void write_int(int fd, int n); void read_all(int fd, char *buffer, size_t bytes); void write_all(int fd, char *buffer, size_t bytes); +void save_outputs(const char* input, int nboxes, detection *dets, float thresh, int classes); +void save_outputs_actual(image im, const char* input, int nboxes, detection *dets, float thresh, char **names); int read_all_fail(int fd, char *buffer, size_t bytes); int write_all_fail(int fd, char *buffer, size_t bytes); LIB_API void find_replace(const char* str, char* orig, char* rep, char* output);