Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

loss.py 9.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
  1. # Loss functions
  2. import torch
  3. import torch.nn as nn
  4. from utils.general import bbox_iou
  5. from utils.torch_utils import is_parallel
  6. def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
  7. # return positive, negative label smoothing BCE targets
  8. return 1.0 - 0.5 * eps, 0.5 * eps
  9. class BCEBlurWithLogitsLoss(nn.Module):
  10. # BCEwithLogitLoss() with reduced missing label effects.
  11. def __init__(self, alpha=0.05):
  12. super(BCEBlurWithLogitsLoss, self).__init__()
  13. self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
  14. self.alpha = alpha
  15. def forward(self, pred, true):
  16. loss = self.loss_fcn(pred, true)
  17. pred = torch.sigmoid(pred) # prob from logits
  18. dx = pred - true # reduce only missing label effects
  19. # dx = (pred - true).abs() # reduce missing label and false label effects
  20. alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
  21. loss *= alpha_factor
  22. return loss.mean()
  23. class FocalLoss(nn.Module):
  24. # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  25. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  26. super(FocalLoss, self).__init__()
  27. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  28. self.gamma = gamma
  29. self.alpha = alpha
  30. self.reduction = loss_fcn.reduction
  31. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  32. def forward(self, pred, true):
  33. loss = self.loss_fcn(pred, true)
  34. # p_t = torch.exp(-loss)
  35. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  36. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  37. pred_prob = torch.sigmoid(pred) # prob from logits
  38. p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  39. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  40. modulating_factor = (1.0 - p_t) ** self.gamma
  41. loss *= alpha_factor * modulating_factor
  42. if self.reduction == 'mean':
  43. return loss.mean()
  44. elif self.reduction == 'sum':
  45. return loss.sum()
  46. else: # 'none'
  47. return loss
  48. class QFocalLoss(nn.Module):
  49. # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  50. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  51. super(QFocalLoss, self).__init__()
  52. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  53. self.gamma = gamma
  54. self.alpha = alpha
  55. self.reduction = loss_fcn.reduction
  56. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  57. def forward(self, pred, true):
  58. loss = self.loss_fcn(pred, true)
  59. pred_prob = torch.sigmoid(pred) # prob from logits
  60. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  61. modulating_factor = torch.abs(true - pred_prob) ** self.gamma
  62. loss *= alpha_factor * modulating_factor
  63. if self.reduction == 'mean':
  64. return loss.mean()
  65. elif self.reduction == 'sum':
  66. return loss.sum()
  67. else: # 'none'
  68. return loss
  69. class ComputeLoss:
  70. # Compute losses
  71. def __init__(self, model, autobalance=False):
  72. super(ComputeLoss, self).__init__()
  73. device = next(model.parameters()).device # get model device
  74. h = model.hyp # hyperparameters
  75. # Define criteria
  76. BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
  77. BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
  78. # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
  79. self.cp, self.cn = smooth_BCE(eps=0.0)
  80. # Focal loss
  81. g = h['fl_gamma'] # focal loss gamma
  82. if g > 0:
  83. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
  84. det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
  85. self.balance = {3: [3.67, 1.0, 0.43], 4: [4.00, 1.0, 0.4, 0.10], 5: [3.88, 1.0, 0.37, 0.17, 0.10]}[det.nl]
  86. # self.balance = {3: [3.67, 1.0, 0.43], 4: [3.78, 1.0, 0.39, 0.22], 5: [3.88, 1.0, 0.37, 0.17, 0.10]}[det.nl]
  87. # self.balance = [1.0] * det.nl
  88. self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index
  89. self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
  90. for k in 'na', 'nc', 'nl', 'anchors':
  91. setattr(self, k, getattr(det, k))
  92. def __call__(self, p, targets): # predictions, targets, model
  93. device = targets.device
  94. lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
  95. tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
  96. # Losses
  97. for i, pi in enumerate(p): # layer index, layer predictions
  98. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
  99. tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
  100. n = b.shape[0] # number of targets
  101. if n:
  102. ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
  103. # Regression
  104. pxy = ps[:, :2].sigmoid() * 2. - 0.5
  105. pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
  106. pbox = torch.cat((pxy, pwh), 1) # predicted box
  107. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
  108. lbox += (1.0 - iou).mean() # iou loss
  109. # Objectness
  110. tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
  111. # Classification
  112. if self.nc > 1: # cls loss (only if multiple classes)
  113. t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
  114. t[range(n), tcls[i]] = self.cp
  115. lcls += self.BCEcls(ps[:, 5:], t) # BCE
  116. # Append targets to text file
  117. # with open('targets.txt', 'a') as file:
  118. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
  119. obji = self.BCEobj(pi[..., 4], tobj)
  120. lobj += obji * self.balance[i] # obj loss
  121. if self.autobalance:
  122. self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
  123. if self.autobalance:
  124. self.balance = [x / self.balance[self.ssi] for x in self.balance]
  125. lbox *= self.hyp['box']
  126. lobj *= self.hyp['obj'] * 1.4 * 3 / 4
  127. lcls *= self.hyp['cls']
  128. bs = tobj.shape[0] # batch size
  129. loss = lbox + lobj + lcls
  130. return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
  131. def build_targets(self, p, targets):
  132. # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
  133. na, nt = self.na, targets.shape[0] # number of anchors, targets
  134. tcls, tbox, indices, anch = [], [], [], []
  135. gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
  136. ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
  137. targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
  138. g = 0.5 # bias
  139. off = torch.tensor([[0, 0],
  140. [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
  141. # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
  142. ], device=targets.device).float() * g # offsets
  143. for i in range(self.nl):
  144. anchors = self.anchors[i]
  145. gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
  146. # Match targets to anchors
  147. t = targets * gain
  148. if nt:
  149. # Matches
  150. r = t[:, :, 4:6] / anchors[:, None] # wh ratio
  151. j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t'] # compare
  152. # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
  153. t = t[j] # filter
  154. # Offsets
  155. gxy = t[:, 2:4] # grid xy
  156. gxi = gain[[2, 3]] - gxy # inverse
  157. j, k = ((gxy % 1. < g) & (gxy > 1.)).T
  158. l, m = ((gxi % 1. < g) & (gxi > 1.)).T
  159. j = torch.stack((torch.ones_like(j), j, k, l, m))
  160. t = t.repeat((5, 1, 1))[j]
  161. offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
  162. else:
  163. t = targets[0]
  164. offsets = 0
  165. # Define
  166. b, c = t[:, :2].long().T # image, class
  167. gxy = t[:, 2:4] # grid xy
  168. gwh = t[:, 4:6] # grid wh
  169. gij = (gxy - offsets).long()
  170. gi, gj = gij.T # grid xy indices
  171. # Append
  172. a = t[:, 6].long() # anchor indices
  173. indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices
  174. tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
  175. anch.append(anchors[a]) # anchors
  176. tcls.append(c) # class
  177. return tcls, tbox, indices, anch
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...