Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#970 Update YoloNASQuickstart.md

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/SG-000_fix_readme_yolonas_snippets
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
  1. from typing import Tuple
  2. import torch
  3. from torch import nn
  4. from torch.nn.modules.loss import _Loss
  5. from super_gradients.common.object_names import Losses
  6. from super_gradients.common.registry.registry import register_loss
  7. from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
  8. from super_gradients.training.utils.ssd_utils import DefaultBoxes
  9. class HardMiningCrossEntropyLoss(_Loss):
  10. """
  11. L_cls = [CE of all positives] + [CE of the hardest backgrounds]
  12. where the second term is built from [neg_pos_ratio * positive pairs] background cells with the highest CE
  13. (the hardest background cells)
  14. """
  15. def __init__(self, neg_pos_ratio: float):
  16. """
  17. :param neg_pos_ratio: a ratio of negative samples to positive samples in the loss
  18. (unlike positives, not all negatives will be used:
  19. for each positive the [neg_pos_ratio] hardest negatives will be selected)
  20. """
  21. super().__init__()
  22. self.neg_pos_ratio = neg_pos_ratio
  23. self.ce = nn.CrossEntropyLoss(reduce=False)
  24. def forward(self, pred_labels, target_labels):
  25. mask = target_labels > 0 # not background
  26. pos_num = mask.sum(dim=1)
  27. # HARD NEGATIVE MINING
  28. con = self.ce(pred_labels, target_labels)
  29. # POSITIVE MASK WILL NOT BE SELECTED
  30. # set 0. loss for all positive objects, leave the loss where the object is background
  31. con_neg = con.clone()
  32. con_neg[mask] = 0
  33. # sort background cells by CE loss value (bigger_first)
  34. _, con_idx = con_neg.sort(dim=1, descending=True)
  35. # restore cells order, get each cell's order (rank) in CE loss sorting
  36. _, con_rank = con_idx.sort(dim=1)
  37. # NUMBER OF NEGATIVE THREE TIMES POSITIVE
  38. neg_num = torch.clamp(self.neg_pos_ratio * pos_num, max=mask.size(1)).unsqueeze(-1)
  39. # for each image into neg mask we'll take (3 * positive pairs) background objects with the highest CE
  40. neg_mask = con_rank < neg_num
  41. closs = (con * (mask.float() + neg_mask.float())).sum(dim=1)
  42. return closs
  43. @register_loss(Losses.SSD_LOSS)
  44. class SSDLoss(_Loss):
  45. """
  46. Implements the loss as the sum of the followings:
  47. 1. Confidence Loss: All labels, with hard negative mining
  48. 2. Localization Loss: Only on positive labels
  49. L = (2 - alpha) * L_l1 + alpha * L_cls, where
  50. * L_cls is HardMiningCrossEntropyLoss
  51. * L_l1 = [SmoothL1Loss for all positives]
  52. """
  53. def __init__(self, dboxes: DefaultBoxes, alpha: float = 1.0, iou_thresh: float = 0.5, neg_pos_ratio: float = 3.0):
  54. """
  55. :param dboxes: model anchors, shape [Num Grid Cells * Num anchors x 4]
  56. :param alpha: a weighting factor between classification and regression loss
  57. :param iou_thresh: a threshold for matching of anchors in each grid cell to GTs
  58. (a match should have IoU > iou_thresh)
  59. :param neg_pos_ratio: a ratio for HardMiningCrossEntropyLoss
  60. """
  61. super(SSDLoss, self).__init__()
  62. self.scale_xy = dboxes.scale_xy
  63. self.scale_wh = dboxes.scale_wh
  64. self.alpha = alpha
  65. self.dboxes = nn.Parameter(dboxes(order="xywh").transpose(0, 1).unsqueeze(dim=0), requires_grad=False)
  66. self.sl1_loss = nn.SmoothL1Loss(reduce=False)
  67. self.con_loss = HardMiningCrossEntropyLoss(neg_pos_ratio)
  68. self.iou_thresh = iou_thresh
  69. @property
  70. def component_names(self):
  71. """
  72. Component names for logging during training.
  73. These correspond to 2nd item in the tuple returned in self.forward(...).
  74. See super_gradients.Trainer.train() docs for more info.
  75. """
  76. return ["smooth_l1", "closs", "Loss"]
  77. def _norm_relative_bbox(self, loc):
  78. """
  79. convert bbox locations into relative locations (relative to the dboxes)
  80. :param loc a tensor of shape [batch, 4, num_boxes]
  81. """
  82. gxy = (
  83. (loc[:, :2, :] - self.dboxes[:, :2, :])
  84. / self.dboxes[
  85. :,
  86. 2:,
  87. ]
  88. ) / self.scale_xy
  89. gwh = (loc[:, 2:, :] / self.dboxes[:, 2:, :]).log() / self.scale_wh
  90. return torch.cat((gxy, gwh), dim=1).contiguous()
  91. def match_dboxes(self, targets):
  92. """
  93. creates tensors with target boxes and labels for each dboxes, so with the same len as dboxes.
  94. * Each GT is assigned with a grid cell with the highest IoU, this creates a pair for each GT and some cells;
  95. * The rest of grid cells are assigned to a GT with the highest IoU, assuming it's > self.iou_thresh;
  96. If this condition is not met the grid cell is marked as background
  97. GT-wise: one to many
  98. Grid-cell-wise: one to one
  99. :param targets: a tensor containing the boxes for a single image;
  100. shape [num_boxes, 6] (image_id, label, x, y, w, h)
  101. :return: two tensors
  102. boxes - shape of dboxes [4, num_dboxes] (x,y,w,h)
  103. labels - sahpe [num_dboxes]
  104. """
  105. device = targets.device
  106. each_cell_target_locations = self.dboxes.data.clone().squeeze()
  107. each_cell_target_labels = torch.zeros((self.dboxes.data.shape[2])).to(device)
  108. if len(targets) > 0:
  109. target_boxes = targets[:, 2:]
  110. target_labels = targets[:, 1]
  111. ious = calculate_bbox_iou_matrix(target_boxes, self.dboxes.data.squeeze().T, x1y1x2y2=False)
  112. # one best GT for EACH cell (does not guarantee that all GTs will be used)
  113. best_target_per_cell, best_target_per_cell_index = ious.max(0)
  114. # one best grid cell (anchor in it) for EACH target
  115. best_cell_per_target, best_cell_per_target_index = ious.max(1)
  116. # make sure EACH target has a grid cell assigned
  117. best_target_per_cell_index[best_cell_per_target_index] = torch.arange(len(targets)).to(device)
  118. # 2. is higher than any IoU, so it is guaranteed to pass any IoU threshold
  119. # which ensures that the pairs selected for each target will be included in the mask below
  120. # while the threshold will only affect other grid cell anchors that aren't pre-assigned to any target
  121. best_target_per_cell[best_cell_per_target_index] = 2.0
  122. mask = best_target_per_cell > self.iou_thresh
  123. each_cell_target_locations[:, mask] = target_boxes[best_target_per_cell_index[mask]].T
  124. each_cell_target_labels[mask] = target_labels[best_target_per_cell_index[mask]] + 1
  125. return each_cell_target_locations, each_cell_target_labels
  126. def forward(self, predictions: Tuple, targets):
  127. """
  128. Compute the loss
  129. :param predictions - predictions tensor coming from the network,
  130. tuple with shapes ([Batch Size, 4, num_dboxes], [Batch Size, num_classes + 1, num_dboxes])
  131. were predictions have logprobs for background and other classes
  132. :param targets - targets for the batch. [num targets, 6] (index in batch, label, x,y,w,h)
  133. """
  134. if isinstance(predictions, tuple) and isinstance(predictions[1], tuple):
  135. # Calculate loss in a validation mode
  136. predictions = predictions[1]
  137. batch_target_locations = []
  138. batch_target_labels = []
  139. (ploc, plabel) = predictions
  140. targets = targets.to(self.dboxes.device)
  141. for i in range(ploc.shape[0]):
  142. target_locations, target_labels = self.match_dboxes(targets[targets[:, 0] == i])
  143. batch_target_locations.append(target_locations)
  144. batch_target_labels.append(target_labels)
  145. batch_target_locations = torch.stack(batch_target_locations)
  146. batch_target_labels = torch.stack(batch_target_labels).type(torch.long)
  147. mask = batch_target_labels > 0 # not background
  148. pos_num = mask.sum(dim=1)
  149. vec_gd = self._norm_relative_bbox(batch_target_locations)
  150. # SUM ON FOUR COORDINATES, AND MASK
  151. sl1 = self.sl1_loss(ploc, vec_gd).sum(dim=1)
  152. sl1 = (mask.float() * sl1).sum(dim=1)
  153. closs = self.con_loss(plabel, batch_target_labels)
  154. # AVOID NO OBJECT DETECTED
  155. total_loss = (2 - self.alpha) * sl1 + self.alpha * closs
  156. num_mask = (pos_num > 0).float() # a mask with 0 for images that have no positive pairs at all
  157. pos_num = pos_num.float().clamp(min=1e-6)
  158. ret = (total_loss * num_mask / pos_num).mean(dim=0) # normalize by the number of positive pairs
  159. return ret, torch.cat((sl1.mean().unsqueeze(0), closs.mean().unsqueeze(0), ret.unsqueeze(0))).detach()
Discard
Tip!

Press p or to see the previous file or, n or to see the next file