autoanchor.py

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Auto-anchor utils
"""

import random

import numpy as np
import torch
import yaml
from tqdm import tqdm

from utils.general import colorstr


def check_anchor_order(m):
    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
    a = m.anchor_grid.prod(-1).view(-1)  # anchor area
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
    if da.sign() != ds.sign():  # same order
        print('Reversing anchor order')
        m.anchors[:] = m.anchors.flip(0)
        m.anchor_grid[:] = m.anchor_grid.flip(0)
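

# Note (added here for clarity, not in the upstream file): for a standard P3/P4/P5 Detect()
# head the strides are (8, 16, 32), so anchor areas are expected to grow from the first
# anchor set to the last; when they shrink instead, check_anchor_order() flips the anchor
# rows so they match the stride order.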


def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    prefix = colorstr('autoanchor: ')
    print(f'\n{prefix}Analyzing anchors... ', end='')
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1. / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1. / thr).float().mean()  # best possible recall
        return bpr, aat

    anchors = m.anchor_grid.clone().cpu().view(-1, 2)  # current anchors
    bpr, aat = metric(anchors)
    print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='')
    if bpr < 0.98:  # threshold to recompute
        print('. Attempting to improve anchors, please wait...')
        na = m.anchor_grid.numel() // 2  # number of anchors
        try:
            anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        except Exception as e:
            print(f'{prefix}ERROR: {e}')
        new_bpr = metric(anchors)[0]
        if new_bpr > bpr:  # replace anchors
            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid)  # for inference
            m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
            check_anchor_order(m)
            print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.')
        else:
            print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.')
    print('')  # newline
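

# Usage sketch (added for illustration; the names below are assumptions, not defined in this file):
# during training, anchor fit is typically checked once before the first epoch, roughly like
#   check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
# where `dataset`, `model`, `hyp` and `imgsz` come from the training script.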


def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            dataset: path to data.yaml, or a loaded dataset
            n: number of anchors
            img_size: image size used for training
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using genetic algorithm
            verbose: print all results

        Return:
            k: kmeans evolved anchors

        Usage:
            from utils.autoanchor import *; _ = kmean_anchors()
    """
    from scipy.cluster.vq import kmeans

    thr = 1. / thr
    prefix = colorstr('autoanchor: ')

    def metric(k, wh):  # compute metrics
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr')
        print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, '
              f'past_thr={x[x > thr].mean():.3f}-mean: ', end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k

    if isinstance(dataset, str):  # *.yaml file
        with open(dataset, errors='ignore') as f:
            data_dict = yaml.safe_load(f)  # model dict
        from utils.datasets import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.')
    wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels
    # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1

    # Kmeans calculation
    print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    assert len(k) == n, f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}'
    k *= s
    wh = torch.tensor(wh, dtype=torch.float32)  # filtered
    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
    k = print_results(k)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0] < 100, 0], 400)
    # ax[1].hist(wh[wh[:, 1] < 100, 1], 400)
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
    pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = anchor_fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
            if verbose:
                print_results(k)

    return print_results(k)
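

# Standalone usage sketch (added for illustration, not part of the upstream file): mirrors the
# docstring usage above. It assumes you run from the YOLOv5 repository root (e.g. with
# `python -m utils.autoanchor`), with scipy installed and ./data/coco128.yaml plus its images
# and labels available locally.
if __name__ == '__main__':
    new_anchors = kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000)
    print(new_anchors)  # (n, 2) array of evolved (width, height) anchor pairs in pixels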