-
Notifications
You must be signed in to change notification settings - Fork 19
/
detector.h
50 lines (41 loc) · 1.44 KB
/
detector.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#ifndef Detector_H
#define Detector_H
#include <vector>
#include <string>
//opencv
#include <opencv2/core.hpp>
#include "opencv2/opencv.hpp"
//tensorflow
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "text_box.h"
class Detector{
public:
Detector(){};
Detector(const std::string frozen_graph_filename){
init_graph(frozen_graph_filename);
}
bool init_graph(const std::string& frozen_graph_filename){
if (!ReadBinaryProto(tensorflow::Env::Default(), frozen_graph_filename, &graph_def).ok()) {
LOG(ERROR) << "error when reading proto" << frozen_graph_fliename;
return -1;
}
tensorflow::SessionOptions sess_opt;
sess_opt.config.mutable_gpu_options()->set_allow_growth(true);
(&session)->reset(tensorflow::NewSession(sess_opt));
if (!session->Create(graph_def).ok()) {
LOG(ERROR) << "error create graph";
return -1;
}
}
virtual int run_graph(const cv::Mat& image, std::vector<TextBox>& results) = 0;
tensorflow::GraphDef graph_def;
std::string input_layer; //for detector, we assume there is only one input
std::unique_ptr<tensorflow::Session> session;
std::vector<std::string> output_layers;
};
#endif