loss.py

# Loss functions

import torch
import torch.nn as nn

from utils.general import bbox_iou
from utils.torch_utils import is_parallel


def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
    # return positive, negative label smoothing BCE targets
    return 1.0 - 0.5 * eps, 0.5 * eps
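# e.g. smooth_BCE(eps=0.1) -> (0.95, 0.05): positive targets are softened to 0.95, negative targets raised to 0.05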


class BCEBlurWithLogitsLoss(nn.Module):
    # BCEWithLogitsLoss() with reduced missing label effects.
    def __init__(self, alpha=0.05):
        super(BCEBlurWithLogitsLoss, self).__init__()
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  # must be nn.BCEWithLogitsLoss()
        self.alpha = alpha

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        pred = torch.sigmoid(pred)  # prob from logits
        dx = pred - true  # reduce only missing label effects
        # dx = (pred - true).abs()  # reduce missing label and false label effects
        alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
        loss *= alpha_factor
        return loss.mean()


class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss


class QFocalLoss(nn.Module):
    # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = QFocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(QFocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)

        pred_prob = torch.sigmoid(pred)  # prob from logits
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = torch.abs(true - pred_prob) ** self.gamma
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss
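# Usage sketch (hypothetical, for illustration only): wrap an element-wise BCE criterion, then call it
# on logits and same-shaped targets:
#   criterion = QFocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)  # or FocalLoss(...) for the standard variant
#   loss = criterion(logits, targets)  # reduction follows the wrapped criterion ('mean' by default)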


class ComputeLoss:
    # Compute losses
    def __init__(self, model, autobalance=False):
        super(ComputeLoss, self).__init__()
        self.sort_obj_iou = False
        device = next(model.parameters()).device  # get model device
        h = model.hyp  # hyperparameters

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive, negative BCE targets

        # Focal loss
        g = h['fl_gamma']  # focal loss gamma
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
        self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])  # P3-P7
        self.ssi = list(det.stride).index(16) if autobalance else 0  # stride 16 index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
        for k in 'na', 'nc', 'nl', 'anchors':
            setattr(self, k, getattr(det, k))

    def __call__(self, p, targets):  # predictions, targets
        device = targets.device
        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets

        # Losses
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
            tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj

            n = b.shape[0]  # number of targets
            if n:
                ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

                # Regression
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)  # predicted box
                iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
                lbox += (1.0 - iou).mean()  # iou loss

                # Objectness
                score_iou = iou.detach().clamp(0).type(tobj.dtype)
                if self.sort_obj_iou:
                    sort_id = torch.argsort(score_iou)
                    b, a, gj, gi, score_iou = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id]
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou  # iou ratio

                # Classification
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
                    t[range(n), tcls[i]] = self.cp
                    lcls += self.BCEcls(ps[:, 5:], t)  # BCE

                # Append targets to text file
                # with open('targets.txt', 'a') as file:
                #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # obj loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size

        loss = lbox + lobj + lcls
        return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

    def build_targets(self, p, targets):
        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
        na, nt = self.na, targets.shape[0]  # number of anchors, targets
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

        g = 0.5  # bias
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                            ], device=targets.device).float() * g  # offsets

        for i in range(self.nl):
            anchors = self.anchors[i]
            gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain
            if nt:
                # Matches
                r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']  # compare
                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                t = t[j]  # filter

                # Offsets
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            # Define
            b, c = t[:, :2].long().T  # image, class
            gxy = t[:, 2:4]  # grid xy
            gwh = t[:, 4:6]  # grid wh
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid xy indices

            # Append
            a = t[:, 6].long()  # anchor indices
            indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))  # image, anchor, grid indices
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
            anch.append(anchors[a])  # anchors
            tcls.append(c)  # class

        return tcls, tbox, indices, anch
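
For context, here is a minimal sketch of how ComputeLoss is typically wired into a training step. The model, dataloader, optimizer, and device names below are placeholders rather than anything defined in this file; the only requirements visible from the code above are that the model exposes .hyp (with 'cls_pw', 'obj_pw', 'fl_gamma', 'box', 'obj', 'cls', and 'anchor_t' entries), .gr, and a Detect() head as its last module, and that targets is an (n, 6) tensor of (image index, class, x, y, w, h).

# Hypothetical usage sketch, not part of loss.py
compute_loss = ComputeLoss(model)  # reads model.hyp, model.gr and the Detect() head once

for imgs, targets in dataloader:  # targets: (n, 6) = (image, class, x, y, w, h)
    preds = model(imgs.to(device))  # per-layer prediction list from a model in training mode
    loss, loss_items = compute_loss(preds, targets.to(device))  # loss is already scaled by batch size
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
# loss_items holds the detached (lbox, lobj, lcls, total) components, handy for logging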