Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#970 Update YoloNASQuickstart.md

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/SG-000_fix_readme_yolonas_snippets
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
  1. from typing import Mapping, Tuple, Union, Optional
  2. import numpy as np
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import nn, Tensor
  6. import super_gradients
  7. from super_gradients.common.object_names import Losses
  8. from super_gradients.common.registry.registry import register_loss
  9. from super_gradients.training.datasets.data_formats.bbox_formats.cxcywh import cxcywh_to_xyxy
  10. from super_gradients.training.utils.bbox_utils import batch_distance2bbox
  11. from super_gradients.training.utils.distributed_training_utils import (
  12. get_world_size,
  13. )
  14. def batch_iou_similarity(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-9) -> float:
  15. """Calculate iou of box1 and box2 in batch. Bboxes are expected to be in x1y1x2y2 format.
  16. :param box1: box with the shape [N, M1, 4]
  17. :param box2: box with the shape [N, M2, 4]
  18. :return iou: iou between box1 and box2 with the shape [N, M1, M2]
  19. """
  20. box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
  21. box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
  22. px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
  23. gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
  24. x1y1 = torch.maximum(px1y1, gx1y1)
  25. x2y2 = torch.minimum(px2y2, gx2y2)
  26. overlap = (x2y2 - x1y1).clip(0).prod(-1)
  27. area1 = (px2y2 - px1y1).clip(0).prod(-1)
  28. area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
  29. union = area1 + area2 - overlap + eps
  30. return overlap / union
  31. def iou_similarity(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-10) -> float:
  32. """
  33. Calculate iou of box1 and box2. Bboxes are expected to be in x1y1x2y2 format.
  34. :param box1: box with the shape [M1, 4]
  35. :param box2: box with the shape [M2, 4]
  36. :return iou: iou between box1 and box2 with the shape [M1, M2]
  37. """
  38. box1 = box1.unsqueeze(1) # [M1, 4] -> [M1, 1, 4]
  39. box2 = box2.unsqueeze(0) # [M2, 4] -> [1, M2, 4]
  40. px1y1, px2y2 = box1[:, :, 0:2], box1[:, :, 2:4]
  41. gx1y1, gx2y2 = box2[:, :, 0:2], box2[:, :, 2:4]
  42. x1y1 = torch.maximum(px1y1, gx1y1)
  43. x2y2 = torch.minimum(px2y2, gx2y2)
  44. overlap = (x2y2 - x1y1).clip(0).prod(-1)
  45. area1 = (px2y2 - px1y1).clip(0).prod(-1)
  46. area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
  47. union = area1 + area2 - overlap + eps
  48. return overlap / union
  49. def bbox_overlaps(bboxes1: torch.Tensor, bboxes2: torch.Tensor, mode: str = "iou", is_aligned: bool = False, eps: float = 1e-6) -> torch.Tensor:
  50. """
  51. Calculate overlap between two set of bboxes.
  52. If ``is_aligned `` is ``False``, then calculate the overlaps between each
  53. bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
  54. pair of bboxes1 and bboxes2.
  55. :param bboxes1: shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
  56. :param bboxes2: shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
  57. B indicates the batch dim, in shape (B1, B2, ..., Bn).
  58. If ``is_aligned `` is ``True``, then m and n must be equal.
  59. :param mode: Either "iou" (intersection over union) or "iof" (intersection over foreground).
  60. :param is_aligned: If True, then m and n must be equal. Default False.
  61. :param eps: A value added to the denominator for numerical stability. Default 1e-6.
  62. :return: Tensor of shape (m, n) if ``is_aligned `` is False else shape (m,)
  63. """
  64. assert mode in ["iou", "iof", "giou"], "Unsupported mode {}".format(mode)
  65. # Either the boxes are empty or the length of boxes's last dimenstion is 4
  66. assert bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0
  67. assert bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0
  68. # Batch dim must be the same
  69. # Batch dim: (B1, B2, ... Bn)
  70. assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
  71. batch_shape = bboxes1.shape[:-2]
  72. rows = bboxes1.shape[-2] if bboxes1.shape[0] > 0 else 0
  73. cols = bboxes2.shape[-2] if bboxes2.shape[0] > 0 else 0
  74. if is_aligned:
  75. assert rows == cols
  76. if rows * cols == 0:
  77. if is_aligned:
  78. return np.random.random(batch_shape + (rows,))
  79. else:
  80. return np.random.random(batch_shape + (rows, cols))
  81. area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
  82. area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
  83. if is_aligned:
  84. lt = np.maximum(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2]
  85. rb = np.minimum(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2]
  86. wh = (rb - lt).clip(min=0) # [B, rows, 2]
  87. overlap = wh[..., 0] * wh[..., 1]
  88. if mode in ["iou", "giou"]:
  89. union = area1 + area2 - overlap
  90. else:
  91. union = area1
  92. if mode == "giou":
  93. enclosed_lt = np.minimum(bboxes1[..., :2], bboxes2[..., :2])
  94. enclosed_rb = np.maximum(bboxes1[..., 2:], bboxes2[..., 2:])
  95. else:
  96. lt = np.maximum(bboxes1[..., :, None, :2], bboxes2[..., None, :, :2]) # [B, rows, cols, 2]
  97. rb = np.minimum(bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:]) # [B, rows, cols, 2]
  98. wh = (rb - lt).clip(min=0) # [B, rows, cols, 2]
  99. overlap = wh[..., 0] * wh[..., 1]
  100. if mode in ["iou", "giou"]:
  101. union = area1[..., None] + area2[..., None, :] - overlap
  102. else:
  103. union = area1[..., None]
  104. if mode == "giou":
  105. enclosed_lt = np.minimum(bboxes1[..., :, None, :2], bboxes2[..., None, :, :2])
  106. enclosed_rb = np.maximum(bboxes1[..., :, None, 2:], bboxes2[..., None, :, 2:])
  107. eps = np.array([eps])
  108. union = np.maximum(union, eps)
  109. ious = overlap / union
  110. if mode in ["iou", "iof"]:
  111. return ious
  112. # calculate gious
  113. enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0)
  114. enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
  115. enclose_area = np.maximum(enclose_area, eps)
  116. gious = ious - (enclose_area - union) / enclose_area
  117. return gious
  118. def topk_(input, k, axis=1, largest=True):
  119. x = -input if largest else input
  120. if axis == 0:
  121. row_index = np.arange(input.shape[1 - axis])
  122. topk_index = np.argpartition(x, k, axis=axis)[0:k, :]
  123. topk_data = x[topk_index, row_index]
  124. topk_index_sort = np.argsort(topk_data, axis=axis)
  125. topk_data_sort = topk_data[topk_index_sort, row_index]
  126. topk_index_sort = topk_index[0:k, :][topk_index_sort, row_index]
  127. else:
  128. column_index = np.arange(x.shape[1 - axis])[:, None]
  129. topk_index = np.argpartition(x, k, axis=axis)[:, 0:k]
  130. topk_data = x[column_index, topk_index]
  131. topk_data = -topk_data if largest else topk_data
  132. topk_index_sort = np.argsort(topk_data, axis=axis)
  133. topk_data_sort = topk_data[column_index, topk_index_sort]
  134. topk_index_sort = topk_index[:, 0:k][column_index, topk_index_sort]
  135. return topk_data_sort, topk_index_sort
  136. def compute_max_iou_anchor(ious: Tensor) -> Tensor:
  137. r"""
  138. For each anchor, find the GT with the largest IOU.
  139. :param ious: Tensor (float32) of shape[B, n, L], n: num_gts, L: num_anchors
  140. :return: is_max_iou is Tensor (float32) of shape[B, n, L], value=1. means selected
  141. """
  142. num_max_boxes = ious.shape[-2]
  143. max_iou_index = ious.argmax(dim=-2)
  144. is_max_iou: Tensor = torch.nn.functional.one_hot(max_iou_index, num_max_boxes).permute([0, 2, 1])
  145. return is_max_iou.type_as(ious)
  146. def check_points_inside_bboxes(points: Tensor, bboxes: Tensor, center_radius_tensor: Optional[Tensor] = None, eps: float = 1e-9) -> Tensor:
  147. """
  148. :param points: Tensor (float32) of shape[L, 2], "xy" format, L: num_anchors
  149. :param bboxes: Tensor (float32) of shape[B, n, 4], "xmin, ymin, xmax, ymax" format
  150. :param center_radius_tensor: Tensor (float32) of shape [L, 1]. Default: None.
  151. :param eps: Default: 1e-9
  152. :return is_in_bboxes: Tensor (float32) of shape[B, n, L], value=1. means selected
  153. """
  154. points = points.unsqueeze(0).unsqueeze(0)
  155. x, y = points.chunk(2, dim=-1)
  156. xmin, ymin, xmax, ymax = bboxes.unsqueeze(2).chunk(4, dim=-1)
  157. # check whether `points` is in `bboxes`
  158. left = x - xmin
  159. top = y - ymin
  160. right = xmax - x
  161. bottom = ymax - y
  162. delta_ltrb = torch.cat([left, top, right, bottom], dim=-1)
  163. is_in_bboxes = delta_ltrb.min(dim=-1).values > eps
  164. if center_radius_tensor is not None:
  165. # check whether `points` is in `center_radius`
  166. center_radius_tensor = center_radius_tensor.unsqueeze(0).unsqueeze(0)
  167. cx = (xmin + xmax) * 0.5
  168. cy = (ymin + ymax) * 0.5
  169. left = x - (cx - center_radius_tensor)
  170. top = y - (cy - center_radius_tensor)
  171. right = (cx + center_radius_tensor) - x
  172. bottom = (cy + center_radius_tensor) - y
  173. delta_ltrb_c = torch.cat([left, top, right, bottom], dim=-1)
  174. is_in_center = delta_ltrb_c.min(dim=-1) > eps
  175. return (torch.logical_and(is_in_bboxes, is_in_center), torch.logical_or(is_in_bboxes, is_in_center))
  176. return is_in_bboxes.type_as(bboxes)
  177. def gather_topk_anchors(metrics: Tensor, topk: int, largest: bool = True, topk_mask: Optional[Tensor] = None, eps: float = 1e-9) -> Tensor:
  178. """
  179. :param metrics: Tensor(float32) of shape[B, n, L], n: num_gts, L: num_anchors
  180. :param topk: The number of top elements to look for along the axis.
  181. :param largest: If set to true, algorithm will sort by descending order, otherwise sort by ascending order.
  182. :param topk_mask: Tensor(float32) of shape[B, n, 1], ignore bbox mask,
  183. :param eps: Default: 1e-9
  184. :return: is_in_topk, Tensor (float32) of shape[B, n, L], value=1. means selected
  185. """
  186. num_anchors = metrics.shape[-1]
  187. topk_metrics, topk_idxs = torch.topk(metrics, topk, dim=-1, largest=largest)
  188. if topk_mask is None:
  189. topk_mask = (topk_metrics.max(dim=-1, keepdim=True).values > eps).type_as(metrics)
  190. is_in_topk = torch.nn.functional.one_hot(topk_idxs, num_anchors).sum(dim=-2).type_as(metrics)
  191. return is_in_topk * topk_mask
  192. def bbox_center(boxes: Tensor) -> Tensor:
  193. """
  194. Get bbox centers from boxes.
  195. :param boxes: Boxes with shape (..., 4), "xmin, ymin, xmax, ymax" format.
  196. :return: Boxes centers with shape (..., 2), "cx, cy" format.
  197. """
  198. boxes_cx = (boxes[..., 0] + boxes[..., 2]) / 2
  199. boxes_cy = (boxes[..., 1] + boxes[..., 3]) / 2
  200. return torch.stack([boxes_cx, boxes_cy], dim=-1)
  201. def compute_max_iou_gt(ious: Tensor) -> Tensor:
  202. """
  203. For each GT, find the anchor with the largest IOU.
  204. :param ious: Tensor (float32) of shape[B, n, L], n: num_gts, L: num_anchors
  205. :return: is_max_iou, Tensor (float32) of shape[B, n, L], value=1. means selected
  206. """
  207. num_anchors = ious.shape[-1]
  208. max_iou_index = ious.argmax(dim=-1)
  209. is_max_iou = torch.nn.functional.one_hot(max_iou_index, num_anchors)
  210. return is_max_iou.astype(ious.dtype)
  211. class ATSSAssigner(nn.Module):
  212. """Bridging the Gap Between Anchor-based and Anchor-free Detection
  213. via Adaptive Training Sample Selection
  214. """
  215. __shared__ = ["num_classes"]
  216. def __init__(self, topk=9, num_classes=80, force_gt_matching=False, eps=1e-9):
  217. """
  218. :param topk: Maximum number of achors that is selected for each gt box
  219. :param num_classes:
  220. :param force_gt_matching: Guarantee that each gt box is matched to at least one anchor.
  221. If two gt boxes match to the same anchor, the one with the larger area will be selected.
  222. And the second-best achnor will be assigned to the other gt box.
  223. :param eps: Small constant for numerical stability
  224. """
  225. super(ATSSAssigner, self).__init__()
  226. self.topk = topk
  227. self.num_classes = num_classes
  228. self.force_gt_matching = force_gt_matching
  229. self.eps = eps
  230. def _gather_topk_pyramid(self, gt2anchor_distances, num_anchors_list, pad_gt_mask):
  231. gt2anchor_distances_list = torch.split(gt2anchor_distances, num_anchors_list, dim=-1)
  232. num_anchors_index = np.cumsum(num_anchors_list).tolist()
  233. num_anchors_index = [
  234. 0,
  235. ] + num_anchors_index[:-1]
  236. is_in_topk_list = []
  237. topk_idxs_list = []
  238. for distances, anchors_index in zip(gt2anchor_distances_list, num_anchors_index):
  239. num_anchors = distances.shape[-1]
  240. _, topk_idxs = torch.topk(distances, self.topk, dim=-1, largest=False)
  241. topk_idxs_list.append(topk_idxs + anchors_index)
  242. is_in_topk = torch.nn.functional.one_hot(topk_idxs, num_anchors).sum(dim=-2).type_as(gt2anchor_distances)
  243. is_in_topk_list.append(is_in_topk * pad_gt_mask)
  244. is_in_topk_list = torch.cat(is_in_topk_list, dim=-1)
  245. topk_idxs_list = torch.cat(topk_idxs_list, dim=-1)
  246. return is_in_topk_list, topk_idxs_list
  247. @torch.no_grad()
  248. def forward(
  249. self,
  250. anchor_bboxes: Tensor,
  251. num_anchors_list: list,
  252. gt_labels: Tensor,
  253. gt_bboxes: Tensor,
  254. pad_gt_mask: Tensor,
  255. bg_index: int,
  256. gt_scores: Optional[Tensor] = None,
  257. pred_bboxes: Optional[Tensor] = None,
  258. ) -> Tuple[Tensor, Tensor, Tensor]:
  259. """
  260. This code is based on https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
  261. The assignment is done in following steps
  262. 1. compute iou between all bbox (bbox of all pyramid levels) and gt
  263. 2. compute center distance between all bbox and gt
  264. 3. on each pyramid level, for each gt, select k bbox whose center
  265. are closest to the gt center, so we total select k*l bbox as
  266. candidates for each gt
  267. 4. get corresponding iou for the these candidates, and compute the
  268. mean and std, set mean + std as the iou threshold
  269. 5. select these candidates whose iou are greater than or equal to
  270. the threshold as positive
  271. 6. limit the positive sample's center in gt
  272. 7. if an anchor box is assigned to multiple gts, the one with the
  273. highest iou will be selected.
  274. :param anchor_bboxes: Tensor(float32) - pre-defined anchors, shape(L, 4), "xmin, xmax, ymin, ymax" format
  275. :param num_anchors_list: Number of anchors in each level
  276. :param gt_labels: Tensor (int64|int32) - Label of gt_bboxes, shape(B, n, 1)
  277. :param gt_bboxes: Tensor (float32) - Ground truth bboxes, shape(B, n, 4)
  278. :param pad_gt_mask: Tensor (float32) - 1 means bbox, 0 means no bbox, shape(B, n, 1)
  279. :param bg_index: Background index
  280. :param gt_scores: Tensor (float32) - Score of gt_bboxes, shape(B, n, 1), if None, then it will initialize with one_hot label
  281. :param pred_bboxes: Tensor (float32) - predicted bounding boxes, shape(B, L, 4)
  282. :return:
  283. - assigned_labels: Tensor of shape (B, L)
  284. - assigned_bboxes: Tensor of shape (B, L, 4)
  285. - assigned_scores: Tensor of shape (B, L, C), if pred_bboxes is not None, then output ious
  286. """
  287. assert gt_labels.ndim == gt_bboxes.ndim and gt_bboxes.ndim == 3
  288. num_anchors, _ = anchor_bboxes.shape
  289. batch_size, num_max_boxes, _ = gt_bboxes.shape
  290. # negative batch
  291. if num_max_boxes == 0:
  292. assigned_labels = torch.full([batch_size, num_anchors], bg_index, dtype=torch.long, device=anchor_bboxes.device)
  293. assigned_bboxes = torch.zeros([batch_size, num_anchors, 4], device=anchor_bboxes.device)
  294. assigned_scores = torch.zeros([batch_size, num_anchors, self.num_classes], device=anchor_bboxes.device)
  295. return assigned_labels, assigned_bboxes, assigned_scores
  296. # 1. compute iou between gt and anchor bbox, [B, n, L]
  297. ious = iou_similarity(gt_bboxes.reshape([-1, 4]), anchor_bboxes)
  298. ious = ious.reshape([batch_size, -1, num_anchors])
  299. # 2. compute center distance between all anchors and gt, [B, n, L]
  300. gt_centers = bbox_center(gt_bboxes.reshape([-1, 4])).unsqueeze(1)
  301. anchor_centers = bbox_center(anchor_bboxes)
  302. # gt2anchor_distances = (
  303. # (gt_centers - anchor_centers.unsqueeze(0)).norm(2, dim=-1).reshape([batch_size, -1, num_anchors])
  304. # )
  305. gt2anchor_distances = torch.norm(gt_centers - anchor_centers.unsqueeze(0), p=2, dim=-1).reshape([batch_size, -1, num_anchors])
  306. # 3. on each pyramid level, selecting top-k closest candidates
  307. # based on the center distance, [B, n, L]
  308. is_in_topk, topk_idxs = self._gather_topk_pyramid(gt2anchor_distances, num_anchors_list, pad_gt_mask)
  309. # 4. get corresponding iou for the these candidates, and compute the
  310. # mean and std, 5. set mean + std as the iou threshold
  311. iou_candidates = ious * is_in_topk
  312. iou_threshold = torch.gather(iou_candidates.flatten(end_dim=-2), dim=1, index=topk_idxs.flatten(end_dim=-2))
  313. iou_threshold = iou_threshold.reshape([batch_size, num_max_boxes, -1])
  314. iou_threshold = iou_threshold.mean(dim=-1, keepdim=True) + iou_threshold.std(dim=-1, keepdim=True)
  315. is_in_topk = torch.where(iou_candidates > iou_threshold, is_in_topk, torch.zeros_like(is_in_topk))
  316. # 6. check the positive sample's center in gt, [B, n, L]
  317. is_in_gts = check_points_inside_bboxes(anchor_centers, gt_bboxes)
  318. # select positive sample, [B, n, L]
  319. mask_positive = is_in_topk * is_in_gts * pad_gt_mask
  320. # 7. if an anchor box is assigned to multiple gts,
  321. # the one with the highest iou will be selected.
  322. mask_positive_sum = mask_positive.sum(dim=-2)
  323. if mask_positive_sum.max() > 1:
  324. mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile([1, num_max_boxes, 1])
  325. is_max_iou = compute_max_iou_anchor(ious)
  326. mask_positive = torch.where(mask_multiple_gts, is_max_iou, mask_positive)
  327. mask_positive_sum = mask_positive.sum(dim=-2)
  328. # 8. make sure every gt_bbox matches the anchor
  329. if self.force_gt_matching:
  330. is_max_iou = compute_max_iou_gt(ious) * pad_gt_mask
  331. mask_max_iou = (is_max_iou.sum(-2, keepdim=True) == 1).tile([1, num_max_boxes, 1])
  332. mask_positive = torch.where(mask_max_iou, is_max_iou, mask_positive)
  333. mask_positive_sum = mask_positive.sum(dim=-2)
  334. assigned_gt_index = mask_positive.argmax(dim=-2)
  335. # assigned target
  336. batch_ind = torch.arange(end=batch_size, dtype=gt_labels.dtype, device=gt_labels.device).unsqueeze(-1)
  337. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  338. assigned_labels = torch.gather(gt_labels.flatten(), index=assigned_gt_index.flatten(), dim=0)
  339. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  340. assigned_labels = torch.where(mask_positive_sum > 0, assigned_labels, torch.full_like(assigned_labels, bg_index))
  341. # assigned_bboxes = torch.gather(gt_bboxes.reshape([-1, 4]), index=assigned_gt_index.flatten(), dim=0)
  342. assigned_bboxes = gt_bboxes.reshape([-1, 4])[assigned_gt_index.flatten(), :]
  343. assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
  344. assigned_scores = torch.nn.functional.one_hot(assigned_labels, self.num_classes + 1).float()
  345. ind = list(range(self.num_classes + 1))
  346. ind.remove(bg_index)
  347. assigned_scores = torch.index_select(assigned_scores, index=torch.tensor(ind, device=assigned_scores.device), dim=-1)
  348. if pred_bboxes is not None:
  349. # assigned iou
  350. ious = batch_iou_similarity(gt_bboxes, pred_bboxes) * mask_positive
  351. ious = ious.max(dim=-2).values.unsqueeze(-1)
  352. assigned_scores *= ious
  353. elif gt_scores is not None:
  354. gather_scores = torch.gather(gt_scores.flatten(), assigned_gt_index.flatten(), dim=0)
  355. gather_scores = gather_scores.reshape([batch_size, num_anchors])
  356. gather_scores = torch.where(mask_positive_sum > 0, gather_scores, torch.zeros_like(gather_scores))
  357. assigned_scores *= gather_scores.unsqueeze(-1)
  358. return assigned_labels, assigned_bboxes, assigned_scores
  359. class TaskAlignedAssigner(nn.Module):
  360. """TOOD: Task-aligned One-stage Object Detection"""
  361. def __init__(self, topk=13, alpha=1.0, beta=6.0, eps=1e-9):
  362. """
  363. :param topk: Maximum number of achors that is selected for each gt box
  364. :param alpha: Power factor for class probabilities of predicted boxes (Used compute alignment metric)
  365. :param beta: Power factor for IoU score of predicted boxes (Used compute alignment metric)
  366. :param eps: Small constant for numerical stability
  367. """
  368. super(TaskAlignedAssigner, self).__init__()
  369. self.topk = topk
  370. self.alpha = alpha
  371. self.beta = beta
  372. self.eps = eps
  373. @torch.no_grad()
  374. def forward(
  375. self,
  376. pred_scores: Tensor,
  377. pred_bboxes: Tensor,
  378. anchor_points: Tensor,
  379. num_anchors_list: list,
  380. gt_labels: Tensor,
  381. gt_bboxes: Tensor,
  382. pad_gt_mask: Tensor,
  383. bg_index: int,
  384. gt_scores: Optional[Tensor] = None,
  385. ):
  386. """
  387. This code is based on https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/task_aligned_assigner.py
  388. The assignment is done in following steps
  389. 1. compute alignment metric between all bbox (bbox of all pyramid levels) and gt
  390. 2. select top-k bbox as candidates for each gt
  391. 3. limit the positive sample's center in gt (because the anchor-free detector
  392. only can predict positive distance)
  393. 4. if an anchor box is assigned to multiple gts, the one with the
  394. highest iou will be selected.
  395. :param pred_scores: Tensor (float32): predicted class probability, shape(B, L, C)
  396. :param pred_bboxes: Tensor (float32): predicted bounding boxes, shape(B, L, 4)
  397. :param anchor_points: Tensor (float32): pre-defined anchors, shape(L, 2), "cxcy" format
  398. :param num_anchors_list: List ( num of anchors in each level, shape(L)
  399. :param gt_labels: Tensor (int64|int32): Label of gt_bboxes, shape(B, n, 1)
  400. :param gt_bboxes: Tensor (float32): Ground truth bboxes, shape(B, n, 4)
  401. :param pad_gt_mask: Tensor (float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  402. :param bg_index: int ( background index
  403. :param gt_scores: Tensor (one, float32) Score of gt_bboxes, shape(B, n, 1)
  404. :return:
  405. - assigned_labels, Tensor of shape (B, L)
  406. - assigned_bboxes, Tensor of shape (B, L, 4)
  407. - assigned_scores, Tensor of shape (B, L, C)
  408. """
  409. assert pred_scores.ndim == pred_bboxes.ndim
  410. assert gt_labels.ndim == gt_bboxes.ndim and gt_bboxes.ndim == 3
  411. batch_size, num_anchors, num_classes = pred_scores.shape
  412. _, num_max_boxes, _ = gt_bboxes.shape
  413. # negative batch
  414. if num_max_boxes == 0:
  415. assigned_labels = torch.full([batch_size, num_anchors], bg_index, dtype=torch.long, device=gt_labels.device)
  416. assigned_bboxes = torch.zeros([batch_size, num_anchors, 4], device=gt_labels.device)
  417. assigned_scores = torch.zeros([batch_size, num_anchors, num_classes], device=gt_labels.device)
  418. return assigned_labels, assigned_bboxes, assigned_scores
  419. # compute iou between gt and pred bbox, [B, n, L]
  420. ious = batch_iou_similarity(gt_bboxes, pred_bboxes)
  421. # gather pred bboxes class score
  422. pred_scores = torch.permute(pred_scores, [0, 2, 1])
  423. batch_ind = torch.arange(end=batch_size, dtype=gt_labels.dtype, device=gt_labels.device).unsqueeze(-1)
  424. gt_labels_ind = torch.stack([batch_ind.tile([1, num_max_boxes]), gt_labels.squeeze(-1)], dim=-1)
  425. bbox_cls_scores = pred_scores[gt_labels_ind[..., 0], gt_labels_ind[..., 1]]
  426. # compute alignment metrics, [B, n, L]
  427. alignment_metrics = bbox_cls_scores.pow(self.alpha) * ious.pow(self.beta)
  428. # check the positive sample's center in gt, [B, n, L]
  429. is_in_gts = check_points_inside_bboxes(anchor_points, gt_bboxes)
  430. # select topk largest alignment metrics pred bbox as candidates
  431. # for each gt, [B, n, L]
  432. is_in_topk = gather_topk_anchors(alignment_metrics * is_in_gts, self.topk, topk_mask=pad_gt_mask)
  433. # select positive sample, [B, n, L]
  434. mask_positive = is_in_topk * is_in_gts * pad_gt_mask
  435. # if an anchor box is assigned to multiple gts,
  436. # the one with the highest iou will be selected, [B, n, L]
  437. mask_positive_sum = mask_positive.sum(dim=-2)
  438. if mask_positive_sum.max() > 1:
  439. mask_multiple_gts = (mask_positive_sum.unsqueeze(1) > 1).tile([1, num_max_boxes, 1])
  440. is_max_iou = compute_max_iou_anchor(ious)
  441. mask_positive = torch.where(mask_multiple_gts, is_max_iou, mask_positive)
  442. mask_positive_sum = mask_positive.sum(dim=-2)
  443. assigned_gt_index = mask_positive.argmax(dim=-2)
  444. # assigned target
  445. assigned_gt_index = assigned_gt_index + batch_ind * num_max_boxes
  446. assigned_labels = torch.gather(gt_labels.flatten(), index=assigned_gt_index.flatten(), dim=0)
  447. assigned_labels = assigned_labels.reshape([batch_size, num_anchors])
  448. assigned_labels = torch.where(mask_positive_sum > 0, assigned_labels, torch.full_like(assigned_labels, bg_index))
  449. assigned_bboxes = gt_bboxes.reshape([-1, 4])[assigned_gt_index.flatten(), :]
  450. assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
  451. assigned_scores = torch.nn.functional.one_hot(assigned_labels, num_classes + 1)
  452. ind = list(range(num_classes + 1))
  453. ind.remove(bg_index)
  454. assigned_scores = torch.index_select(assigned_scores, index=torch.tensor(ind, device=assigned_scores.device, dtype=torch.long), dim=-1)
  455. # rescale alignment metrics
  456. alignment_metrics *= mask_positive
  457. max_metrics_per_instance = alignment_metrics.max(dim=-1, keepdim=True).values
  458. max_ious_per_instance = (ious * mask_positive).max(dim=-1, keepdim=True).values
  459. alignment_metrics = alignment_metrics / (max_metrics_per_instance + self.eps) * max_ious_per_instance
  460. alignment_metrics = alignment_metrics.max(dim=-2).values.unsqueeze(-1)
  461. assigned_scores = assigned_scores * alignment_metrics
  462. return assigned_labels, assigned_bboxes, assigned_scores
  463. class GIoULoss(object):
  464. """
  465. Generalized Intersection over Union, see https://arxiv.org/abs/1902.09630
  466. :param loss_weight: giou loss weight, default as 1
  467. :param eps: epsilon to avoid divide by zero, default as 1e-10
  468. :param reduction: Options are "none", "mean" and "sum". default as none
  469. """
  470. def __init__(self, loss_weight: float = 1.0, eps: float = 1e-10, reduction: str = "none"):
  471. self.loss_weight = loss_weight
  472. self.eps = eps
  473. assert reduction in ("none", "mean", "sum")
  474. self.reduction = reduction
  475. def bbox_overlap(self, box1: Tensor, box2: Tensor, eps: float = 1e-10) -> Tuple[Tensor, Tensor, Tensor]:
  476. """
  477. Calculate the iou of box1 and box2.
  478. :param box1: box1 with the shape (..., 4)
  479. :param box2: box1 with the shape (..., 4)
  480. :param eps: epsilon to avoid divide by zero
  481. :return:
  482. - iou: iou of box1 and box2
  483. - overlap: overlap of box1 and box2
  484. - union: union of box1 and box2
  485. """
  486. x1, y1, x2, y2 = box1
  487. x1g, y1g, x2g, y2g = box2
  488. xkis1 = torch.maximum(x1, x1g)
  489. ykis1 = torch.maximum(y1, y1g)
  490. xkis2 = torch.minimum(x2, x2g)
  491. ykis2 = torch.minimum(y2, y2g)
  492. w_inter = (xkis2 - xkis1).clip(0)
  493. h_inter = (ykis2 - ykis1).clip(0)
  494. overlap = w_inter * h_inter
  495. area1 = (x2 - x1) * (y2 - y1)
  496. area2 = (x2g - x1g) * (y2g - y1g)
  497. union = area1 + area2 - overlap + eps
  498. iou = overlap / union
  499. return iou, overlap, union
  500. def __call__(self, pbox: Tensor, gbox: Tensor, iou_weight=1.0, loc_reweight=None):
  501. # x1, y1, x2, y2 = torch.split(pbox, split_size_or_sections=4, dim=-1)
  502. # x1g, y1g, x2g, y2g = torch.split(gbox, split_size_or_sections=4, dim=-1)
  503. x1, y1, x2, y2 = pbox.chunk(4, dim=-1)
  504. x1g, y1g, x2g, y2g = gbox.chunk(4, dim=-1)
  505. box1 = [x1, y1, x2, y2]
  506. box2 = [x1g, y1g, x2g, y2g]
  507. iou, overlap, union = self.bbox_overlap(box1, box2, self.eps)
  508. xc1 = torch.minimum(x1, x1g)
  509. yc1 = torch.minimum(y1, y1g)
  510. xc2 = torch.maximum(x2, x2g)
  511. yc2 = torch.maximum(y2, y2g)
  512. area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps
  513. miou = iou - ((area_c - union) / area_c)
  514. if loc_reweight is not None:
  515. loc_reweight = torch.reshape(loc_reweight, shape=(-1, 1))
  516. loc_thresh = 0.9
  517. giou = 1 - (1 - loc_thresh) * miou - loc_thresh * miou * loc_reweight
  518. else:
  519. giou = 1 - miou
  520. if self.reduction == "none":
  521. loss = giou
  522. elif self.reduction == "sum":
  523. loss = torch.sum(giou * iou_weight)
  524. else:
  525. loss = torch.mean(giou * iou_weight)
  526. return loss * self.loss_weight
  527. @register_loss(Losses.PPYOLOE_LOSS)
  528. class PPYoloELoss(nn.Module):
  529. def __init__(
  530. self,
  531. num_classes: int,
  532. use_varifocal_loss: bool = True,
  533. use_static_assigner: bool = True,
  534. reg_max: int = 16,
  535. classification_loss_weight: float = 1.0,
  536. iou_loss_weight: float = 2.5,
  537. dfl_loss_weight: float = 0.5,
  538. ):
  539. """
  540. :param num_classes: Number of classes
  541. :param use_varifocal_loss: Whether to use Varifocal loss for classification loss; otherwise use Focal loss
  542. :param static_assigner_epoch: Whether to use static assigner or Task-Aligned assigner
  543. :param classification_loss_weight: Classification loss weight
  544. :param iou_loss_weight: IoU loss weight
  545. :param dfl_loss_weight: DFL loss weight
  546. :param reg_max: Number of regression bins (Must match the number of bins in the PPYoloE head)
  547. """
  548. super().__init__()
  549. self.use_varifocal_loss = use_varifocal_loss
  550. self.classification_loss_weight = classification_loss_weight
  551. self.dfl_loss_weight = dfl_loss_weight
  552. self.iou_loss_weight = iou_loss_weight
  553. self.iou_loss = GIoULoss()
  554. self.static_assigner = ATSSAssigner(topk=9, num_classes=num_classes)
  555. self.assigner = TaskAlignedAssigner(topk=13, alpha=1.0, beta=6.0)
  556. self.use_static_assigner = use_static_assigner
  557. self.reg_max = reg_max
  558. self.num_classes = num_classes
  559. # Same as in PPYoloE head
  560. proj = torch.linspace(0, self.reg_max, self.reg_max + 1).reshape([1, self.reg_max + 1, 1, 1])
  561. self.register_buffer("proj_conv", proj)
  562. @torch.no_grad()
  563. def _yolox_targets_to_ppyolo(self, targets: torch.Tensor, batch_size: int) -> Mapping[str, torch.Tensor]:
  564. """
  565. Convert targets from YoloX format to PPYolo since its the easiest (not the cleanest) way to
  566. have PP Yolo training & metrics computed
  567. :param targets: (N, 6) format of bboxes is meant to be LABEL_CXCYWH (index, c, cx, cy, w, h)
  568. :return: (Dictionary [str,Tensor]) with keys:
  569. - gt_class: (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  570. - gt_bbox: (Tensor, float32): Ground truth bboxes, shape(B, n, 4) in x1y1x2y2 format
  571. - pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  572. """
  573. image_index = targets[:, 0]
  574. gt_class = targets[:, 1:2].long()
  575. gt_bbox = cxcywh_to_xyxy(targets[:, 2:6], image_shape=None)
  576. per_image_class = []
  577. per_image_bbox = []
  578. per_image_pad_mask = []
  579. max_boxes = 0
  580. for i in range(batch_size):
  581. mask = image_index == i
  582. image_labels = gt_class[mask]
  583. image_bboxes = gt_bbox[mask, :]
  584. valid_bboxes = image_bboxes.sum(dim=1, keepdims=True) > 0
  585. per_image_class.append(image_labels)
  586. per_image_bbox.append(image_bboxes)
  587. per_image_pad_mask.append(valid_bboxes)
  588. max_boxes = max(max_boxes, mask.sum().item())
  589. for i in range(batch_size):
  590. elements_to_pad = max_boxes - len(per_image_class[i])
  591. padding_left = 0
  592. padding_right = 0
  593. padding_top = 0
  594. padding_bottom = elements_to_pad
  595. pad = padding_left, padding_right, padding_top, padding_bottom
  596. per_image_class[i] = F.pad(per_image_class[i], pad, mode="constant", value=0)
  597. per_image_bbox[i] = F.pad(per_image_bbox[i], pad, mode="constant", value=0)
  598. per_image_pad_mask[i] = F.pad(per_image_pad_mask[i], pad, mode="constant", value=0)
  599. return {
  600. "gt_class": torch.stack(per_image_class, dim=0),
  601. "gt_bbox": torch.stack(per_image_bbox, dim=0),
  602. "pad_gt_mask": torch.stack(per_image_pad_mask, dim=0),
  603. }
  604. def forward(
  605. self,
  606. outputs: Union[
  607. Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor], Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]]
  608. ],
  609. targets: Tensor,
  610. ) -> Mapping[str, Tensor]:
  611. """
  612. :param outputs: Tuple of pred_scores, pred_distri, anchors, anchor_points, num_anchors_list, stride_tensor
  613. :param targets: (Dictionary [str,Tensor]) with keys:
  614. - gt_class: (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
  615. - gt_bbox: (Tensor, float32): Ground truth bboxes, shape(B, n, 4) in x1y1x2y2 format
  616. - pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
  617. :return:
  618. """
  619. # in test/eval mode the model outputs a tuple where the second item is the raw predictions
  620. if isinstance(outputs, tuple) and len(outputs) == 2:
  621. # in test/eval mode the Yolo model outputs a tuple where the second item is the raw predictions
  622. _, predictions = outputs
  623. else:
  624. predictions = outputs
  625. (
  626. pred_scores,
  627. pred_distri,
  628. anchors,
  629. anchor_points,
  630. num_anchors_list,
  631. stride_tensor,
  632. ) = predictions
  633. targets = self._yolox_targets_to_ppyolo(targets, batch_size=pred_scores.size(0)) # yolox -> ppyolo
  634. anchor_points_s = anchor_points / stride_tensor
  635. pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
  636. gt_labels = targets["gt_class"]
  637. gt_bboxes = targets["gt_bbox"]
  638. pad_gt_mask = targets["pad_gt_mask"]
  639. # label assignment
  640. if self.use_static_assigner:
  641. assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
  642. anchor_bboxes=anchors,
  643. num_anchors_list=num_anchors_list,
  644. gt_labels=gt_labels,
  645. gt_bboxes=gt_bboxes,
  646. pad_gt_mask=pad_gt_mask,
  647. bg_index=self.num_classes,
  648. pred_bboxes=pred_bboxes.detach() * stride_tensor,
  649. )
  650. alpha_l = 0.25
  651. else:
  652. assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
  653. pred_scores=pred_scores.detach().sigmoid(), # Pred scores are logits on training for numerical stability
  654. pred_bboxes=pred_bboxes.detach() * stride_tensor,
  655. anchor_points=anchor_points,
  656. num_anchors_list=num_anchors_list,
  657. gt_labels=gt_labels,
  658. gt_bboxes=gt_bboxes,
  659. pad_gt_mask=pad_gt_mask,
  660. bg_index=self.num_classes,
  661. )
  662. alpha_l = -1
  663. # rescale bbox
  664. assigned_bboxes /= stride_tensor
  665. # cls loss
  666. if self.use_varifocal_loss:
  667. one_hot_label = torch.nn.functional.one_hot(assigned_labels, self.num_classes + 1)[..., :-1]
  668. loss_cls = self._varifocal_loss(pred_scores, assigned_scores, one_hot_label)
  669. else:
  670. loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)
  671. assigned_scores_sum = assigned_scores.sum()
  672. if super_gradients.is_distributed():
  673. torch.distributed.all_reduce(assigned_scores_sum, op=torch.distributed.ReduceOp.SUM)
  674. assigned_scores_sum /= get_world_size()
  675. assigned_scores_sum = torch.clip(assigned_scores_sum, min=1.0)
  676. loss_cls /= assigned_scores_sum
  677. loss_iou, loss_dfl = self._bbox_loss(
  678. pred_distri,
  679. pred_bboxes,
  680. anchor_points_s,
  681. assigned_labels,
  682. assigned_bboxes,
  683. assigned_scores,
  684. assigned_scores_sum,
  685. )
  686. loss = self.classification_loss_weight * loss_cls + self.iou_loss_weight * loss_iou + self.dfl_loss_weight * loss_dfl
  687. log_losses = torch.stack([loss_cls.detach(), loss_iou.detach(), loss_dfl.detach(), loss.detach()])
  688. return loss, log_losses
  689. @property
  690. def component_names(self):
  691. return ["loss_cls", "loss_iou", "loss_dfl", "loss"]
  692. def _df_loss(self, pred_dist: Tensor, target: Tensor) -> Tensor:
  693. target_left = target.long()
  694. target_right = target_left + 1
  695. weight_left = target_right.float() - target
  696. weight_right = 1 - weight_left
  697. # [B,L,C] -> [B,C,L] to make compatible with torch.nn.functional.cross_entropy
  698. # which expects channel dim to be at index 1
  699. pred_dist = torch.moveaxis(pred_dist, -1, 1)
  700. loss_left = torch.nn.functional.cross_entropy(pred_dist, target_left, reduction="none") * weight_left
  701. loss_right = torch.nn.functional.cross_entropy(pred_dist, target_right, reduction="none") * weight_right
  702. return (loss_left + loss_right).mean(dim=-1, keepdim=True)
  703. def _bbox_loss(
  704. self,
  705. pred_dist,
  706. pred_bboxes,
  707. anchor_points,
  708. assigned_labels,
  709. assigned_bboxes,
  710. assigned_scores,
  711. assigned_scores_sum,
  712. ):
  713. # select positive samples mask
  714. mask_positive = assigned_labels != self.num_classes
  715. num_pos = mask_positive.sum()
  716. # pos/neg loss
  717. if num_pos > 0:
  718. # l1 + iou
  719. bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
  720. pred_bboxes_pos = torch.masked_select(pred_bboxes, bbox_mask).reshape([-1, 4])
  721. assigned_bboxes_pos = torch.masked_select(assigned_bboxes, bbox_mask).reshape([-1, 4])
  722. bbox_weight = torch.masked_select(assigned_scores.sum(-1), mask_positive).unsqueeze(-1)
  723. loss_iou = self.iou_loss(pred_bboxes_pos, assigned_bboxes_pos) * bbox_weight
  724. loss_iou = loss_iou.sum() / assigned_scores_sum
  725. dist_mask = mask_positive.unsqueeze(-1).tile([1, 1, (self.reg_max + 1) * 4])
  726. pred_dist_pos = torch.masked_select(pred_dist, dist_mask).reshape([-1, 4, self.reg_max + 1])
  727. assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes)
  728. assigned_ltrb_pos = torch.masked_select(assigned_ltrb, bbox_mask).reshape([-1, 4])
  729. loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos) * bbox_weight
  730. loss_dfl = loss_dfl.sum() / assigned_scores_sum
  731. else:
  732. loss_iou = torch.zeros([], device=pred_bboxes.device)
  733. loss_dfl = pred_dist.sum() * 0.0
  734. return loss_iou, loss_dfl
  735. def _bbox_decode(self, anchor_points: Tensor, pred_dist: Tensor):
  736. b, l, *_ = pred_dist.size()
  737. pred_dist = torch.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1]), dim=-1)
  738. pred_dist = torch.nn.functional.conv2d(pred_dist.permute(0, 3, 1, 2), self.proj_conv).squeeze(1)
  739. return batch_distance2bbox(anchor_points, pred_dist)
  740. def _bbox2distance(self, points, bbox):
  741. x1y1, x2y2 = torch.split(bbox, 2, -1)
  742. lt = points - x1y1
  743. rb = x2y2 - points
  744. return torch.cat([lt, rb], dim=-1).clip(0, self.reg_max - 0.01)
  745. @staticmethod
  746. def _focal_loss(pred_logits: Tensor, label: Tensor, alpha=0.25, gamma=2.0) -> Tensor:
  747. pred_score = pred_logits.sigmoid()
  748. weight = (pred_score - label).pow(gamma)
  749. if alpha > 0:
  750. alpha_t = alpha * label + (1 - alpha) * (1 - label)
  751. weight *= alpha_t
  752. loss = -weight * (label * torch.nn.functional.logsigmoid(pred_logits) + (1 - label) * torch.nn.functional.logsigmoid(-pred_logits))
  753. return loss.sum()
  754. @staticmethod
  755. def _varifocal_loss(pred_logits: Tensor, gt_score: Tensor, label: Tensor, alpha=0.75, gamma=2.0) -> Tensor:
  756. pred_score = pred_logits.sigmoid()
  757. weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
  758. loss = -weight * (gt_score * torch.nn.functional.logsigmoid(pred_logits) + (1 - gt_score) * torch.nn.functional.logsigmoid(-pred_logits))
  759. return loss.sum()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file