Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

loss.py 9.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Loss functions
  4. """
  5. import torch
  6. import torch.nn as nn
  7. from utils.metrics import bbox_iou
  8. from utils.torch_utils import is_parallel
  9. def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
  10. # return positive, negative label smoothing BCE targets
  11. return 1.0 - 0.5 * eps, 0.5 * eps
  12. class BCEBlurWithLogitsLoss(nn.Module):
  13. # BCEwithLogitLoss() with reduced missing label effects.
  14. def __init__(self, alpha=0.05):
  15. super(BCEBlurWithLogitsLoss, self).__init__()
  16. self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
  17. self.alpha = alpha
  18. def forward(self, pred, true):
  19. loss = self.loss_fcn(pred, true)
  20. pred = torch.sigmoid(pred) # prob from logits
  21. dx = pred - true # reduce only missing label effects
  22. # dx = (pred - true).abs() # reduce missing label and false label effects
  23. alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
  24. loss *= alpha_factor
  25. return loss.mean()
  26. class FocalLoss(nn.Module):
  27. # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  28. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  29. super(FocalLoss, self).__init__()
  30. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  31. self.gamma = gamma
  32. self.alpha = alpha
  33. self.reduction = loss_fcn.reduction
  34. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  35. def forward(self, pred, true):
  36. loss = self.loss_fcn(pred, true)
  37. # p_t = torch.exp(-loss)
  38. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  39. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  40. pred_prob = torch.sigmoid(pred) # prob from logits
  41. p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  42. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  43. modulating_factor = (1.0 - p_t) ** self.gamma
  44. loss *= alpha_factor * modulating_factor
  45. if self.reduction == 'mean':
  46. return loss.mean()
  47. elif self.reduction == 'sum':
  48. return loss.sum()
  49. else: # 'none'
  50. return loss
  51. class QFocalLoss(nn.Module):
  52. # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  53. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  54. super(QFocalLoss, self).__init__()
  55. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  56. self.gamma = gamma
  57. self.alpha = alpha
  58. self.reduction = loss_fcn.reduction
  59. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  60. def forward(self, pred, true):
  61. loss = self.loss_fcn(pred, true)
  62. pred_prob = torch.sigmoid(pred) # prob from logits
  63. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  64. modulating_factor = torch.abs(true - pred_prob) ** self.gamma
  65. loss *= alpha_factor * modulating_factor
  66. if self.reduction == 'mean':
  67. return loss.mean()
  68. elif self.reduction == 'sum':
  69. return loss.sum()
  70. else: # 'none'
  71. return loss
  72. class ComputeLoss:
  73. # Compute losses
  74. def __init__(self, model, autobalance=False):
  75. self.sort_obj_iou = False
  76. device = next(model.parameters()).device # get model device
  77. h = model.hyp # hyperparameters
  78. # Define criteria
  79. BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
  80. BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
  81. # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
  82. self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
  83. # Focal loss
  84. g = h['fl_gamma'] # focal loss gamma
  85. if g > 0:
  86. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
  87. det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
  88. self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
  89. self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
  90. self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
  91. for k in 'na', 'nc', 'nl', 'anchors':
  92. setattr(self, k, getattr(det, k))
  93. def __call__(self, p, targets): # predictions, targets, model
  94. device = targets.device
  95. lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
  96. tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
  97. # Losses
  98. for i, pi in enumerate(p): # layer index, layer predictions
  99. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
  100. tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
  101. n = b.shape[0] # number of targets
  102. if n:
  103. ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
  104. # Regression
  105. pxy = ps[:, :2].sigmoid() * 2. - 0.5
  106. pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
  107. pbox = torch.cat((pxy, pwh), 1) # predicted box
  108. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
  109. lbox += (1.0 - iou).mean() # iou loss
  110. # Objectness
  111. score_iou = iou.detach().clamp(0).type(tobj.dtype)
  112. if self.sort_obj_iou:
  113. sort_id = torch.argsort(score_iou)
  114. b, a, gj, gi, score_iou = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id]
  115. tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou # iou ratio
  116. # Classification
  117. if self.nc > 1: # cls loss (only if multiple classes)
  118. t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
  119. t[range(n), tcls[i]] = self.cp
  120. lcls += self.BCEcls(ps[:, 5:], t) # BCE
  121. # Append targets to text file
  122. # with open('targets.txt', 'a') as file:
  123. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
  124. obji = self.BCEobj(pi[..., 4], tobj)
  125. lobj += obji * self.balance[i] # obj loss
  126. if self.autobalance:
  127. self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
  128. if self.autobalance:
  129. self.balance = [x / self.balance[self.ssi] for x in self.balance]
  130. lbox *= self.hyp['box']
  131. lobj *= self.hyp['obj']
  132. lcls *= self.hyp['cls']
  133. bs = tobj.shape[0] # batch size
  134. return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()
  135. def build_targets(self, p, targets):
  136. # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
  137. na, nt = self.na, targets.shape[0] # number of anchors, targets
  138. tcls, tbox, indices, anch = [], [], [], []
  139. gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
  140. ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
  141. targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
  142. g = 0.5 # bias
  143. off = torch.tensor([[0, 0],
  144. [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
  145. # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
  146. ], device=targets.device).float() * g # offsets
  147. for i in range(self.nl):
  148. anchors = self.anchors[i]
  149. gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
  150. # Match targets to anchors
  151. t = targets * gain
  152. if nt:
  153. # Matches
  154. r = t[:, :, 4:6] / anchors[:, None] # wh ratio
  155. j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t'] # compare
  156. # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
  157. t = t[j] # filter
  158. # Offsets
  159. gxy = t[:, 2:4] # grid xy
  160. gxi = gain[[2, 3]] - gxy # inverse
  161. j, k = ((gxy % 1. < g) & (gxy > 1.)).T
  162. l, m = ((gxi % 1. < g) & (gxi > 1.)).T
  163. j = torch.stack((torch.ones_like(j), j, k, l, m))
  164. t = t.repeat((5, 1, 1))[j]
  165. offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
  166. else:
  167. t = targets[0]
  168. offsets = 0
  169. # Define
  170. b, c = t[:, :2].long().T # image, class
  171. gxy = t[:, 2:4] # grid xy
  172. gwh = t[:, 4:6] # grid wh
  173. gij = (gxy - offsets).long()
  174. gi, gj = gij.T # grid xy indices
  175. # Append
  176. a = t[:, 6].long() # anchor indices
  177. indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices
  178. tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
  179. anch.append(anchors[a]) # anchors
  180. tcls.append(c) # class
  181. return tcls, tbox, indices, anch
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...