main.py
import argparse

import cv2.dnn
import numpy as np

from ultralytics.utils import ASSETS, yaml_load
from ultralytics.utils.checks import check_yaml

CLASSES = yaml_load(check_yaml('coco128.yaml'))['names']
colors = np.random.uniform(0, 255, size=(len(CLASSES), 3))


def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
    """
    Draws bounding boxes on the input image based on the provided arguments.

    Args:
        img (numpy.ndarray): The input image to draw the bounding box on.
        class_id (int): Class ID of the detected object.
        confidence (float): Confidence score of the detected object.
        x (int): X-coordinate of the top-left corner of the bounding box.
        y (int): Y-coordinate of the top-left corner of the bounding box.
        x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box.
        y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box.
    """
    label = f'{CLASSES[class_id]} ({confidence:.2f})'
    color = colors[class_id]
    cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2)
    cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)


def main(onnx_model, input_image):
    """
    Main function to load ONNX model, perform inference, draw bounding boxes, and display the output image.

    Args:
        onnx_model (str): Path to the ONNX model.
        input_image (str): Path to the input image.

    Returns:
        list: List of dictionaries containing detection information such as class_id, class_name, confidence, etc.
    """
    # Load the ONNX model
    model: cv2.dnn.Net = cv2.dnn.readNetFromONNX(onnx_model)

    # Read the input image
    original_image: np.ndarray = cv2.imread(input_image)
    [height, width, _] = original_image.shape

    # Prepare a square image for inference
    length = max((height, width))
    image = np.zeros((length, length, 3), np.uint8)
    image[0:height, 0:width] = original_image

    # Calculate scale factor
    scale = length / 640

    # Preprocess the image and prepare blob for model
    blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True)
    model.setInput(blob)

    # Perform inference
    outputs = model.forward()

    # Prepare output array
    outputs = np.array([cv2.transpose(outputs[0])])
    rows = outputs.shape[1]

    boxes = []
    scores = []
    class_ids = []

    # Iterate through output to collect bounding boxes, confidence scores, and class IDs
    for i in range(rows):
        classes_scores = outputs[0][i][4:]
        (minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores)
        if maxScore >= 0.25:
            box = [
                outputs[0][i][0] - (0.5 * outputs[0][i][2]), outputs[0][i][1] - (0.5 * outputs[0][i][3]),
                outputs[0][i][2], outputs[0][i][3]]
            boxes.append(box)
            scores.append(maxScore)
            class_ids.append(maxClassIndex)

    # Apply NMS (Non-maximum suppression)
    result_boxes = cv2.dnn.NMSBoxes(boxes, scores, 0.25, 0.45, 0.5)

    detections = []

    # Iterate through NMS results to draw bounding boxes and labels
    for i in range(len(result_boxes)):
        index = result_boxes[i]
        box = boxes[index]
        detection = {
            'class_id': class_ids[index],
            'class_name': CLASSES[class_ids[index]],
            'confidence': scores[index],
            'box': box,
            'scale': scale}
        detections.append(detection)
        draw_bounding_box(original_image, class_ids[index], scores[index], round(box[0] * scale), round(box[1] * scale),
                          round((box[0] + box[2]) * scale), round((box[1] + box[3]) * scale))

    # Display the image with bounding boxes
    cv2.imshow('image', original_image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    return detections


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='yolov8n.onnx', help='Input your ONNX model.')
    parser.add_argument('--img', default=str(ASSETS / 'bus.jpg'), help='Path to input image.')
    args = parser.parse_args()
    main(args.model, args.img)
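
The script expects a YOLOv8 model already exported to ONNX (yolov8n.onnx by default). A minimal sketch of producing that file with the Ultralytics Python API and then running the script; the weights name and image path are illustrative, not part of this file:

# Export YOLOv8n weights to ONNX so cv2.dnn.readNetFromONNX can load them.
# The pretrained yolov8n.pt checkpoint is downloaded automatically by the
# ultralytics package if it is not already cached locally.
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
model.export(format='onnx')  # writes yolov8n.onnx next to the weights file

With the export in place, the detector can be run from the command line, for example:

python main.py --model yolov8n.onnx --img bus.jpg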