Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#578 Feature/sg 516 support head replacement for local pretrained weights unknown dataset

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-516_support_head_replacement_for_local_pretrained_weights_unknown_dataset
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
  1. import numpy as np
  2. import torch
  3. from torchvision.transforms import RandomErasing
  4. class DataAugmentation:
  5. @staticmethod
  6. def to_tensor():
  7. def _to_tensor(image):
  8. if len(image.shape) == 3:
  9. return torch.from_numpy(
  10. image.transpose(2, 0, 1).astype(np.float32))
  11. else:
  12. return torch.from_numpy(image[None, :, :].astype(np.float32))
  13. return _to_tensor
  14. @staticmethod
  15. def normalize(mean, std):
  16. mean = np.array(mean)
  17. std = np.array(std)
  18. def _normalize(image):
  19. image = np.asarray(image).astype(np.float32) / 255.
  20. image = (image - mean) / std
  21. return image
  22. return _normalize
  23. @staticmethod
  24. def cutout(mask_size, p=1, cutout_inside=False, mask_color=(0, 0, 0)):
  25. mask_size_half = mask_size // 2
  26. offset = 1 if mask_size % 2 == 0 else 0
  27. def _cutout(image):
  28. image = np.asarray(image).copy()
  29. if np.random.random() > p:
  30. return image
  31. h, w = image.shape[:2]
  32. if cutout_inside:
  33. cxmin, cxmax = mask_size_half, w + offset - mask_size_half
  34. cymin, cymax = mask_size_half, h + offset - mask_size_half
  35. else:
  36. cxmin, cxmax = 0, w + offset
  37. cymin, cymax = 0, h + offset
  38. cx = np.random.randint(cxmin, cxmax)
  39. cy = np.random.randint(cymin, cymax)
  40. xmin = cx - mask_size_half
  41. ymin = cy - mask_size_half
  42. xmax = xmin + mask_size
  43. ymax = ymin + mask_size
  44. xmin = max(0, xmin)
  45. ymin = max(0, ymin)
  46. xmax = min(w, xmax)
  47. ymax = min(h, ymax)
  48. image[ymin:ymax, xmin:xmax] = mask_color
  49. return image
  50. return _cutout
  51. IMAGENET_PCA = {
  52. 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
  53. 'eigvec': torch.Tensor([[-0.5675, 0.7192, 0.4009],
  54. [-0.5808, -0.0045, -0.8140],
  55. [-0.5836, -0.6948, 0.4203]])}
  56. class Lighting(object):
  57. """
  58. Lighting noise(AlexNet - style PCA - based noise)
  59. Taken from fastai Imagenet training -
  60. https://github.com/fastai/imagenet-fast/blob/faa0f9dfc9e8e058ffd07a248724bf384f526fae/imagenet_nv/fastai_imagenet.py#L103
  61. To use:
  62. - training_params = {"imagenet_pca_aug": 0.1}
  63. - Default training_params arg is 0.0 ("don't use")
  64. - 0.1 is that default in the original paper
  65. """
  66. def __init__(self, alphastd, eigval=IMAGENET_PCA['eigval'], eigvec=IMAGENET_PCA['eigvec']):
  67. self.alphastd = alphastd
  68. self.eigval = eigval
  69. self.eigvec = eigvec
  70. def __call__(self, img):
  71. if self.alphastd == 0:
  72. return img
  73. alpha = img.new().resize_(3).normal_(0, self.alphastd)
  74. rgb = self.eigvec.type_as(img).clone() \
  75. .mul(alpha.view(1, 3).expand(3, 3)) \
  76. .mul(self.eigval.view(1, 3).expand(3, 3)) \
  77. .sum(1).squeeze()
  78. return img.add(rgb.view(3, 1, 1).expand_as(img))
  79. class RandomErase(RandomErasing):
  80. """
  81. A simple class that translates the parameters supported in SuperGradient's code base
  82. """
  83. def __init__(self, probability: float, value: str):
  84. # value might be a string representing a float. First we try to convert to float and if fails,
  85. # pass it as-is to super
  86. try:
  87. value = float(value)
  88. except ValueError:
  89. pass
  90. super().__init__(p=probability, value=value)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file