Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

main.py 8.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import argparse
  3. from typing import List, Tuple, Union
  4. import cv2
  5. import numpy as np
  6. import onnxruntime as ort
  7. import torch
  8. import ultralytics.utils.ops as ops
  9. from ultralytics.engine.results import Results
  10. from ultralytics.utils import ASSETS, YAML
  11. from ultralytics.utils.checks import check_yaml
  12. class YOLOv8Seg:
  13. """
  14. YOLOv8 segmentation model for performing instance segmentation using ONNX Runtime.
  15. This class implements a YOLOv8 instance segmentation model using ONNX Runtime for inference. It handles
  16. preprocessing of input images, running inference with the ONNX model, and postprocessing the results to
  17. generate bounding boxes and segmentation masks.
  18. Attributes:
  19. session (ort.InferenceSession): ONNX Runtime inference session for model execution.
  20. imgsz (Tuple[int, int]): Input image size as (height, width) for the model.
  21. classes (dict): Dictionary mapping class indices to class names from the dataset.
  22. conf (float): Confidence threshold for filtering detections.
  23. iou (float): IoU threshold used by non-maximum suppression.
  24. Methods:
  25. letterbox: Resize and pad image while maintaining aspect ratio.
  26. preprocess: Preprocess the input image before feeding it into the model.
  27. postprocess: Post-process model predictions to extract meaningful results.
  28. process_mask: Process prototype masks with predicted mask coefficients to generate instance segmentation masks.
  29. Examples:
  30. >>> model = YOLOv8Seg("yolov8n-seg.onnx", conf=0.25, iou=0.7)
  31. >>> img = cv2.imread("image.jpg")
  32. >>> results = model(img)
  33. >>> cv2.imshow("Segmentation", results[0].plot())
  34. """
  35. def __init__(self, onnx_model: str, conf: float = 0.25, iou: float = 0.7, imgsz: Union[int, Tuple[int, int]] = 640):
  36. """
  37. Initialize the instance segmentation model using an ONNX model.
  38. Args:
  39. onnx_model (str): Path to the ONNX model file.
  40. conf (float, optional): Confidence threshold for filtering detections.
  41. iou (float, optional): IoU threshold for non-maximum suppression.
  42. imgsz (int | Tuple[int, int], optional): Input image size of the model. Can be an integer for square
  43. input or a tuple for rectangular input.
  44. """
  45. self.session = ort.InferenceSession(
  46. onnx_model,
  47. providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
  48. if torch.cuda.is_available()
  49. else ["CPUExecutionProvider"],
  50. )
  51. self.imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
  52. self.classes = YAML.load(check_yaml("coco8.yaml"))["names"]
  53. self.conf = conf
  54. self.iou = iou
  55. def __call__(self, img: np.ndarray) -> List[Results]:
  56. """
  57. Run inference on the input image using the ONNX model.
  58. Args:
  59. img (np.ndarray): The original input image in BGR format.
  60. Returns:
  61. (List[Results]): Processed detection results after post-processing, containing bounding boxes and
  62. segmentation masks.
  63. """
  64. prep_img = self.preprocess(img, self.imgsz)
  65. outs = self.session.run(None, {self.session.get_inputs()[0].name: prep_img})
  66. return self.postprocess(img, prep_img, outs)
  67. def letterbox(self, img: np.ndarray, new_shape: Tuple[int, int] = (640, 640)) -> np.ndarray:
  68. """
  69. Resize and pad image while maintaining aspect ratio.
  70. Args:
  71. img (np.ndarray): Input image in BGR format.
  72. new_shape (Tuple[int, int], optional): Target shape as (height, width).
  73. Returns:
  74. (np.ndarray): Resized and padded image.
  75. """
  76. shape = img.shape[:2] # current shape [height, width]
  77. # Scale ratio (new / old)
  78. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  79. # Compute padding
  80. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  81. dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding
  82. if shape[::-1] != new_unpad: # resize
  83. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  84. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  85. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  86. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
  87. return img
  88. def preprocess(self, img: np.ndarray, new_shape: Tuple[int, int]) -> np.ndarray:
  89. """
  90. Preprocess the input image before feeding it into the model.
  91. Args:
  92. img (np.ndarray): The input image in BGR format.
  93. new_shape (Tuple[int, int]): The target shape for resizing as (height, width).
  94. Returns:
  95. (np.ndarray): Preprocessed image ready for model inference, with shape (1, 3, height, width) and
  96. normalized to [0, 1].
  97. """
  98. img = self.letterbox(img, new_shape)
  99. img = img[..., ::-1].transpose([2, 0, 1])[None] # BGR to RGB, BHWC to BCHW
  100. img = np.ascontiguousarray(img)
  101. img = img.astype(np.float32) / 255 # Normalize to [0, 1]
  102. return img
  103. def postprocess(self, img: np.ndarray, prep_img: np.ndarray, outs: List) -> List[Results]:
  104. """
  105. Post-process model predictions to extract meaningful results.
  106. Args:
  107. img (np.ndarray): The original input image.
  108. prep_img (np.ndarray): The preprocessed image used for inference.
  109. outs (List): Model outputs containing predictions and prototype masks.
  110. Returns:
  111. (List[Results]): Processed detection results containing bounding boxes and segmentation masks.
  112. """
  113. preds, protos = [torch.from_numpy(p) for p in outs]
  114. preds = ops.non_max_suppression(preds, self.conf, self.iou, nc=len(self.classes))
  115. results = []
  116. for i, pred in enumerate(preds):
  117. pred[:, :4] = ops.scale_boxes(prep_img.shape[2:], pred[:, :4], img.shape)
  118. masks = self.process_mask(protos[i], pred[:, 6:], pred[:, :4], img.shape[:2])
  119. results.append(Results(img, path="", names=self.classes, boxes=pred[:, :6], masks=masks))
  120. return results
  121. def process_mask(
  122. self, protos: torch.Tensor, masks_in: torch.Tensor, bboxes: torch.Tensor, shape: Tuple[int, int]
  123. ) -> torch.Tensor:
  124. """
  125. Process prototype masks with predicted mask coefficients to generate instance segmentation masks.
  126. Args:
  127. protos (torch.Tensor): Prototype masks with shape (mask_dim, mask_h, mask_w).
  128. masks_in (torch.Tensor): Predicted mask coefficients with shape (N, mask_dim), where N is number of
  129. detections.
  130. bboxes (torch.Tensor): Bounding boxes with shape (N, 4), where N is number of detections.
  131. shape (Tuple[int, int]): The size of the input image as (height, width).
  132. Returns:
  133. (torch.Tensor): Binary segmentation masks with shape (N, height, width).
  134. """
  135. c, mh, mw = protos.shape # CHW
  136. masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # Matrix multiplication
  137. masks = ops.scale_masks(masks[None], shape)[0] # Scale masks to original image size
  138. masks = ops.crop_mask(masks, bboxes) # Crop masks to bounding boxes
  139. return masks.gt_(0.0) # Convert to binary masks
  140. if __name__ == "__main__":
  141. parser = argparse.ArgumentParser()
  142. parser.add_argument("--model", type=str, required=True, help="Path to ONNX model")
  143. parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
  144. parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
  145. parser.add_argument("--iou", type=float, default=0.7, help="NMS IoU threshold")
  146. args = parser.parse_args()
  147. model = YOLOv8Seg(args.model, args.conf, args.iou)
  148. img = cv2.imread(args.source)
  149. results = model(img)
  150. cv2.imshow("Segmented Image", results[0].plot())
  151. cv2.waitKey(0)
  152. cv2.destroyAllWindows()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...