#578 Feature/sg 516 support head replacement for local pretrained weights unknown dataset

Merged
Ghost merged 1 commit into Deci-AI:master from deci-ai:feature/SG-516_support_head_replacement_for_local_pretrained_weights_unknown_dataset
  1. """ Mixup and Cutmix
  2. Papers:
  3. mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
  4. CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
  5. Code Reference:
  6. CutMix: https://github.com/clovaai/CutMix-PyTorch
  7. CutMix by timm: https://github.com/rwightman/pytorch-image-models/timm
  8. """
  9. from typing import List, Union
  10. import numpy as np
  11. import torch
  12. from super_gradients.training.exceptions.dataset_exceptions import IllegalDatasetParameterException
  13. def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
  14. x = x.long().view(-1, 1)
  15. return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
  16. def mixup_target(target: torch.Tensor, num_classes: int, lam: float = 1., smoothing: float = 0.0, device: str = 'cuda'):
  17. """
  18. generate a smooth target (label) two-hot tensor to support the mixed images with different labels
  19. :param target: the targets tensor
  20. :param num_classes: number of classes (to set the final tensor size)
  21. :param lam: percentage of label a range [0, 1] in the mixing
  22. :param smoothing: the smoothing multiplier
  23. :param device: usable device ['cuda', 'cpu']
  24. :return:
  25. """
  26. off_value = smoothing / num_classes
  27. on_value = 1. - smoothing + off_value
  28. y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
  29. y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
  30. return y1 * lam + y2 * (1. - lam)
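# Worked example (illustrative): for target=[0, 1], num_classes=3, lam=0.7 and
# smoothing=0.0, y1 is [[1, 0, 0], [0, 1, 0]] and y2 (the flipped batch) is
# [[0, 1, 0], [1, 0, 0]], so mixup_target returns
# [[0.7, 0.3, 0.0], [0.3, 0.7, 0.0]] - each row weights the two mixed labels
# by lam and (1 - lam) and sums to 1.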
def rand_bbox(img_shape: tuple, lam: float, margin: float = 0., count: int = None):
    """ Standard CutMix bounding-box
    Generates a random square bbox based on the lambda value. This implementation includes
    support for enforcing a border margin as a percentage of the bbox dimensions.

    :param img_shape: Image shape as tuple
    :param lam: Cutmix lambda value
    :param margin: Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
    :param count: Number of bboxes to generate
    """
    ratio = np.sqrt(1 - lam)
    img_h, img_w = img_shape[-2:]
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
    yl = np.clip(cy - cut_h // 2, 0, img_h)
    yh = np.clip(cy + cut_h // 2, 0, img_h)
    xl = np.clip(cx - cut_w // 2, 0, img_w)
    xh = np.clip(cx + cut_w // 2, 0, img_w)
    return yl, yh, xl, xh


def rand_bbox_minmax(img_shape: tuple, minmax: Union[tuple, list], count: int = None):
    """ Min-Max CutMix bounding-box
    Inspired by the Darknet cutmix implementation: generates a random rectangular bbox
    based on min/max percent values applied to each dimension of the input image.
    Typical values for minmax are in the .2-.3 range for min and the .8-.9 range for max.

    :param img_shape: Image shape as tuple
    :param minmax: Min and max bbox ratios (as percent of image size)
    :param count: Number of bboxes to generate
    """
    assert len(minmax) == 2
    img_h, img_w = img_shape[-2:]
    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
    yl = np.random.randint(0, img_h - cut_h, size=count)
    xl = np.random.randint(0, img_w - cut_w, size=count)
    yu = yl + cut_h
    xu = xl + cut_w
    return yl, yu, xl, xu


def cutmix_bbox_and_lam(img_shape: tuple, lam: float, ratio_minmax: Union[tuple, list] = None, correct_lam: bool = True,
                        count: int = None):
    """
    Generate a bbox and apply the lambda correction.
    """
    if ratio_minmax is not None:
        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
    else:
        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
    if correct_lam or ratio_minmax is not None:
        bbox_area = (yu - yl) * (xu - xl)
        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
    return (yl, yu, xl, xu), lam
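# Note on the correction above: rand_bbox draws a box whose nominal area is
# (1 - lam) * H * W, but clipping at the image borders can shrink it, so the
# effective mixing ratio is recomputed from the clipped box area. For example,
# lam = 0.75 nominally requests a box covering 25% of the image; if clipping
# leaves only 16% covered, the corrected lam becomes 1 - 0.16 = 0.84.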
class CollateMixup:
    """
    Collate function with Mixup/Cutmix that applies different params to each element or to the whole batch.
    A Mixup implementation that is performed while collating the batches.
    """

    def __init__(self, mixup_alpha: float = 1., cutmix_alpha: float = 0., cutmix_minmax: List[float] = None,
                 prob: float = 1.0, switch_prob: float = 0.5,
                 mode: str = 'batch', correct_lam: bool = True, label_smoothing: float = 0.1, num_classes: int = 1000):
        """
        Mixup/Cutmix that applies different params to each element or to the whole batch.

        :param mixup_alpha: mixup alpha value, mixup is active if > 0.
        :param cutmix_alpha: cutmix alpha value, cutmix is active if > 0.
        :param cutmix_minmax: cutmix min/max image ratio; cutmix is active and uses this instead of alpha if not None.
        :param prob: probability of applying mixup or cutmix per batch or element
        :param switch_prob: probability of switching to cutmix instead of mixup when both are active
        :param mode: how to apply mixup/cutmix params: per 'batch', 'pair' (pair of elements) or 'elem' (element)
        :param correct_lam: apply lambda correction when the cutmix bbox is clipped by the image borders
        :param label_smoothing: apply label smoothing to the mixed target tensor
        :param num_classes: number of classes for the target
        """
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
            self.cutmix_alpha = 1.0
        self.mix_prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.mode = mode
        self.correct_lam = correct_lam  # correct lambda based on the clipped area for cutmix
        self.mixup_enabled = True  # set to False to disable mixing (intended to be set by the train loop)

    def _params_per_elem(self, batch_size):
        """
        Generate two random masks that define which elements of the batch will be mixed and how (depending on the
        self.mixup_enabled, self.mixup_alpha and self.cutmix_alpha parameters).

        :param batch_size: the size of the batch
        :return: two tensors with shape=batch_size - the first contains the lambda value per batch element
            and the second is a binary flag indicating use of cutmix per batch element
        """
        lam = torch.ones(batch_size, dtype=torch.float32)
        use_cutmix = torch.zeros(batch_size, dtype=torch.bool)
        if self.mixup_enabled:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = torch.rand(batch_size) < self.switch_prob
                lam_mix = torch.where(
                    use_cutmix,
                    torch.distributions.beta.Beta(self.cutmix_alpha, self.cutmix_alpha).sample(sample_shape=(batch_size,)),
                    torch.distributions.beta.Beta(self.mixup_alpha, self.mixup_alpha).sample(sample_shape=(batch_size,)))
            elif self.mixup_alpha > 0.:
                lam_mix = torch.distributions.beta.Beta(self.mixup_alpha, self.mixup_alpha).sample(sample_shape=(batch_size,))
            elif self.cutmix_alpha > 0.:
                use_cutmix = torch.ones(batch_size, dtype=torch.bool)
                lam_mix = torch.distributions.beta.Beta(self.cutmix_alpha, self.cutmix_alpha).sample(sample_shape=(batch_size,))
            else:
                raise IllegalDatasetParameterException("One of mixup_alpha > 0., cutmix_alpha > 0., "
                                                       "cutmix_minmax not None should be true.")
            lam = torch.where(torch.rand(batch_size) < self.mix_prob, lam_mix.type(torch.float32), lam)
        return lam, use_cutmix

    def _params_per_batch(self):
        """
        Generate two random parameters that define if the batch will be mixed and how (depending on the
        self.mixup_enabled, self.mixup_alpha and self.cutmix_alpha parameters).

        :return: two parameters - the first contains the lambda value for the whole batch
            and the second is a binary flag indicating use of cutmix for the batch
        """
        lam = 1.
        use_cutmix = False
        if self.mixup_enabled and torch.rand(1) < self.mix_prob:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = torch.rand(1) < self.switch_prob
                lam_mix = torch.distributions.beta.Beta(self.cutmix_alpha, self.cutmix_alpha).sample() if use_cutmix else \
                    torch.distributions.beta.Beta(self.mixup_alpha, self.mixup_alpha).sample()
            elif self.mixup_alpha > 0.:
                lam_mix = torch.distributions.beta.Beta(self.mixup_alpha, self.mixup_alpha).sample()
            elif self.cutmix_alpha > 0.:
                use_cutmix = True
                lam_mix = torch.distributions.beta.Beta(self.cutmix_alpha, self.cutmix_alpha).sample()
            else:
                raise IllegalDatasetParameterException("One of mixup_alpha > 0., cutmix_alpha > 0., "
                                                       "cutmix_minmax not None should be true.")
            lam = float(lam_mix)
        return lam, use_cutmix
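    # Note on the Beta(alpha, alpha) draws above: alpha = 1 yields a uniform lambda
    # in [0, 1], alpha < 1 pushes lambda toward 0 or 1 (mostly one image survives),
    # and alpha > 1 concentrates lambda around 0.5 (a stronger blend).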
    def _mix_elem_collate(self, output: torch.Tensor, batch: list, half: bool = False):
        """
        This is the implementation for the 'elem' and 'half' modes.

        :param output: the output tensor to fill
        :param batch: list of the batch items
        :return: a tensor containing the lambda values used for the mixing (this vector can be used for
            mixing the labels as well)
        """
        batch_size = len(batch)
        num_elem = batch_size // 2 if half else batch_size
        assert len(output) == num_elem
        lam_batch, use_cutmix = self._params_per_elem(num_elem)
        for i in range(num_elem):
            j = batch_size - i - 1
            lam = lam_batch[i]
            mixed = batch[i][0]
            if lam != 1.:
                if use_cutmix[i]:
                    if not half:
                        mixed = torch.clone(mixed)
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                    lam_batch[i] = lam
                else:
                    mixed = mixed * lam + batch[j][0] * (1 - lam)
            output[i] += mixed
        if half:
            lam_batch = torch.cat((lam_batch, torch.ones(num_elem)))
        return lam_batch.clone().unsqueeze(1)

    def _mix_pair_collate(self, output: torch.Tensor, batch: list):
        """
        This is the implementation for the 'pair' mode.

        :param output: the output tensor to fill
        :param batch: list of the batch items
        :return: a tensor containing the lambda values used for the mixing (this vector can be used for
            mixing the labels as well)
        """
        batch_size = len(batch)
        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
        for i in range(batch_size // 2):
            j = batch_size - i - 1
            lam = lam_batch[i]
            mixed_i = batch[i][0]
            mixed_j = batch[j][0]
            assert 0 <= lam <= 1.0
            if lam < 1.:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    patch_i = torch.clone(mixed_i[:, yl:yh, xl:xh])
                    mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
                    mixed_j[:, yl:yh, xl:xh] = patch_i
                    lam_batch[i] = lam
                else:
                    mixed_temp = mixed_i.type(torch.float32) * lam + mixed_j.type(torch.float32) * (1 - lam)
                    mixed_j = mixed_j.type(torch.float32) * lam + mixed_i.type(torch.float32) * (1 - lam)
                    mixed_i = mixed_temp
                    torch.round(mixed_j, out=mixed_j)
                    torch.round(mixed_i, out=mixed_i)
            output[i] += mixed_i
            output[j] += mixed_j
        lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
        return lam_batch.clone().unsqueeze(1)

    def _mix_batch_collate(self, output: torch.Tensor, batch: list):
        """
        This is the implementation for the 'batch' mode.

        :param output: the output tensor to fill
        :param batch: list of the batch items
        :return: the lambda value used for the mixing
        """
        batch_size = len(batch)
        lam, use_cutmix = self._params_per_batch()
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
        for i in range(batch_size):
            j = batch_size - i - 1
            mixed = batch[i][0]
            if lam != 1.:
                if use_cutmix:
                    mixed = torch.clone(mixed)  # don't want to modify the original while iterating
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                else:
                    mixed = mixed * lam + batch[j][0] * (1 - lam)
            output[i] += mixed
        return lam

    def __call__(self, batch, _=None):
        batch_size = len(batch)
        if batch_size % 2 != 0:
            raise IllegalDatasetParameterException('Batch size should be even when using this collate function')
        half = 'half' in self.mode
        if half:
            batch_size //= 2
        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.float32)
        if self.mode == 'elem' or self.mode == 'half':
            lam = self._mix_elem_collate(output, batch, half=half)
        elif self.mode == 'pair':
            lam = self._mix_pair_collate(output, batch)
        else:
            lam = self._mix_batch_collate(output, batch)
        target = torch.tensor([b[1] for b in batch], dtype=torch.int32)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
        target = target[:batch_size]
        return output, target
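For context, a minimal usage sketch (not part of this PR) of how CollateMixup can be plugged into a torch DataLoader. The toy TensorDataset, model, and soft-target cross-entropy below are illustrative assumptions: the collate returns soft (two-hot, smoothed) targets, so the loss must consume a label distribution rather than integer class indices.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

num_classes = 10
# Toy dataset yielding (image, integer label) pairs, as __call__ expects.
dataset = TensorDataset(torch.rand(64, 3, 32, 32), torch.randint(0, num_classes, (64,)))

collate = CollateMixup(mixup_alpha=0.2, cutmix_alpha=1.0, mode='batch',
                       label_smoothing=0.1, num_classes=num_classes)
# Batch size must be even, as enforced in CollateMixup.__call__.
loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True, collate_fn=collate)


def soft_target_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Cross-entropy against the soft (two-hot, smoothed) targets the collate emits.
    return torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1).mean()


model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, num_classes))
for mixed_images, soft_targets in loader:
    loss = soft_target_cross_entropy(model(mixed_images), soft_targets)
    loss.backward()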