Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#20413 YOLOE: Fix visual prompt training

Merged
Ghost merged 1 commits into Ultralytics:main from ultralytics:yoloe-vp-fix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
  1. # Ultralytics ๐Ÿš€ AGPL-3.0 License - https://ultralytics.com/license
  2. """Model validation metrics."""
  3. import math
  4. import warnings
  5. from pathlib import Path
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. import torch
  9. from ultralytics.utils import LOGGER, SimpleClass, TryExcept, checks, plt_settings
# Per-keypoint standard deviations used by Object Keypoint Similarity (OKS); consumed as the
# `sigma` argument of `kpt_iou` (17 values, one per keypoint). NOTE(review): values appear to be
# the standard COCO 17-keypoint sigmas scaled by 1/10 — confirm against cocoeval if modified.
OKS_SIGMA = (
    np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
    / 10.0
)
  14. def bbox_ioa(box1, box2, iou=False, eps=1e-7):
  15. """
  16. Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
  17. Args:
  18. box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
  19. box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
  20. iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
  21. eps (float, optional): A small value to avoid division by zero.
  22. Returns:
  23. (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
  24. """
  25. # Get the coordinates of bounding boxes
  26. b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
  27. b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
  28. # Intersection area
  29. inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
  30. np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
  31. ).clip(0)
  32. # Box2 area
  33. area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
  34. if iou:
  35. box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
  36. area = area + box1_area[:, None] - inter_area
  37. # Intersection over box2 area
  38. return inter_area / (area + eps)
  39. def box_iou(box1, box2, eps=1e-7):
  40. """
  41. Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  42. Based on https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py.
  43. Args:
  44. box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
  45. box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
  46. eps (float, optional): A small value to avoid division by zero.
  47. Returns:
  48. (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
  49. """
  50. # NOTE: Need .float() to get accurate iou values
  51. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  52. (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
  53. inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
  54. # IoU = inter / (area1 + area2 - inter)
  55. return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
  56. def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  57. """
  58. Calculate the Intersection over Union (IoU) between bounding boxes.
  59. This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
  60. For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
  61. Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
  62. or (x1, y1, x2, y2) if `xywh=False`.
  63. Args:
  64. box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
  65. box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
  66. xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
  67. (x1, y1, x2, y2) format.
  68. GIoU (bool, optional): If True, calculate Generalized IoU.
  69. DIoU (bool, optional): If True, calculate Distance IoU.
  70. CIoU (bool, optional): If True, calculate Complete IoU.
  71. eps (float, optional): A small value to avoid division by zero.
  72. Returns:
  73. (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
  74. """
  75. # Get the coordinates of bounding boxes
  76. if xywh: # transform from xywh to xyxy
  77. (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
  78. w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
  79. b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
  80. b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
  81. else: # x1, y1, x2, y2 = box1
  82. b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
  83. b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
  84. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  85. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  86. # Intersection area
  87. inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
  88. b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
  89. ).clamp_(0)
  90. # Union Area
  91. union = w1 * h1 + w2 * h2 - inter + eps
  92. # IoU
  93. iou = inter / union
  94. if CIoU or DIoU or GIoU:
  95. cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
  96. ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
  97. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  98. c2 = cw.pow(2) + ch.pow(2) + eps # convex diagonal squared
  99. rho2 = (
  100. (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)
  101. ) / 4 # center dist**2
  102. if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  103. v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
  104. with torch.no_grad():
  105. alpha = v / (v - iou + (1 + eps))
  106. return iou - (rho2 / c2 + v * alpha) # CIoU
  107. return iou - rho2 / c2 # DIoU
  108. c_area = cw * ch + eps # convex area
  109. return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
  110. return iou # IoU
  111. def mask_iou(mask1, mask2, eps=1e-7):
  112. """
  113. Calculate masks IoU.
  114. Args:
  115. mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
  116. product of image width and height.
  117. mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
  118. product of image width and height.
  119. eps (float, optional): A small value to avoid division by zero.
  120. Returns:
  121. (torch.Tensor): A tensor of shape (N, M) representing masks IoU.
  122. """
  123. intersection = torch.matmul(mask1, mask2.T).clamp_(0)
  124. union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
  125. return intersection / (union + eps)
  126. def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
  127. """
  128. Calculate Object Keypoint Similarity (OKS).
  129. Args:
  130. kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
  131. kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
  132. area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
  133. sigma (list): A list containing 17 values representing keypoint scales.
  134. eps (float, optional): A small value to avoid division by zero.
  135. Returns:
  136. (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
  137. """
  138. d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17)
  139. sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
  140. kpt_mask = kpt1[..., 2] != 0 # (N, 17)
  141. e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval
  142. # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula
  143. return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
  144. def _get_covariance_matrix(boxes):
  145. """
  146. Generate covariance matrix from oriented bounding boxes.
  147. Args:
  148. boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
  149. Returns:
  150. (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
  151. """
  152. # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
  153. gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
  154. a, b, c = gbbs.split(1, dim=-1)
  155. cos = c.cos()
  156. sin = c.sin()
  157. cos2 = cos.pow(2)
  158. sin2 = sin.pow(2)
  159. return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin
  160. def probiou(obb1, obb2, CIoU=False, eps=1e-7):
  161. """
  162. Calculate probabilistic IoU between oriented bounding boxes.
  163. Args:
  164. obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
  165. obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
  166. CIoU (bool, optional): If True, calculate CIoU.
  167. eps (float, optional): Small value to avoid division by zero.
  168. Returns:
  169. (torch.Tensor): OBB similarities, shape (N,).
  170. Notes:
  171. - OBB format: [center_x, center_y, width, height, rotation_angle].
  172. - Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf.
  173. """
  174. x1, y1 = obb1[..., :2].split(1, dim=-1)
  175. x2, y2 = obb2[..., :2].split(1, dim=-1)
  176. a1, b1, c1 = _get_covariance_matrix(obb1)
  177. a2, b2, c2 = _get_covariance_matrix(obb2)
  178. t1 = (
  179. ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
  180. ) * 0.25
  181. t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
  182. t3 = (
  183. ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
  184. / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
  185. + eps
  186. ).log() * 0.5
  187. bd = (t1 + t2 + t3).clamp(eps, 100.0)
  188. hd = (1.0 - (-bd).exp() + eps).sqrt()
  189. iou = 1 - hd
  190. if CIoU: # only include the wh aspect ratio part
  191. w1, h1 = obb1[..., 2:4].split(1, dim=-1)
  192. w2, h2 = obb2[..., 2:4].split(1, dim=-1)
  193. v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
  194. with torch.no_grad():
  195. alpha = v / (v - iou + (1 + eps))
  196. return iou - v * alpha # CIoU
  197. return iou
  198. def batch_probiou(obb1, obb2, eps=1e-7):
  199. """
  200. Calculate the probabilistic IoU between oriented bounding boxes.
  201. Args:
  202. obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
  203. obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
  204. eps (float, optional): A small value to avoid division by zero.
  205. Returns:
  206. (torch.Tensor): A tensor of shape (N, M) representing obb similarities.
  207. References:
  208. https://arxiv.org/pdf/2106.06072v1.pdf
  209. """
  210. obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
  211. obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
  212. x1, y1 = obb1[..., :2].split(1, dim=-1)
  213. x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
  214. a1, b1, c1 = _get_covariance_matrix(obb1)
  215. a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
  216. t1 = (
  217. ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
  218. ) * 0.25
  219. t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
  220. t3 = (
  221. ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
  222. / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
  223. + eps
  224. ).log() * 0.5
  225. bd = (t1 + t2 + t3).clamp(eps, 100.0)
  226. hd = (1.0 - (-bd).exp() + eps).sqrt()
  227. return 1 - hd
  228. def smooth_bce(eps=0.1):
  229. """
  230. Compute smoothed positive and negative Binary Cross-Entropy targets.
  231. Args:
  232. eps (float, optional): The epsilon value for label smoothing.
  233. Returns:
  234. (tuple): A tuple containing the positive and negative label smoothing BCE targets.
  235. References:
  236. https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
  237. """
  238. return 1.0 - 0.5 * eps, 0.5 * eps
class ConfusionMatrix:
    """
    A class for calculating and updating a confusion matrix for object detection and classification tasks.

    Attributes:
        task (str): The type of task, either 'detect' or 'classify'.
        matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
        nc (int): The number of classes.
        conf (float): The confidence threshold for detections.
        iou_thres (float): The Intersection over Union threshold.
    """

    def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"):
        """
        Initialize a ConfusionMatrix instance.

        Args:
            nc (int): Number of classes.
            conf (float, optional): Confidence threshold for detections.
            iou_thres (float, optional): IoU threshold for matching detections to ground truth.
            task (str, optional): Type of task, either 'detect' or 'classify'.
        """
        self.task = task
        # Detection matrices get one extra row/column for the background class
        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
        self.nc = nc  # number of classes
        self.conf = 0.25 if conf in {None, 0.001} else conf  # apply 0.25 if default val conf is passed
        self.iou_thres = iou_thres

    def process_cls_preds(self, preds, targets):
        """
        Update confusion matrix for classification task.

        Args:
            preds (Array[N, min(nc,5)]): Predicted class labels.
            targets (Array[N, 1]): Ground truth class labels.
        """
        # Use only the top-1 prediction (column 0) for each sample
        preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
        for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
            self.matrix[p][t] += 1

    def process_batch(self, detections, gt_bboxes, gt_cls):
        """
        Update confusion matrix for object detection task.

        Args:
            detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information.
                Each row should contain (x1, y1, x2, y2, conf, class)
                or with an additional element `angle` when it's obb.
            gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format.
            gt_cls (Array[M]): The class labels.
        """
        if gt_cls.shape[0] == 0:  # Check if labels is empty
            if detections is not None:
                # No labels: every confident detection is a false positive against background
                detections = detections[detections[:, 4] > self.conf]
                detection_classes = detections[:, 5].int()
                for dc in detection_classes:
                    self.matrix[dc, self.nc] += 1  # false positives
            return
        if detections is None:
            # No detections: every label is a background false negative
            gt_classes = gt_cls.int()
            for gc in gt_classes:
                self.matrix[self.nc, gc] += 1  # background FN
            return

        detections = detections[detections[:, 4] > self.conf]  # drop low-confidence detections
        gt_classes = gt_cls.int()
        detection_classes = detections[:, 5].int()
        is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5  # with additional `angle` dimension
        # OBB detections carry the angle in the last column; re-attach it to xyxy for probiou
        iou = (
            batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
            if is_obb
            else box_iou(gt_bboxes, detections[:, :4])
        )

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            # matches columns: (gt index, detection index, iou)
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Greedy one-to-one assignment: keep the highest-IoU match per detection, then per gt
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(int)  # matched gt indices, matched detection indices
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                self.matrix[detection_classes[m1[j]], gc] += 1  # correct
            else:
                self.matrix[self.nc, gc] += 1  # true background
        for i, dc in enumerate(detection_classes):
            if not any(m1 == i):
                self.matrix[dc, self.nc] += 1  # predicted background

    # NOTE(review): this method is shadowed by the `self.matrix` ndarray attribute assigned in
    # __init__, so `cm.matrix` yields the array directly and this method is effectively unreachable
    # on instances — confirm whether it can be removed.
    def matrix(self):
        """Return the confusion matrix."""
        return self.matrix

    def tp_fp(self):
        """
        Return true positives and false positives.

        Returns:
            (tuple): True positives and false positives.
        """
        tp = self.matrix.diagonal()  # true positives
        fp = self.matrix.sum(1) - tp  # false positives
        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
        return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp)  # remove background class if task=detect

    @TryExcept(msg="ConfusionMatrix plot failure")
    @plt_settings()
    def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
        """
        Plot the confusion matrix using seaborn and save it to a file.

        Args:
            normalize (bool): Whether to normalize the confusion matrix.
            save_dir (str): Directory where the plot will be saved.
            names (tuple): Names of classes, used as labels on the plot.
            on_plot (func): An optional callback to pass plots path and data when they are rendered.
        """
        import seaborn  # scope for faster 'import ultralytics'

        array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)  # normalize columns
        array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

        fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
        nc, nn = self.nc, len(names)  # number of classes, names
        seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8)  # for label size
        labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
        ticklabels = (list(names) + ["background"]) if labels else "auto"
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
            seaborn.heatmap(
                array,
                ax=ax,
                annot=nc < 30,
                annot_kws={"size": 8},
                cmap="Blues",
                fmt=".2f" if normalize else ".0f",
                square=True,
                vmin=0.0,
                xticklabels=ticklabels,
                yticklabels=ticklabels,
            ).set_facecolor((1, 1, 1))
        title = "Confusion Matrix" + " Normalized" * normalize
        ax.set_xlabel("True")
        ax.set_ylabel("Predicted")
        ax.set_title(title)
        plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png"
        fig.savefig(plot_fname, dpi=250)
        plt.close(fig)
        if on_plot:
            on_plot(plot_fname)

    def print(self):
        """Print the confusion matrix to the console."""
        for i in range(self.matrix.shape[0]):
            LOGGER.info(" ".join(map(str, self.matrix[i])))
  384. def smooth(y, f=0.05):
  385. """Box filter of fraction f."""
  386. nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
  387. p = np.ones(nf // 2) # ones padding
  388. yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
  389. return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed
  390. @plt_settings()
  391. def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None):
  392. """
  393. Plot precision-recall curve.
  394. Args:
  395. px (np.ndarray): X values for the PR curve.
  396. py (np.ndarray): Y values for the PR curve.
  397. ap (np.ndarray): Average precision values.
  398. save_dir (Path, optional): Path to save the plot.
  399. names (dict, optional): Dictionary mapping class indices to class names.
  400. on_plot (callable, optional): Function to call after plot is saved.
  401. """
  402. fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
  403. py = np.stack(py, axis=1)
  404. if 0 < len(names) < 21: # display per-class legend if < 21 classes
  405. for i, y in enumerate(py.T):
  406. ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
  407. else:
  408. ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
  409. ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
  410. ax.set_xlabel("Recall")
  411. ax.set_ylabel("Precision")
  412. ax.set_xlim(0, 1)
  413. ax.set_ylim(0, 1)
  414. ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
  415. ax.set_title("Precision-Recall Curve")
  416. fig.savefig(save_dir, dpi=250)
  417. plt.close(fig)
  418. if on_plot:
  419. on_plot(save_dir)
  420. @plt_settings()
  421. def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None):
  422. """
  423. Plot metric-confidence curve.
  424. Args:
  425. px (np.ndarray): X values for the metric-confidence curve.
  426. py (np.ndarray): Y values for the metric-confidence curve.
  427. save_dir (Path, optional): Path to save the plot.
  428. names (dict, optional): Dictionary mapping class indices to class names.
  429. xlabel (str, optional): X-axis label.
  430. ylabel (str, optional): Y-axis label.
  431. on_plot (callable, optional): Function to call after plot is saved.
  432. """
  433. fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
  434. if 0 < len(names) < 21: # display per-class legend if < 21 classes
  435. for i, y in enumerate(py):
  436. ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
  437. else:
  438. ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
  439. y = smooth(py.mean(0), 0.1)
  440. ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
  441. ax.set_xlabel(xlabel)
  442. ax.set_ylabel(ylabel)
  443. ax.set_xlim(0, 1)
  444. ax.set_ylim(0, 1)
  445. ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
  446. ax.set_title(f"{ylabel}-Confidence Curve")
  447. fig.savefig(save_dir, dpi=250)
  448. plt.close(fig)
  449. if on_plot:
  450. on_plot(save_dir)
  451. def compute_ap(recall, precision):
  452. """
  453. Compute the average precision (AP) given the recall and precision curves.
  454. Args:
  455. recall (list): The recall curve.
  456. precision (list): The precision curve.
  457. Returns:
  458. (float): Average precision.
  459. (np.ndarray): Precision envelope curve.
  460. (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
  461. """
  462. # Append sentinel values to beginning and end
  463. mrec = np.concatenate(([0.0], recall, [1.0]))
  464. mpre = np.concatenate(([1.0], precision, [0.0]))
  465. # Compute the precision envelope
  466. mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
  467. # Integrate area under curve
  468. method = "interp" # methods: 'continuous', 'interp'
  469. if method == "interp":
  470. x = np.linspace(0, 1, 101) # 101-point interp (COCO)
  471. func = np.trapezoid if checks.check_version(np.__version__, ">=2.0") else np.trapz # np.trapz deprecated
  472. ap = func(np.interp(x, mrec, mpre), x) # integrate
  473. else: # 'continuous'
  474. i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes
  475. ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
  476. return ap, mpre, mrec
  477. def ap_per_class(
  478. tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix=""
  479. ):
  480. """
  481. Compute the average precision per class for object detection evaluation.
  482. Args:
  483. tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
  484. conf (np.ndarray): Array of confidence scores of the detections.
  485. pred_cls (np.ndarray): Array of predicted classes of the detections.
  486. target_cls (np.ndarray): Array of true classes of the detections.
  487. plot (bool, optional): Whether to plot PR curves or not.
  488. on_plot (func, optional): A callback to pass plots path and data when they are rendered.
  489. save_dir (Path, optional): Directory to save the PR curves.
  490. names (dict, optional): Dict of class names to plot PR curves.
  491. eps (float, optional): A small value to avoid division by zero.
  492. prefix (str, optional): A prefix string for saving the plot files.
  493. Returns:
  494. tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.
  495. fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.
  496. p (np.ndarray): Precision values at threshold given by max F1 metric for each class.
  497. r (np.ndarray): Recall values at threshold given by max F1 metric for each class.
  498. f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.
  499. ap (np.ndarray): Average precision for each class at different IoU thresholds.
  500. unique_classes (np.ndarray): An array of unique classes that have data.
  501. p_curve (np.ndarray): Precision curves for each class.
  502. r_curve (np.ndarray): Recall curves for each class.
  503. f1_curve (np.ndarray): F1-score curves for each class.
  504. x (np.ndarray): X-axis values for the curves.
  505. prec_values (np.ndarray): Precision values at mAP@0.5 for each class.
  506. """
  507. # Sort by objectness
  508. i = np.argsort(-conf)
  509. tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
  510. # Find unique classes
  511. unique_classes, nt = np.unique(target_cls, return_counts=True)
  512. nc = unique_classes.shape[0] # number of classes, number of detections
  513. # Create Precision-Recall curve and compute AP for each class
  514. x, prec_values = np.linspace(0, 1, 1000), []
  515. # Average precision, precision and recall curves
  516. ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
  517. for ci, c in enumerate(unique_classes):
  518. i = pred_cls == c
  519. n_l = nt[ci] # number of labels
  520. n_p = i.sum() # number of predictions
  521. if n_p == 0 or n_l == 0:
  522. continue
  523. # Accumulate FPs and TPs
  524. fpc = (1 - tp[i]).cumsum(0)
  525. tpc = tp[i].cumsum(0)
  526. # Recall
  527. recall = tpc / (n_l + eps) # recall curve
  528. r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
  529. # Precision
  530. precision = tpc / (tpc + fpc) # precision curve
  531. p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score
  532. # AP from recall-precision curve
  533. for j in range(tp.shape[1]):
  534. ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
  535. if j == 0:
  536. prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5
  537. prec_values = np.array(prec_values) if prec_values else np.zeros((1, 1000)) # (nc, 1000)
  538. # Compute F1 (harmonic mean of precision and recall)
  539. f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)
  540. names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
  541. names = dict(enumerate(names)) # to dict
  542. if plot:
  543. plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
  544. plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
  545. plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot)
  546. plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot)
  547. i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index
  548. p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values
  549. tp = (r * nt).round() # true positives
  550. fp = (tp / (p + eps) - tp).round() # false positives
  551. return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values
  552. class Metric(SimpleClass):
  553. """
  554. Class for computing evaluation metrics for YOLOv8 model.
  555. Attributes:
  556. p (list): Precision for each class. Shape: (nc,).
  557. r (list): Recall for each class. Shape: (nc,).
  558. f1 (list): F1 score for each class. Shape: (nc,).
  559. all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
  560. ap_class_index (list): Index of class for each AP score. Shape: (nc,).
  561. nc (int): Number of classes.
  562. Methods:
  563. ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
  564. ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
  565. mp(): Mean precision of all classes. Returns: Float.
  566. mr(): Mean recall of all classes. Returns: Float.
  567. map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
  568. map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
  569. map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
  570. mean_results(): Mean of results, returns mp, mr, map50, map.
  571. class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
  572. maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
  573. fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
  574. update(results): Update metric attributes with new evaluation results.
  575. """
  576. def __init__(self) -> None:
  577. """Initialize a Metric instance for computing evaluation metrics for the YOLOv8 model."""
  578. self.p = [] # (nc, )
  579. self.r = [] # (nc, )
  580. self.f1 = [] # (nc, )
  581. self.all_ap = [] # (nc, 10)
  582. self.ap_class_index = [] # (nc, )
  583. self.nc = 0
  584. @property
  585. def ap50(self):
  586. """
  587. Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
  588. Returns:
  589. (np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
  590. """
  591. return self.all_ap[:, 0] if len(self.all_ap) else []
  592. @property
  593. def ap(self):
  594. """
  595. Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
  596. Returns:
  597. (np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
  598. """
  599. return self.all_ap.mean(1) if len(self.all_ap) else []
  600. @property
  601. def mp(self):
  602. """
  603. Return the Mean Precision of all classes.
  604. Returns:
  605. (float): The mean precision of all classes.
  606. """
  607. return self.p.mean() if len(self.p) else 0.0
  608. @property
  609. def mr(self):
  610. """
  611. Return the Mean Recall of all classes.
  612. Returns:
  613. (float): The mean recall of all classes.
  614. """
  615. return self.r.mean() if len(self.r) else 0.0
  616. @property
  617. def map50(self):
  618. """
  619. Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
  620. Returns:
  621. (float): The mAP at an IoU threshold of 0.5.
  622. """
  623. return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
  624. @property
  625. def map75(self):
  626. """
  627. Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
  628. Returns:
  629. (float): The mAP at an IoU threshold of 0.75.
  630. """
  631. return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
  632. @property
  633. def map(self):
  634. """
  635. Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
  636. Returns:
  637. (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
  638. """
  639. return self.all_ap.mean() if len(self.all_ap) else 0.0
  640. def mean_results(self):
  641. """Return mean of results, mp, mr, map50, map."""
  642. return [self.mp, self.mr, self.map50, self.map]
  643. def class_result(self, i):
  644. """Return class-aware result, p[i], r[i], ap50[i], ap[i]."""
  645. return self.p[i], self.r[i], self.ap50[i], self.ap[i]
  646. @property
  647. def maps(self):
  648. """Return mAP of each class."""
  649. maps = np.zeros(self.nc) + self.map
  650. for i, c in enumerate(self.ap_class_index):
  651. maps[c] = self.ap[i]
  652. return maps
  653. def fitness(self):
  654. """Return model fitness as a weighted combination of metrics."""
  655. w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
  656. return (np.array(self.mean_results()) * w).sum()
  657. def update(self, results):
  658. """
  659. Update the evaluation metrics with a new set of results.
  660. Args:
  661. results (tuple): A tuple containing evaluation metrics:
  662. - p (list): Precision for each class.
  663. - r (list): Recall for each class.
  664. - f1 (list): F1 score for each class.
  665. - all_ap (list): AP scores for all classes and all IoU thresholds.
  666. - ap_class_index (list): Index of class for each AP score.
  667. - p_curve (list): Precision curve for each class.
  668. - r_curve (list): Recall curve for each class.
  669. - f1_curve (list): F1 curve for each class.
  670. - px (list): X values for the curves.
  671. - prec_values (list): Precision values for each class.
  672. """
  673. (
  674. self.p,
  675. self.r,
  676. self.f1,
  677. self.all_ap,
  678. self.ap_class_index,
  679. self.p_curve,
  680. self.r_curve,
  681. self.f1_curve,
  682. self.px,
  683. self.prec_values,
  684. ) = results
  685. @property
  686. def curves(self):
  687. """Return a list of curves for accessing specific metrics curves."""
  688. return []
  689. @property
  690. def curves_results(self):
  691. """Return a list of curves for accessing specific metrics curves."""
  692. return [
  693. [self.px, self.prec_values, "Recall", "Precision"],
  694. [self.px, self.f1_curve, "Confidence", "F1"],
  695. [self.px, self.p_curve, "Confidence", "Precision"],
  696. [self.px, self.r_curve, "Confidence", "Recall"],
  697. ]
  698. class DetMetrics(SimpleClass):
  699. """
  700. Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
  701. Attributes:
  702. save_dir (Path): A path to the directory where the output plots will be saved.
  703. plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
  704. names (dict): A dictionary of class names.
  705. box (Metric): An instance of the Metric class for storing detection results.
  706. speed (dict): A dictionary for storing execution times of different parts of the detection process.
  707. task (str): The task type, set to 'detect'.
  708. """
  709. def __init__(self, save_dir=Path("."), plot=False, names={}) -> None:
  710. """
  711. Initialize a DetMetrics instance with a save directory, plot flag, and class names.
  712. Args:
  713. save_dir (Path, optional): Directory to save plots.
  714. plot (bool, optional): Whether to plot precision-recall curves.
  715. names (dict, optional): Dictionary mapping class indices to names.
  716. """
  717. self.save_dir = save_dir
  718. self.plot = plot
  719. self.names = names
  720. self.box = Metric()
  721. self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
  722. self.task = "detect"
  723. def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
  724. """
  725. Process predicted results for object detection and update metrics.
  726. Args:
  727. tp (np.ndarray): True positive array.
  728. conf (np.ndarray): Confidence array.
  729. pred_cls (np.ndarray): Predicted class indices array.
  730. target_cls (np.ndarray): Target class indices array.
  731. on_plot (callable, optional): Function to call after plots are generated.
  732. """
  733. results = ap_per_class(
  734. tp,
  735. conf,
  736. pred_cls,
  737. target_cls,
  738. plot=self.plot,
  739. save_dir=self.save_dir,
  740. names=self.names,
  741. on_plot=on_plot,
  742. )[2:]
  743. self.box.nc = len(self.names)
  744. self.box.update(results)
  745. @property
  746. def keys(self):
  747. """Return a list of keys for accessing specific metrics."""
  748. return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
  749. def mean_results(self):
  750. """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
  751. return self.box.mean_results()
  752. def class_result(self, i):
  753. """Return the result of evaluating the performance of an object detection model on a specific class."""
  754. return self.box.class_result(i)
  755. @property
  756. def maps(self):
  757. """Return mean Average Precision (mAP) scores per class."""
  758. return self.box.maps
  759. @property
  760. def fitness(self):
  761. """Return the fitness of box object."""
  762. return self.box.fitness()
  763. @property
  764. def ap_class_index(self):
  765. """Return the average precision index per class."""
  766. return self.box.ap_class_index
  767. @property
  768. def results_dict(self):
  769. """Return dictionary of computed performance metrics and statistics."""
  770. return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
  771. @property
  772. def curves(self):
  773. """Return a list of curves for accessing specific metrics curves."""
  774. return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
  775. @property
  776. def curves_results(self):
  777. """Return dictionary of computed performance metrics and statistics."""
  778. return self.box.curves_results
  779. class SegmentMetrics(SimpleClass):
  780. """
  781. Calculates and aggregates detection and segmentation metrics over a given set of classes.
  782. Attributes:
  783. save_dir (Path): Path to the directory where the output plots should be saved.
  784. plot (bool): Whether to save the detection and segmentation plots.
  785. names (dict): Dictionary of class names.
  786. box (Metric): An instance of the Metric class to calculate box detection metrics.
  787. seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
  788. speed (dict): Dictionary to store the time taken in different phases of inference.
  789. task (str): The task type, set to 'segment'.
  790. """
  791. def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
  792. """
  793. Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
  794. Args:
  795. save_dir (Path, optional): Directory to save plots.
  796. plot (bool, optional): Whether to plot precision-recall curves.
  797. names (dict, optional): Dictionary mapping class indices to names.
  798. """
  799. self.save_dir = save_dir
  800. self.plot = plot
  801. self.names = names
  802. self.box = Metric()
  803. self.seg = Metric()
  804. self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
  805. self.task = "segment"
  806. def process(self, tp, tp_m, conf, pred_cls, target_cls, on_plot=None):
  807. """
  808. Process the detection and segmentation metrics over the given set of predictions.
  809. Args:
  810. tp (np.ndarray): True positive array for boxes.
  811. tp_m (np.ndarray): True positive array for masks.
  812. conf (np.ndarray): Confidence array.
  813. pred_cls (np.ndarray): Predicted class indices array.
  814. target_cls (np.ndarray): Target class indices array.
  815. on_plot (callable, optional): Function to call after plots are generated.
  816. """
  817. results_mask = ap_per_class(
  818. tp_m,
  819. conf,
  820. pred_cls,
  821. target_cls,
  822. plot=self.plot,
  823. on_plot=on_plot,
  824. save_dir=self.save_dir,
  825. names=self.names,
  826. prefix="Mask",
  827. )[2:]
  828. self.seg.nc = len(self.names)
  829. self.seg.update(results_mask)
  830. results_box = ap_per_class(
  831. tp,
  832. conf,
  833. pred_cls,
  834. target_cls,
  835. plot=self.plot,
  836. on_plot=on_plot,
  837. save_dir=self.save_dir,
  838. names=self.names,
  839. prefix="Box",
  840. )[2:]
  841. self.box.nc = len(self.names)
  842. self.box.update(results_box)
  843. @property
  844. def keys(self):
  845. """Return a list of keys for accessing metrics."""
  846. return [
  847. "metrics/precision(B)",
  848. "metrics/recall(B)",
  849. "metrics/mAP50(B)",
  850. "metrics/mAP50-95(B)",
  851. "metrics/precision(M)",
  852. "metrics/recall(M)",
  853. "metrics/mAP50(M)",
  854. "metrics/mAP50-95(M)",
  855. ]
  856. def mean_results(self):
  857. """Return the mean metrics for bounding box and segmentation results."""
  858. return self.box.mean_results() + self.seg.mean_results()
  859. def class_result(self, i):
  860. """Return classification results for a specified class index."""
  861. return self.box.class_result(i) + self.seg.class_result(i)
  862. @property
  863. def maps(self):
  864. """Return mAP scores for object detection and semantic segmentation models."""
  865. return self.box.maps + self.seg.maps
  866. @property
  867. def fitness(self):
  868. """Return the fitness score for both segmentation and bounding box models."""
  869. return self.seg.fitness() + self.box.fitness()
  870. @property
  871. def ap_class_index(self):
  872. """
  873. Return the class indices.
  874. Boxes and masks have the same ap_class_index.
  875. """
  876. return self.box.ap_class_index
  877. @property
  878. def results_dict(self):
  879. """Return results of object detection model for evaluation."""
  880. return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
  881. @property
  882. def curves(self):
  883. """Return a list of curves for accessing specific metrics curves."""
  884. return [
  885. "Precision-Recall(B)",
  886. "F1-Confidence(B)",
  887. "Precision-Confidence(B)",
  888. "Recall-Confidence(B)",
  889. "Precision-Recall(M)",
  890. "F1-Confidence(M)",
  891. "Precision-Confidence(M)",
  892. "Recall-Confidence(M)",
  893. ]
  894. @property
  895. def curves_results(self):
  896. """Return dictionary of computed performance metrics and statistics."""
  897. return self.box.curves_results + self.seg.curves_results
  898. class PoseMetrics(SegmentMetrics):
  899. """
  900. Calculates and aggregates detection and pose metrics over a given set of classes.
  901. Attributes:
  902. save_dir (Path): Path to the directory where the output plots should be saved.
  903. plot (bool): Whether to save the detection and pose plots.
  904. names (dict): Dictionary of class names.
  905. box (Metric): An instance of the Metric class to calculate box detection metrics.
  906. pose (Metric): An instance of the Metric class to calculate pose metrics.
  907. speed (dict): Dictionary to store the time taken in different phases of inference.
  908. task (str): The task type, set to 'pose'.
  909. Methods:
  910. process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
  911. mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
  912. class_result(i): Returns the detection and segmentation metrics of class `i`.
  913. maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
  914. fitness: Returns the fitness scores, which are a single weighted combination of metrics.
  915. ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
  916. results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
  917. """
  918. def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
  919. """
  920. Initialize the PoseMetrics class with directory path, class names, and plotting options.
  921. Args:
  922. save_dir (Path, optional): Directory to save plots.
  923. plot (bool, optional): Whether to plot precision-recall curves.
  924. names (dict, optional): Dictionary mapping class indices to names.
  925. """
  926. super().__init__(save_dir, plot, names)
  927. self.save_dir = save_dir
  928. self.plot = plot
  929. self.names = names
  930. self.box = Metric()
  931. self.pose = Metric()
  932. self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
  933. self.task = "pose"
  934. def process(self, tp, tp_p, conf, pred_cls, target_cls, on_plot=None):
  935. """
  936. Process the detection and pose metrics over the given set of predictions.
  937. Args:
  938. tp (np.ndarray): True positive array for boxes.
  939. tp_p (np.ndarray): True positive array for keypoints.
  940. conf (np.ndarray): Confidence array.
  941. pred_cls (np.ndarray): Predicted class indices array.
  942. target_cls (np.ndarray): Target class indices array.
  943. on_plot (callable, optional): Function to call after plots are generated.
  944. """
  945. results_pose = ap_per_class(
  946. tp_p,
  947. conf,
  948. pred_cls,
  949. target_cls,
  950. plot=self.plot,
  951. on_plot=on_plot,
  952. save_dir=self.save_dir,
  953. names=self.names,
  954. prefix="Pose",
  955. )[2:]
  956. self.pose.nc = len(self.names)
  957. self.pose.update(results_pose)
  958. results_box = ap_per_class(
  959. tp,
  960. conf,
  961. pred_cls,
  962. target_cls,
  963. plot=self.plot,
  964. on_plot=on_plot,
  965. save_dir=self.save_dir,
  966. names=self.names,
  967. prefix="Box",
  968. )[2:]
  969. self.box.nc = len(self.names)
  970. self.box.update(results_box)
  971. @property
  972. def keys(self):
  973. """Return list of evaluation metric keys."""
  974. return [
  975. "metrics/precision(B)",
  976. "metrics/recall(B)",
  977. "metrics/mAP50(B)",
  978. "metrics/mAP50-95(B)",
  979. "metrics/precision(P)",
  980. "metrics/recall(P)",
  981. "metrics/mAP50(P)",
  982. "metrics/mAP50-95(P)",
  983. ]
  984. def mean_results(self):
  985. """Return the mean results of box and pose."""
  986. return self.box.mean_results() + self.pose.mean_results()
  987. def class_result(self, i):
  988. """Return the class-wise detection results for a specific class i."""
  989. return self.box.class_result(i) + self.pose.class_result(i)
  990. @property
  991. def maps(self):
  992. """Return the mean average precision (mAP) per class for both box and pose detections."""
  993. return self.box.maps + self.pose.maps
  994. @property
  995. def fitness(self):
  996. """Return combined fitness score for pose and box detection."""
  997. return self.pose.fitness() + self.box.fitness()
  998. @property
  999. def curves(self):
  1000. """Return a list of curves for accessing specific metrics curves."""
  1001. return [
  1002. "Precision-Recall(B)",
  1003. "F1-Confidence(B)",
  1004. "Precision-Confidence(B)",
  1005. "Recall-Confidence(B)",
  1006. "Precision-Recall(P)",
  1007. "F1-Confidence(P)",
  1008. "Precision-Confidence(P)",
  1009. "Recall-Confidence(P)",
  1010. ]
  1011. @property
  1012. def curves_results(self):
  1013. """Return dictionary of computed performance metrics and statistics."""
  1014. return self.box.curves_results + self.pose.curves_results
  1015. class ClassifyMetrics(SimpleClass):
  1016. """
  1017. Class for computing classification metrics including top-1 and top-5 accuracy.
  1018. Attributes:
  1019. top1 (float): The top-1 accuracy.
  1020. top5 (float): The top-5 accuracy.
  1021. speed (dict): A dictionary containing the time taken for each step in the pipeline.
  1022. task (str): The task type, set to 'classify'.
  1023. """
  1024. def __init__(self) -> None:
  1025. """Initialize a ClassifyMetrics instance."""
  1026. self.top1 = 0
  1027. self.top5 = 0
  1028. self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
  1029. self.task = "classify"
  1030. def process(self, targets, pred):
  1031. """
  1032. Process target classes and predicted classes to compute metrics.
  1033. Args:
  1034. targets (torch.Tensor): Target classes.
  1035. pred (torch.Tensor): Predicted classes.
  1036. """
  1037. pred, targets = torch.cat(pred), torch.cat(targets)
  1038. correct = (targets[:, None] == pred).float()
  1039. acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
  1040. self.top1, self.top5 = acc.mean(0).tolist()
  1041. @property
  1042. def fitness(self):
  1043. """Return mean of top-1 and top-5 accuracies as fitness score."""
  1044. return (self.top1 + self.top5) / 2
  1045. @property
  1046. def results_dict(self):
  1047. """Return a dictionary with model's performance metrics and fitness score."""
  1048. return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
  1049. @property
  1050. def keys(self):
  1051. """Return a list of keys for the results_dict property."""
  1052. return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
  1053. @property
  1054. def curves(self):
  1055. """Return a list of curves for accessing specific metrics curves."""
  1056. return []
  1057. @property
  1058. def curves_results(self):
  1059. """Return a list of curves for accessing specific metrics curves."""
  1060. return []
  1061. class OBBMetrics(SimpleClass):
  1062. """
  1063. Metrics for evaluating oriented bounding box (OBB) detection.
  1064. Attributes:
  1065. save_dir (Path): Path to the directory where the output plots should be saved.
  1066. plot (bool): Whether to save the detection plots.
  1067. names (dict): Dictionary of class names.
  1068. box (Metric): An instance of the Metric class for storing detection results.
  1069. speed (dict): A dictionary for storing execution times of different parts of the detection process.
  1070. References:
  1071. https://arxiv.org/pdf/2106.06072.pdf
  1072. """
  1073. def __init__(self, save_dir=Path("."), plot=False, names=()) -> None:
  1074. """
  1075. Initialize an OBBMetrics instance with directory, plotting, and class names.
  1076. Args:
  1077. save_dir (Path, optional): Directory to save plots.
  1078. plot (bool, optional): Whether to plot precision-recall curves.
  1079. names (dict, optional): Dictionary mapping class indices to names.
  1080. """
  1081. self.save_dir = save_dir
  1082. self.plot = plot
  1083. self.names = names
  1084. self.box = Metric()
  1085. self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
  1086. def process(self, tp, conf, pred_cls, target_cls, on_plot=None):
  1087. """
  1088. Process predicted results for object detection and update metrics.
  1089. Args:
  1090. tp (np.ndarray): True positive array.
  1091. conf (np.ndarray): Confidence array.
  1092. pred_cls (np.ndarray): Predicted class indices array.
  1093. target_cls (np.ndarray): Target class indices array.
  1094. on_plot (callable, optional): Function to call after plots are generated.
  1095. """
  1096. results = ap_per_class(
  1097. tp,
  1098. conf,
  1099. pred_cls,
  1100. target_cls,
  1101. plot=self.plot,
  1102. save_dir=self.save_dir,
  1103. names=self.names,
  1104. on_plot=on_plot,
  1105. )[2:]
  1106. self.box.nc = len(self.names)
  1107. self.box.update(results)
  1108. @property
  1109. def keys(self):
  1110. """Return a list of keys for accessing specific metrics."""
  1111. return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
  1112. def mean_results(self):
  1113. """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
  1114. return self.box.mean_results()
  1115. def class_result(self, i):
  1116. """Return the result of evaluating the performance of an object detection model on a specific class."""
  1117. return self.box.class_result(i)
  1118. @property
  1119. def maps(self):
  1120. """Return mean Average Precision (mAP) scores per class."""
  1121. return self.box.maps
  1122. @property
  1123. def fitness(self):
  1124. """Return the fitness of box object."""
  1125. return self.box.fitness()
  1126. @property
  1127. def ap_class_index(self):
  1128. """Return the average precision index per class."""
  1129. return self.box.ap_class_index
  1130. @property
  1131. def results_dict(self):
  1132. """Return dictionary of computed performance metrics and statistics."""
  1133. return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness]))
  1134. @property
  1135. def curves(self):
  1136. """Return a list of curves for accessing specific metrics curves."""
  1137. return []
  1138. @property
  1139. def curves_results(self):
  1140. """Return a list of curves for accessing specific metrics curves."""
  1141. return []
Discard
Tip!

Press p or to see the previous file or, n or to see the next file