Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

main.py 10 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import argparse
  3. from typing import Tuple, Union
  4. import cv2
  5. import numpy as np
  6. import yaml
  7. from ultralytics.utils import ASSETS
  8. try:
  9. from tflite_runtime.interpreter import Interpreter
  10. except ImportError:
  11. import tensorflow as tf
  12. Interpreter = tf.lite.Interpreter
  13. class YOLOv8TFLite:
  14. """
  15. A YOLOv8 object detection class using TensorFlow Lite for efficient inference.
  16. This class handles model loading, preprocessing, inference, and visualization of detection results for YOLOv8
  17. models converted to TensorFlow Lite format.
  18. Attributes:
  19. model (Interpreter): TensorFlow Lite interpreter for the YOLOv8 model.
  20. conf (float): Confidence threshold for filtering detections.
  21. iou (float): Intersection over Union threshold for non-maximum suppression.
  22. classes (dict): Dictionary mapping class IDs to class names.
  23. color_palette (np.ndarray): Random color palette for visualization with shape (num_classes, 3).
  24. in_width (int): Input width required by the model.
  25. in_height (int): Input height required by the model.
  26. in_index (int): Input tensor index in the model.
  27. in_scale (float): Input quantization scale factor.
  28. in_zero_point (int): Input quantization zero point.
  29. int8 (bool): Whether the model uses int8 quantization.
  30. out_index (int): Output tensor index in the model.
  31. out_scale (float): Output quantization scale factor.
  32. out_zero_point (int): Output quantization zero point.
  33. Methods:
  34. letterbox: Resize and pad image while maintaining aspect ratio.
  35. draw_detections: Draw bounding boxes and labels on the input image.
  36. preprocess: Preprocess the input image before inference.
  37. postprocess: Process model outputs to extract and visualize detections.
  38. detect: Perform object detection on an input image.
  39. Examples:
  40. Initialize detector and run inference
  41. >>> detector = YOLOv8TFLite("yolov8n.tflite", conf=0.25, iou=0.45)
  42. >>> result = detector.detect("image.jpg")
  43. >>> cv2.imshow("Result", result)
  44. """
  45. def __init__(self, model: str, conf: float = 0.25, iou: float = 0.45, metadata: Union[str, None] = None):
  46. """
  47. Initialize the YOLOv8TFLite detector.
  48. Args:
  49. model (str): Path to the TFLite model file.
  50. conf (float): Confidence threshold for filtering detections.
  51. iou (float): IoU threshold for non-maximum suppression.
  52. metadata (str | None): Path to the metadata file containing class names.
  53. """
  54. self.conf = conf
  55. self.iou = iou
  56. if metadata is None:
  57. self.classes = {i: i for i in range(1000)}
  58. else:
  59. with open(metadata) as f:
  60. self.classes = yaml.safe_load(f)["names"]
  61. np.random.seed(42) # Set seed for reproducible colors
  62. self.color_palette = np.random.uniform(128, 255, size=(len(self.classes), 3))
  63. # Initialize the TFLite interpreter
  64. self.model = Interpreter(model_path=model)
  65. self.model.allocate_tensors()
  66. # Get input details
  67. input_details = self.model.get_input_details()[0]
  68. self.in_width, self.in_height = input_details["shape"][1:3]
  69. self.in_index = input_details["index"]
  70. self.in_scale, self.in_zero_point = input_details["quantization"]
  71. self.int8 = input_details["dtype"] == np.int8
  72. # Get output details
  73. output_details = self.model.get_output_details()[0]
  74. self.out_index = output_details["index"]
  75. self.out_scale, self.out_zero_point = output_details["quantization"]
  76. def letterbox(
  77. self, img: np.ndarray, new_shape: Tuple[int, int] = (640, 640)
  78. ) -> Tuple[np.ndarray, Tuple[float, float]]:
  79. """
  80. Resize and pad image while maintaining aspect ratio.
  81. Args:
  82. img (np.ndarray): Input image with shape (H, W, C).
  83. new_shape (Tuple[int, int]): Target shape (height, width).
  84. Returns:
  85. (np.ndarray): Resized and padded image.
  86. (Tuple[float, float]): Padding ratios (top/height, left/width) for coordinate adjustment.
  87. """
  88. shape = img.shape[:2] # Current shape [height, width]
  89. # Scale ratio (new / old)
  90. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  91. # Compute padding
  92. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  93. dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding
  94. if shape[::-1] != new_unpad: # Resize if needed
  95. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  96. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  97. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  98. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
  99. return img, (top / img.shape[0], left / img.shape[1])
  100. def draw_detections(self, img: np.ndarray, box: np.ndarray, score: np.float32, class_id: int) -> None:
  101. """
  102. Draw bounding boxes and labels on the input image based on detected objects.
  103. Args:
  104. img (np.ndarray): The input image to draw detections on.
  105. box (np.ndarray): Detected bounding box in the format [x1, y1, width, height].
  106. score (np.float32): Confidence score of the detection.
  107. class_id (int): Class ID for the detected object.
  108. """
  109. x1, y1, w, h = box
  110. color = self.color_palette[class_id]
  111. # Draw bounding box
  112. cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
  113. # Create label with class name and score
  114. label = f"{self.classes[class_id]}: {score:.2f}"
  115. # Get text size for background rectangle
  116. (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
  117. # Position label above or below box depending on space
  118. label_x = x1
  119. label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
  120. # Draw label background
  121. cv2.rectangle(
  122. img,
  123. (int(label_x), int(label_y - label_height)),
  124. (int(label_x + label_width), int(label_y + label_height)),
  125. color,
  126. cv2.FILLED,
  127. )
  128. # Draw text
  129. cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
  130. def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float]]:
  131. """
  132. Preprocess the input image before performing inference.
  133. Args:
  134. img (np.ndarray): The input image to be preprocessed with shape (H, W, C).
  135. Returns:
  136. (np.ndarray): Preprocessed image ready for model input.
  137. (Tuple[float, float]): Padding ratios for coordinate adjustment.
  138. """
  139. img, pad = self.letterbox(img, (self.in_width, self.in_height))
  140. img = img[..., ::-1][None] # BGR to RGB and add batch dimension (N, H, W, C) for TFLite
  141. img = np.ascontiguousarray(img)
  142. img = img.astype(np.float32)
  143. return img / 255, pad # Normalize to [0, 1]
  144. def postprocess(self, img: np.ndarray, outputs: np.ndarray, pad: Tuple[float, float]) -> np.ndarray:
  145. """
  146. Process model outputs to extract and visualize detections.
  147. Args:
  148. img (np.ndarray): The original input image.
  149. outputs (np.ndarray): Raw model outputs.
  150. pad (Tuple[float, float]): Padding ratios from preprocessing.
  151. Returns:
  152. (np.ndarray): The input image with detections drawn on it.
  153. """
  154. # Adjust coordinates based on padding and scale to original image size
  155. outputs[:, 0] -= pad[1]
  156. outputs[:, 1] -= pad[0]
  157. outputs[:, :4] *= max(img.shape)
  158. # Transform outputs to [x, y, w, h] format
  159. outputs = outputs.transpose(0, 2, 1)
  160. outputs[..., 0] -= outputs[..., 2] / 2 # x center to top-left x
  161. outputs[..., 1] -= outputs[..., 3] / 2 # y center to top-left y
  162. for out in outputs:
  163. # Get scores and apply confidence threshold
  164. scores = out[:, 4:].max(-1)
  165. keep = scores > self.conf
  166. boxes = out[keep, :4]
  167. scores = scores[keep]
  168. class_ids = out[keep, 4:].argmax(-1)
  169. # Apply non-maximum suppression
  170. indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf, self.iou).flatten()
  171. # Draw detections that survived NMS
  172. [self.draw_detections(img, boxes[i], scores[i], class_ids[i]) for i in indices]
  173. return img
  174. def detect(self, img_path: str) -> np.ndarray:
  175. """
  176. Perform object detection on an input image.
  177. Args:
  178. img_path (str): Path to the input image file.
  179. Returns:
  180. (np.ndarray): The output image with drawn detections.
  181. """
  182. # Load and preprocess image
  183. img = cv2.imread(img_path)
  184. x, pad = self.preprocess(img)
  185. # Apply quantization if model is int8
  186. if self.int8:
  187. x = (x / self.in_scale + self.in_zero_point).astype(np.int8)
  188. # Set input tensor and run inference
  189. self.model.set_tensor(self.in_index, x)
  190. self.model.invoke()
  191. # Get output and dequantize if necessary
  192. y = self.model.get_tensor(self.out_index)
  193. if self.int8:
  194. y = (y.astype(np.float32) - self.out_zero_point) * self.out_scale
  195. # Process detections and return result
  196. return self.postprocess(img, y, pad)
  197. if __name__ == "__main__":
  198. parser = argparse.ArgumentParser()
  199. parser.add_argument(
  200. "--model",
  201. type=str,
  202. default="yolov8n_saved_model/yolov8n_full_integer_quant.tflite",
  203. help="Path to TFLite model.",
  204. )
  205. parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
  206. parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
  207. parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
  208. parser.add_argument("--metadata", type=str, default="yolov8n_saved_model/metadata.yaml", help="Metadata yaml")
  209. args = parser.parse_args()
  210. detector = YOLOv8TFLite(args.model, args.conf, args.iou, args.metadata)
  211. result = detector.detect(str(ASSETS / "bus.jpg"))
  212. cv2.imshow("Output", result)
  213. cv2.waitKey(0)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...