#257 allow using an external Optimizer (not initialized outside)

Merged
Ofri Masad merged 1 commit into Deci-AI:master from deci-ai:feature/SG-184_external_optimizer
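This change lets the training loop consume an optimizer object that was constructed and initialized by the caller, rather than only building one internally from a name and hyper-parameters. A minimal sketch of what that usage could look like is below; the training_params keys and the idea of passing the instance under "optimizer" are assumptions based on the feature description, not lines taken from this diff:

    import torch
    from torch import nn

    # placeholder module, illustrative only
    model = nn.Linear(10, 2)

    # the optimizer is built (and can be customized, wrapped, etc.) outside the library
    external_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

    # hypothetical training_params dict - key names are assumptions, not taken from this PR
    training_params = {
        "max_epochs": 50,
        "optimizer": external_optimizer,   # the already-initialized torch.optim.Optimizer instance
        "loss": "cross_entropy",
    }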
from typing import Tuple

import torch
from torch import nn
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F

from super_gradients.training.utils.detection_utils import calculate_bbox_iou_matrix
from super_gradients.training.utils.ssd_utils import DefaultBoxes


class HardMiningCrossEntropyLoss(_Loss):
    """
    L_cls = [CE of all positives] + [CE of the hardest backgrounds]
    where the second term is built from [neg_pos_ratio * positive pairs] background cells with the highest CE
    (the hardest background cells)
    """
    def __init__(self, neg_pos_ratio: float):
        """
        :param neg_pos_ratio: a ratio of negative samples to positive samples in the loss
                              (unlike positives, not all negatives will be used:
                              for each positive, the [neg_pos_ratio] hardest negatives will be selected)
        """
        super().__init__()
        self.neg_pos_ratio = neg_pos_ratio
        self.ce = nn.CrossEntropyLoss(reduce=False)

    def forward(self, pred_labels, target_labels):
        mask = target_labels > 0  # not background
        pos_num = mask.sum(dim=1)

        # HARD NEGATIVE MINING
        con = self.ce(pred_labels, target_labels)

        # POSITIVE MASK WILL NOT BE SELECTED
        # set 0. loss for all positive objects, leave the loss where the object is background
        con_neg = con.clone()
        con_neg[mask] = 0
        # sort background cells by CE loss value (bigger first)
        _, con_idx = con_neg.sort(dim=1, descending=True)
        # restore cells order, get each cell's order (rank) in CE loss sorting
        _, con_rank = con_idx.sort(dim=1)

        # NUMBER OF NEGATIVES IS neg_pos_ratio TIMES THE NUMBER OF POSITIVES
        neg_num = torch.clamp(self.neg_pos_ratio * pos_num, max=mask.size(1)).unsqueeze(-1)
        # for each image, the negative mask keeps the (neg_pos_ratio * positive pairs) background cells with the highest CE
        neg_mask = con_rank < neg_num

        closs = (con * (mask.float() + neg_mask.float())).sum(dim=1)
        return closs


class SSDLoss(_Loss):
    """
    Implements the loss as the sum of the following:
        1. Confidence Loss: all labels, with hard negative mining
        2. Localization Loss: only on positive labels

    L = (2 - alpha) * L_l1 + alpha * L_cls, where
        * L_cls is HardMiningCrossEntropyLoss
        * L_l1 = [SmoothL1Loss for all positives]
    """
    def __init__(self, dboxes: DefaultBoxes, alpha: float = 1.0, iou_thresh: float = 0.5, neg_pos_ratio: float = 3.):
        """
        :param dboxes:        model anchors, shape [Num Grid Cells * Num anchors x 4]
        :param alpha:         a weighting factor between classification and regression loss
        :param iou_thresh:    a threshold for matching of anchors in each grid cell to GTs
                              (a match should have IoU > iou_thresh)
        :param neg_pos_ratio: a ratio for HardMiningCrossEntropyLoss
        """
        super(SSDLoss, self).__init__()
        self.scale_xy = dboxes.scale_xy
        self.scale_wh = dboxes.scale_wh
        self.alpha = alpha
        self.dboxes = nn.Parameter(dboxes(order="xywh").transpose(0, 1).unsqueeze(dim=0), requires_grad=False)
        self.sl1_loss = nn.SmoothL1Loss(reduce=False)

        self.con_loss = HardMiningCrossEntropyLoss(neg_pos_ratio)
        self.iou_thresh = iou_thresh

    def _norm_relative_bbox(self, loc):
        """
        convert bbox locations into relative locations (relative to the dboxes)
        :param loc: a tensor of shape [batch, 4, num_boxes]
        """
        gxy = ((loc[:, :2, :] - self.dboxes[:, :2, :]) / self.dboxes[:, 2:, :]) / self.scale_xy
        gwh = (loc[:, 2:, :] / self.dboxes[:, 2:, :]).log() / self.scale_wh

        return torch.cat((gxy, gwh), dim=1).contiguous()

    def match_dboxes(self, targets):
        """
        creates tensors with target boxes and labels for each dbox, i.e. with the same length as dboxes:
            * each GT is assigned the grid cell with the highest IoU; this creates a pair for each GT and some cells;
            * each of the remaining grid cells is assigned to the GT with the highest IoU, provided it is > self.iou_thresh;
              if this condition is not met, the grid cell is marked as background

        GT-wise: one to many
        Grid-cell-wise: one to one

        :param targets: a tensor containing the boxes for a single image;
                        shape [num_boxes, 6] (image_id, label, x, y, w, h)
        :return: two tensors
                 boxes  - shape of dboxes [4, num_dboxes] (x, y, w, h)
                 labels - shape [num_dboxes]
        """
        device = targets.device
        each_cell_target_locations = self.dboxes.data.clone().squeeze()
        each_cell_target_labels = torch.zeros((self.dboxes.data.shape[2])).to(device)

        if len(targets) > 0:
            target_boxes = targets[:, 2:]
            target_labels = targets[:, 1]
            ious = calculate_bbox_iou_matrix(target_boxes, self.dboxes.data.squeeze().T, x1y1x2y2=False)

            # one best GT for EACH cell (does not guarantee that all GTs will be used)
            best_target_per_cell, best_target_per_cell_index = ious.max(0)

            # one best grid cell (anchor in it) for EACH target
            best_cell_per_target, best_cell_per_target_index = ious.max(1)

            # make sure EACH target has a grid cell assigned
            best_target_per_cell_index[best_cell_per_target_index] = torch.arange(len(targets)).to(device)
            # 2. is higher than any IoU, so it is guaranteed to pass any IoU threshold,
            # which ensures that the pairs selected for each target will be included in the mask below,
            # while the threshold will only affect other grid cell anchors that aren't pre-assigned to any target
            best_target_per_cell[best_cell_per_target_index] = 2.

            mask = best_target_per_cell > self.iou_thresh
            each_cell_target_locations[:, mask] = target_boxes[best_target_per_cell_index[mask]].T
            each_cell_target_labels[mask] = target_labels[best_target_per_cell_index[mask]] + 1

        return each_cell_target_locations, each_cell_target_labels

    def forward(self, predictions: Tuple, targets):
        """
        Compute the loss
        :param predictions: predictions tensor coming from the network,
                            a tuple with shapes ([Batch Size, 4, num_dboxes], [Batch Size, num_classes + 1, num_dboxes]),
                            where predictions hold logprobs for background and other classes
        :param targets:     targets for the batch, shape [num targets, 6] (index in batch, label, x, y, w, h)
        """
        if isinstance(predictions, tuple) and isinstance(predictions[1], tuple):
            # Calculate loss in a validation mode
            predictions = predictions[1]

        batch_target_locations = []
        batch_target_labels = []
        (ploc, plabel) = predictions
        targets = targets.to(self.dboxes.device)
        for i in range(ploc.shape[0]):
            target_locations, target_labels = self.match_dboxes(targets[targets[:, 0] == i])
            batch_target_locations.append(target_locations)
            batch_target_labels.append(target_labels)
        batch_target_locations = torch.stack(batch_target_locations)
        batch_target_labels = torch.stack(batch_target_labels).type(torch.long)

        mask = batch_target_labels > 0  # not background
        pos_num = mask.sum(dim=1)

        vec_gd = self._norm_relative_bbox(batch_target_locations)

        # SUM ON FOUR COORDINATES, AND MASK
        sl1 = self.sl1_loss(ploc, vec_gd).sum(dim=1)
        sl1 = (mask.float() * sl1).sum(dim=1)

        closs = self.con_loss(plabel, batch_target_labels)

        # AVOID NO OBJECT DETECTED
        total_loss = (2 - self.alpha) * sl1 + self.alpha * closs
        num_mask = (pos_num > 0).float()  # a mask with 0 for images that have no positive pairs at all
        pos_num = pos_num.float().clamp(min=1e-6)
        ret = (total_loss * num_mask / pos_num).mean(dim=0)  # normalize by the number of positive pairs

        return ret, torch.cat((sl1.mean().unsqueeze(0), closs.mean().unsqueeze(0), ret.unsqueeze(0))).detach()
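For reference, a minimal sketch of driving SSDLoss with dummy tensors follows. The import path and the TinyDBoxes stand-in are assumptions: TinyDBoxes only mimics the small interface SSDLoss actually touches (scale_xy, scale_wh, and __call__(order="xywh")); in real training a DefaultBoxes instance built for the SSD architecture would be passed instead.

    import torch
    # assumed import path for the file above; adjust to where ssd_loss.py lives in the repo
    from super_gradients.training.losses.ssd_loss import SSDLoss

    class TinyDBoxes:
        """Hypothetical stand-in exposing only what SSDLoss uses from DefaultBoxes."""
        scale_xy, scale_wh = 0.1, 0.2

        def __init__(self, num_dboxes: int = 8):
            # random anchors in (x, y, w, h) format, shape [num_dboxes, 4], kept strictly positive
            self.boxes = torch.rand(num_dboxes, 4) * 0.5 + 0.25

        def __call__(self, order: str = "xywh"):
            return self.boxes

    batch_size, num_classes, num_dboxes = 2, 3, 8
    criterion = SSDLoss(dboxes=TinyDBoxes(num_dboxes), alpha=1.0)

    ploc = torch.rand(batch_size, 4, num_dboxes, requires_grad=True)                  # predicted box offsets
    plabel = torch.rand(batch_size, num_classes + 1, num_dboxes, requires_grad=True)  # class logits incl. background
    targets = torch.tensor([[0., 1., 0.5, 0.5, 0.2, 0.2],                             # (batch index, label, x, y, w, h)
                            [1., 2., 0.3, 0.4, 0.1, 0.3]])

    loss, loss_items = criterion((ploc, plabel), targets)  # loss_items: (smooth-L1, confidence, total), detached
    loss.backward()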