# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import argparse
from typing import Tuple, Union

import cv2
import numpy as np
import yaml

from ultralytics.utils import ASSETS

try:
    from tflite_runtime.interpreter import Interpreter
except ImportError:
    import tensorflow as tf

    Interpreter = tf.lite.Interpreter


class YOLOv8TFLite:
- """
- A YOLOv8 object detection class using TensorFlow Lite for efficient inference.
- This class handles model loading, preprocessing, inference, and visualization of detection results for YOLOv8
- models converted to TensorFlow Lite format.
- Attributes:
- model (Interpreter): TensorFlow Lite interpreter for the YOLOv8 model.
- conf (float): Confidence threshold for filtering detections.
- iou (float): Intersection over Union threshold for non-maximum suppression.
- classes (dict): Dictionary mapping class IDs to class names.
- color_palette (np.ndarray): Random color palette for visualization with shape (num_classes, 3).
- in_width (int): Input width required by the model.
- in_height (int): Input height required by the model.
- in_index (int): Input tensor index in the model.
- in_scale (float): Input quantization scale factor.
- in_zero_point (int): Input quantization zero point.
- int8 (bool): Whether the model uses int8 quantization.
- out_index (int): Output tensor index in the model.
- out_scale (float): Output quantization scale factor.
- out_zero_point (int): Output quantization zero point.
- Methods:
- letterbox: Resize and pad image while maintaining aspect ratio.
- draw_detections: Draw bounding boxes and labels on the input image.
- preprocess: Preprocess the input image before inference.
- postprocess: Process model outputs to extract and visualize detections.
- detect: Perform object detection on an input image.
- Examples:
- Initialize detector and run inference
- >>> detector = YOLOv8TFLite("yolov8n.tflite", conf=0.25, iou=0.45)
- >>> result = detector.detect("image.jpg")
- >>> cv2.imshow("Result", result)
- """

    def __init__(self, model: str, conf: float = 0.25, iou: float = 0.45, metadata: Union[str, None] = None):
        """
        Initialize the YOLOv8TFLite detector.

        Args:
            model (str): Path to the TFLite model file.
            conf (float): Confidence threshold for filtering detections.
            iou (float): IoU threshold for non-maximum suppression.
            metadata (str | None): Path to the metadata file containing class names.
        """
        self.conf = conf
        self.iou = iou
        if metadata is None:
            self.classes = {i: i for i in range(1000)}
        else:
            with open(metadata) as f:
                self.classes = yaml.safe_load(f)["names"]
        np.random.seed(42)  # Set seed for reproducible colors
        self.color_palette = np.random.uniform(128, 255, size=(len(self.classes), 3))

        # Initialize the TFLite interpreter
        self.model = Interpreter(model_path=model)
        self.model.allocate_tensors()

        # Get input details
        input_details = self.model.get_input_details()[0]
        self.in_width, self.in_height = input_details["shape"][1:3]
        self.in_index = input_details["index"]
        self.in_scale, self.in_zero_point = input_details["quantization"]
        self.int8 = input_details["dtype"] == np.int8

        # Get output details
        output_details = self.model.get_output_details()[0]
        self.out_index = output_details["index"]
        self.out_scale, self.out_zero_point = output_details["quantization"]

    def letterbox(
        self, img: np.ndarray, new_shape: Tuple[int, int] = (640, 640)
    ) -> Tuple[np.ndarray, Tuple[float, float]]:
        """
        Resize and pad image while maintaining aspect ratio.

        Args:
            img (np.ndarray): Input image with shape (H, W, C).
            new_shape (Tuple[int, int]): Target shape (height, width).

        Returns:
            (np.ndarray): Resized and padded image.
            (Tuple[float, float]): Padding ratios (top/height, left/width) for coordinate adjustment.
        """
        shape = img.shape[:2]  # Current shape [height, width]

        # Scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

        # Compute padding
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding

        if shape[::-1] != new_unpad:  # Resize if needed
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))

        return img, (top / img.shape[0], left / img.shape[1])

    def draw_detections(self, img: np.ndarray, box: np.ndarray, score: np.float32, class_id: int) -> None:
        """
        Draw bounding boxes and labels on the input image based on detected objects.

        Args:
            img (np.ndarray): The input image to draw detections on.
            box (np.ndarray): Detected bounding box in the format [x1, y1, width, height].
            score (np.float32): Confidence score of the detection.
            class_id (int): Class ID for the detected object.
        """
        x1, y1, w, h = box
        color = self.color_palette[class_id]

        # Draw bounding box
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

        # Create label with class name and score
        label = f"{self.classes[class_id]}: {score:.2f}"

        # Get text size for background rectangle
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Position label above or below box depending on space
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        # Draw label background
        cv2.rectangle(
            img,
            (int(label_x), int(label_y - label_height)),
            (int(label_x + label_width), int(label_y + label_height)),
            color,
            cv2.FILLED,
        )

        # Draw text
        cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

    def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float]]:
        """
        Preprocess the input image before performing inference.

        Args:
            img (np.ndarray): The input image to be preprocessed with shape (H, W, C).

        Returns:
            (np.ndarray): Preprocessed image ready for model input.
            (Tuple[float, float]): Padding ratios for coordinate adjustment.
        """
        img, pad = self.letterbox(img, (self.in_width, self.in_height))
        img = img[..., ::-1][None]  # BGR to RGB and add batch dimension (N, H, W, C) for TFLite
        img = np.ascontiguousarray(img)
        img = img.astype(np.float32)
        return img / 255, pad  # Normalize to [0, 1]

    def postprocess(self, img: np.ndarray, outputs: np.ndarray, pad: Tuple[float, float]) -> np.ndarray:
        """
        Process model outputs to extract and visualize detections.

        Args:
            img (np.ndarray): The original input image.
            outputs (np.ndarray): Raw model outputs.
            pad (Tuple[float, float]): Padding ratios from preprocessing.

        Returns:
            (np.ndarray): The input image with detections drawn on it.
        """
        # Adjust coordinates based on padding and scale to original image size
        outputs[:, 0] -= pad[1]
        outputs[:, 1] -= pad[0]
        outputs[:, :4] *= max(img.shape)

        # Transform outputs to [x, y, w, h] format
        outputs = outputs.transpose(0, 2, 1)
        outputs[..., 0] -= outputs[..., 2] / 2  # x center to top-left x
        outputs[..., 1] -= outputs[..., 3] / 2  # y center to top-left y

        for out in outputs:
            # Get scores and apply confidence threshold
            scores = out[:, 4:].max(-1)
            keep = scores > self.conf
            boxes = out[keep, :4]
            scores = scores[keep]
            class_ids = out[keep, 4:].argmax(-1)

            # Apply non-maximum suppression
            indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf, self.iou).flatten()

            # Draw detections that survived NMS
            [self.draw_detections(img, boxes[i], scores[i], class_ids[i]) for i in indices]

        return img

    def detect(self, img_path: str) -> np.ndarray:
        """
        Perform object detection on an input image.

        Args:
            img_path (str): Path to the input image file.

        Returns:
            (np.ndarray): The output image with drawn detections.
        """
        # Load and preprocess image
        img = cv2.imread(img_path)
        x, pad = self.preprocess(img)

        # Apply quantization if model is int8
        if self.int8:
            x = (x / self.in_scale + self.in_zero_point).astype(np.int8)

        # Set input tensor and run inference
        self.model.set_tensor(self.in_index, x)
        self.model.invoke()

        # Get output and dequantize if necessary
        y = self.model.get_tensor(self.out_index)
        if self.int8:
            y = (y.astype(np.float32) - self.out_zero_point) * self.out_scale

        # Process detections and return result
        return self.postprocess(img, y, pad)
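

# A minimal sketch of how the TFLite model and metadata.yaml expected by this script could be
# produced, assuming the `ultralytics` package (already imported above for ASSETS) is installed.
# The helper name `export_yolov8_tflite` is illustrative and not part of the original example;
# the resulting "*_saved_model" directory matches the default paths used by the argument parser below.
def export_yolov8_tflite(weights: str = "yolov8n.pt", int8: bool = True) -> str:
    """Export a YOLOv8 PyTorch model to TFLite and return the path of the exported model file."""
    from ultralytics import YOLO  # Deferred import; only needed when exporting

    model = YOLO(weights)
    return model.export(format="tflite", int8=int8)  # Also writes metadata.yaml next to the exported model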


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="yolov8n_saved_model/yolov8n_full_integer_quant.tflite",
        help="Path to TFLite model.",
    )
    parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
    parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
    parser.add_argument("--metadata", type=str, default="yolov8n_saved_model/metadata.yaml", help="Metadata yaml")
    args = parser.parse_args()

    detector = YOLOv8TFLite(args.model, args.conf, args.iou, args.metadata)
    result = detector.detect(args.img)

    cv2.imshow("Output", result)
    cv2.waitKey(0)