Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#378 Feature/sg 281 add kd notebook

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-281-add_kd_notebook
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
  1. from typing import Tuple
  2. import torch
  3. from torch import nn
  4. from torch.nn.modules.loss import _Loss
  5. from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
  6. from super_gradients.training.utils.ssd_utils import DefaultBoxes
  7. class HardMiningCrossEntropyLoss(_Loss):
  8. """
  9. L_cls = [CE of all positives] + [CE of the hardest backgrounds]
  10. where the second term is built from [neg_pos_ratio * positive pairs] background cells with the highest CE
  11. (the hardest background cells)
  12. """
  13. def __init__(self, neg_pos_ratio: float):
  14. """
  15. :param neg_pos_ratio: a ratio of negative samples to positive samples in the loss
  16. (unlike positives, not all negatives will be used:
  17. for each positive the [neg_pos_ratio] hardest negatives will be selected)
  18. """
  19. super().__init__()
  20. self.neg_pos_ratio = neg_pos_ratio
  21. self.ce = nn.CrossEntropyLoss(reduce=False)
  22. def forward(self, pred_labels, target_labels):
  23. mask = target_labels > 0 # not background
  24. pos_num = mask.sum(dim=1)
  25. # HARD NEGATIVE MINING
  26. con = self.ce(pred_labels, target_labels)
  27. # POSITIVE MASK WILL NOT BE SELECTED
  28. # set 0. loss for all positive objects, leave the loss where the object is background
  29. con_neg = con.clone()
  30. con_neg[mask] = 0
  31. # sort background cells by CE loss value (bigger_first)
  32. _, con_idx = con_neg.sort(dim=1, descending=True)
  33. # restore cells order, get each cell's order (rank) in CE loss sorting
  34. _, con_rank = con_idx.sort(dim=1)
  35. # NUMBER OF NEGATIVE THREE TIMES POSITIVE
  36. neg_num = torch.clamp(self.neg_pos_ratio * pos_num, max=mask.size(1)).unsqueeze(-1)
  37. # for each image into neg mask we'll take (3 * positive pairs) background objects with the highest CE
  38. neg_mask = con_rank < neg_num
  39. closs = (con * (mask.float() + neg_mask.float())).sum(dim=1)
  40. return closs
  41. class SSDLoss(_Loss):
  42. """
  43. Implements the loss as the sum of the followings:
  44. 1. Confidence Loss: All labels, with hard negative mining
  45. 2. Localization Loss: Only on positive labels
  46. L = (2 - alpha) * L_l1 + alpha * L_cls, where
  47. * L_cls is HardMiningCrossEntropyLoss
  48. * L_l1 = [SmoothL1Loss for all positives]
  49. """
  50. def __init__(self, dboxes: DefaultBoxes, alpha: float = 1.0, iou_thresh: float = 0.5, neg_pos_ratio: float = 3.):
  51. """
  52. :param dboxes: model anchors, shape [Num Grid Cells * Num anchors x 4]
  53. :param alpha: a weighting factor between classification and regression loss
  54. :param iou_thresh: a threshold for matching of anchors in each grid cell to GTs
  55. (a match should have IoU > iou_thresh)
  56. :param neg_pos_ratio: a ratio for HardMiningCrossEntropyLoss
  57. """
  58. super(SSDLoss, self).__init__()
  59. self.scale_xy = dboxes.scale_xy
  60. self.scale_wh = dboxes.scale_wh
  61. self.alpha = alpha
  62. self.dboxes = nn.Parameter(dboxes(order="xywh").transpose(0, 1).unsqueeze(dim=0), requires_grad=False)
  63. self.sl1_loss = nn.SmoothL1Loss(reduce=False)
  64. self.con_loss = HardMiningCrossEntropyLoss(neg_pos_ratio)
  65. self.iou_thresh = iou_thresh
  66. def _norm_relative_bbox(self, loc):
  67. """
  68. convert bbox locations into relative locations (relative to the dboxes)
  69. :param loc a tensor of shape [batch, 4, num_boxes]
  70. """
  71. gxy = ((loc[:, :2, :] - self.dboxes[:, :2, :]) / self.dboxes[:, 2:, ]) / self.scale_xy
  72. gwh = (loc[:, 2:, :] / self.dboxes[:, 2:, :]).log() / self.scale_wh
  73. return torch.cat((gxy, gwh), dim=1).contiguous()
  74. def match_dboxes(self, targets):
  75. """
  76. creates tensors with target boxes and labels for each dboxes, so with the same len as dboxes.
  77. * Each GT is assigned with a grid cell with the highest IoU, this creates a pair for each GT and some cells;
  78. * The rest of grid cells are assigned to a GT with the highest IoU, assuming it's > self.iou_thresh;
  79. If this condition is not met the grid cell is marked as background
  80. GT-wise: one to many
  81. Grid-cell-wise: one to one
  82. :param targets: a tensor containing the boxes for a single image;
  83. shape [num_boxes, 6] (image_id, label, x, y, w, h)
  84. :return: two tensors
  85. boxes - shape of dboxes [4, num_dboxes] (x,y,w,h)
  86. labels - sahpe [num_dboxes]
  87. """
  88. device = targets.device
  89. each_cell_target_locations = self.dboxes.data.clone().squeeze()
  90. each_cell_target_labels = torch.zeros((self.dboxes.data.shape[2])).to(device)
  91. if len(targets) > 0:
  92. target_boxes = targets[:, 2:]
  93. target_labels = targets[:, 1]
  94. ious = calculate_bbox_iou_matrix(target_boxes, self.dboxes.data.squeeze().T, x1y1x2y2=False)
  95. # one best GT for EACH cell (does not guarantee that all GTs will be used)
  96. best_target_per_cell, best_target_per_cell_index = ious.max(0)
  97. # one best grid cell (anchor in it) for EACH target
  98. best_cell_per_target, best_cell_per_target_index = ious.max(1)
  99. # make sure EACH target has a grid cell assigned
  100. best_target_per_cell_index[best_cell_per_target_index] = torch.arange(len(targets)).to(device)
  101. # 2. is higher than any IoU, so it is guaranteed to pass any IoU threshold
  102. # which ensures that the pairs selected for each target will be included in the mask below
  103. # while the threshold will only affect other grid cell anchors that aren't pre-assigned to any target
  104. best_target_per_cell[best_cell_per_target_index] = 2.
  105. mask = best_target_per_cell > self.iou_thresh
  106. each_cell_target_locations[:, mask] = target_boxes[best_target_per_cell_index[mask]].T
  107. each_cell_target_labels[mask] = target_labels[best_target_per_cell_index[mask]] + 1
  108. return each_cell_target_locations, each_cell_target_labels
  109. def forward(self, predictions: Tuple, targets):
  110. """
  111. Compute the loss
  112. :param predictions - predictions tensor coming from the network,
  113. tuple with shapes ([Batch Size, 4, num_dboxes], [Batch Size, num_classes + 1, num_dboxes])
  114. were predictions have logprobs for background and other classes
  115. :param targets - targets for the batch. [num targets, 6] (index in batch, label, x,y,w,h)
  116. """
  117. if isinstance(predictions, tuple) and isinstance(predictions[1], tuple):
  118. # Calculate loss in a validation mode
  119. predictions = predictions[1]
  120. batch_target_locations = []
  121. batch_target_labels = []
  122. (ploc, plabel) = predictions
  123. targets = targets.to(self.dboxes.device)
  124. for i in range(ploc.shape[0]):
  125. target_locations, target_labels = self.match_dboxes(targets[targets[:, 0] == i])
  126. batch_target_locations.append(target_locations)
  127. batch_target_labels.append(target_labels)
  128. batch_target_locations = torch.stack(batch_target_locations)
  129. batch_target_labels = torch.stack(batch_target_labels).type(torch.long)
  130. mask = batch_target_labels > 0 # not background
  131. pos_num = mask.sum(dim=1)
  132. vec_gd = self._norm_relative_bbox(batch_target_locations)
  133. # SUM ON FOUR COORDINATES, AND MASK
  134. sl1 = self.sl1_loss(ploc, vec_gd).sum(dim=1)
  135. sl1 = (mask.float() * sl1).sum(dim=1)
  136. closs = self.con_loss(plabel, batch_target_labels)
  137. # AVOID NO OBJECT DETECTED
  138. total_loss = (2 - self.alpha) * sl1 + self.alpha * closs
  139. num_mask = (pos_num > 0).float() # a mask with 0 for images that have no positive pairs at all
  140. pos_num = pos_num.float().clamp(min=1e-6)
  141. ret = (total_loss * num_mask / pos_num).mean(dim=0) # normalize by the number of positive pairs
  142. return ret, torch.cat((sl1.mean().unsqueeze(0), closs.mean().unsqueeze(0), ret.unsqueeze(0))).detach()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file