Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

loss.py 8.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
  1. # Loss functions
  2. import torch
  3. import torch.nn as nn
  4. from utils.general import bbox_iou
  5. from utils.torch_utils import is_parallel
  6. def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
  7. # return positive, negative label smoothing BCE targets
  8. return 1.0 - 0.5 * eps, 0.5 * eps
  9. class BCEBlurWithLogitsLoss(nn.Module):
  10. # BCEwithLogitLoss() with reduced missing label effects.
  11. def __init__(self, alpha=0.05):
  12. super(BCEBlurWithLogitsLoss, self).__init__()
  13. self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
  14. self.alpha = alpha
  15. def forward(self, pred, true):
  16. loss = self.loss_fcn(pred, true)
  17. pred = torch.sigmoid(pred) # prob from logits
  18. dx = pred - true # reduce only missing label effects
  19. # dx = (pred - true).abs() # reduce missing label and false label effects
  20. alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
  21. loss *= alpha_factor
  22. return loss.mean()
  23. class FocalLoss(nn.Module):
  24. # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  25. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  26. super(FocalLoss, self).__init__()
  27. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  28. self.gamma = gamma
  29. self.alpha = alpha
  30. self.reduction = loss_fcn.reduction
  31. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  32. def forward(self, pred, true):
  33. loss = self.loss_fcn(pred, true)
  34. # p_t = torch.exp(-loss)
  35. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  36. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  37. pred_prob = torch.sigmoid(pred) # prob from logits
  38. p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  39. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  40. modulating_factor = (1.0 - p_t) ** self.gamma
  41. loss *= alpha_factor * modulating_factor
  42. if self.reduction == 'mean':
  43. return loss.mean()
  44. elif self.reduction == 'sum':
  45. return loss.sum()
  46. else: # 'none'
  47. return loss
  48. class QFocalLoss(nn.Module):
  49. # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  50. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  51. super(QFocalLoss, self).__init__()
  52. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  53. self.gamma = gamma
  54. self.alpha = alpha
  55. self.reduction = loss_fcn.reduction
  56. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  57. def forward(self, pred, true):
  58. loss = self.loss_fcn(pred, true)
  59. pred_prob = torch.sigmoid(pred) # prob from logits
  60. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  61. modulating_factor = torch.abs(true - pred_prob) ** self.gamma
  62. loss *= alpha_factor * modulating_factor
  63. if self.reduction == 'mean':
  64. return loss.mean()
  65. elif self.reduction == 'sum':
  66. return loss.sum()
  67. else: # 'none'
  68. return loss
  69. def compute_loss(p, targets, model): # predictions, targets, model
  70. device = targets.device
  71. lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
  72. tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets
  73. h = model.hyp # hyperparameters
  74. # Define criteria
  75. BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
  76. BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)
  77. # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
  78. cp, cn = smooth_BCE(eps=0.0)
  79. # Focal loss
  80. g = h['fl_gamma'] # focal loss gamma
  81. if g > 0:
  82. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
  83. # Losses
  84. nt = 0 # number of targets
  85. no = len(p) # number of outputs
  86. balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1] # P3-5 or P3-6
  87. for i, pi in enumerate(p): # layer index, layer predictions
  88. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
  89. tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
  90. n = b.shape[0] # number of targets
  91. if n:
  92. nt += n # cumulative targets
  93. ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
  94. # Regression
  95. pxy = ps[:, :2].sigmoid() * 2. - 0.5
  96. pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
  97. pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box
  98. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
  99. lbox += (1.0 - iou).mean() # iou loss
  100. # Objectness
  101. tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
  102. # Classification
  103. if model.nc > 1: # cls loss (only if multiple classes)
  104. t = torch.full_like(ps[:, 5:], cn, device=device) # targets
  105. t[range(n), tcls[i]] = cp
  106. lcls += BCEcls(ps[:, 5:], t) # BCE
  107. # Append targets to text file
  108. # with open('targets.txt', 'a') as file:
  109. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
  110. lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
  111. s = 3 / no # output count scaling
  112. lbox *= h['box'] * s
  113. lobj *= h['obj'] * s * (1.4 if no == 4 else 1.)
  114. lcls *= h['cls'] * s
  115. bs = tobj.shape[0] # batch size
  116. loss = lbox + lobj + lcls
  117. return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
  118. def build_targets(p, targets, model):
  119. # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
  120. det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
  121. na, nt = det.na, targets.shape[0] # number of anchors, targets
  122. tcls, tbox, indices, anch = [], [], [], []
  123. gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
  124. ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
  125. targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
  126. g = 0.5 # bias
  127. off = torch.tensor([[0, 0],
  128. # [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
  129. # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
  130. ], device=targets.device).float() * g # offsets
  131. for i in range(det.nl):
  132. anchors = det.anchors[i]
  133. gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
  134. # Match targets to anchors
  135. t = targets * gain
  136. if nt:
  137. # Matches
  138. r = t[:, :, 4:6] / anchors[:, None] # wh ratio
  139. j = torch.max(r, 1. / r).max(2)[0] < model.hyp['anchor_t'] # compare
  140. # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
  141. t = t[j] # filter
  142. # Offsets
  143. gxy = t[:, 2:4] # grid xy
  144. gxi = gain[[2, 3]] - gxy # inverse
  145. j, k = ((gxy % 1. < g) & (gxy > 1.)).T
  146. l, m = ((gxi % 1. < g) & (gxi > 1.)).T
  147. j = torch.stack((torch.ones_like(j),))
  148. t = t.repeat((off.shape[0], 1, 1))[j]
  149. offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
  150. else:
  151. t = targets[0]
  152. offsets = 0
  153. # Define
  154. b, c = t[:, :2].long().T # image, class
  155. gxy = t[:, 2:4] # grid xy
  156. gwh = t[:, 4:6] # grid wh
  157. gij = (gxy - offsets).long()
  158. gi, gj = gij.T # grid xy indices
  159. # Append
  160. a = t[:, 6].long() # anchor indices
  161. indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices
  162. tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
  163. anch.append(anchors[a]) # anchors
  164. tcls.append(c) # class
  165. return tcls, tbox, indices, anch
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...