train.py

import argparse
import glob
import math
import os
import random
import time

import numpy as np
import torch.distributed as dist
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import test  # import test.py to get mAP after each epoch
from models.yolo import Model
from utils import google_utils
from utils.datasets import *
from utils.utils import *

mixed_precision = True
try:  # Mixed precision training https://github.com/NVIDIA/apex
    from apex import amp
except ImportError:
    print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
    mixed_precision = False  # not installed

wdir = 'weights' + os.sep  # weights dir
os.makedirs(wdir, exist_ok=True)
last = wdir + 'last.pt'
best = wdir + 'best.pt'
results_file = 'results.txt'

# Hyperparameters
hyp = {'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
       'momentum': 0.937,  # SGD momentum
       'weight_decay': 5e-4,  # optimizer weight decay
       'giou': 0.05,  # giou loss gain
       'cls': 0.58,  # cls loss gain
       'cls_pw': 1.0,  # cls BCELoss positive_weight
       'obj': 1.0,  # obj loss gain (*=img_size/320 if img_size != 320)
       'obj_pw': 1.0,  # obj BCELoss positive_weight
       'iou_t': 0.20,  # iou training threshold
       'anchor_t': 4.0,  # anchor-multiple threshold
       'fl_gamma': 0.0,  # focal loss gamma (EfficientDet default is gamma=1.5)
       'hsv_h': 0.014,  # image HSV-Hue augmentation (fraction)
       'hsv_s': 0.68,  # image HSV-Saturation augmentation (fraction)
       'hsv_v': 0.36,  # image HSV-Value augmentation (fraction)
       'degrees': 0.0,  # image rotation (+/- deg)
       'translate': 0.0,  # image translation (+/- fraction)
       'scale': 0.5,  # image scale (+/- gain)
       'shear': 0.0}  # image shear (+/- deg)
print(hyp)

# Overwrite hyp with hyp*.txt (optional)
f = glob.glob('hyp*.txt')
if f:
    print('Using %s' % f[0])
    for k, v in zip(hyp.keys(), np.loadtxt(f[0])):
        hyp[k] = v
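
# A minimal sketch of the hyp*.txt format assumed by the loader above:
# np.loadtxt() reads whitespace-separated floats, and zip() pairs them with
# hyp.keys() in insertion order, so the file must list all 18 values in the
# same order as the dict, e.g. (illustrative values only):
#   0.01 0.937 0.0005 0.05 0.58 1.0 1.0 1.0 0.20 4.0 0.0 0.014 0.68 0.36 0.0 0.0 0.5 0.0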

# Print focal loss if gamma > 0
if hyp['fl_gamma']:
    print('Using FocalLoss(gamma=%g)' % hyp['fl_gamma'])


def train(hyp):
    epochs = opt.epochs  # 300
    batch_size = opt.batch_size  # 64
    weights = opt.weights  # initial training weights

    # Configure
    init_seeds(1)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes

    # Remove previous results
    for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
        os.remove(f)

    # Create model
    model = Model(opt.cfg).to(device)
    assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
    model.names = data_dict['names']

    # Image sizes
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
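    # Worked example of the scaling above: with the default --batch-size 16,
    # accumulate = max(round(64 / 16), 1) = 4, so gradients from 4 batches are
    # accumulated before each optimizer.step(), for an effective batch size of
    # 64, and weight_decay is scaled by 16 * 4 / 64 = 1.0 (i.e. unchanged).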

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_parameters():
        if v.requires_grad:
            if '.bias' in k:
                pg2.append(v)  # biases
            elif '.weight' in k and '.bn' not in k:
                pg1.append(v)  # apply weight decay
            else:
                pg0.append(v)  # all else

    optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
        optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2
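    # The three-way split above applies weight decay only to conv/linear
    # weights (pg1); biases (pg2) and BatchNorm weights (pg0) are left
    # undecayed, a common practice that avoids shrinking parameters which
    # contribute little to overfitting.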

    # Load Model
    google_utils.attempt_download(weights)
    start_epoch, best_fitness = 0, 0.0
    if weights.endswith('.pt'):  # pytorch format
        ckpt = torch.load(weights, map_location=device)  # load checkpoint

        # load model
        try:
            ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
                             if model.state_dict()[k].shape == v.shape}  # to FP32, filter
            model.load_state_dict(ckpt['model'], strict=False)
        except KeyError as e:
            s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
                % (opt.weights, opt.cfg, opt.weights)
            raise KeyError(s) from e

        # load optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # load results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        start_epoch = ckpt['epoch'] + 1
        del ckpt

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    scheduler.last_epoch = start_epoch - 1  # do not move
    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
    # plot_lr_scheduler(optimizer, scheduler, epochs)
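    # The cosine lambda above decays each group's lr from its initial value to
    # 10% of it over training: lf(0) = 1.0, lf(epochs / 2) = 0.55, lf(epochs) = 0.1.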

    # Initialize distributed training
    if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
        dist.init_process_group(backend='nccl',  # distributed backend
                                init_method='tcp://127.0.0.1:9999',  # init method
                                world_size=1,  # number of nodes
                                rank=0)  # node rank
        model = torch.nn.parallel.DistributedDataParallel(model)
        # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html

    # Dataset
    dataset = LoadImagesAndLabels(train_path, imgsz, batch_size,
                                  augment=True,
                                  hyp=hyp,  # augmentation hyperparameters
                                  rect=opt.rect,  # rectangular training
                                  cache_images=opt.cache_images,
                                  single_cls=opt.single_cls,
                                  stride=gs)
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)

    # Dataloader
    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             shuffle=not opt.rect,  # Shuffle=True unless rectangular training is used
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    # Testloader
    testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
                                                                 hyp=hyp,
                                                                 rect=True,
                                                                 cache_images=opt.cache_images,
                                                                 single_cls=opt.single_cls,
                                                                 stride=gs),
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    # Model parameters
    hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights

    # Class frequency
    labels = np.concatenate(dataset.labels, 0)
    c = torch.tensor(labels[:, 0])  # classes
    # cf = torch.bincount(c.long(), minlength=nc) + 1.
    # model._initialize_biases(cf.to(device))
    if tb_writer:
        plot_labels(labels)
        tb_writer.add_histogram('classes', c, 0)

    # Check anchors
    if not opt.noautoanchor:
        check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

    # Exponential moving average
    ema = torch_utils.ModelEMA(model)
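    # ModelEMA maintains an exponentially decayed moving average of the model
    # weights; note below that the EMA copy (ema.ema), not the live model, is
    # what gets evaluated by test.test() and saved in checkpoints, which tends
    # to give smoother, slightly better final weights.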

    # Start training
    t0 = time.time()
    nb = len(dataloader)  # number of batches
    n_burn = max(3 * nb, 1e3)  # burn-in iterations, max(3 epochs, 1k iterations)
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
    print('Using %g dataloader workers' % nw)
    print('Starting training for %g epochs...' % epochs)
    # torch.autograd.set_detect_anomaly(True)
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if dataset.image_weights:
            w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
            image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
            dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx

        mloss = torch.zeros(4, device=device)  # mean losses
        print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
        pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0

            # Burn-in
            if ni <= n_burn:
                xi = [0, n_burn]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
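            # Burn-in sketch: over the first n_burn iterations np.interp
            # ramps each value linearly, e.g. at ni = n_burn / 2 the bias
            # group's lr sits halfway between 0.1 and initial_lr * lf(epoch),
            # momentum halfway between 0.9 and hyp['momentum'] (0.937), and
            # accumulate halfway between 1 and nbs / batch_size (rounded).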

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            pred = model(imgs)

            # Loss
            loss, loss_items = compute_loss(pred, targets.to(device), model)
            if not torch.isfinite(loss):
                print('WARNING: non-finite loss, ending training ', loss_items)
                return results

            # Backward
            if mixed_precision:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Optimize
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()
                ema.update(model)

            # Print
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
            mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
            s = ('%10s' * 2 + '%10.4g' * 6) % (
                '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
            pbar.set_description(s)

            # Plot
            if ni < 3:
                f = 'train_batch%g.jpg' % i  # filename
                res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
                if tb_writer:
                    tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
                    # tb_writer.add_graph(model, imgs)  # add model to tensorboard

            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        scheduler.step()

        # mAP
        ema.update_attr(model)
        final_epoch = epoch + 1 == epochs
        if not opt.notest or final_epoch:  # Calculate mAP
            results, maps, times = test.test(opt.data,
                                             batch_size=batch_size,
                                             imgsz=imgsz_test,
                                             save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
                                             model=ema.ema,
                                             single_cls=opt.single_cls,
                                             dataloader=testloader)

        # Write
        with open(results_file, 'a') as f:
            f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
        if len(opt.name) and opt.bucket:
            os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))

        # Tensorboard
        if tb_writer:
            tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
                    'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
            for x, tag in zip(list(mloss[:-1]) + list(results), tags):
                tb_writer.add_scalar(tag, x, epoch)

        # Update best mAP
        fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
        if fi > best_fitness:
            best_fitness = fi

        # Save model
        save = (not opt.nosave) or (final_epoch and not opt.evolve)
        if save:
            with open(results_file, 'r') as f:  # create checkpoint
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        'training_results': f.read(),
                        'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
                        'optimizer': None if final_epoch else optimizer.state_dict()}

            # Save last, best and delete
            torch.save(ckpt, last)
            if (best_fitness == fi) and not final_epoch:
                torch.save(ckpt, best)
            del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    n = opt.name
    if len(n):
        n = '_' + n if not n.isnumeric() else n
        fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
        for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
            if os.path.exists(f1):
                os.rename(f1, f2)  # rename
                ispt = f2.endswith('.pt')  # is *.pt
                strip_optimizer(f2) if ispt else None  # strip optimizer
                os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload

    if not opt.evolve:
        plot_results()  # save as results.png
    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
    dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
    torch.cuda.empty_cache()
    return results


if __name__ == '__main__':
    check_git_status()
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='*.cfg path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='*.data path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--weights', type=str, default='', help='initial weights path')
    parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--adam', action='store_true', help='use adam optimizer')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')  # %% escapes argparse help formatting
    parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
    opt = parser.parse_args()

    opt.weights = last if opt.resume else opt.weights
    opt.cfg = check_file(opt.cfg)  # check file
    opt.data = check_file(opt.data)  # check file
    print(opt)
    opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
    device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
    if device.type == 'cpu':
        mixed_precision = False

    # Train
    if not opt.evolve:
        tb_writer = SummaryWriter(comment=opt.name)
        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
        train(hyp)

    # Evolve hyperparameters (optional)
    else:
        tb_writer = None
        opt.notest, opt.nosave = True, True  # only test/save final epoch
        if opt.bucket:
            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

        for _ in range(10):  # generations to evolve
            if os.path.exists('evolve.txt'):  # if evolve.txt exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt('evolve.txt', ndmin=2)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min()  # weights
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.9, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([1, 1, 1, 1, 1, 1, 1, 0, .1, 1, 0, 1, 1, 1, 1, 1, 1, 1])  # gains
                ng = len(g)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = x[i + 7] * v[i]  # mutate

                # Clip to limits
                keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale', 'fl_gamma']
                limits = [(1e-5, 1e-2), (0.00, 0.70), (0.60, 0.98), (0, 0.001), (0, .9), (0, .9), (0, .9), (0, .9), (0, 3)]
                for k, v in zip(keys, limits):
                    hyp[k] = np.clip(hyp[k], v[0], v[1])
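                # Note on the x[i + 7] offset above: each evolve.txt row
                # appears to begin with the 7 result values (P, R, mAP, F1
                # and the three val losses, matching the 7-tuple returned by
                # train()), so the hyp values start at column 7. The zero
                # entries in g freeze obj_pw and fl_gamma during mutation.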

            # Train mutation
            results = train(hyp.copy())

            # Write mutation results
            print_mutation(hyp, results, opt.bucket)

            # Plot results
            # plot_evolution_results(hyp)
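
# Usage sketch (paths are the argparse defaults above; the weights file is
# illustrative, not guaranteed to exist):
#   python train.py --data data/coco128.yaml --cfg models/yolov5s.yaml --weights '' --batch-size 16
#   python train.py --resume            # continue from weights/last.pt
#   python train.py --evolve            # run 10 generations of hyperparameter evolution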