Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

inference.cc 7.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
  1. #include "inference.h"
  2. #include <memory>
  3. #include <opencv2/dnn.hpp>
  4. #include <random>
  5. namespace yolo {
  6. // Constructor to initialize the model with default input shape
  7. Inference::Inference(const std::string &model_path, const float &model_confidence_threshold, const float &model_NMS_threshold) {
  8. model_input_shape_ = cv::Size(640, 640); // Set the default size for models with dynamic shapes to prevent errors.
  9. model_confidence_threshold_ = model_confidence_threshold;
  10. model_NMS_threshold_ = model_NMS_threshold;
  11. InitializeModel(model_path);
  12. }
  13. // Constructor to initialize the model with specified input shape
  14. Inference::Inference(const std::string &model_path, const cv::Size model_input_shape, const float &model_confidence_threshold, const float &model_NMS_threshold) {
  15. model_input_shape_ = model_input_shape;
  16. model_confidence_threshold_ = model_confidence_threshold;
  17. model_NMS_threshold_ = model_NMS_threshold;
  18. InitializeModel(model_path);
  19. }
  20. void Inference::InitializeModel(const std::string &model_path) {
  21. ov::Core core; // OpenVINO core object
  22. std::shared_ptr<ov::Model> model = core.read_model(model_path); // Read the model from file
  23. // If the model has dynamic shapes, reshape it to the specified input shape
  24. if (model->is_dynamic()) {
  25. model->reshape({1, 3, static_cast<long int>(model_input_shape_.height), static_cast<long int>(model_input_shape_.width)});
  26. }
  27. // Preprocessing setup for the model
  28. ov::preprocess::PrePostProcessor ppp = ov::preprocess::PrePostProcessor(model);
  29. ppp.input().tensor().set_element_type(ov::element::u8).set_layout("NHWC").set_color_format(ov::preprocess::ColorFormat::BGR);
  30. ppp.input().preprocess().convert_element_type(ov::element::f32).convert_color(ov::preprocess::ColorFormat::RGB).scale({255, 255, 255});
  31. ppp.input().model().set_layout("NCHW");
  32. ppp.output().tensor().set_element_type(ov::element::f32);
  33. model = ppp.build(); // Build the preprocessed model
  34. // Compile the model for inference
  35. compiled_model_ = core.compile_model(model, "AUTO");
  36. inference_request_ = compiled_model_.create_infer_request(); // Create inference request
  37. short width, height;
  38. // Get input shape from the model
  39. const std::vector<ov::Output<ov::Node>> inputs = model->inputs();
  40. const ov::Shape input_shape = inputs[0].get_shape();
  41. height = input_shape[1];
  42. width = input_shape[2];
  43. model_input_shape_ = cv::Size2f(width, height);
  44. // Get output shape from the model
  45. const std::vector<ov::Output<ov::Node>> outputs = model->outputs();
  46. const ov::Shape output_shape = outputs[0].get_shape();
  47. height = output_shape[1];
  48. width = output_shape[2];
  49. model_output_shape_ = cv::Size(width, height);
  50. }
  51. // Method to run inference on an input frame
  52. void Inference::RunInference(cv::Mat &frame) {
  53. Preprocessing(frame); // Preprocess the input frame
  54. inference_request_.infer(); // Run inference
  55. PostProcessing(frame); // Postprocess the inference results
  56. }
  57. // Method to preprocess the input frame
  58. void Inference::Preprocessing(const cv::Mat &frame) {
  59. cv::Mat resized_frame;
  60. cv::resize(frame, resized_frame, model_input_shape_, 0, 0, cv::INTER_AREA); // Resize the frame to match the model input shape
  61. // Calculate scaling factor
  62. scale_factor_.x = static_cast<float>(frame.cols / model_input_shape_.width);
  63. scale_factor_.y = static_cast<float>(frame.rows / model_input_shape_.height);
  64. float *input_data = (float *)resized_frame.data; // Get pointer to resized frame data
  65. const ov::Tensor input_tensor = ov::Tensor(compiled_model_.input().get_element_type(), compiled_model_.input().get_shape(), input_data); // Create input tensor
  66. inference_request_.set_input_tensor(input_tensor); // Set input tensor for inference
  67. }
  68. // Method to postprocess the inference results
  69. void Inference::PostProcessing(cv::Mat &frame) {
  70. std::vector<int> class_list;
  71. std::vector<float> confidence_list;
  72. std::vector<cv::Rect> box_list;
  73. // Get the output tensor from the inference request
  74. const float *detections = inference_request_.get_output_tensor().data<const float>();
  75. const cv::Mat detection_outputs(model_output_shape_, CV_32F, (float *)detections); // Create OpenCV matrix from output tensor
  76. // Iterate over detections and collect class IDs, confidence scores, and bounding boxes
  77. for (int i = 0; i < detection_outputs.cols; ++i) {
  78. const cv::Mat classes_scores = detection_outputs.col(i).rowRange(4, detection_outputs.rows);
  79. cv::Point class_id;
  80. double score;
  81. cv::minMaxLoc(classes_scores, nullptr, &score, nullptr, &class_id); // Find the class with the highest score
  82. // Check if the detection meets the confidence threshold
  83. if (score > model_confidence_threshold_) {
  84. class_list.push_back(class_id.y);
  85. confidence_list.push_back(score);
  86. const float x = detection_outputs.at<float>(0, i);
  87. const float y = detection_outputs.at<float>(1, i);
  88. const float w = detection_outputs.at<float>(2, i);
  89. const float h = detection_outputs.at<float>(3, i);
  90. cv::Rect box;
  91. box.x = static_cast<int>(x);
  92. box.y = static_cast<int>(y);
  93. box.width = static_cast<int>(w);
  94. box.height = static_cast<int>(h);
  95. box_list.push_back(box);
  96. }
  97. }
  98. // Apply Non-Maximum Suppression (NMS) to filter overlapping bounding boxes
  99. std::vector<int> NMS_result;
  100. cv::dnn::NMSBoxes(box_list, confidence_list, model_confidence_threshold_, model_NMS_threshold_, NMS_result);
  101. // Collect final detections after NMS
  102. for (int i = 0; i < NMS_result.size(); ++i) {
  103. Detection result;
  104. const unsigned short id = NMS_result[i];
  105. result.class_id = class_list[id];
  106. result.confidence = confidence_list[id];
  107. result.box = GetBoundingBox(box_list[id]);
  108. DrawDetectedObject(frame, result);
  109. }
  110. }
  111. // Method to get the bounding box in the correct scale
  112. cv::Rect Inference::GetBoundingBox(const cv::Rect &src) const {
  113. cv::Rect box = src;
  114. box.x = (box.x - box.width / 2) * scale_factor_.x;
  115. box.y = (box.y - box.height / 2) * scale_factor_.y;
  116. box.width *= scale_factor_.x;
  117. box.height *= scale_factor_.y;
  118. return box;
  119. }
  120. void Inference::DrawDetectedObject(cv::Mat &frame, const Detection &detection) const {
  121. const cv::Rect &box = detection.box;
  122. const float &confidence = detection.confidence;
  123. const int &class_id = detection.class_id;
  124. // Generate a random color for the bounding box
  125. std::random_device rd;
  126. std::mt19937 gen(rd());
  127. std::uniform_int_distribution<int> dis(120, 255);
  128. const cv::Scalar &color = cv::Scalar(dis(gen), dis(gen), dis(gen));
  129. // Draw the bounding box around the detected object
  130. cv::rectangle(frame, cv::Point(box.x, box.y), cv::Point(box.x + box.width, box.y + box.height), color, 3);
  131. // Prepare the class label and confidence text
  132. std::string classString = classes_[class_id] + std::to_string(confidence).substr(0, 4);
  133. // Get the size of the text box
  134. cv::Size textSize = cv::getTextSize(classString, cv::FONT_HERSHEY_DUPLEX, 0.75, 2, 0);
  135. cv::Rect textBox(box.x, box.y - 40, textSize.width + 10, textSize.height + 20);
  136. // Draw the text box
  137. cv::rectangle(frame, textBox, color, cv::FILLED);
  138. // Put the class label and confidence text above the bounding box
  139. cv::putText(frame, classString, cv::Point(box.x + 5, box.y - 10), cv::FONT_HERSHEY_DUPLEX, 0.75, cv::Scalar(0, 0, 0), 2, 0);
  140. }
  141. } // namespace yolo
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...