Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

loss.py 9.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
  1. # Loss functions
  2. import torch
  3. import torch.nn as nn
  4. from utils.general import bbox_iou
  5. from utils.torch_utils import is_parallel
  6. def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
  7. # return positive, negative label smoothing BCE targets
  8. return 1.0 - 0.5 * eps, 0.5 * eps
  9. class BCEBlurWithLogitsLoss(nn.Module):
  10. # BCEwithLogitLoss() with reduced missing label effects.
  11. def __init__(self, alpha=0.05):
  12. super(BCEBlurWithLogitsLoss, self).__init__()
  13. self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
  14. self.alpha = alpha
  15. def forward(self, pred, true):
  16. loss = self.loss_fcn(pred, true)
  17. pred = torch.sigmoid(pred) # prob from logits
  18. dx = pred - true # reduce only missing label effects
  19. # dx = (pred - true).abs() # reduce missing label and false label effects
  20. alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
  21. loss *= alpha_factor
  22. return loss.mean()
  23. class FocalLoss(nn.Module):
  24. # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  25. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  26. super(FocalLoss, self).__init__()
  27. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  28. self.gamma = gamma
  29. self.alpha = alpha
  30. self.reduction = loss_fcn.reduction
  31. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  32. def forward(self, pred, true):
  33. loss = self.loss_fcn(pred, true)
  34. # p_t = torch.exp(-loss)
  35. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  36. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  37. pred_prob = torch.sigmoid(pred) # prob from logits
  38. p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  39. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  40. modulating_factor = (1.0 - p_t) ** self.gamma
  41. loss *= alpha_factor * modulating_factor
  42. if self.reduction == 'mean':
  43. return loss.mean()
  44. elif self.reduction == 'sum':
  45. return loss.sum()
  46. else: # 'none'
  47. return loss
  48. class QFocalLoss(nn.Module):
  49. # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  50. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  51. super(QFocalLoss, self).__init__()
  52. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  53. self.gamma = gamma
  54. self.alpha = alpha
  55. self.reduction = loss_fcn.reduction
  56. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  57. def forward(self, pred, true):
  58. loss = self.loss_fcn(pred, true)
  59. pred_prob = torch.sigmoid(pred) # prob from logits
  60. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  61. modulating_factor = torch.abs(true - pred_prob) ** self.gamma
  62. loss *= alpha_factor * modulating_factor
  63. if self.reduction == 'mean':
  64. return loss.mean()
  65. elif self.reduction == 'sum':
  66. return loss.sum()
  67. else: # 'none'
  68. return loss
  69. class ComputeLoss:
  70. # Compute losses
  71. def __init__(self, model, autobalance=False):
  72. super(ComputeLoss, self).__init__()
  73. device = next(model.parameters()).device # get model device
  74. h = model.hyp # hyperparameters
  75. # Define criteria
  76. BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
  77. BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
  78. # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
  79. self.cp, self.cn = smooth_BCE(eps=0.0)
  80. # Focal loss
  81. g = h['fl_gamma'] # focal loss gamma
  82. if g > 0:
  83. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
  84. det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
  85. self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
  86. self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
  87. self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
  88. for k in 'na', 'nc', 'nl', 'anchors':
  89. setattr(self, k, getattr(det, k))
  90. def __call__(self, p, targets): # predictions, targets, model
  91. device = targets.device
  92. lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
  93. tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
  94. # Losses
  95. for i, pi in enumerate(p): # layer index, layer predictions
  96. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
  97. tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
  98. n = b.shape[0] # number of targets
  99. if n:
  100. ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
  101. # Regression
  102. pxy = ps[:, :2].sigmoid() * 2. - 0.5
  103. pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
  104. pbox = torch.cat((pxy, pwh), 1) # predicted box
  105. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
  106. lbox += (1.0 - iou).mean() # iou loss
  107. # Objectness
  108. tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
  109. # Classification
  110. if self.nc > 1: # cls loss (only if multiple classes)
  111. t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
  112. t[range(n), tcls[i]] = self.cp
  113. lcls += self.BCEcls(ps[:, 5:], t) # BCE
  114. # Append targets to text file
  115. # with open('targets.txt', 'a') as file:
  116. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
  117. obji = self.BCEobj(pi[..., 4], tobj)
  118. lobj += obji * self.balance[i] # obj loss
  119. if self.autobalance:
  120. self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
  121. if self.autobalance:
  122. self.balance = [x / self.balance[self.ssi] for x in self.balance]
  123. lbox *= self.hyp['box']
  124. lobj *= self.hyp['obj']
  125. lcls *= self.hyp['cls']
  126. bs = tobj.shape[0] # batch size
  127. loss = lbox + lobj + lcls
  128. return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
  129. def build_targets(self, p, targets):
  130. # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
  131. na, nt = self.na, targets.shape[0] # number of anchors, targets
  132. tcls, tbox, indices, anch = [], [], [], []
  133. gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
  134. ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
  135. targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
  136. g = 0.5 # bias
  137. off = torch.tensor([[0, 0],
  138. [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
  139. # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
  140. ], device=targets.device).float() * g # offsets
  141. for i in range(self.nl):
  142. anchors = self.anchors[i]
  143. gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
  144. # Match targets to anchors
  145. t = targets * gain
  146. if nt:
  147. # Matches
  148. r = t[:, :, 4:6] / anchors[:, None] # wh ratio
  149. j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t'] # compare
  150. # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
  151. t = t[j] # filter
  152. # Offsets
  153. gxy = t[:, 2:4] # grid xy
  154. gxi = gain[[2, 3]] - gxy # inverse
  155. j, k = ((gxy % 1. < g) & (gxy > 1.)).T
  156. l, m = ((gxi % 1. < g) & (gxi > 1.)).T
  157. j = torch.stack((torch.ones_like(j), j, k, l, m))
  158. t = t.repeat((5, 1, 1))[j]
  159. offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
  160. else:
  161. t = targets[0]
  162. offsets = 0
  163. # Define
  164. b, c = t[:, :2].long().T # image, class
  165. gxy = t[:, 2:4] # grid xy
  166. gwh = t[:, 4:6] # grid wh
  167. gij = (gxy - offsets).long()
  168. gi, gj = gij.T # grid xy indices
  169. # Append
  170. a = t[:, 6].long() # anchor indices
  171. indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices
  172. tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
  173. anch.append(anchors[a]) # anchors
  174. tcls.append(c) # class
  175. return tcls, tbox, indices, anch
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...