yolo.py
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
YOLO-specific modules

Usage:
    $ python path/to/models/yolo.py --cfg yolov5s.yaml
"""

import argparse
import logging
import math  # logging, math, torch and nn are used directly in this file
import sys
from copy import deepcopy
from pathlib import Path

import torch
import torch.nn as nn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH

from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import check_yaml, make_divisible, set_logging
from utils.plots import feature_visualization
from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \
    select_device, time_sync

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None

LOGGER = logging.getLogger(__name__)
class Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
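
# Decoding sketch (illustrative note, not part of the original file): Detect.forward
# above recovers boxes from sigmoid outputs s in (0, 1), the cell offset (cx, cy)
# taken from the grid, and the per-layer anchor size (aw, ah):
#   x = (2 * sx - 0.5 + cx) * stride   # center may fall up to half a cell outside its cell
#   w = (2 * sw) ** 2 * aw             # width is bounded to 4x the anchor width
# e.g. sx = 0.5 in cell cx = 10 at stride 8 gives x = (1.0 - 0.5 + 10) * 8 = 84 px.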
class Model(nn.Module):
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.safe_load(f)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        self.inplace = self.yaml.get('inplace', True)

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases()  # only run once

        # Init weights, biases
        initialize_weights(self)
        self.info()
        LOGGER.info('')

    def forward(self, x, augment=False, profile=False, visualize=False):
        if augment:
            return self._forward_augment(x)  # augmented inference, None
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

    def _forward_augment(self, x):
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = self._forward_once(xi)[0]  # forward
            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        return torch.cat(y, 1), None  # augmented inference, train

    def _forward_once(self, x, profile=False, visualize=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x

    def _descale_pred(self, p, flips, scale, img_size):
        # de-scale predictions following augmented inference (inverse operation)
        if self.inplace:
            p[..., :4] /= scale  # de-scale
            if flips == 2:
                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
            elif flips == 3:
                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        else:
            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            p = torch.cat((x, y, wh, p[..., 4:]), -1)
        return p

    def _profile_one_layer(self, m, x, dt):
        c = isinstance(m, Detect)  # is final layer, copy input as inplace fix
        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _print_biases(self):
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            LOGGER.info(
                ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        LOGGER.info('Fusing layers... ')
        for m in self.model.modules():
            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.forward_fuse  # update forward
        self.info()
        return self

    def autoshape(self):  # add AutoShape module
        LOGGER.info('Adding AutoShape... ')
        m = AutoShape(self)  # wrap model
        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
        return m

    def info(self, verbose=False, img_size=640):  # print model information
        model_info(self, verbose, img_size)
def parse_model(d, ch):  # model_dict, input_channels(3)
    LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except:
                pass

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[x] for x in f])
        elif m is Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n_, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
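
# parse_model example (illustrative note, not part of the original file): a yolov5s.yaml
# backbone row such as [-1, 3, C3, [128]], with depth_multiple=0.33 and width_multiple=0.50,
# is scaled to n = max(round(3 * 0.33), 1) = 1 repeat of C3 with
# c2 = make_divisible(128 * 0.50, 8) = 64 output channels, fed from the previous layer (f=-1).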
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--profile', action='store_true', help='profile model speed')
    opt = parser.parse_args()
    opt.cfg = check_yaml(opt.cfg)  # check YAML
    set_logging()
    device = select_device(opt.device)

    # Create model
    model = Model(opt.cfg).to(device)
    model.train()

    # Profile
    if opt.profile:
        img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
        y = model(img, profile=True)

    # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
    # from torch.utils.tensorboard import SummaryWriter
    # tb_writer = SummaryWriter('.')
    # LOGGER.info("Run 'tensorboard --logdir=models' to view tensorboard at http://localhost:6006/")
    # tb_writer.add_graph(torch.jit.trace(model, img, strict=False), [])  # add model graph
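
A minimal usage sketch (not part of yolo.py), assuming a standard YOLOv5 checkout run from the repository root so that models/yolov5s.yaml resolves and the models/utils imports above are importable:

import torch
from models.yolo import Model

model = Model('models/yolov5s.yaml', ch=3, nc=80)  # build model from YAML definition
model.eval()  # in eval mode Detect() returns (decoded predictions, raw feature maps)
img = torch.zeros(1, 3, 640, 640)  # dummy 640x640 RGB batch
with torch.no_grad():
    pred, _ = model(img)
print(pred.shape)  # (1, 25200, 85) for yolov5s at 640: 85 = 4 box + 1 obj + 80 cls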