yolo.py

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
YOLO-specific modules

Usage:
    $ python path/to/models/yolo.py --cfg yolov5s.yaml
"""

import argparse
import logging
import sys
from copy import deepcopy
from pathlib import Path

FILE = Path(__file__).resolve()
sys.path.append(FILE.parents[1].as_posix())  # add yolov5/ to path

from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import check_yaml, make_divisible, set_logging
from utils.plots import feature_visualization
from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \
    select_device, time_sync

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None

LOGGER = logging.getLogger(__name__)

class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

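# Note on the decode in Detect.forward() above: with t the raw conv output and
# y = t.sigmoid(), predictions are mapped back to input-image coordinates as
#   xy = (2 * y[..., 0:2] - 0.5 + grid) * stride
#   wh = (2 * y[..., 2:4]) ** 2 * anchor
# so box centers may fall up to half a cell outside their grid cell, and
# width/height are bounded to 4x the matched anchor dimensions.
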
class Model(nn.Module):
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.safe_load(f)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        self.inplace = self.yaml.get('inplace', True)
        # LOGGER.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases()  # only run once
            # LOGGER.info('Strides: %s' % m.stride.tolist())

        # Init weights, biases
        initialize_weights(self)
        self.info()
        LOGGER.info('')

    def forward(self, x, augment=False, profile=False, visualize=False):
        if augment:
            return self.forward_augment(x)  # augmented inference, None
        return self.forward_once(x, profile, visualize)  # single-scale inference, train

    def forward_augment(self, x):
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = self.forward_once(xi)[0]  # forward
            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        return torch.cat(y, 1), None  # augmented inference, train

    def forward_once(self, x, profile=False, visualize=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            if profile:
                c = isinstance(m, Detect)  # copy input as inplace fix
                o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
                t = time_sync()
                for _ in range(10):
                    m(x.copy() if c else x)
                dt.append((time_sync() - t) * 100)
                if m == self.model[0]:
                    LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
                LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')

            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output

            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)

        if profile:
            LOGGER.info('%.1fms total' % sum(dt))
        return x

    def _descale_pred(self, p, flips, scale, img_size):
        # de-scale predictions following augmented inference (inverse operation)
        if self.inplace:
            p[..., :4] /= scale  # de-scale
            if flips == 2:
                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
            elif flips == 3:
                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        else:
            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            p = torch.cat((x, y, wh, p[..., 4:]), -1)
        return p

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _print_biases(self):
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            LOGGER.info(
                ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        LOGGER.info('Fusing layers... ')
        for m in self.model.modules():
            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.forward_fuse  # update forward
        self.info()
        return self

    def autoshape(self):  # add AutoShape module
        LOGGER.info('Adding AutoShape... ')
        m = AutoShape(self)  # wrap model
        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
        return m

    def info(self, verbose=False, img_size=640):  # print model information
        model_info(self, verbose, img_size)

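# Note on fuse() above: fuse_conv_and_bn() (utils.torch_utils) performs standard
# BatchNorm folding, rewriting each Conv2d weight and bias so the fused layer
# computes the same function as Conv2d + BatchNorm2d:
#   w_fused = w * gamma / sqrt(running_var + eps)
#   b_fused = (b - running_mean) * gamma / sqrt(running_var + eps) + beta
# Removing the BatchNorm2d modules this way speeds up inference without
# changing outputs.
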
def parse_model(d, ch):  # model_dict, input_channels(3)
    LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except:
                pass

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[x] for x in f])
        elif m is Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n_, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

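# parse_model() expects a dict with exactly the keys read above. A minimal
# illustrative config (values are hypothetical, not a real yolov5*.yaml; each
# layer entry is [from, number, module, args]) might look like:
#   nc: 80
#   depth_multiple: 0.33
#   width_multiple: 0.50
#   anchors:
#     - [10, 13, 16, 30, 33, 23]
#   backbone:
#     - [-1, 1, Conv, [64, 6, 2, 2]]
#   head:
#     - [[-1], 1, Detect, [nc, anchors]]
# String args such as 'nc' and 'anchors' are resolved by the eval() calls above.
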
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--profile', action='store_true', help='profile model speed')
    opt = parser.parse_args()
    opt.cfg = check_yaml(opt.cfg)  # check YAML
    set_logging()
    device = select_device(opt.device)

    # Create model
    model = Model(opt.cfg).to(device)
    model.train()

    # Profile
    if opt.profile:
        img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
        y = model(img, profile=True)

    # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
    # from torch.utils.tensorboard import SummaryWriter
    # tb_writer = SummaryWriter('.')
    # LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
    # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), [])  # add model graph
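
    # Minimal usage sketch (commented out, in the style of the Tensorboard block
    # above; assumes a yolov5s.yaml reachable from the working directory):
    # model = Model('yolov5s.yaml')
    # model.eval()  # in eval mode Detect() returns (decoded predictions, raw per-layer outputs)
    # pred, _ = model(torch.zeros(1, 3, 640, 640))
    # print(pred.shape)  # (1, num_predictions, nc + 5)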