Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ops.py 34 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
  1. # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
  2. import contextlib
  3. import math
  4. import re
  5. import time
  6. from typing import Optional
  7. import cv2
  8. import numpy as np
  9. import torch
  10. import torch.nn.functional as F
  11. from ultralytics.utils import LOGGER
  12. from ultralytics.utils.metrics import batch_probiou
  13. class Profile(contextlib.ContextDecorator):
  14. """
  15. Ultralytics Profile class for timing code execution.
  16. Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing
  17. measurements with CUDA synchronization support for GPU operations.
  18. Attributes:
  19. t (float): Accumulated time in seconds.
  20. device (torch.device): Device used for model inference.
  21. cuda (bool): Whether CUDA is being used for timing synchronization.
  22. Examples:
  23. Use as a context manager to time code execution
  24. >>> with Profile(device=device) as dt:
  25. ... pass # slow operation here
  26. >>> print(dt) # prints "Elapsed time is 9.5367431640625e-07 s"
  27. Use as a decorator to time function execution
  28. >>> @Profile()
  29. ... def slow_function():
  30. ... time.sleep(0.1)
  31. """
  32. def __init__(self, t: float = 0.0, device: Optional[torch.device] = None):
  33. """
  34. Initialize the Profile class.
  35. Args:
  36. t (float): Initial accumulated time in seconds.
  37. device (torch.device, optional): Device used for model inference to enable CUDA synchronization.
  38. """
  39. self.t = t
  40. self.device = device
  41. self.cuda = bool(device and str(device).startswith("cuda"))
  42. def __enter__(self):
  43. """Start timing."""
  44. self.start = self.time()
  45. return self
  46. def __exit__(self, type, value, traceback): # noqa
  47. """Stop timing."""
  48. self.dt = self.time() - self.start # delta-time
  49. self.t += self.dt # accumulate dt
  50. def __str__(self):
  51. """Return a human-readable string representing the accumulated elapsed time."""
  52. return f"Elapsed time is {self.t} s"
  53. def time(self):
  54. """Get current time with CUDA synchronization if applicable."""
  55. if self.cuda:
  56. torch.cuda.synchronize(self.device)
  57. return time.perf_counter()
  58. def segment2box(segment, width: int = 640, height: int = 640):
  59. """
  60. Convert segment coordinates to bounding box coordinates.
  61. Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.
  62. Applies inside-image constraint and clips coordinates when necessary.
  63. Args:
  64. segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.
  65. width (int): Width of the image in pixels.
  66. height (int): Height of the image in pixels.
  67. Returns:
  68. (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].
  69. """
  70. x, y = segment.T # segment xy
  71. # Clip coordinates if 3 out of 4 sides are outside the image
  72. if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
  73. x = x.clip(0, width)
  74. y = y.clip(0, height)
  75. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  76. x = x[inside]
  77. y = y[inside]
  78. return (
  79. np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
  80. if any(x)
  81. else np.zeros(4, dtype=segment.dtype)
  82. ) # xyxy
  83. def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
  84. """
  85. Rescale bounding boxes from one image shape to another.
  86. Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.
  87. Supports both xyxy and xywh box formats.
  88. Args:
  89. img1_shape (tuple): Shape of the source image (height, width).
  90. boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).
  91. img0_shape (tuple): Shape of the target image (height, width).
  92. ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.
  93. padding (bool): Whether boxes are based on YOLO-style augmented images with padding.
  94. xywh (bool): Whether box format is xywh (True) or xyxy (False).
  95. Returns:
  96. (torch.Tensor): Rescaled bounding boxes in the same format as input.
  97. """
  98. if ratio_pad is None: # calculate from img0_shape
  99. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  100. pad = (
  101. round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
  102. round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
  103. ) # wh padding
  104. else:
  105. gain = ratio_pad[0][0]
  106. pad = ratio_pad[1]
  107. if padding:
  108. boxes[..., 0] -= pad[0] # x padding
  109. boxes[..., 1] -= pad[1] # y padding
  110. if not xywh:
  111. boxes[..., 2] -= pad[0] # x padding
  112. boxes[..., 3] -= pad[1] # y padding
  113. boxes[..., :4] /= gain
  114. return clip_boxes(boxes, img0_shape)
  115. def make_divisible(x: int, divisor):
  116. """
  117. Return the nearest number that is divisible by the given divisor.
  118. Args:
  119. x (int): The number to make divisible.
  120. divisor (int | torch.Tensor): The divisor.
  121. Returns:
  122. (int): The nearest number divisible by the divisor.
  123. """
  124. if isinstance(divisor, torch.Tensor):
  125. divisor = int(divisor.max()) # to int
  126. return math.ceil(x / divisor) * divisor
def nms_rotated(boxes, scores, threshold: float = 0.45, use_triu: bool = True):
    """
    Perform NMS on oriented bounding boxes using probiou and fast-nms.

    Boxes are ordered by descending score and a pairwise `batch_probiou` matrix is computed. Only the strictly upper
    triangle is kept, so each column holds the overlaps of one box with the boxes ranked before it; a box is kept
    when none of those overlaps reaches `threshold`.

    Args:
        boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format.
        scores (torch.Tensor): Confidence scores with shape (N,).
        threshold (float): IoU threshold for NMS.
        use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations. When False the
            upper-triangular mask is built from index grids and a fixed-length index tensor is returned.

    Returns:
        (torch.Tensor): Indices of boxes to keep after NMS.
    """
    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    ious = batch_probiou(boxes, boxes)
    if use_triu:
        ious = ious.triu_(diagonal=1)  # in-place: zero the diagonal and everything below it
        # NOTE: selecting via torch.nonzero on the boolean mask avoids an explicit if-else on len(boxes),
        # which keeps this branch exportable
        pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
    else:
        n = boxes.shape[0]
        # Build the strict upper-triangular mask (row < col) from index grids instead of torch.triu
        row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
        col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
        upper_mask = row_idx < col_idx
        ious = ious * upper_mask
        # Zeroing these scores ensures the additional indices would not affect the final results
        # NOTE(review): this writes into the caller's `scores` tensor in place, and the mask is in sorted order
        # while `scores` itself was never reordered by `sorted_idx` - presumably callers pass scores that are
        # already sorted descending; confirm before using this branch with unsorted scores.
        scores[~((ious >= threshold).sum(0) <= 0)] = 0
        # NOTE: return indices with fixed length to avoid TFLite reshape error
        pick = torch.topk(scores, scores.shape[0]).indices
    return sorted_idx[pick]
def non_max_suppression(
    prediction,
    conf_thres: float = 0.25,
    iou_thres: float = 0.45,
    classes=None,
    agnostic: bool = False,
    multi_label: bool = False,
    labels=(),
    max_det: int = 300,
    nc: int = 0,  # number of classes (optional)
    max_time_img: float = 0.05,
    max_nms: int = 30000,
    max_wh: int = 7680,
    in_place: bool = True,
    rotated: bool = False,
    end2end: bool = False,
    return_idxs: bool = False,
):
    """
    Perform non-maximum suppression (NMS) on prediction results.

    Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
    detection formats including standard boxes, rotated boxes, and masks.

    Args:
        prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing boxes, classes, and optional masks. A list/tuple is also accepted, in which case only its
            first element is used.
        conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
        iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
        classes (List[int], optional): List of class indices to consider. If None, all classes are considered.
        agnostic (bool): Whether to perform class-agnostic NMS.
        multi_label (bool): Whether each box can have multiple labels. Ignored when there is only one class.
        labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image.
        max_det (int): Maximum number of detections to keep per image.
        nc (int): Number of classes. Indices after this are considered masks. When 0 it is derived as
            `prediction.shape[1] - 4`.
        max_time_img (float): Maximum time in seconds for processing one image; the total budget is
            `2.0 + max_time_img * batch_size`.
        max_nms (int): Maximum number of boxes for torchvision.ops.nms().
        max_wh (int): Maximum box width and height in pixels; also used as the per-class coordinate offset.
        in_place (bool): Whether to modify the input prediction tensor in place.
        rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
        end2end (bool): Whether the model is end-to-end and doesn't require NMS.
        return_idxs (bool): Whether to return the indices of kept detections.

    Returns:
        output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
            containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
        keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True.

    Raises:
        AssertionError: If `conf_thres` or `iou_thres` is outside [0, 1].
    """
    import torchvision  # scope for faster 'import ultralytics'

    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output
    if classes is not None:
        classes = torch.tensor(classes, device=prediction.device)

    if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
        # Only confidence/class filtering and the max_det cap are applied here; no IoU suppression is run
        output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
        if classes is not None:
            output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
        # NOTE(review): this early return yields only `output`, even when return_idxs=True
        return output

    bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    extra = prediction.shape[1] - nc - 4  # number of extra info
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
    xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None]  # to track idxs

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    if not rotated:
        if in_place:
            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
        else:
            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy

    t = time.time()
    # Placeholders for images that end up with no detections (or are skipped after the time limit)
    output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
    keepi = [torch.zeros((0, 1), device=prediction.device)] * bs  # to store the kept idxs
    for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        filt = xc[xi]  # confidence
        x, xk = x[filt], xk[filt]

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, extra), 1)

        if multi_label:
            # One output row per (box, class) pair whose class score exceeds the threshold
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
            xk = xk[i]
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            filt = conf.view(-1) > conf_thres
            x = torch.cat((box, conf, j.float(), mask), 1)[filt]
            xk = xk[filt]

        # Filter by class
        if classes is not None:
            filt = (x[:, 5:6] == classes).any(1)
            x, xk = x[filt], xk[filt]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
            x, xk = x[filt], xk[filt]

        # Batched NMS
        # Shifting each box by class_index * max_wh keeps boxes of different classes from overlapping, so a single
        # NMS call acts per class; with agnostic=True the offset is zero
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        scores = x[:, 4]  # scores
        if rotated:
            # The last column of x is passed as the rotation of each box
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = nms_rotated(boxes, scores, iou_thres)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class)
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections

        output[xi], keepi[xi] = x[i], xk[i].reshape(-1)
        # The budget is checked after each image; images not yet processed keep their empty placeholders
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return (output, keepi) if return_idxs else output
  285. def clip_boxes(boxes, shape):
  286. """
  287. Clip bounding boxes to image boundaries.
  288. Args:
  289. boxes (torch.Tensor | numpy.ndarray): Bounding boxes to clip.
  290. shape (tuple): Image shape as (height, width).
  291. Returns:
  292. (torch.Tensor | numpy.ndarray): Clipped bounding boxes.
  293. """
  294. if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
  295. boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1
  296. boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1
  297. boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2
  298. boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2
  299. else: # np.array (faster grouped)
  300. boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
  301. boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
  302. return boxes
  303. def clip_coords(coords, shape):
  304. """
  305. Clip line coordinates to image boundaries.
  306. Args:
  307. coords (torch.Tensor | numpy.ndarray): Line coordinates to clip.
  308. shape (tuple): Image shape as (height, width).
  309. Returns:
  310. (torch.Tensor | numpy.ndarray): Clipped coordinates.
  311. """
  312. if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
  313. coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x
  314. coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y
  315. else: # np.array (faster grouped)
  316. coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x
  317. coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y
  318. return coords
  319. def scale_image(masks, im0_shape, ratio_pad=None):
  320. """
  321. Rescale masks to original image size.
  322. Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding
  323. that was applied during preprocessing.
  324. Args:
  325. masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].
  326. im0_shape (tuple): Original image shape as (height, width).
  327. ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
  328. Returns:
  329. (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.
  330. """
  331. # Rescale coordinates (xyxy) from im1_shape to im0_shape
  332. im1_shape = masks.shape
  333. if im1_shape[:2] == im0_shape[:2]:
  334. return masks
  335. if ratio_pad is None: # calculate from im0_shape
  336. gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new
  337. pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding
  338. else:
  339. pad = ratio_pad[1]
  340. top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1)))
  341. bottom, right = (
  342. im1_shape[0] - int(round(pad[1] + 0.1)),
  343. im1_shape[1] - int(round(pad[0] + 0.1)),
  344. )
  345. if len(masks.shape) < 2:
  346. raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
  347. masks = masks[top:bottom, left:right]
  348. masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
  349. if len(masks.shape) == 2:
  350. masks = masks[:, :, None]
  351. return masks
  352. def xyxy2xywh(x):
  353. """
  354. Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
  355. top-left corner and (x2, y2) is the bottom-right corner.
  356. Args:
  357. x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
  358. Returns:
  359. (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.
  360. """
  361. assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
  362. y = empty_like(x) # faster than clone/copy
  363. y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
  364. y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
  365. y[..., 2] = x[..., 2] - x[..., 0] # width
  366. y[..., 3] = x[..., 3] - x[..., 1] # height
  367. return y
  368. def xywh2xyxy(x):
  369. """
  370. Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
  371. top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.
  372. Args:
  373. x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.
  374. Returns:
  375. (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.
  376. """
  377. assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
  378. y = empty_like(x) # faster than clone/copy
  379. xy = x[..., :2] # centers
  380. wh = x[..., 2:] / 2 # half width-height
  381. y[..., :2] = xy - wh # top left xy
  382. y[..., 2:] = xy + wh # bottom right xy
  383. return y
  384. def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
  385. """
  386. Convert normalized bounding box coordinates to pixel coordinates.
  387. Args:
  388. x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.
  389. w (int): Image width in pixels.
  390. h (int): Image height in pixels.
  391. padw (int): Padding width in pixels.
  392. padh (int): Padding height in pixels.
  393. Returns:
  394. y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
  395. x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
  396. """
  397. assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
  398. y = empty_like(x) # faster than clone/copy
  399. y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
  400. y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
  401. y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
  402. y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
  403. return y
  404. def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
  405. """
  406. Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
  407. width and height are normalized to image dimensions.
  408. Args:
  409. x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
  410. w (int): Image width in pixels.
  411. h (int): Image height in pixels.
  412. clip (bool): Whether to clip boxes to image boundaries.
  413. eps (float): Minimum value for box width and height.
  414. Returns:
  415. (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.
  416. """
  417. if clip:
  418. x = clip_boxes(x, (h - eps, w - eps))
  419. assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
  420. y = empty_like(x) # faster than clone/copy
  421. y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
  422. y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
  423. y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
  424. y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
  425. return y
  426. def xywh2ltwh(x):
  427. """
  428. Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.
  429. Args:
  430. x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.
  431. Returns:
  432. (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
  433. """
  434. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  435. y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
  436. y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
  437. return y
  438. def xyxy2ltwh(x):
  439. """
  440. Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.
  441. Args:
  442. x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.
  443. Returns:
  444. (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
  445. """
  446. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  447. y[..., 2] = x[..., 2] - x[..., 0] # width
  448. y[..., 3] = x[..., 3] - x[..., 1] # height
  449. return y
  450. def ltwh2xywh(x):
  451. """
  452. Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
  453. Args:
  454. x (torch.Tensor): Input bounding box coordinates.
  455. Returns:
  456. (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.
  457. """
  458. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  459. y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
  460. y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y
  461. return y
  462. def xyxyxyxy2xywhr(x):
  463. """
  464. Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.
  465. Args:
  466. x (numpy.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.
  467. Returns:
  468. (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).
  469. Rotation values are in radians from 0 to pi/2.
  470. """
  471. is_torch = isinstance(x, torch.Tensor)
  472. points = x.cpu().numpy() if is_torch else x
  473. points = points.reshape(len(x), -1, 2)
  474. rboxes = []
  475. for pts in points:
  476. # NOTE: Use cv2.minAreaRect to get accurate xywhr,
  477. # especially some objects are cut off by augmentations in dataloader.
  478. (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
  479. rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
  480. return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)
  481. def xywhr2xyxyxyxy(x):
  482. """
  483. Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.
  484. Args:
  485. x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).
  486. Rotation values should be in radians from 0 to pi/2.
  487. Returns:
  488. (numpy.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).
  489. """
  490. cos, sin, cat, stack = (
  491. (torch.cos, torch.sin, torch.cat, torch.stack)
  492. if isinstance(x, torch.Tensor)
  493. else (np.cos, np.sin, np.concatenate, np.stack)
  494. )
  495. ctr = x[..., :2]
  496. w, h, angle = (x[..., i : i + 1] for i in range(2, 5))
  497. cos_value, sin_value = cos(angle), sin(angle)
  498. vec1 = [w / 2 * cos_value, w / 2 * sin_value]
  499. vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
  500. vec1 = cat(vec1, -1)
  501. vec2 = cat(vec2, -1)
  502. pt1 = ctr + vec1 + vec2
  503. pt2 = ctr + vec1 - vec2
  504. pt3 = ctr - vec1 - vec2
  505. pt4 = ctr - vec1 + vec2
  506. return stack([pt1, pt2, pt3, pt4], -2)
  507. def ltwh2xyxy(x):
  508. """
  509. Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
  510. Args:
  511. x (np.ndarray | torch.Tensor): Input bounding box coordinates.
  512. Returns:
  513. (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.
  514. """
  515. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  516. y[..., 2] = x[..., 2] + x[..., 0] # width
  517. y[..., 3] = x[..., 3] + x[..., 1] # height
  518. return y
  519. def segments2boxes(segments):
  520. """
  521. Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
  522. Args:
  523. segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.
  524. Returns:
  525. (np.ndarray): Bounding box coordinates in xywh format.
  526. """
  527. boxes = []
  528. for s in segments:
  529. x, y = s.T # segment xy
  530. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  531. return xyxy2xywh(np.array(boxes)) # cls, xywh
  532. def resample_segments(segments, n: int = 1000):
  533. """
  534. Resample segments to n points each using linear interpolation.
  535. Args:
  536. segments (list): List of (N, 2) arrays where N is the number of points in each segment.
  537. n (int): Number of points to resample each segment to.
  538. Returns:
  539. (list): Resampled segments with n points each.
  540. """
  541. for i, s in enumerate(segments):
  542. if len(s) == n:
  543. continue
  544. s = np.concatenate((s, s[0:1, :]), axis=0)
  545. x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)
  546. xp = np.arange(len(s))
  547. x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x
  548. segments[i] = (
  549. np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
  550. ) # segment xy
  551. return segments
  552. def crop_mask(masks, boxes):
  553. """
  554. Crop masks to bounding box regions.
  555. Args:
  556. masks (torch.Tensor): Masks with shape (N, H, W).
  557. boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.
  558. Returns:
  559. (torch.Tensor): Cropped masks.
  560. """
  561. _, h, w = masks.shape
  562. x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
  563. r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
  564. c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
  565. return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
  566. def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
  567. """
  568. Apply masks to bounding boxes using mask head output.
  569. Args:
  570. protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
  571. masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
  572. bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
  573. shape (tuple): Input image size as (height, width).
  574. upsample (bool): Whether to upsample masks to original image size.
  575. Returns:
  576. (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
  577. are the height and width of the input image. The mask is applied to the bounding boxes.
  578. """
  579. c, mh, mw = protos.shape # CHW
  580. ih, iw = shape
  581. masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW
  582. width_ratio = mw / iw
  583. height_ratio = mh / ih
  584. downsampled_bboxes = bboxes.clone()
  585. downsampled_bboxes[:, 0] *= width_ratio
  586. downsampled_bboxes[:, 2] *= width_ratio
  587. downsampled_bboxes[:, 3] *= height_ratio
  588. downsampled_bboxes[:, 1] *= height_ratio
  589. masks = crop_mask(masks, downsampled_bboxes) # CHW
  590. if upsample:
  591. masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
  592. return masks.gt_(0.0)
  593. def process_mask_native(protos, masks_in, bboxes, shape):
  594. """
  595. Apply masks to bounding boxes using mask head output with native upsampling.
  596. Args:
  597. protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
  598. masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
  599. bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
  600. shape (tuple): Input image size as (height, width).
  601. Returns:
  602. (torch.Tensor): Binary mask tensor with shape (H, W, N).
  603. """
  604. c, mh, mw = protos.shape # CHW
  605. masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
  606. masks = scale_masks(masks[None], shape)[0] # CHW
  607. masks = crop_mask(masks, bboxes) # CHW
  608. return masks.gt_(0.0)
  609. def scale_masks(masks, shape, padding: bool = True):
  610. """
  611. Rescale segment masks to target shape.
  612. Args:
  613. masks (torch.Tensor): Masks with shape (N, C, H, W).
  614. shape (tuple): Target height and width as (height, width).
  615. padding (bool): Whether masks are based on YOLO-style augmented images with padding.
  616. Returns:
  617. (torch.Tensor): Rescaled masks.
  618. """
  619. mh, mw = masks.shape[2:]
  620. gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
  621. pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding
  622. if padding:
  623. pad[0] /= 2
  624. pad[1] /= 2
  625. top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))) if padding else (0, 0) # y, x
  626. bottom, right = (
  627. mh - int(round(pad[1] + 0.1)),
  628. mw - int(round(pad[0] + 0.1)),
  629. )
  630. masks = masks[..., top:bottom, left:right]
  631. masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW
  632. return masks
  633. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
  634. """
  635. Rescale segment coordinates from img1_shape to img0_shape.
  636. Args:
  637. img1_shape (tuple): Shape of the source image.
  638. coords (torch.Tensor): Coordinates to scale with shape (N, 2).
  639. img0_shape (tuple): Shape of the target image.
  640. ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
  641. normalize (bool): Whether to normalize coordinates to range [0, 1].
  642. padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.
  643. Returns:
  644. (torch.Tensor): Scaled coordinates.
  645. """
  646. if ratio_pad is None: # calculate from img0_shape
  647. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  648. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  649. else:
  650. gain = ratio_pad[0][0]
  651. pad = ratio_pad[1]
  652. if padding:
  653. coords[..., 0] -= pad[0] # x padding
  654. coords[..., 1] -= pad[1] # y padding
  655. coords[..., 0] /= gain
  656. coords[..., 1] /= gain
  657. coords = clip_coords(coords, img0_shape)
  658. if normalize:
  659. coords[..., 0] /= img0_shape[1] # width
  660. coords[..., 1] /= img0_shape[0] # height
  661. return coords
  662. def regularize_rboxes(rboxes):
  663. """
  664. Regularize rotated bounding boxes to range [0, pi/2].
  665. Args:
  666. rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.
  667. Returns:
  668. (torch.Tensor): Regularized rotated boxes.
  669. """
  670. x, y, w, h, t = rboxes.unbind(dim=-1)
  671. # Swap edge if t >= pi/2 while not being symmetrically opposite
  672. swap = t % math.pi >= math.pi / 2
  673. w_ = torch.where(swap, h, w)
  674. h_ = torch.where(swap, w, h)
  675. t = t % (math.pi / 2)
  676. return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
  677. def masks2segments(masks, strategy: str = "all"):
  678. """
  679. Convert masks to segments using contour detection.
  680. Args:
  681. masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).
  682. strategy (str): Segmentation strategy, either 'all' or 'largest'.
  683. Returns:
  684. (list): List of segment masks as float32 arrays.
  685. """
  686. from ultralytics.data.converter import merge_multi_segment
  687. segments = []
  688. for x in masks.int().cpu().numpy().astype("uint8"):
  689. c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
  690. if c:
  691. if strategy == "all": # merge and concatenate all segments
  692. c = (
  693. np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))
  694. if len(c) > 1
  695. else c[0].reshape(-1, 2)
  696. )
  697. elif strategy == "largest": # select largest segment
  698. c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
  699. else:
  700. c = np.zeros((0, 2)) # no segments found
  701. segments.append(c.astype("float32"))
  702. return segments
  703. def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
  704. """
  705. Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.
  706. Args:
  707. batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.
  708. Returns:
  709. (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
  710. """
  711. return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
  712. def clean_str(s):
  713. """
  714. Clean a string by replacing special characters with '_' character.
  715. Args:
  716. s (str): A string needing special characters replaced.
  717. Returns:
  718. (str): A string with special characters replaced by an underscore _.
  719. """
  720. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  721. def empty_like(x):
  722. """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
  723. return (
  724. torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
  725. )
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...