Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

metrics.py 13 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Model validation metrics
  4. """
  5. import math
  6. import warnings
  7. from pathlib import Path
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. import torch
  11. def fitness(x):
  12. # Model fitness as a weighted combination of metrics
  13. w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
  14. return (x[:, :4] * w).sum(1)
  15. def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
  16. """ Compute the average precision, given the recall and precision curves.
  17. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
  18. # Arguments
  19. tp: True positives (nparray, nx1 or nx10).
  20. conf: Objectness value from 0-1 (nparray).
  21. pred_cls: Predicted object classes (nparray).
  22. target_cls: True object classes (nparray).
  23. plot: Plot precision-recall curve at mAP@0.5
  24. save_dir: Plot save directory
  25. # Returns
  26. The average precision as computed in py-faster-rcnn.
  27. """
  28. # Sort by objectness
  29. i = np.argsort(-conf)
  30. tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
  31. # Find unique classes
  32. unique_classes = np.unique(target_cls)
  33. nc = unique_classes.shape[0] # number of classes, number of detections
  34. # Create Precision-Recall curve and compute AP for each class
  35. px, py = np.linspace(0, 1, 1000), [] # for plotting
  36. ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
  37. for ci, c in enumerate(unique_classes):
  38. i = pred_cls == c
  39. n_l = (target_cls == c).sum() # number of labels
  40. n_p = i.sum() # number of predictions
  41. if n_p == 0 or n_l == 0:
  42. continue
  43. else:
  44. # Accumulate FPs and TPs
  45. fpc = (1 - tp[i]).cumsum(0)
  46. tpc = tp[i].cumsum(0)
  47. # Recall
  48. recall = tpc / (n_l + 1e-16) # recall curve
  49. r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
  50. # Precision
  51. precision = tpc / (tpc + fpc) # precision curve
  52. p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
  53. # AP from recall-precision curve
  54. for j in range(tp.shape[1]):
  55. ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
  56. if plot and j == 0:
  57. py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
  58. # Compute F1 (harmonic mean of precision and recall)
  59. f1 = 2 * p * r / (p + r + 1e-16)
  60. if plot:
  61. plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
  62. plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
  63. plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
  64. plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
  65. i = f1.mean(0).argmax() # max F1 index
  66. return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
  67. def compute_ap(recall, precision):
  68. """ Compute the average precision, given the recall and precision curves
  69. # Arguments
  70. recall: The recall curve (list)
  71. precision: The precision curve (list)
  72. # Returns
  73. Average precision, precision curve, recall curve
  74. """
  75. # Append sentinel values to beginning and end
  76. mrec = np.concatenate(([0.0], recall, [1.0]))
  77. mpre = np.concatenate(([1.0], precision, [0.0]))
  78. # Compute the precision envelope
  79. mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
  80. # Integrate area under curve
  81. method = 'interp' # methods: 'continuous', 'interp'
  82. if method == 'interp':
  83. x = np.linspace(0, 1, 101) # 101-point interp (COCO)
  84. ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
  85. else: # 'continuous'
  86. i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
  87. ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
  88. return ap, mpre, mrec
  89. class ConfusionMatrix:
  90. # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
  91. def __init__(self, nc, conf=0.25, iou_thres=0.45):
  92. self.matrix = np.zeros((nc + 1, nc + 1))
  93. self.nc = nc # number of classes
  94. self.conf = conf
  95. self.iou_thres = iou_thres
  96. def process_batch(self, detections, labels):
  97. """
  98. Return intersection-over-union (Jaccard index) of boxes.
  99. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  100. Arguments:
  101. detections (Array[N, 6]), x1, y1, x2, y2, conf, class
  102. labels (Array[M, 5]), class, x1, y1, x2, y2
  103. Returns:
  104. None, updates confusion matrix accordingly
  105. """
  106. detections = detections[detections[:, 4] > self.conf]
  107. gt_classes = labels[:, 0].int()
  108. detection_classes = detections[:, 5].int()
  109. iou = box_iou(labels[:, 1:], detections[:, :4])
  110. x = torch.where(iou > self.iou_thres)
  111. if x[0].shape[0]:
  112. matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
  113. if x[0].shape[0] > 1:
  114. matches = matches[matches[:, 2].argsort()[::-1]]
  115. matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
  116. matches = matches[matches[:, 2].argsort()[::-1]]
  117. matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
  118. else:
  119. matches = np.zeros((0, 3))
  120. n = matches.shape[0] > 0
  121. m0, m1, _ = matches.transpose().astype(np.int16)
  122. for i, gc in enumerate(gt_classes):
  123. j = m0 == i
  124. if n and sum(j) == 1:
  125. self.matrix[detection_classes[m1[j]], gc] += 1 # correct
  126. else:
  127. self.matrix[self.nc, gc] += 1 # background FP
  128. if n:
  129. for i, dc in enumerate(detection_classes):
  130. if not any(m1 == i):
  131. self.matrix[dc, self.nc] += 1 # background FN
  132. def matrix(self):
  133. return self.matrix
  134. def plot(self, normalize=True, save_dir='', names=()):
  135. try:
  136. import seaborn as sn
  137. array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-6) if normalize else 1) # normalize columns
  138. array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
  139. fig = plt.figure(figsize=(12, 9), tight_layout=True)
  140. sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
  141. labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
  142. with warnings.catch_warnings():
  143. warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
  144. sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
  145. xticklabels=names + ['background FP'] if labels else "auto",
  146. yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
  147. fig.axes[0].set_xlabel('True')
  148. fig.axes[0].set_ylabel('Predicted')
  149. fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
  150. plt.close()
  151. except Exception as e:
  152. print(f'WARNING: ConfusionMatrix plot failure: {e}')
  153. def print(self):
  154. for i in range(self.nc + 1):
  155. print(' '.join(map(str, self.matrix[i])))
  156. def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  157. # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
  158. box2 = box2.T
  159. # Get the coordinates of bounding boxes
  160. if x1y1x2y2: # x1, y1, x2, y2 = box1
  161. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  162. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  163. else: # transform from xywh to xyxy
  164. b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
  165. b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
  166. b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
  167. b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
  168. # Intersection area
  169. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  170. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  171. # Union Area
  172. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  173. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  174. union = w1 * h1 + w2 * h2 - inter + eps
  175. iou = inter / union
  176. if GIoU or DIoU or CIoU:
  177. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  178. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  179. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  180. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  181. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
  182. (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
  183. if DIoU:
  184. return iou - rho2 / c2 # DIoU
  185. elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  186. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  187. with torch.no_grad():
  188. alpha = v / (v - iou + (1 + eps))
  189. return iou - (rho2 / c2 + v * alpha) # CIoU
  190. else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
  191. c_area = cw * ch + eps # convex area
  192. return iou - (c_area - union) / c_area # GIoU
  193. else:
  194. return iou # IoU
  195. def box_iou(box1, box2):
  196. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  197. """
  198. Return intersection-over-union (Jaccard index) of boxes.
  199. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  200. Arguments:
  201. box1 (Tensor[N, 4])
  202. box2 (Tensor[M, 4])
  203. Returns:
  204. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  205. IoU values for every element in boxes1 and boxes2
  206. """
  207. def box_area(box):
  208. # box = 4xn
  209. return (box[2] - box[0]) * (box[3] - box[1])
  210. area1 = box_area(box1.T)
  211. area2 = box_area(box2.T)
  212. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  213. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  214. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  215. def bbox_ioa(box1, box2, eps=1E-7):
  216. """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
  217. box1: np.array of shape(4)
  218. box2: np.array of shape(nx4)
  219. returns: np.array of shape(n)
  220. """
  221. box2 = box2.transpose()
  222. # Get the coordinates of bounding boxes
  223. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  224. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  225. # Intersection area
  226. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  227. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  228. # box2 area
  229. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
  230. # Intersection over box2 area
  231. return inter_area / box2_area
  232. def wh_iou(wh1, wh2):
  233. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  234. wh1 = wh1[:, None] # [N,1,2]
  235. wh2 = wh2[None] # [1,M,2]
  236. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  237. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  238. # Plots ----------------------------------------------------------------------------------------------------------------
  239. def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
  240. # Precision-recall curve
  241. fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
  242. py = np.stack(py, axis=1)
  243. if 0 < len(names) < 21: # display per-class legend if < 21 classes
  244. for i, y in enumerate(py.T):
  245. ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
  246. else:
  247. ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
  248. ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
  249. ax.set_xlabel('Recall')
  250. ax.set_ylabel('Precision')
  251. ax.set_xlim(0, 1)
  252. ax.set_ylim(0, 1)
  253. plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
  254. fig.savefig(Path(save_dir), dpi=250)
  255. plt.close()
  256. def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
  257. # Metric-confidence curve
  258. fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
  259. if 0 < len(names) < 21: # display per-class legend if < 21 classes
  260. for i, y in enumerate(py):
  261. ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
  262. else:
  263. ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)
  264. y = py.mean(0)
  265. ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
  266. ax.set_xlabel(xlabel)
  267. ax.set_ylabel(ylabel)
  268. ax.set_xlim(0, 1)
  269. ax.set_ylim(0, 1)
  270. plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
  271. fig.savefig(Path(save_dir), dpi=250)
  272. plt.close()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...